ResidualMonitorActions.hpp
1 // Distributed under the MIT License.
2 // See LICENSE.txt for details.
3 
4 #pragma once
5 
8 #include "Informer/Tags.hpp"
9 #include "Informer/Verbosity.hpp"
11 #include "NumericalAlgorithms/LinearSolver/Observe.hpp"
14 #include "Parallel/Info.hpp"
15 #include "Parallel/Invoke.hpp"
16 #include "Parallel/Printf.hpp"
17 #include "Utilities/EqualWithinRoundoff.hpp"
18 #include "Utilities/Functional.hpp"
19 #include "Utilities/Requires.hpp"
20 
21 /// \cond
22 namespace tuples {
23 template <typename...>
24 class TaggedTuple;
25 } // namespace tuples
26 namespace LinearSolver {
27 namespace cg_detail {
28 struct InitializeHasConverged;
29 struct UpdateFieldValues;
30 struct UpdateOperand;
31 } // namespace cg_detail
32 } // namespace LinearSolver
33 /// \endcond
34 
35 namespace LinearSolver {
36 namespace cg_detail {
37 
/// \brief Receives a residual square, records the initial residual magnitude
/// and broadcasts the initial convergence state.
///
/// Presumably invoked once, before the first linear-solver iteration, with the
/// globally reduced residual square (confirm with the caller). The action does
/// the following:
/// - stores `residual_square` in the DataBox;
/// - copies the value of the residual-magnitude tag into
///   `initial_residual_magnitude_tag`;
/// - contributes to the reduction observer;
/// - evaluates `LinearSolver::Tags::HasConverged`;
/// - sends the result to the `BroadcastTarget` component through the
///   `InitializeHasConverged` simple action.
///
/// \tparam BroadcastTarget the parallel component that receives
/// `InitializeHasConverged`
38 template <typename BroadcastTarget>
39 struct InitializeResidual {
// Restricted to a non-empty DataBox by the `Requires` constraint.
// NOTE(review): original source lines 45-46 are elided in this rendering;
// `cache`, used below, is presumably declared there. Confirm against the file.
// `residual_square`: the residual square delivered to this action.
40  template <typename... DbTags, typename... InboxTags, typename Metavariables,
41  typename ArrayIndex, typename ActionList,
42  typename ParallelComponent,
43  Requires<sizeof...(DbTags) != 0> = nullptr>
44  static auto apply(db::DataBox<tmpl::list<DbTags...>>& box,
47  const ArrayIndex& /*array_index*/,
48  const ActionList /*meta*/,
49  const ParallelComponent* const /*meta*/,
50  const double residual_square) noexcept {
51  using fields_tag = typename Metavariables::system::fields_tag;
// NOTE(review): parts of the following tag aliases (source lines 53-54, 56,
// 59) are elided in this rendering.
52  using residual_square_tag = db::add_tag_prefix<
55  using residual_magnitude_tag = db::add_tag_prefix<
57  db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
58  using initial_residual_magnitude_tag =
60
// Store the received residual square in the DataBox.
61  db::mutate<residual_square_tag>(
62  make_not_null(&box), [residual_square](
64  local_residual_square) noexcept {
65  *local_residual_square = residual_square;
66  });
67  // Perform a separate `db::mutate` so that we can retrieve the
68  // `residual_magnitude_tag` from the compute item
69  db::mutate<initial_residual_magnitude_tag>(
70  make_not_null(&box),
71  [](const gsl::not_null<double*> local_initial_residual_magnitude,
72  const double& initial_residual_magnitude) noexcept {
73  *local_initial_residual_magnitude = initial_residual_magnitude;
74  },
75  get<residual_magnitude_tag>(box));
76
// Send the current DataBox state to the reduction observer.
77  LinearSolver::observe_detail::contribute_to_reduction_observer(box, cache);
78
79  // Determine whether the linear solver has converged. This invokes the
80  // compute item.
81  const auto& has_converged = db::get<LinearSolver::Tags::HasConverged>(box);
82
// Print the convergence state only when the solver has converged and the
// verbosity is at least `Quiet`.
83  if (UNLIKELY(has_converged and
84  static_cast<int>(get<::Tags::Verbosity>(box)) >=
85  static_cast<int>(::Verbosity::Quiet))) {
86  Parallel::printf("%s", has_converged);
87  }
88
// Forward the convergence state to the broadcast target.
89  Parallel::simple_action<InitializeHasConverged>(
90  Parallel::get_parallel_component<BroadcastTarget>(cache),
91  has_converged);
92  }
93 };
94 
/// \brief Computes the ratio of the stored residual square to the received
/// inner product and sends it to the `BroadcastTarget` component.
///
/// The forwarded value is `get<residual_square_tag>(box) /
/// conj_grad_inner_product`. It is passed to the `UpdateFieldValues` simple
/// action. Judging by the struct's name, this is presumably the conjugate
/// gradient step length `alpha`, with `conj_grad_inner_product` being the
/// inner product of the operand with the operator applied to it. This is not
/// established by this file; confirm it against the algorithm's callers.
///
/// \tparam BroadcastTarget the parallel component that receives
/// `UpdateFieldValues`
95 template <typename BroadcastTarget>
96 struct ComputeAlpha {
// Restricted to a non-empty DataBox by the `Requires` constraint.
// NOTE(review): original source lines 102-103 are elided in this rendering;
// `cache`, used below, is presumably declared there. Confirm against the file.
// NOTE(review): the quotient below is not guarded against
// `conj_grad_inner_product == 0`.
97  template <typename... DbTags, typename... InboxTags, typename Metavariables,
98  typename ArrayIndex, typename ActionList,
99  typename ParallelComponent,
100  Requires<sizeof...(DbTags) != 0> = nullptr>
101  static auto apply(db::DataBox<tmpl::list<DbTags...>>& box,
104  const ArrayIndex& /*array_index*/,
105  const ActionList /*meta*/,
106  const ParallelComponent* const /*meta*/,
107  const double conj_grad_inner_product) noexcept {
108  using fields_tag = typename Metavariables::system::fields_tag;
// NOTE(review): the rest of this alias (source lines 110-111) is elided in
// this rendering.
109  using residual_square_tag = db::add_tag_prefix<
112
// The DataBox is not modified here; only the quotient is forwarded.
113  Parallel::simple_action<UpdateFieldValues>(
114  Parallel::get_parallel_component<BroadcastTarget>(cache),
115  get<residual_square_tag>(box) / conj_grad_inner_product);
116  }
117 };
118 
/// \brief Receives the new residual square at the end of an iteration, updates
/// the monitor's state and broadcasts the data for the next operand update.
///
/// The action does the following:
/// - computes `res_ratio`, the new residual square divided by the previously
///   stored one;
/// - stores the new residual square;
/// - increments `LinearSolver::Tags::IterationId::step_number`;
/// - contributes to the reduction observer;
/// - evaluates `LinearSolver::Tags::HasConverged`;
/// - logs according to `::Tags::Verbosity`;
/// - sends `res_ratio` and the convergence state to the `BroadcastTarget`
///   component through the `UpdateOperand` simple action.
///
/// `res_ratio` is presumably the conjugate gradient `beta` coefficient. Confirm
/// this against `UpdateOperand`.
///
/// \tparam BroadcastTarget the parallel component that receives
/// `UpdateOperand`
119 template <typename BroadcastTarget>
120 struct UpdateResidual {
// Restricted to a non-empty DataBox by the `Requires` constraint.
// NOTE(review): original source lines 126-127 are elided in this rendering;
// `cache`, used below, is presumably declared there. Confirm against the file.
// `residual_square`: the new residual square delivered to this action.
121  template <typename... DbTags, typename... InboxTags, typename Metavariables,
122  typename ArrayIndex, typename ActionList,
123  typename ParallelComponent,
124  Requires<sizeof...(DbTags) != 0> = nullptr>
125  static auto apply(db::DataBox<tmpl::list<DbTags...>>& box,
128  const ArrayIndex& /*array_index*/,
129  const ActionList /*meta*/,
130  const ParallelComponent* const /*meta*/,
131  const double residual_square) noexcept {
132  using fields_tag = typename Metavariables::system::fields_tag;
// NOTE(review): parts of the following tag aliases (source lines 134-135, 137)
// are elided in this rendering.
133  using residual_square_tag = db::add_tag_prefix<
136  using residual_magnitude_tag = db::add_tag_prefix<
138  db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>>;
139
140  // Compute the residual ratio before mutating the DataBox
// NOTE(review): there is no guard against a zero previously stored residual
// square.
141  const double res_ratio = residual_square / get<residual_square_tag>(box);
142
// Store the new residual square and advance the iteration counter in a single
// mutation.
// NOTE(review): the lambda's second parameter declaration (source line 146) is
// elided in this rendering.
143  db::mutate<residual_square_tag, LinearSolver::Tags::IterationId>(
144  make_not_null(&box), [residual_square](const gsl::not_null<double*>
145  local_residual_square,
147  iteration_id) noexcept {
148  *local_residual_square = residual_square;
149  // Prepare for the next iteration
150  iteration_id->step_number++;
151  });
152
153  // At this point, the iteration is complete. We proceed with observing,
154  // logging and checking convergence before broadcasting back to the
155  // elements.
156
157  LinearSolver::observe_detail::contribute_to_reduction_observer(box, cache);
158
159  // Determine whether the linear solver has converged. This invokes the
160  // compute item.
161  const auto& has_converged = get<LinearSolver::Tags::HasConverged>(box);
162
163  // Do some logging
// At `Verbose` verbosity or higher, print the (already incremented) step
// number and the current residual magnitude.
// NOTE(review): the opening of the print call (source line 166, presumably
// `Parallel::printf(`) is elided in this rendering.
164  if (UNLIKELY(static_cast<int>(get<::Tags::Verbosity>(box)) >=
165  static_cast<int>(::Verbosity::Verbose))) {
167  "Linear solver iteration %d done. Remaining residual: %e\n",
168  get<LinearSolver::Tags::IterationId>(box).step_number,
169  get<residual_magnitude_tag>(box));
170  }
// When converged and the verbosity is at least `Quiet`, print the convergence
// state.
171  if (UNLIKELY(has_converged and
172  static_cast<int>(get<::Tags::Verbosity>(box)) >=
173  static_cast<int>(::Verbosity::Quiet))) {
174  Parallel::printf("%s", has_converged);
175  }
176
// Send the residual ratio and the convergence state to the broadcast target.
177  Parallel::simple_action<UpdateOperand>(
178  Parallel::get_parallel_component<BroadcastTarget>(cache), res_ratio,
179  has_converged);
180  }
181 };
182 
183 } // namespace cg_detail
184 } // namespace LinearSolver
The magnitude w.r.t. the LinearSolver::inner_product
Definition: Tags.hpp:115
Prefix< DataBox_detail::dispatch_add_tag_prefix_impl< Prefix, Tag, Args... >, Args... > add_tag_prefix
Wrap Tag in Prefix<_, Args...>, also wrapping variables tags if Tag is a Tags::Variables.
Definition: DataBoxTag.hpp:533
Definition: Variables.hpp:46
Defines functions for interfacing with the parallelization framework.
Defines DataBox tags for the linear solver.
Functionality for solving linear systems of equations.
Definition: TerminateIfConverged.hpp:22
#define UNLIKELY(x)
Definition: Gsl.hpp:72
Definition: TaggedTuple.hpp:25
The magnitude square w.r.t. the LinearSolver::inner_product
Definition: Tags.hpp:101
Define prefixes for DataBox tags.
Defines the type alias Requires.
constexpr auto apply(F &&f, const DataBox< BoxTags > &box, Args &&... args)
Apply the function f with argument Tags TagsList from DataBox box
Definition: DataBox.hpp:1595
An associative container that is indexed by structs.
Definition: TaggedTuple.hpp:272
Defines classes and functions used for manipulating DataBox's.
Defines Parallel::printf for writing to stdout.
Definition: InterpolationTargetWedgeSectionTorus.hpp:24
A Charm++ chare that caches constant data once per Charm++ node.
Definition: ConstGlobalCache.hpp:76
gsl::not_null< T * > make_not_null(T *ptr) noexcept
Construct a not_null from a pointer. Often this will be done as an implicit conversion, but it may be necessary to perform the conversion explicitly when type deduction is desired.
Definition: Gsl.hpp:863
typename Requires_detail::requires_impl< B >::template_error_type_failed_to_meet_requirements_on_template_parameters Requires
Express requirements on the template parameters of a function or class, replaces std::enable_if_t ...
Definition: Requires.hpp:67
Definition: SolvePoissonProblem.hpp:38
Defines class template ConstGlobalCache.
void printf(const std::string &format, Args &&... args)
Print an atomic message to stdout with C printf usage.
Definition: Printf.hpp:100
Require a pointer to not be a nullptr
Definition: ConservativeFromPrimitive.hpp:12
Defines class IterationId.