SpECTRE Documentation Coverage Report
Current view: top level - ParallelAlgorithms/LinearSolver/ConjugateGradient - ElementActions.hpp Hit Total Coverage
Commit: aabde07399ba7837e5db64eedfd0a21f31f96922 Lines: 0 1 0.0 %
Date: 2024-04-26 02:38:13
Legend: Lines: hit not hit

          Line data    Source code
       1           0 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : #pragma once
       5             : 
       6             : #include <cstddef>
       7             : #include <limits>
       8             : #include <optional>
       9             : #include <tuple>
      10             : #include <utility>
      11             : 
      12             : #include "DataStructures/DataBox/DataBox.hpp"
      13             : #include "DataStructures/DataBox/PrefixHelpers.hpp"
      14             : #include "NumericalAlgorithms/Convergence/Tags.hpp"
      15             : #include "NumericalAlgorithms/LinearSolver/InnerProduct.hpp"
      16             : #include "Parallel/AlgorithmExecution.hpp"
      17             : #include "Parallel/GlobalCache.hpp"
      18             : #include "Parallel/Invoke.hpp"
      19             : #include "Parallel/Reduction.hpp"
      20             : #include "ParallelAlgorithms/LinearSolver/ConjugateGradient/ResidualMonitorActions.hpp"
      21             : #include "ParallelAlgorithms/LinearSolver/ConjugateGradient/Tags/InboxTags.hpp"
      22             : #include "ParallelAlgorithms/LinearSolver/Tags.hpp"
      23             : #include "Utilities/Functional.hpp"
      24             : #include "Utilities/Gsl.hpp"
      25             : #include "Utilities/Requires.hpp"
      26             : #include "Utilities/TMPL.hpp"
      27             : 
      28             : /// \cond
      29             : namespace Convergence {
      30             : struct HasConverged;
      31             : }  // namespace Convergence
      32             : namespace tuples {
      33             : template <typename...>
      34             : class TaggedTuple;
      35             : }  // namespace tuples
      36             : namespace LinearSolver::cg::detail {
      37             : template <typename Metavariables, typename FieldsTag, typename OptionsGroup>
      38             : struct ResidualMonitor;
      39             : template <typename FieldsTag, typename OptionsGroup, typename Label>
      40             : struct UpdateOperand;
      41             : }  // namespace LinearSolver::cg::detail
      42             : /// \endcond
      43             : 
      44             : namespace LinearSolver::cg::detail {
      45             : 
      46             : template <typename FieldsTag, typename OptionsGroup, typename Label,
      47             :           typename SourceTag>
      48             : struct PrepareSolve {
      49             :  private:
      50             :   using fields_tag = FieldsTag;
      51             :   using source_tag = SourceTag;
      52             :   using operator_applied_to_fields_tag =
      53             :       db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, fields_tag>;
      54             :   using operand_tag =
      55             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
      56             :   using residual_tag =
      57             :       db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
      58             : 
      59             :  public:
      60             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
      61             :             typename ArrayIndex, typename ActionList,
      62             :             typename ParallelComponent>
      63             :   static Parallel::iterable_action_return_t apply(
      64             :       db::DataBox<DbTagsList>& box,
      65             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
      66             :       Parallel::GlobalCache<Metavariables>& cache,
      67             :       const ArrayIndex& array_index, const ActionList /*meta*/,
      68             :       const ParallelComponent* const /*meta*/) {
      69             :     db::mutate<Convergence::Tags::IterationId<OptionsGroup>, operand_tag,
      70             :                residual_tag>(
      71             :         [](const gsl::not_null<size_t*> iteration_id, const auto operand,
      72             :            const auto residual, const auto& source,
      73             :            const auto& operator_applied_to_fields) {
      74             :           *iteration_id = 0;
      75             :           *operand = source - operator_applied_to_fields;
      76             :           *residual = *operand;
      77             :         },
      78             :         make_not_null(&box), get<source_tag>(box),
      79             :         get<operator_applied_to_fields_tag>(box));
      80             : 
      81             :     // Perform global reduction to compute initial residual magnitude square for
      82             :     // residual monitor
      83             :     const auto& residual = get<residual_tag>(box);
      84             :     Parallel::contribute_to_reduction<
      85             :         InitializeResidual<FieldsTag, OptionsGroup, ParallelComponent>>(
      86             :         Parallel::ReductionData<
      87             :             Parallel::ReductionDatum<double, funcl::Plus<>>>{
      88             :             inner_product(residual, residual)},
      89             :         Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
      90             :         Parallel::get_parallel_component<
      91             :             ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache));
      92             : 
      93             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
      94             :   }
      95             : };
      96             : 
      97             : template <typename FieldsTag, typename OptionsGroup, typename Label>
      98             : struct InitializeHasConverged {
      99             :   using inbox_tags = tmpl::list<Tags::InitialHasConverged<OptionsGroup>>;
     100             : 
     101             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     102             :             typename ArrayIndex, typename ActionList,
     103             :             typename ParallelComponent>
     104             :   static Parallel::iterable_action_return_t apply(
     105             :       db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
     106             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     107             :       const ArrayIndex& /*array_index*/, const ActionList /*meta*/,
     108             :       const ParallelComponent* const /*meta*/) {
     109             :     auto& inbox = get<Tags::InitialHasConverged<OptionsGroup>>(inboxes);
     110             :     const auto& iteration_id =
     111             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     112             :     if (inbox.find(iteration_id) == inbox.end()) {
     113             :       return {Parallel::AlgorithmExecution::Retry, std::nullopt};
     114             :     }
     115             : 
     116             :     auto has_converged = std::move(inbox.extract(iteration_id).mapped());
     117             : 
     118             :     db::mutate<Convergence::Tags::HasConverged<OptionsGroup>>(
     119             :         [&has_converged](const gsl::not_null<Convergence::HasConverged*>
     120             :                              local_has_converged) {
     121             :           *local_has_converged = std::move(has_converged);
     122             :         },
     123             :         make_not_null(&box));
     124             : 
     125             :     // Skip steps entirely if the solve has already converged
     126             :     constexpr size_t step_end_index =
     127             :         tmpl::index_of<ActionList,
     128             :                        UpdateOperand<FieldsTag, OptionsGroup, Label>>::value;
     129             :     constexpr size_t this_action_index =
     130             :         tmpl::index_of<ActionList, InitializeHasConverged>::value;
     131             :     return {Parallel::AlgorithmExecution::Continue,
     132             :             has_converged ? (step_end_index + 1) : (this_action_index + 1)};
     133             :   }
     134             : };
     135             : 
     136             : template <typename FieldsTag, typename OptionsGroup, typename Label>
     137             : struct PerformStep {
     138             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     139             :             typename ArrayIndex, typename ActionList,
     140             :             typename ParallelComponent>
     141             :   static Parallel::iterable_action_return_t apply(
     142             :       db::DataBox<DbTagsList>& box,
     143             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     144             :       Parallel::GlobalCache<Metavariables>& cache,
     145             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     146             :       const ParallelComponent* const /*meta*/) {
     147             :     using fields_tag = FieldsTag;
     148             :     using operand_tag =
     149             :         db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     150             :     using operator_tag =
     151             :         db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, operand_tag>;
     152             : 
     153             :     // At this point Ap must have been computed in a previous action
     154             :     // We compute the inner product <p,p> w.r.t A. This requires a global
     155             :     // reduction.
     156             :     const double local_conj_grad_inner_product =
     157             :         inner_product(get<operand_tag>(box), get<operator_tag>(box));
     158             : 
     159             :     Parallel::contribute_to_reduction<
     160             :         ComputeAlpha<FieldsTag, OptionsGroup, ParallelComponent>>(
     161             :         Parallel::ReductionData<
     162             :             Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
     163             :             Parallel::ReductionDatum<double, funcl::Plus<>>>{
     164             :             get<Convergence::Tags::IterationId<OptionsGroup>>(box),
     165             :             local_conj_grad_inner_product},
     166             :         Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
     167             :         Parallel::get_parallel_component<
     168             :             ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache));
     169             : 
     170             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     171             :   }
     172             : };
     173             : 
     174             : template <typename FieldsTag, typename OptionsGroup, typename Label>
     175             : struct UpdateFieldValues {
     176             :  private:
     177             :   using fields_tag = FieldsTag;
     178             :   using operand_tag =
     179             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     180             :   using operator_tag =
     181             :       db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, operand_tag>;
     182             :   using residual_tag =
     183             :       db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
     184             : 
     185             :  public:
     186             :   using inbox_tags = tmpl::list<Tags::Alpha<OptionsGroup>>;
     187             : 
     188             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     189             :             typename ArrayIndex, typename ActionList,
     190             :             typename ParallelComponent>
     191             :   static Parallel::iterable_action_return_t apply(
     192             :       db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
     193             :       Parallel::GlobalCache<Metavariables>& cache,
     194             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     195             :       const ParallelComponent* const /*meta*/) {
     196             :     auto& inbox = get<Tags::Alpha<OptionsGroup>>(inboxes);
     197             :     const auto& iteration_id =
     198             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     199             :     if (inbox.find(iteration_id) == inbox.end()) {
     200             :       return {Parallel::AlgorithmExecution::Retry, std::nullopt};
     201             :     }
     202             : 
     203             :     const double alpha = std::move(inbox.extract(iteration_id).mapped());
     204             : 
     205             :     db::mutate<residual_tag, fields_tag>(
     206             :         [alpha](const auto residual, const auto fields, const auto& operand,
     207             :                 const auto& operator_applied_to_operand) {
     208             :           *fields += alpha * operand;
     209             :           *residual -= alpha * operator_applied_to_operand;
     210             :         },
     211             :         make_not_null(&box), get<operand_tag>(box), get<operator_tag>(box));
     212             : 
     213             :     // Compute new residual norm in a second global reduction
     214             :     const auto& residual = get<residual_tag>(box);
     215             :     const double local_residual_magnitude_square =
     216             :         inner_product(residual, residual);
     217             : 
     218             :     Parallel::contribute_to_reduction<
     219             :         UpdateResidual<FieldsTag, OptionsGroup, ParallelComponent>>(
     220             :         Parallel::ReductionData<
     221             :             Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
     222             :             Parallel::ReductionDatum<double, funcl::Plus<>>>{
     223             :             get<Convergence::Tags::IterationId<OptionsGroup>>(box),
     224             :             local_residual_magnitude_square},
     225             :         Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
     226             :         Parallel::get_parallel_component<
     227             :             ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache));
     228             : 
     229             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     230             :   }
     231             : };
     232             : 
     233             : template <typename FieldsTag, typename OptionsGroup, typename Label>
     234             : struct UpdateOperand {
     235             :  private:
     236             :   using fields_tag = FieldsTag;
     237             :   using operand_tag =
     238             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     239             :   using residual_tag =
     240             :       db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
     241             : 
     242             :  public:
     243             :   using inbox_tags =
     244             :       tmpl::list<Tags::ResidualRatioAndHasConverged<OptionsGroup>>;
     245             : 
     246             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     247             :             typename ArrayIndex, typename ActionList,
     248             :             typename ParallelComponent>
     249             :   static Parallel::iterable_action_return_t apply(
     250             :       db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
     251             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     252             :       const ArrayIndex& /*array_index*/, const ActionList /*meta*/,
     253             :       const ParallelComponent* const /*meta*/) {
     254             :     auto& inbox =
     255             :         get<Tags::ResidualRatioAndHasConverged<OptionsGroup>>(inboxes);
     256             :     const auto& iteration_id =
     257             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     258             :     if (inbox.find(iteration_id) == inbox.end()) {
     259             :       return {Parallel::AlgorithmExecution::Retry, std::nullopt};
     260             :     }
     261             : 
     262             :     auto received_data = std::move(inbox.extract(iteration_id).mapped());
     263             :     const double res_ratio = get<0>(received_data);
     264             :     auto& has_converged = get<1>(received_data);
     265             : 
     266             :     db::mutate<operand_tag, Convergence::Tags::IterationId<OptionsGroup>,
     267             :                Convergence::Tags::HasConverged<OptionsGroup>>(
     268             :         [res_ratio, &has_converged](
     269             :             const auto operand, const gsl::not_null<size_t*> local_iteration_id,
     270             :             const gsl::not_null<Convergence::HasConverged*> local_has_converged,
     271             :             const auto& residual) {
     272             :           *operand = residual + res_ratio * *operand;
     273             :           ++(*local_iteration_id);
     274             :           *local_has_converged = std::move(has_converged);
     275             :         },
     276             :         make_not_null(&box), get<residual_tag>(box));
     277             : 
     278             :     // Repeat steps until the solve has converged
     279             :     constexpr size_t this_action_index =
     280             :         tmpl::index_of<ActionList, UpdateOperand>::value;
     281             :     constexpr size_t prepare_step_index =
     282             :         tmpl::index_of<ActionList, InitializeHasConverged<
     283             :                                        FieldsTag, OptionsGroup, Label>>::value +
     284             :         1;
     285             :     return {Parallel::AlgorithmExecution::Continue,
     286             :             get<Convergence::Tags::HasConverged<OptionsGroup>>(box)
     287             :                 ? (this_action_index + 1)
     288             :                 : prepare_step_index};
     289             :   }
     290             : };
     291             : 
     292             : }  // namespace LinearSolver::cg::detail

Generated by: LCOV version 1.14