Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #include <cstddef>
7 : #include <limits>
8 : #include <optional>
9 : #include <tuple>
10 : #include <utility>
11 :
12 : #include "DataStructures/DataBox/DataBox.hpp"
13 : #include "DataStructures/DataBox/PrefixHelpers.hpp"
14 : #include "NumericalAlgorithms/Convergence/Tags.hpp"
15 : #include "NumericalAlgorithms/LinearSolver/InnerProduct.hpp"
16 : #include "Parallel/AlgorithmExecution.hpp"
17 : #include "Parallel/GlobalCache.hpp"
18 : #include "Parallel/Invoke.hpp"
19 : #include "Parallel/Reduction.hpp"
20 : #include "ParallelAlgorithms/LinearSolver/ConjugateGradient/ResidualMonitorActions.hpp"
21 : #include "ParallelAlgorithms/LinearSolver/ConjugateGradient/Tags/InboxTags.hpp"
22 : #include "ParallelAlgorithms/LinearSolver/Tags.hpp"
23 : #include "Utilities/Functional.hpp"
24 : #include "Utilities/Gsl.hpp"
25 : #include "Utilities/Requires.hpp"
26 : #include "Utilities/TMPL.hpp"
27 :
28 : /// \cond
29 : namespace Convergence {
30 : struct HasConverged;
31 : } // namespace Convergence
32 : namespace tuples {
33 : template <typename...>
34 : class TaggedTuple;
35 : } // namespace tuples
36 : namespace LinearSolver::cg::detail {
37 : template <typename Metavariables, typename FieldsTag, typename OptionsGroup>
38 : struct ResidualMonitor;
39 : template <typename FieldsTag, typename OptionsGroup, typename Label>
40 : struct UpdateOperand;
41 : } // namespace LinearSolver::cg::detail
42 : /// \endcond
43 :
44 : namespace LinearSolver::cg::detail {
45 :
46 : template <typename FieldsTag, typename OptionsGroup, typename Label,
47 : typename SourceTag>
48 : struct PrepareSolve {
49 : private:
50 : using fields_tag = FieldsTag;
51 : using source_tag = SourceTag;
52 : using operator_applied_to_fields_tag =
53 : db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, fields_tag>;
54 : using operand_tag =
55 : db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
56 : using residual_tag =
57 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
58 :
59 : public:
60 : template <typename DbTagsList, typename... InboxTags, typename Metavariables,
61 : typename ArrayIndex, typename ActionList,
62 : typename ParallelComponent>
63 : static Parallel::iterable_action_return_t apply(
64 : db::DataBox<DbTagsList>& box,
65 : const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
66 : Parallel::GlobalCache<Metavariables>& cache,
67 : const ArrayIndex& array_index, const ActionList /*meta*/,
68 : const ParallelComponent* const /*meta*/) {
69 : db::mutate<Convergence::Tags::IterationId<OptionsGroup>, operand_tag,
70 : residual_tag>(
71 : [](const gsl::not_null<size_t*> iteration_id, const auto operand,
72 : const auto residual, const auto& source,
73 : const auto& operator_applied_to_fields) {
74 : *iteration_id = 0;
75 : *operand = source - operator_applied_to_fields;
76 : *residual = *operand;
77 : },
78 : make_not_null(&box), get<source_tag>(box),
79 : get<operator_applied_to_fields_tag>(box));
80 :
81 : // Perform global reduction to compute initial residual magnitude square for
82 : // residual monitor
83 : const auto& residual = get<residual_tag>(box);
84 : Parallel::contribute_to_reduction<
85 : InitializeResidual<FieldsTag, OptionsGroup, ParallelComponent>>(
86 : Parallel::ReductionData<
87 : Parallel::ReductionDatum<double, funcl::Plus<>>>{
88 : inner_product(residual, residual)},
89 : Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
90 : Parallel::get_parallel_component<
91 : ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache));
92 :
93 : return {Parallel::AlgorithmExecution::Continue, std::nullopt};
94 : }
95 : };
96 :
97 : template <typename FieldsTag, typename OptionsGroup, typename Label>
98 : struct InitializeHasConverged {
99 : using inbox_tags = tmpl::list<Tags::InitialHasConverged<OptionsGroup>>;
100 :
101 : template <typename DbTagsList, typename... InboxTags, typename Metavariables,
102 : typename ArrayIndex, typename ActionList,
103 : typename ParallelComponent>
104 : static Parallel::iterable_action_return_t apply(
105 : db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
106 : const Parallel::GlobalCache<Metavariables>& /*cache*/,
107 : const ArrayIndex& /*array_index*/, const ActionList /*meta*/,
108 : const ParallelComponent* const /*meta*/) {
109 : auto& inbox = get<Tags::InitialHasConverged<OptionsGroup>>(inboxes);
110 : const auto& iteration_id =
111 : db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
112 : if (inbox.find(iteration_id) == inbox.end()) {
113 : return {Parallel::AlgorithmExecution::Retry, std::nullopt};
114 : }
115 :
116 : auto has_converged = std::move(inbox.extract(iteration_id).mapped());
117 :
118 : db::mutate<Convergence::Tags::HasConverged<OptionsGroup>>(
119 : [&has_converged](const gsl::not_null<Convergence::HasConverged*>
120 : local_has_converged) {
121 : *local_has_converged = std::move(has_converged);
122 : },
123 : make_not_null(&box));
124 :
125 : // Skip steps entirely if the solve has already converged
126 : constexpr size_t step_end_index =
127 : tmpl::index_of<ActionList,
128 : UpdateOperand<FieldsTag, OptionsGroup, Label>>::value;
129 : constexpr size_t this_action_index =
130 : tmpl::index_of<ActionList, InitializeHasConverged>::value;
131 : return {Parallel::AlgorithmExecution::Continue,
132 : has_converged ? (step_end_index + 1) : (this_action_index + 1)};
133 : }
134 : };
135 :
136 : template <typename FieldsTag, typename OptionsGroup, typename Label>
137 : struct PerformStep {
138 : template <typename DbTagsList, typename... InboxTags, typename Metavariables,
139 : typename ArrayIndex, typename ActionList,
140 : typename ParallelComponent>
141 : static Parallel::iterable_action_return_t apply(
142 : db::DataBox<DbTagsList>& box,
143 : const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
144 : Parallel::GlobalCache<Metavariables>& cache,
145 : const ArrayIndex& array_index, const ActionList /*meta*/,
146 : const ParallelComponent* const /*meta*/) {
147 : using fields_tag = FieldsTag;
148 : using operand_tag =
149 : db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
150 : using operator_tag =
151 : db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, operand_tag>;
152 :
153 : // At this point Ap must have been computed in a previous action
154 : // We compute the inner product <p,p> w.r.t A. This requires a global
155 : // reduction.
156 : const double local_conj_grad_inner_product =
157 : inner_product(get<operand_tag>(box), get<operator_tag>(box));
158 :
159 : Parallel::contribute_to_reduction<
160 : ComputeAlpha<FieldsTag, OptionsGroup, ParallelComponent>>(
161 : Parallel::ReductionData<
162 : Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
163 : Parallel::ReductionDatum<double, funcl::Plus<>>>{
164 : get<Convergence::Tags::IterationId<OptionsGroup>>(box),
165 : local_conj_grad_inner_product},
166 : Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
167 : Parallel::get_parallel_component<
168 : ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache));
169 :
170 : return {Parallel::AlgorithmExecution::Continue, std::nullopt};
171 : }
172 : };
173 :
174 : template <typename FieldsTag, typename OptionsGroup, typename Label>
175 : struct UpdateFieldValues {
176 : private:
177 : using fields_tag = FieldsTag;
178 : using operand_tag =
179 : db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
180 : using operator_tag =
181 : db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, operand_tag>;
182 : using residual_tag =
183 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
184 :
185 : public:
186 : using inbox_tags = tmpl::list<Tags::Alpha<OptionsGroup>>;
187 :
188 : template <typename DbTagsList, typename... InboxTags, typename Metavariables,
189 : typename ArrayIndex, typename ActionList,
190 : typename ParallelComponent>
191 : static Parallel::iterable_action_return_t apply(
192 : db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
193 : Parallel::GlobalCache<Metavariables>& cache,
194 : const ArrayIndex& array_index, const ActionList /*meta*/,
195 : const ParallelComponent* const /*meta*/) {
196 : auto& inbox = get<Tags::Alpha<OptionsGroup>>(inboxes);
197 : const auto& iteration_id =
198 : db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
199 : if (inbox.find(iteration_id) == inbox.end()) {
200 : return {Parallel::AlgorithmExecution::Retry, std::nullopt};
201 : }
202 :
203 : const double alpha = std::move(inbox.extract(iteration_id).mapped());
204 :
205 : db::mutate<residual_tag, fields_tag>(
206 : [alpha](const auto residual, const auto fields, const auto& operand,
207 : const auto& operator_applied_to_operand) {
208 : *fields += alpha * operand;
209 : *residual -= alpha * operator_applied_to_operand;
210 : },
211 : make_not_null(&box), get<operand_tag>(box), get<operator_tag>(box));
212 :
213 : // Compute new residual norm in a second global reduction
214 : const auto& residual = get<residual_tag>(box);
215 : const double local_residual_magnitude_square =
216 : inner_product(residual, residual);
217 :
218 : Parallel::contribute_to_reduction<
219 : UpdateResidual<FieldsTag, OptionsGroup, ParallelComponent>>(
220 : Parallel::ReductionData<
221 : Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
222 : Parallel::ReductionDatum<double, funcl::Plus<>>>{
223 : get<Convergence::Tags::IterationId<OptionsGroup>>(box),
224 : local_residual_magnitude_square},
225 : Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
226 : Parallel::get_parallel_component<
227 : ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache));
228 :
229 : return {Parallel::AlgorithmExecution::Continue, std::nullopt};
230 : }
231 : };
232 :
233 : template <typename FieldsTag, typename OptionsGroup, typename Label>
234 : struct UpdateOperand {
235 : private:
236 : using fields_tag = FieldsTag;
237 : using operand_tag =
238 : db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
239 : using residual_tag =
240 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
241 :
242 : public:
243 : using inbox_tags =
244 : tmpl::list<Tags::ResidualRatioAndHasConverged<OptionsGroup>>;
245 :
246 : template <typename DbTagsList, typename... InboxTags, typename Metavariables,
247 : typename ArrayIndex, typename ActionList,
248 : typename ParallelComponent>
249 : static Parallel::iterable_action_return_t apply(
250 : db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
251 : const Parallel::GlobalCache<Metavariables>& /*cache*/,
252 : const ArrayIndex& /*array_index*/, const ActionList /*meta*/,
253 : const ParallelComponent* const /*meta*/) {
254 : auto& inbox =
255 : get<Tags::ResidualRatioAndHasConverged<OptionsGroup>>(inboxes);
256 : const auto& iteration_id =
257 : db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
258 : if (inbox.find(iteration_id) == inbox.end()) {
259 : return {Parallel::AlgorithmExecution::Retry, std::nullopt};
260 : }
261 :
262 : auto received_data = std::move(inbox.extract(iteration_id).mapped());
263 : const double res_ratio = get<0>(received_data);
264 : auto& has_converged = get<1>(received_data);
265 :
266 : db::mutate<operand_tag, Convergence::Tags::IterationId<OptionsGroup>,
267 : Convergence::Tags::HasConverged<OptionsGroup>>(
268 : [res_ratio, &has_converged](
269 : const auto operand, const gsl::not_null<size_t*> local_iteration_id,
270 : const gsl::not_null<Convergence::HasConverged*> local_has_converged,
271 : const auto& residual) {
272 : *operand = residual + res_ratio * *operand;
273 : ++(*local_iteration_id);
274 : *local_has_converged = std::move(has_converged);
275 : },
276 : make_not_null(&box), get<residual_tag>(box));
277 :
278 : // Repeat steps until the solve has converged
279 : constexpr size_t this_action_index =
280 : tmpl::index_of<ActionList, UpdateOperand>::value;
281 : constexpr size_t prepare_step_index =
282 : tmpl::index_of<ActionList, InitializeHasConverged<
283 : FieldsTag, OptionsGroup, Label>>::value +
284 : 1;
285 : return {Parallel::AlgorithmExecution::Continue,
286 : get<Convergence::Tags::HasConverged<OptionsGroup>>(box)
287 : ? (this_action_index + 1)
288 : : prepare_step_index};
289 : }
290 : };
291 :
292 : } // namespace LinearSolver::cg::detail
|