ElementActions.hpp
1 // Distributed under the MIT License.
2 // See LICENSE.txt for details.
3 
4 #pragma once
5 
6 #include <tuple>
7 
11 #include "Parallel/Info.hpp"
12 #include "Parallel/Invoke.hpp"
13 #include "Parallel/Reduction.hpp"
14 #include "ParallelAlgorithms/LinearSolver/ConjugateGradient/ResidualMonitorActions.hpp"
17 #include "Utilities/Functional.hpp"
18 #include "Utilities/Gsl.hpp"
19 #include "Utilities/Requires.hpp"
20 
21 /// \cond
22 namespace tuples {
23 template <typename...>
24 class TaggedTuple;
25 } // namespace tuples
26 namespace LinearSolver {
27 namespace cg_detail {
28 template <typename Metavariables, typename FieldsTag>
29 struct ResidualMonitor;
30 } // namespace cg_detail
31 } // namespace LinearSolver
32 /// \endcond
33 
34 namespace LinearSolver {
35 namespace cg_detail {
36 
37 struct PrepareStep {
38  template <typename DbTagsList, typename... InboxTags, typename Metavariables,
39  typename ArrayIndex, typename ActionList,
40  typename ParallelComponent>
43  const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
45  const ArrayIndex& /*array_index*/, const ActionList /*meta*/,
46  const ParallelComponent* const /*meta*/) noexcept {
47  db::mutate<LinearSolver::Tags::IterationId>(
48  make_not_null(&box),
50  iteration_id,
52  LinearSolver::Tags::IterationId>>& next_iteration_id) noexcept {
53  *iteration_id = next_iteration_id;
54  },
55  get<::Tags::Next<LinearSolver::Tags::IterationId>>(box));
56  return {std::move(box)};
57  }
58 };
59 
60 template <typename FieldsTag>
61 struct PerformStep {
62  template <typename DbTagsList, typename... InboxTags, typename Metavariables,
63  typename ArrayIndex, typename ActionList,
64  typename ParallelComponent>
67  const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
69  const ArrayIndex& array_index,
70  // NOLINTNEXTLINE(readability-avoid-const-params-in-decls)
71  const ActionList /*meta*/,
72  // NOLINTNEXTLINE(readability-avoid-const-params-in-decls)
73  const ParallelComponent* const /*meta*/) noexcept {
74  using fields_tag = FieldsTag;
75  using operand_tag =
77  using operator_tag =
79 
80  ASSERT(get<LinearSolver::Tags::IterationId>(box) !=
82  "Linear solve iteration ID is at initial state. Did you forget to "
83  "invoke 'PrepareStep'?");
84 
85  // At this point Ap must have been computed in a previous action
86  // We compute the inner product <p,p> w.r.t A. This requires a global
87  // reduction.
88  const double local_conj_grad_inner_product =
89  inner_product(get<operand_tag>(box), get<operator_tag>(box));
90 
92  ComputeAlpha<FieldsTag, ParallelComponent>>(
93  Parallel::ReductionData<
95  local_conj_grad_inner_product},
96  Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
98  ResidualMonitor<Metavariables, FieldsTag>>(cache));
99 
100  // Terminate algorithm for now. The `ResidualMonitor` will receive the
101  // reduction that is performed above and then broadcast to the following
102  // action, which is responsible for restarting the algorithm.
103  return {std::move(box), true};
104  }
105 };
106 
107 template <typename FieldsTag>
108 struct UpdateFieldValues {
109  private:
110  using fields_tag = FieldsTag;
111  using operand_tag =
113  using operator_tag =
115  using residual_tag =
117 
118  public:
119  template <typename ParallelComponent, typename DbTagsList,
120  typename Metavariables, typename ArrayIndex,
121  typename DataBox = db::DataBox<DbTagsList>,
123  db::tag_is_retrievable_v<fields_tag, DataBox> and
124  db::tag_is_retrievable_v<operand_tag, DataBox> and
125  db::tag_is_retrievable_v<operator_tag, DataBox>> = nullptr>
126  static auto apply(db::DataBox<DbTagsList>& box,
128  const ArrayIndex& array_index,
129  const double alpha) noexcept {
130  // Received global reduction result, proceed with conjugate gradient.
131  db::mutate<residual_tag, fields_tag>(
132  make_not_null(&box),
136  const db::const_item_type<operator_tag>& Ap) noexcept {
137  *x += alpha * p;
138  *r -= alpha * Ap;
139  },
140  get<operand_tag>(box), get<operator_tag>(box));
141 
142  // Compute new residual norm in a second global reduction
143  const auto& r = get<residual_tag>(box);
144  const double local_residual_magnitude_square = inner_product(r, r);
145 
147  UpdateResidual<FieldsTag, ParallelComponent>>(
148  Parallel::ReductionData<
150  local_residual_magnitude_square},
151  Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
153  ResidualMonitor<Metavariables, FieldsTag>>(cache));
154  }
155 };
156 
157 template <typename FieldsTag>
158 struct UpdateOperand {
159  private:
160  using fields_tag = FieldsTag;
161  using operand_tag =
163  using residual_tag =
165 
166  public:
167  template <typename ParallelComponent, typename DbTagsList,
168  typename Metavariables, typename ArrayIndex,
169  typename DataBox = db::DataBox<DbTagsList>,
172  DataBox> and
173  db::tag_is_retrievable_v<residual_tag, DataBox>> = nullptr>
174  static auto apply(db::DataBox<DbTagsList>& box,
176  const ArrayIndex& array_index, const double res_ratio,
178  has_converged) noexcept {
179  db::mutate<operand_tag, LinearSolver::Tags::HasConverged>(
180  make_not_null(&box),
181  [
182  res_ratio, &has_converged
185  local_has_converged,
186  const db::const_item_type<residual_tag>& r) noexcept {
187  *p = r + res_ratio * *p;
188  *local_has_converged = has_converged;
189  },
190  get<residual_tag>(box));
191 
192  // Proceed with algorithm
193  Parallel::get_parallel_component<ParallelComponent>(cache)[array_index]
194  .perform_algorithm(true);
195  }
196 };
197 
198 } // namespace cg_detail
199 } // namespace LinearSolver
Prefix< DataBox_detail::dispatch_add_tag_prefix_impl< Prefix, Tag, Args... >, Args... > add_tag_prefix
Wrap Tag in Prefix<_, Args...>, also wrapping variables tags if Tag is a Tags::Variables.
Definition: DataBoxTag.hpp:630
Defines functions for interfacing with the parallelization framework.
Functionality for solving linear systems of equations.
Definition: TerminateIfConverged.hpp:22
Definition: TaggedTuple.hpp:27
Define prefixes for DataBox tags.
Holds a Convergence::HasConverged flag that signals the linear solver has converged, along with the reason for convergence.
Definition: Tags.hpp:206
void contribute_to_reduction(ReductionData< Ts... > reduction_data, const SenderProxy &sender_component, const TargetProxy &target_component) noexcept
Perform a reduction from the sender_component (typically your own parallel component) to the target_component.
Definition: Reduction.hpp:243
#define ASSERT(a, m)
Assert that an expression should be true.
Definition: Assert.hpp:51
Defines the type alias Requires.
Holds an IterationId that identifies a step in the linear solver algorithm.
Definition: Tags.hpp:76
An associative container that is indexed by structs.
Definition: TaggedTuple.hpp:273
Defines classes and functions used for manipulating DataBoxes.
constexpr bool tag_is_retrievable_v
Equal to true if Tag can be retrieved from a DataBox of type DataBoxType.
Definition: DataBox.hpp:72
Definition: InterpolationTargetWedgeSectionTorus.hpp:24
double inner_product(const Lhs &lhs, const Rhs &rhs) noexcept
The local part of the Euclidean inner product on the vector space w.r.t. which the addition and scalar multiplication of both lhs and rhs are defined.
Definition: InnerProduct.hpp:71
A Charm++ chare that caches constant data once per Charm++ node.
Definition: ConstGlobalCache.hpp:135
constexpr auto apply(F &&f, const DataBox< BoxTags > &box, Args &&... args) noexcept
Apply the invokable f with argument Tags TagsList from DataBox box
Definition: DataBox.hpp:1628
Defines an inner product for the linear solver.
auto get_parallel_component(ConstGlobalCache< Metavariables > &cache) noexcept -> Parallel::proxy_from_parallel_component< ConstGlobalCache_detail::get_component_if_mocked< typename Metavariables::component_list, ParallelComponentTag >> &
Access the Charm++ proxy associated with a ParallelComponent.
Definition: ConstGlobalCache.hpp:222
Prefix indicating the value a quantity will take on the next iteration of the algorithm.
Definition: Prefixes.hpp:150
Defines functions and classes from the GSL.
gsl::not_null< T * > make_not_null(T *ptr) noexcept
Construct a not_null from a pointer. Often this will be done as an implicit conversion, but it may be necessary to perform the conversion explicitly when type deduction is desired.
Definition: Gsl.hpp:879
typename Requires_detail::requires_impl< B >::template_error_type_failed_to_meet_requirements_on_template_parameters Requires
Express requirements on the template parameters of a function or class, replaces std::enable_if_t ...
Definition: Requires.hpp:67
Defines class template ConstGlobalCache.
Defines DataBox tags for the linear solver.
The data to be reduced, and invokables to be called whenever two reduction messages are combined and ...
Definition: Reduction.hpp:64
Require a pointer to not be a nullptr
Definition: Gsl.hpp:182