Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
#include <cmath>
#include <cstddef>
#include <tuple>
#include <utility>
9 :
10 : #include "DataStructures/DataBox/DataBox.hpp"
11 : #include "DataStructures/DataBox/PrefixHelpers.hpp"
12 : #include "IO/Logging/Tags.hpp"
13 : #include "IO/Logging/Verbosity.hpp"
14 : #include "NumericalAlgorithms/Convergence/HasConverged.hpp"
15 : #include "NumericalAlgorithms/Convergence/Tags.hpp"
16 : #include "Parallel/GlobalCache.hpp"
17 : #include "Parallel/Invoke.hpp"
18 : #include "Parallel/Printf/Printf.hpp"
19 : #include "ParallelAlgorithms/LinearSolver/ConjugateGradient/Tags/InboxTags.hpp"
20 : #include "ParallelAlgorithms/LinearSolver/Observe.hpp"
21 : #include "ParallelAlgorithms/LinearSolver/Tags.hpp"
22 : #include "Utilities/EqualWithinRoundoff.hpp"
23 : #include "Utilities/Functional.hpp"
24 : #include "Utilities/Gsl.hpp"
25 : #include "Utilities/PrettyType.hpp"
26 : #include "Utilities/Requires.hpp"
27 :
28 : /// \cond
29 : namespace tuples {
30 : template <typename...>
31 : class TaggedTuple;
32 : } // namespace tuples
33 : /// \endcond
34 :
35 : namespace LinearSolver::cg::detail {
36 :
37 : template <typename FieldsTag, typename OptionsGroup, typename BroadcastTarget>
38 : struct InitializeResidual {
39 : private:
40 : using fields_tag = FieldsTag;
41 : using residual_square_tag = LinearSolver::Tags::MagnitudeSquare<
42 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
43 : using initial_residual_magnitude_tag =
44 : ::Tags::Initial<LinearSolver::Tags::Magnitude<
45 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>>;
46 :
47 : public:
48 : template <typename ParallelComponent, typename DbTagsList,
49 : typename Metavariables, typename ArrayIndex,
50 : typename DataBox = db::DataBox<DbTagsList>>
51 : static void apply(db::DataBox<DbTagsList>& box,
52 : Parallel::GlobalCache<Metavariables>& cache,
53 : const ArrayIndex& /*array_index*/,
54 : const double residual_square) {
55 : constexpr size_t iteration_id = 0;
56 : const double residual_magnitude = sqrt(residual_square);
57 :
58 : db::mutate<residual_square_tag, initial_residual_magnitude_tag>(
59 : [residual_square, residual_magnitude](
60 : const gsl::not_null<double*> local_residual_square,
61 : const gsl::not_null<double*> initial_residual_magnitude) {
62 : *local_residual_square = residual_square;
63 : *initial_residual_magnitude = residual_magnitude;
64 : },
65 : make_not_null(&box));
66 :
67 : LinearSolver::observe_detail::contribute_to_reduction_observer<
68 : OptionsGroup, ParallelComponent>(iteration_id, residual_magnitude,
69 : cache);
70 :
71 : // Determine whether the linear solver has converged
72 : Convergence::HasConverged has_converged{
73 : get<Convergence::Tags::Criteria<OptionsGroup>>(box), iteration_id,
74 : residual_magnitude, residual_magnitude};
75 :
76 : // Do some logging
77 : if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(cache) >=
78 : ::Verbosity::Quiet)) {
79 : Parallel::printf("%s initialized with residual: %e\n",
80 : pretty_type::name<OptionsGroup>(), residual_magnitude);
81 : }
82 : if (UNLIKELY(has_converged and get<logging::Tags::Verbosity<OptionsGroup>>(
83 : cache) >= ::Verbosity::Quiet)) {
84 : Parallel::printf("%s has converged without any iterations: %s\n",
85 : pretty_type::name<OptionsGroup>(), has_converged);
86 : }
87 :
88 : Parallel::receive_data<Tags::InitialHasConverged<OptionsGroup>>(
89 : Parallel::get_parallel_component<BroadcastTarget>(cache), iteration_id,
90 : // NOLINTNEXTLINE(performance-move-const-arg)
91 : std::move(has_converged));
92 : }
93 : };
94 :
95 : template <typename FieldsTag, typename OptionsGroup, typename BroadcastTarget>
96 : struct ComputeAlpha {
97 : private:
98 : using fields_tag = FieldsTag;
99 : using residual_square_tag = LinearSolver::Tags::MagnitudeSquare<
100 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
101 :
102 : public:
103 : template <typename ParallelComponent, typename DbTagsList,
104 : typename Metavariables, typename ArrayIndex,
105 : typename DataBox = db::DataBox<DbTagsList>>
106 : static void apply(db::DataBox<DbTagsList>& box,
107 : Parallel::GlobalCache<Metavariables>& cache,
108 : const ArrayIndex& /*array_index*/,
109 : const size_t iteration_id,
110 : const double conj_grad_inner_product) {
111 : Parallel::receive_data<Tags::Alpha<OptionsGroup>>(
112 : Parallel::get_parallel_component<BroadcastTarget>(cache), iteration_id,
113 : get<residual_square_tag>(box) / conj_grad_inner_product);
114 : }
115 : };
116 :
117 : template <typename FieldsTag, typename OptionsGroup, typename BroadcastTarget>
118 : struct UpdateResidual {
119 : private:
120 : using fields_tag = FieldsTag;
121 : using residual_square_tag = LinearSolver::Tags::MagnitudeSquare<
122 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
123 : using initial_residual_magnitude_tag =
124 : ::Tags::Initial<LinearSolver::Tags::Magnitude<
125 : db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>>;
126 :
127 : public:
128 : template <typename ParallelComponent, typename DbTagsList,
129 : typename Metavariables, typename ArrayIndex,
130 : typename DataBox = db::DataBox<DbTagsList>>
131 : static void apply(db::DataBox<DbTagsList>& box,
132 : Parallel::GlobalCache<Metavariables>& cache,
133 : const ArrayIndex& /*array_index*/,
134 : const size_t iteration_id, const double residual_square) {
135 : // Compute the residual ratio before mutating the DataBox
136 : const double res_ratio = residual_square / get<residual_square_tag>(box);
137 :
138 : db::mutate<residual_square_tag>(
139 : [residual_square](const gsl::not_null<double*> local_residual_square) {
140 : *local_residual_square = residual_square;
141 : },
142 : make_not_null(&box));
143 :
144 : // At this point, the iteration is complete. We proceed with observing,
145 : // logging and checking convergence before broadcasting back to the
146 : // elements.
147 :
148 : const size_t completed_iterations = iteration_id + 1;
149 : const double residual_magnitude = sqrt(residual_square);
150 : LinearSolver::observe_detail::contribute_to_reduction_observer<
151 : OptionsGroup, ParallelComponent>(completed_iterations,
152 : residual_magnitude, cache);
153 :
154 : // Determine whether the linear solver has converged
155 : Convergence::HasConverged has_converged{
156 : get<Convergence::Tags::Criteria<OptionsGroup>>(box),
157 : completed_iterations, residual_magnitude,
158 : get<initial_residual_magnitude_tag>(box)};
159 :
160 : // Do some logging
161 : if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(cache) >=
162 : ::Verbosity::Quiet)) {
163 : Parallel::printf("%s(%zu) iteration complete. Remaining residual: %e\n",
164 : pretty_type::name<OptionsGroup>(), completed_iterations,
165 : residual_magnitude);
166 : }
167 : if (UNLIKELY(has_converged and get<logging::Tags::Verbosity<OptionsGroup>>(
168 : cache) >= ::Verbosity::Quiet)) {
169 : Parallel::printf("%s has converged in %zu iterations: %s\n",
170 : pretty_type::name<OptionsGroup>(), completed_iterations,
171 : has_converged);
172 : }
173 :
174 : Parallel::receive_data<Tags::ResidualRatioAndHasConverged<OptionsGroup>>(
175 : Parallel::get_parallel_component<BroadcastTarget>(cache), iteration_id,
176 : // NOLINTNEXTLINE(performance-move-const-arg)
177 : std::make_tuple(res_ratio, std::move(has_converged)));
178 : }
179 : };
180 :
181 : } // namespace LinearSolver::cg::detail
|