Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #include <blaze/math/DynamicMatrix.h>
7 : #include <blaze/math/DynamicVector.h>
8 : #include <cstddef>
9 : #include <tuple>
10 : #include <utility>
11 :
12 : #include "DataStructures/DataBox/DataBox.hpp"
13 : #include "DataStructures/DataBox/PrefixHelpers.hpp"
14 : #include "IO/Logging/Tags.hpp"
15 : #include "IO/Logging/Verbosity.hpp"
16 : #include "NumericalAlgorithms/Convergence/Tags.hpp"
17 : #include "NumericalAlgorithms/LinearSolver/InnerProduct.hpp"
18 : #include "Parallel/GlobalCache.hpp"
19 : #include "Parallel/Invoke.hpp"
20 : #include "Parallel/Printf/Printf.hpp"
21 : #include "ParallelAlgorithms/LinearSolver/Gmres/Tags/InboxTags.hpp"
22 : #include "ParallelAlgorithms/LinearSolver/Observe.hpp"
23 : #include "ParallelAlgorithms/LinearSolver/Tags.hpp"
24 : #include "Utilities/EqualWithinRoundoff.hpp"
25 : #include "Utilities/Gsl.hpp"
26 : #include "Utilities/PrettyType.hpp"
27 : #include "Utilities/Requires.hpp"
28 :
29 : /// \cond
30 : namespace tuples {
31 : template <typename...>
32 : class TaggedTuple;
33 : } // namespace tuples
34 : /// \endcond
35 :
36 : namespace LinearSolver::gmres::detail {
37 :
38 : template <typename FieldsTag, typename OptionsGroup, typename BroadcastTarget>
39 : struct InitializeResidualMagnitude {
40 : private:
41 : using fields_tag = FieldsTag;
42 : using residual_magnitude_tag = LinearSolver::Tags::Magnitude<
43 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
44 : using initial_residual_magnitude_tag =
45 : ::Tags::Initial<residual_magnitude_tag>;
46 : using previous_residual_magnitude_tag =
47 : ::Tags::Previous<residual_magnitude_tag>;
48 : using orthogonalization_history_tag =
49 : LinearSolver::Tags::OrthogonalizationHistory<fields_tag>;
50 :
51 : public:
52 : template <typename ParallelComponent, typename DbTagsList,
53 : typename Metavariables, typename ArrayIndex,
54 : typename DataBox = db::DataBox<DbTagsList>>
55 : static void apply(db::DataBox<DbTagsList>& box,
56 : Parallel::GlobalCache<Metavariables>& cache,
57 : const ArrayIndex& /*array_index*/,
58 : const double residual_magnitude) {
59 : constexpr size_t iteration_id = 0;
60 :
61 : db::mutate<initial_residual_magnitude_tag, previous_residual_magnitude_tag>(
62 : [residual_magnitude](
63 : const gsl::not_null<double*> initial_residual_magnitude,
64 : const gsl::not_null<double*> previous_residual_magnitude) {
65 : *initial_residual_magnitude = residual_magnitude;
66 : *previous_residual_magnitude = residual_magnitude;
67 : },
68 : make_not_null(&box));
69 :
70 : LinearSolver::observe_detail::contribute_to_reduction_observer<
71 : OptionsGroup, ParallelComponent>(iteration_id, residual_magnitude,
72 : cache);
73 :
74 : // Determine whether the linear solver has already converged
75 : Convergence::HasConverged has_converged{
76 : get<Convergence::Tags::Criteria<OptionsGroup>>(box), iteration_id,
77 : residual_magnitude, residual_magnitude};
78 :
79 : // Do some logging
80 : if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(cache) >=
81 : ::Verbosity::Quiet)) {
82 : Parallel::printf("%s initialized with residual: %e\n",
83 : pretty_type::name<OptionsGroup>(), residual_magnitude);
84 : }
85 : if (UNLIKELY(has_converged and get<logging::Tags::Verbosity<OptionsGroup>>(
86 : cache) >= ::Verbosity::Quiet)) {
87 : Parallel::printf("%s has converged without any iterations: %s\n",
88 : pretty_type::name<OptionsGroup>(), has_converged);
89 : }
90 :
91 : Parallel::receive_data<Tags::InitialOrthogonalization<OptionsGroup>>(
92 : Parallel::get_parallel_component<BroadcastTarget>(cache), iteration_id,
93 : // NOLINTNEXTLINE(performance-move-const-arg)
94 : std::make_tuple(residual_magnitude, std::move(has_converged)));
95 : }
96 : };
97 :
98 : template <typename FieldsTag, typename OptionsGroup, typename BroadcastTarget>
99 : struct StoreOrthogonalization {
100 : private:
101 : using fields_tag = FieldsTag;
102 : using residual_magnitude_tag = LinearSolver::Tags::Magnitude<
103 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
104 : using initial_residual_magnitude_tag =
105 : ::Tags::Initial<residual_magnitude_tag>;
106 : using previous_residual_magnitude_tag =
107 : ::Tags::Previous<residual_magnitude_tag>;
108 : using orthogonalization_history_tag =
109 : LinearSolver::Tags::OrthogonalizationHistory<fields_tag>;
110 : using ValueType =
111 : tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;
112 :
113 : public:
114 : template <typename ParallelComponent, typename DbTagsList,
115 : typename Metavariables, typename ArrayIndex,
116 : typename DataBox = db::DataBox<DbTagsList>>
117 : static void apply(db::DataBox<DbTagsList>& box,
118 : Parallel::GlobalCache<Metavariables>& cache,
119 : const ArrayIndex& /*array_index*/,
120 : const size_t iteration_id,
121 : const size_t orthogonalization_iteration_id,
122 : const ValueType orthogonalization) {
123 : if (UNLIKELY(orthogonalization_iteration_id == 0)) {
124 : // Append a row and a column to the orthogonalization history. Zero the
125 : // entries that won't be set during the orthogonalization procedure below.
126 : db::mutate<orthogonalization_history_tag>(
127 : [iteration_id](const auto orthogonalization_history) {
128 : orthogonalization_history->resize(iteration_id + 1, iteration_id);
129 : for (size_t j = 0; j < orthogonalization_history->columns() - 1;
130 : ++j) {
131 : (*orthogonalization_history)(
132 : orthogonalization_history->rows() - 1, j) = 0.;
133 : }
134 : },
135 : make_not_null(&box));
136 : }
137 :
138 : // While the orthogonalization procedure is not complete, store the
139 : // orthogonalization, broadcast it back to all elements and return early
140 : if (orthogonalization_iteration_id < iteration_id) {
141 : db::mutate<orthogonalization_history_tag>(
142 : [orthogonalization, iteration_id, orthogonalization_iteration_id](
143 : const auto orthogonalization_history) {
144 : (*orthogonalization_history)(orthogonalization_iteration_id,
145 : iteration_id - 1) = orthogonalization;
146 : },
147 : make_not_null(&box));
148 :
149 : Parallel::receive_data<Tags::Orthogonalization<OptionsGroup, ValueType>>(
150 : Parallel::get_parallel_component<BroadcastTarget>(cache),
151 : iteration_id, orthogonalization);
152 : return;
153 : }
154 :
155 : // At this point, the orthogonalization procedure is complete.
156 : ASSERT(equal_within_roundoff(imag(orthogonalization), 0.0),
157 : "Normalization is not real: " << orthogonalization);
158 : const double normalization = sqrt(real(orthogonalization));
159 : db::mutate<orthogonalization_history_tag>(
160 : [normalization, iteration_id,
161 : orthogonalization_iteration_id](const auto orthogonalization_history) {
162 : (*orthogonalization_history)(orthogonalization_iteration_id,
163 : iteration_id - 1) = normalization;
164 : },
165 : make_not_null(&box));
166 :
167 : // Perform a QR decomposition of the Hessenberg matrix that was built during
168 : // the orthogonalization
169 : const auto& orthogonalization_history =
170 : get<orthogonalization_history_tag>(box);
171 : const auto num_rows = orthogonalization_iteration_id + 1;
172 : blaze::DynamicMatrix<ValueType> qr_Q;
173 : blaze::DynamicMatrix<ValueType> qr_R;
174 : blaze::qr(orthogonalization_history, qr_Q, qr_R);
175 : // Compute the residual vector from the QR decomposition
176 : blaze::DynamicVector<double> beta(num_rows, 0.);
177 : const double initial_residual_magnitude =
178 : get<initial_residual_magnitude_tag>(box);
179 : beta[0] = initial_residual_magnitude;
180 : blaze::DynamicVector<ValueType> minres =
181 : blaze::inv(qr_R) * blaze::ctrans(qr_Q) * beta;
182 : blaze::DynamicVector<ValueType> res =
183 : beta - orthogonalization_history * minres;
184 : const double residual_magnitude = sqrt(magnitude_square(res));
185 :
186 : // At this point, the iteration is complete. We proceed with observing,
187 : // logging and checking convergence before broadcasting back to the
188 : // elements.
189 :
190 : LinearSolver::observe_detail::contribute_to_reduction_observer<
191 : OptionsGroup, ParallelComponent>(iteration_id, residual_magnitude,
192 : cache);
193 :
194 : // Determine whether the linear solver has converged.
195 : // GMRES is guaranteed to decrease the residual monotonically, so an
196 : // increase in the residual is an error.
197 : const auto& convergence_criteria =
198 : get<Convergence::Tags::Criteria<OptionsGroup>>(box);
199 : const double previous_residual_magnitude =
200 : get<previous_residual_magnitude_tag>(box);
201 : auto has_converged =
202 : residual_magnitude < previous_residual_magnitude
203 : ? Convergence::HasConverged{convergence_criteria, iteration_id,
204 : residual_magnitude,
205 : initial_residual_magnitude}
206 : : Convergence::HasConverged{
207 : Convergence::Reason::Error,
208 : MakeString{} << std::scientific
209 : << "Residual should decrease monotonically, but "
210 : "increased from "
211 : << previous_residual_magnitude << " to "
212 : << residual_magnitude << ".",
213 : iteration_id};
214 :
215 : db::mutate<previous_residual_magnitude_tag>(
216 : [residual_magnitude](
217 : const gsl::not_null<double*> stored_previous_residual_magnitude) {
218 : *stored_previous_residual_magnitude = residual_magnitude;
219 : },
220 : make_not_null(&box));
221 :
222 : // Do some logging
223 : if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(cache) >=
224 : ::Verbosity::Quiet)) {
225 : Parallel::printf("%s(%zu) iteration complete. Remaining residual: %e\n",
226 : pretty_type::name<OptionsGroup>(), iteration_id,
227 : residual_magnitude);
228 : }
229 : if (UNLIKELY(has_converged and get<logging::Tags::Verbosity<OptionsGroup>>(
230 : cache) >= ::Verbosity::Quiet)) {
231 : if (has_converged.reason() == Convergence::Reason::Error) {
232 : Parallel::printf("%s has encountered an error in iteration %zu: %s\n",
233 : pretty_type::name<OptionsGroup>(), iteration_id,
234 : has_converged.error_message());
235 : } else {
236 : Parallel::printf("%s has converged in %zu iterations: %s\n",
237 : pretty_type::name<OptionsGroup>(), iteration_id,
238 : has_converged);
239 : }
240 : }
241 :
242 : Parallel::receive_data<
243 : Tags::FinalOrthogonalization<OptionsGroup, ValueType>>(
244 : Parallel::get_parallel_component<BroadcastTarget>(cache), iteration_id,
245 : std::make_tuple(normalization, std::move(minres),
246 : // NOLINTNEXTLINE(performance-move-const-arg)
247 : std::move(has_converged)));
248 : }
249 : };
250 :
251 : } // namespace LinearSolver::gmres::detail
|