SpECTRE Documentation Coverage Report
Current view: top level - ParallelAlgorithms/LinearSolver/Gmres - ElementActions.hpp Hit Total Coverage
Commit: 1f2210958b4f38fdc0400907ee7c6d5af5111418 Lines: 0 1 0.0 %
Date: 2025-12-05 05:03:31
Legend: Lines: hit not hit

          Line data    Source code
       1           0 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : #pragma once
       5             : 
       6             : #include <cstddef>
       7             : #include <limits>
       8             : #include <optional>
       9             : #include <tuple>
      10             : #include <type_traits>
      11             : #include <utility>
      12             : 
      13             : #include "DataStructures/DataBox/DataBox.hpp"
      14             : #include "DataStructures/DataBox/PrefixHelpers.hpp"
      15             : #include "NumericalAlgorithms/Convergence/HasConverged.hpp"
      16             : #include "NumericalAlgorithms/Convergence/Tags.hpp"
      17             : #include "NumericalAlgorithms/LinearSolver/InnerProduct.hpp"
      18             : #include "Parallel/AlgorithmExecution.hpp"
      19             : #include "Parallel/GetSection.hpp"
      20             : #include "Parallel/GlobalCache.hpp"
      21             : #include "Parallel/Invoke.hpp"
      22             : #include "Parallel/Printf/Printf.hpp"
      23             : #include "Parallel/Reduction.hpp"
      24             : #include "Parallel/Tags/Section.hpp"
      25             : #include "ParallelAlgorithms/LinearSolver/Gmres/ResidualMonitorActions.hpp"
      26             : #include "ParallelAlgorithms/LinearSolver/Gmres/Tags/InboxTags.hpp"
      27             : #include "ParallelAlgorithms/LinearSolver/Tags.hpp"
      28             : #include "Utilities/ErrorHandling/Assert.hpp"
      29             : #include "Utilities/Functional.hpp"
      30             : #include "Utilities/GetOutput.hpp"
      31             : #include "Utilities/Gsl.hpp"
      32             : #include "Utilities/MakeWithValue.hpp"
      33             : #include "Utilities/PrettyType.hpp"
      34             : #include "Utilities/Requires.hpp"
      35             : #include "Utilities/TMPL.hpp"
      36             : #include "Utilities/TypeTraits/GetFundamentalType.hpp"
      37             : 
      38             : /// \cond
      39             : namespace tuples {
      40             : template <typename...>
      41             : class TaggedTuple;
      42             : }  // namespace tuples
      43             : namespace LinearSolver::gmres::detail {
      44             : template <typename Metavariables, typename FieldsTag, typename OptionsGroup>
      45             : struct ResidualMonitor;
      46             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
      47             :           typename Label, typename ArraySectionIdTag>
      48             : struct PrepareStep;
      49             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
      50             :           typename Label, typename ArraySectionIdTag>
      51             : struct NormalizeOperandAndUpdateField;
      52             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
      53             :           typename Label, typename ArraySectionIdTag>
      54             : struct CompleteStep;
      55             : }  // namespace LinearSolver::gmres::detail
      56             : /// \endcond
      57             : 
      58             : namespace LinearSolver::gmres::detail {
      59             : 
      60             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
      61             :           typename Label, typename SourceTag, typename ArraySectionIdTag>
      62             : struct PrepareSolve {
      63             :  private:
      64             :   using fields_tag = FieldsTag;
      65             :   using initial_fields_tag = db::add_tag_prefix<::Tags::Initial, fields_tag>;
      66             :   using source_tag = SourceTag;
      67             :   using operator_applied_to_fields_tag =
      68             :       db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, fields_tag>;
      69             :   using operand_tag =
      70             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
      71             :   using basis_history_tag =
      72             :       LinearSolver::Tags::KrylovSubspaceBasis<operand_tag>;
      73             : 
      74             :  public:
      75             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
      76             :             typename ArrayIndex, typename ActionList,
      77             :             typename ParallelComponent>
      78             :   static Parallel::iterable_action_return_t apply(
      79             :       db::DataBox<DbTagsList>& box,
      80             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
      81             :       Parallel::GlobalCache<Metavariables>& cache,
      82             :       const ArrayIndex& array_index, const ActionList /*meta*/,
      83             :       const ParallelComponent* const /*meta*/) {
      84             :     db::mutate<Convergence::Tags::IterationId<OptionsGroup>>(
      85             :         [](const gsl::not_null<size_t*> iteration_id) { *iteration_id = 0; },
      86             :         make_not_null(&box));
      87             : 
      88             :     // Skip the initial reduction on elements that are not part of the section
      89             :     if constexpr (not std::is_same_v<ArraySectionIdTag, void>) {
      90             :       if (not db::get<Parallel::Tags::Section<ParallelComponent,
      91             :                                               ArraySectionIdTag>>(box)
      92             :                   .has_value()) {
      93             :         return {Parallel::AlgorithmExecution::Continue, std::nullopt};
      94             :       }
      95             :     }
      96             : 
      97             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
      98             :                  ::Verbosity::Debug)) {
      99             :       Parallel::printf("%s %s: Prepare solve\n", get_output(array_index),
     100             :                        pretty_type::name<OptionsGroup>());
     101             :     }
     102             : 
     103             :     db::mutate<operand_tag, initial_fields_tag, basis_history_tag>(
     104             :         [](const auto operand, const auto initial_fields,
     105             :            const auto basis_history, const auto& source,
     106             :            const auto& operator_applied_to_fields, const auto& fields) {
     107             :           *operand = source - operator_applied_to_fields;
     108             :           *initial_fields = fields;
     109             :           *basis_history = typename basis_history_tag::type{};
     110             :         },
     111             :         make_not_null(&box), get<source_tag>(box),
     112             :         get<operator_applied_to_fields_tag>(box), get<fields_tag>(box));
     113             : 
     114             :     auto& section = Parallel::get_section<ParallelComponent, ArraySectionIdTag>(
     115             :         make_not_null(&box));
     116             :     Parallel::contribute_to_reduction<InitializeResidualMagnitude<
     117             :         FieldsTag, OptionsGroup, ParallelComponent>>(
     118             :         Parallel::ReductionData<
     119             :             Parallel::ReductionDatum<double, funcl::Plus<>, funcl::Sqrt<>>>{
     120             :             magnitude_square(get<operand_tag>(box))},
     121             :         Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
     122             :         Parallel::get_parallel_component<
     123             :             ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache),
     124             :         make_not_null(&section));
     125             : 
     126             :     if constexpr (Preconditioned) {
     127             :       using preconditioned_operand_tag =
     128             :           db::add_tag_prefix<LinearSolver::Tags::Preconditioned, operand_tag>;
     129             :       using preconditioned_basis_history_tag =
     130             :           LinearSolver::Tags::KrylovSubspaceBasis<preconditioned_operand_tag>;
     131             : 
     132             :       db::mutate<preconditioned_basis_history_tag>(
     133             :           [](const auto preconditioned_basis_history) {
     134             :             *preconditioned_basis_history =
     135             :                 typename preconditioned_basis_history_tag::type{};
     136             :           },
     137             :           make_not_null(&box));
     138             :     }
     139             : 
     140             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     141             :   }
     142             : };
     143             : 
     144             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
     145             :           typename Label, typename ArraySectionIdTag>
     146             : struct NormalizeInitialOperand {
     147             :  private:
     148             :   using fields_tag = FieldsTag;
     149             :   using operand_tag =
     150             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     151             :   using basis_history_tag =
     152             :       LinearSolver::Tags::KrylovSubspaceBasis<operand_tag>;
     153             : 
     154             :  public:
     155             :   using const_global_cache_tags =
     156             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     157             :   using inbox_tags = tmpl::list<Tags::InitialOrthogonalization<OptionsGroup>>;
     158             : 
     159             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     160             :             typename ArrayIndex, typename ActionList,
     161             :             typename ParallelComponent>
     162             :   static Parallel::iterable_action_return_t apply(
     163             :       db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
     164             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     165             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     166             :       const ParallelComponent* const /*meta*/) {
     167             :     const size_t iteration_id =
     168             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     169             :     auto& inbox = get<Tags::InitialOrthogonalization<OptionsGroup>>(inboxes);
     170             :     if (inbox.find(iteration_id) == inbox.end()) {
     171             :       return {Parallel::AlgorithmExecution::Retry, std::nullopt};
     172             :     }
     173             : 
     174             :     auto received_data = std::move(inbox.extract(iteration_id).mapped());
     175             :     const double residual_magnitude = get<0>(received_data);
     176             :     auto& has_converged = get<1>(received_data);
     177             :     db::mutate<Convergence::Tags::HasConverged<OptionsGroup>>(
     178             :         [&has_converged](const gsl::not_null<Convergence::HasConverged*>
     179             :                              local_has_converged) {
     180             :           *local_has_converged = std::move(has_converged);
     181             :         },
     182             :         make_not_null(&box));
     183             : 
     184             :     // Skip steps entirely if the solve has already converged
     185             :     constexpr size_t step_end_index =
     186             :         tmpl::index_of<ActionList,
     187             :                        CompleteStep<FieldsTag, OptionsGroup, Preconditioned,
     188             :                                     Label, ArraySectionIdTag>>::value;
     189             :     if (get<Convergence::Tags::HasConverged<OptionsGroup>>(box)) {
     190             :       return {Parallel::AlgorithmExecution::Continue, step_end_index + 1};
     191             :     }
     192             : 
     193             :     // Skip the solve entirely on elements that are not part of the section. To
     194             :     // do so, we skip ahead to the `ApplyOperatorActions` action list between
     195             :     // `PrepareStep` and `PerformStep`. Those actions need a chance to run even
     196             :     // on elements that are not part of the section, because they may take part
     197             :     // in preconditioning (see Multigrid preconditioner).
     198             :     if constexpr (not std::is_same_v<ArraySectionIdTag, void>) {
     199             :       if (not db::get<Parallel::Tags::Section<ParallelComponent,
     200             :                                               ArraySectionIdTag>>(box)
     201             :                   .has_value()) {
     202             :         return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     203             :       }
     204             :     }
     205             : 
     206             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
     207             :                  ::Verbosity::Debug)) {
     208             :       Parallel::printf("%s %s(%zu): Normalize initial operand\n",
     209             :                        get_output(array_index),
     210             :                        pretty_type::name<OptionsGroup>(), iteration_id);
     211             :     }
     212             : 
     213             :     db::mutate<operand_tag, basis_history_tag>(
     214             :         [residual_magnitude](const auto operand, const auto basis_history) {
     215             :           *operand /= residual_magnitude;
     216             :           basis_history->push_back(*operand);
     217             :         },
     218             :         make_not_null(&box));
     219             : 
     220             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     221             :   }
     222             : };
     223             : 
     224             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
     225             :           typename Label, typename ArraySectionIdTag>
     226             : struct PrepareStep {
     227             :   using const_global_cache_tags =
     228             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     229             : 
     230             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     231             :             typename ArrayIndex, typename ActionList,
     232             :             typename ParallelComponent>
     233             :   static Parallel::iterable_action_return_t apply(
     234             :       db::DataBox<DbTagsList>& box,
     235             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     236             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     237             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     238             :       const ParallelComponent* const /*meta*/) {
     239             :     db::mutate<Convergence::Tags::IterationId<OptionsGroup>>(
     240             :         [](const gsl::not_null<size_t*> iteration_id) { ++(*iteration_id); },
     241             :         make_not_null(&box));
     242             : 
     243             :     if constexpr (not std::is_same_v<ArraySectionIdTag, void>) {
     244             :       if (not db::get<Parallel::Tags::Section<ParallelComponent,
     245             :                                               ArraySectionIdTag>>(box)
     246             :                   .has_value()) {
     247             :         return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     248             :       }
     249             :     }
     250             : 
     251             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
     252             :                  ::Verbosity::Debug)) {
     253             :       Parallel::printf(
     254             :           "%s %s(%zu): Prepare step\n", get_output(array_index),
     255             :           pretty_type::name<OptionsGroup>(),
     256             :           db::get<Convergence::Tags::IterationId<OptionsGroup>>(box));
     257             :     }
     258             : 
     259             :     if constexpr (Preconditioned) {
     260             :       using fields_tag = FieldsTag;
     261             :       using operand_tag =
     262             :           db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     263             :       using preconditioned_operand_tag =
     264             :           db::add_tag_prefix<LinearSolver::Tags::Preconditioned, operand_tag>;
     265             :       using operator_tag = db::add_tag_prefix<
     266             :           LinearSolver::Tags::OperatorAppliedTo,
     267             :           std::conditional_t<Preconditioned, preconditioned_operand_tag,
     268             :                              operand_tag>>;
     269             : 
     270             :       db::mutate<preconditioned_operand_tag, operator_tag>(
     271             :           [](const auto preconditioned_operand,
     272             :              const auto operator_applied_to_operand, const auto& operand) {
     273             :             // Start the preconditioner at zero because we have no reason to
     274             :             // expect the remaining residual to have a particular form.
     275             :             // Another possibility would be to start the preconditioner with an
     276             :             // initial guess equal to its source, so not running the
     277             :             // preconditioner at all means it is the identity, but that approach
     278             :             // appears to yield worse results.
     279             :             *preconditioned_operand =
     280             :                 make_with_value<typename preconditioned_operand_tag::type>(
     281             :                     operand, 0.);
     282             :             // Also set the operator applied to the initial preconditioned
     283             :             // operand to zero because it's linear. This may save the
     284             :             // preconditioner an operator application if it's optimized for
     285             :             // this.
     286             :             *operator_applied_to_operand =
     287             :                 make_with_value<typename operator_tag::type>(operand, 0.);
     288             :           },
     289             :           make_not_null(&box), get<operand_tag>(box));
     290             :     }
     291             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     292             :   }
     293             : };
     294             : 
     295             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
     296             :           typename Label, typename ArraySectionIdTag>
     297             : struct PerformStep {
     298             :  private:
     299             :   using fields_tag = FieldsTag;
     300             :   using operand_tag =
     301             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     302             :   using preconditioned_operand_tag =
     303             :       db::add_tag_prefix<LinearSolver::Tags::Preconditioned, operand_tag>;
     304             :   using ValueType =
     305             :       tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;
     306             : 
     307             :  public:
     308             :   using const_global_cache_tags =
     309             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     310             : 
     311             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     312             :             typename ArrayIndex, typename ActionList,
     313             :             typename ParallelComponent>
     314             :   static Parallel::iterable_action_return_t apply(
     315             :       db::DataBox<DbTagsList>& box,
     316             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     317             :       Parallel::GlobalCache<Metavariables>& cache,
     318             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     319             :       const ParallelComponent* const /*meta*/) {
     320             :     // Skip to the end of the step on elements that are not part of the section
     321             :     if constexpr (not std::is_same_v<ArraySectionIdTag, void>) {
     322             :       if (not db::get<Parallel::Tags::Section<ParallelComponent,
     323             :                                               ArraySectionIdTag>>(box)
     324             :                   .has_value()) {
     325             :         constexpr size_t step_end_index =
     326             :             tmpl::index_of<ActionList,
     327             :                            NormalizeOperandAndUpdateField<
     328             :                                FieldsTag, OptionsGroup, Preconditioned, Label,
     329             :                                ArraySectionIdTag>>::value;
     330             :         return {Parallel::AlgorithmExecution::Continue, step_end_index};
     331             :       }
     332             :     }
     333             : 
     334             :     const size_t iteration_id =
     335             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     336             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
     337             :                  ::Verbosity::Debug)) {
     338             :       Parallel::printf("%s %s(%zu): Perform step\n", get_output(array_index),
     339             :                        pretty_type::name<OptionsGroup>(), iteration_id);
     340             :     }
     341             : 
     342             :     using operator_tag = db::add_tag_prefix<
     343             :         LinearSolver::Tags::OperatorAppliedTo,
     344             :         std::conditional_t<Preconditioned, preconditioned_operand_tag,
     345             :                            operand_tag>>;
     346             :     using orthogonalization_iteration_id_tag =
     347             :         LinearSolver::Tags::Orthogonalization<
     348             :             Convergence::Tags::IterationId<OptionsGroup>>;
     349             :     using basis_history_tag =
     350             :         LinearSolver::Tags::KrylovSubspaceBasis<operand_tag>;
     351             : 
     352             :     if constexpr (Preconditioned) {
     353             :       using preconditioned_basis_history_tag =
     354             :           LinearSolver::Tags::KrylovSubspaceBasis<preconditioned_operand_tag>;
     355             : 
     356             :       db::mutate<preconditioned_basis_history_tag>(
     357             :           [](const auto preconditioned_basis_history,
     358             :              const auto& preconditioned_operand) {
     359             :             preconditioned_basis_history->push_back(preconditioned_operand);
     360             :           },
     361             :           make_not_null(&box), get<preconditioned_operand_tag>(box));
     362             :     }
     363             : 
     364             :     db::mutate<operand_tag, orthogonalization_iteration_id_tag>(
     365             :         [](const auto operand,
     366             :            const gsl::not_null<size_t*> orthogonalization_iteration_id,
     367             :            const auto& operator_action) {
     368             :           *operand = typename operand_tag::type(operator_action);
     369             :           *orthogonalization_iteration_id = 0;
     370             :         },
     371             :         make_not_null(&box), get<operator_tag>(box));
     372             : 
     373             :     auto& section = Parallel::get_section<ParallelComponent, ArraySectionIdTag>(
     374             :         make_not_null(&box));
     375             :     Parallel::contribute_to_reduction<
     376             :         StoreOrthogonalization<FieldsTag, OptionsGroup, ParallelComponent>>(
     377             :         Parallel::ReductionData<
     378             :             Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
     379             :             Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
     380             :             Parallel::ReductionDatum<ValueType, funcl::Plus<>>>{
     381             :             get<Convergence::Tags::IterationId<OptionsGroup>>(box),
     382             :             get<orthogonalization_iteration_id_tag>(box),
     383             :             inner_product(get<basis_history_tag>(box)[0],
     384             :                           get<operand_tag>(box))},
     385             :         Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
     386             :         Parallel::get_parallel_component<
     387             :             ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache),
     388             :         make_not_null(&section));
     389             : 
     390             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     391             :   }
     392             : };
     393             : 
     394             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
     395             :           typename Label, typename ArraySectionIdTag>
     396             : struct OrthogonalizeOperand {
     397             :  private:
     398             :   using fields_tag = FieldsTag;
     399             :   using operand_tag =
     400             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     401             :   using orthogonalization_iteration_id_tag =
     402             :       LinearSolver::Tags::Orthogonalization<
     403             :           Convergence::Tags::IterationId<OptionsGroup>>;
     404             :   using basis_history_tag =
     405             :       LinearSolver::Tags::KrylovSubspaceBasis<operand_tag>;
     406             :   using ValueType =
     407             :       tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;
     408             : 
     409             :  public:
     410             :   using inbox_tags =
     411             :       tmpl::list<Tags::Orthogonalization<OptionsGroup, ValueType>>;
     412             : 
     413             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     414             :             typename ArrayIndex, typename ActionList,
     415             :             typename ParallelComponent>
     416             :   static Parallel::iterable_action_return_t apply(
     417             :       db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
     418             :       Parallel::GlobalCache<Metavariables>& cache,
     419             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     420             :       const ParallelComponent* const /*meta*/) {
     421             :     const size_t iteration_id =
     422             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     423             :     auto& inbox =
     424             :         get<Tags::Orthogonalization<OptionsGroup, ValueType>>(inboxes);
     425             :     if (inbox.find(iteration_id) == inbox.end()) {
     426             :       return {Parallel::AlgorithmExecution::Retry, std::nullopt};
     427             :     }
     428             : 
     429             :     const ValueType orthogonalization =
     430             :         std::move(inbox.extract(iteration_id).mapped());
     431             : 
     432             :     db::mutate<operand_tag, orthogonalization_iteration_id_tag>(
     433             :         [orthogonalization](
     434             :             const auto operand,
     435             :             const gsl::not_null<size_t*> orthogonalization_iteration_id,
     436             :             const auto& basis_history) {
     437             :           *operand -= orthogonalization *
     438             :                       gsl::at(basis_history, *orthogonalization_iteration_id);
     439             :           ++(*orthogonalization_iteration_id);
     440             :         },
     441             :         make_not_null(&box), get<basis_history_tag>(box));
     442             : 
     443             :     const auto& next_orthogonalization_iteration_id =
     444             :         get<orthogonalization_iteration_id_tag>(box);
     445             :     const bool orthogonalization_complete =
     446             :         next_orthogonalization_iteration_id == iteration_id;
     447             :     const ValueType local_orthogonalization =
     448             :         inner_product(orthogonalization_complete
     449             :                           ? get<operand_tag>(box)
     450             :                           : gsl::at(get<basis_history_tag>(box),
     451             :                                     next_orthogonalization_iteration_id),
     452             :                       get<operand_tag>(box));
     453             : 
     454             :     auto& section = Parallel::get_section<ParallelComponent, ArraySectionIdTag>(
     455             :         make_not_null(&box));
     456             :     Parallel::contribute_to_reduction<
     457             :         StoreOrthogonalization<FieldsTag, OptionsGroup, ParallelComponent>>(
     458             :         Parallel::ReductionData<
     459             :             Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
     460             :             Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
     461             :             Parallel::ReductionDatum<ValueType, funcl::Plus<>>>{
     462             :             iteration_id, next_orthogonalization_iteration_id,
     463             :             local_orthogonalization},
     464             :         Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
     465             :         Parallel::get_parallel_component<
     466             :             ResidualMonitor<Metavariables, FieldsTag, OptionsGroup>>(cache),
     467             :         make_not_null(&section));
     468             : 
     469             :     // Repeat this action until orthogonalization is complete
     470             :     constexpr size_t this_action_index =
     471             :         tmpl::index_of<ActionList, OrthogonalizeOperand>::value;
     472             :     return {Parallel::AlgorithmExecution::Continue,
     473             :             orthogonalization_complete ? (this_action_index + 1)
     474             :                                        : this_action_index};
     475             :   }
     476             : };
     477             : 
     478             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
     479             :           typename Label, typename ArraySectionIdTag>
     480             : struct NormalizeOperandAndUpdateField {
     481             :  private:
     482             :   using fields_tag = FieldsTag;
     483             :   using initial_fields_tag = db::add_tag_prefix<::Tags::Initial, fields_tag>;
     484             :   using operand_tag =
     485             :       db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
     486             :   using preconditioned_operand_tag =
     487             :       db::add_tag_prefix<LinearSolver::Tags::Preconditioned, operand_tag>;
     488             :   using basis_history_tag =
     489             :       LinearSolver::Tags::KrylovSubspaceBasis<operand_tag>;
     490             :   using preconditioned_basis_history_tag =
     491             :       LinearSolver::Tags::KrylovSubspaceBasis<std::conditional_t<
     492             :           Preconditioned, preconditioned_operand_tag, operand_tag>>;
     493             :   using ValueType =
     494             :       tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;
     495             : 
     496             :  public:
     497             :   using const_global_cache_tags =
     498             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     499             :   using inbox_tags =
     500             :       tmpl::list<Tags::FinalOrthogonalization<OptionsGroup, ValueType>>;
     501             : 
     502             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     503             :             typename ArrayIndex, typename ActionList,
     504             :             typename ParallelComponent>
     505             :   static Parallel::iterable_action_return_t apply(
     506             :       db::DataBox<DbTagsList>& box, tuples::TaggedTuple<InboxTags...>& inboxes,
     507             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     508             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     509             :       const ParallelComponent* const /*meta*/) {
     510             :     const size_t iteration_id =
     511             :         db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     512             :     auto& inbox =
     513             :         get<Tags::FinalOrthogonalization<OptionsGroup, ValueType>>(inboxes);
     514             :     if (inbox.find(iteration_id) == inbox.end()) {
     515             :       return {Parallel::AlgorithmExecution::Retry, std::nullopt};
     516             :     }
     517             : 
     518             :     // Retrieve reduction data from inbox
     519             :     auto received_data = std::move(inbox.extract(iteration_id).mapped());
     520             :     const double normalization = get<0>(received_data);
     521             :     const auto& minres = get<1>(received_data);
     522             :     db::mutate<Convergence::Tags::HasConverged<OptionsGroup>>(
     523             :         [&received_data](
     524             :             const gsl::not_null<Convergence::HasConverged*> has_converged) {
     525             :           *has_converged = std::move(get<2>(received_data));
     526             :         },
     527             :         make_not_null(&box));
     528             : 
     529             :     // Elements that are not part of the section jump directly to the
     530             :     // `ApplyOperationActions` for the next step.
     531             :     if constexpr (not std::is_same_v<ArraySectionIdTag, void>) {
     532             :       constexpr size_t complete_step_index =
     533             :           tmpl::index_of<ActionList,
     534             :                          CompleteStep<FieldsTag, OptionsGroup, Preconditioned,
     535             :                                       Label, ArraySectionIdTag>>::value;
     536             :       constexpr size_t prepare_step_index =
     537             :           tmpl::index_of<ActionList,
     538             :                          PrepareStep<FieldsTag, OptionsGroup, Preconditioned,
     539             :                                      Label, ArraySectionIdTag>>::value;
     540             :       if (not db::get<Parallel::Tags::Section<ParallelComponent,
     541             :                                               ArraySectionIdTag>>(box)
     542             :                   .has_value()) {
     543             :         return {Parallel::AlgorithmExecution::Continue,
     544             :                 get<Convergence::Tags::HasConverged<OptionsGroup>>(box)
     545             :                     ? (complete_step_index + 1)
     546             :                     : prepare_step_index};
     547             :       }
     548             :     }
     549             : 
     550             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
     551             :                  ::Verbosity::Debug)) {
     552             :       Parallel::printf("%s %s(%zu): Update field\n", get_output(array_index),
     553             :                        pretty_type::name<OptionsGroup>(), iteration_id);
     554             :     }
     555             : 
     556             :     db::mutate<operand_tag, basis_history_tag, fields_tag>(
     557             :         [normalization, &minres](const auto operand, const auto basis_history,
     558             :                                  const auto field, const auto& initial_field,
     559             :                                  const auto& preconditioned_basis_history,
     560             :                                  const auto& has_converged) {
     561             :           // Avoid an FPE if the new operand norm is exactly zero. In that case
     562             :           // the problem is solved and the algorithm will terminate (see
     563             :           // Proposition 9.3 in \cite Saad2003). Since there will be no next
     564             :           // iteration we don't need to normalize the operand.
     565             :           if (LIKELY(normalization > 0.)) {
     566             :             *operand /= normalization;
     567             :           }
     568             :           basis_history->push_back(*operand);
     569             :           // Don't update the solution if an error occurred
     570             :           if (not(has_converged and
     571             :                   has_converged.reason() == Convergence::Reason::Error)) {
     572             :             *field = initial_field;
     573             :             for (size_t i = 0; i < minres.size(); i++) {
     574             :               *field += minres[i] * gsl::at(preconditioned_basis_history, i);
     575             :             }
     576             :           }
     577             :         },
     578             :         make_not_null(&box), get<initial_fields_tag>(box),
     579             :         get<preconditioned_basis_history_tag>(box),
     580             :         get<Convergence::Tags::HasConverged<OptionsGroup>>(box));
     581             : 
     582             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     583             :   }
     584             : };
     585             : 
     586             : // Jump back to `PrepareStep` to continue iterating if the algorithm has not yet
     587             : // converged, or complete the solve and proceed with the action list if it has
     588             : // converged. This is a separate action because the user has the opportunity to
     589             : // insert actions before the step completes, for example to do observations.
     590             : template <typename FieldsTag, typename OptionsGroup, bool Preconditioned,
     591             :           typename Label, typename ArraySectionIdTag>
     592             : struct CompleteStep {
     593             :   using const_global_cache_tags =
     594             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     595             : 
     596             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     597             :             typename ArrayIndex, typename ActionList,
     598             :             typename ParallelComponent>
     599             :   static Parallel::iterable_action_return_t apply(
     600             :       db::DataBox<DbTagsList>& box,
     601             :       tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     602             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     603             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     604             :       const ParallelComponent* const /*meta*/) {
     605             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
     606             :                  ::Verbosity::Debug)) {
     607             :       Parallel::printf(
     608             :           "%s %s(%zu): Complete step\n", get_output(array_index),
     609             :           pretty_type::name<OptionsGroup>(),
     610             :           db::get<Convergence::Tags::IterationId<OptionsGroup>>(box));
     611             :     }
     612             : 
     613             :     // Repeat steps until the solve has converged
     614             :     constexpr size_t prepare_step_index =
     615             :         tmpl::index_of<ActionList,
     616             :                        PrepareStep<FieldsTag, OptionsGroup, Preconditioned,
     617             :                                    Label, ArraySectionIdTag>>::value;
     618             :     constexpr size_t this_action_index =
     619             :         tmpl::index_of<ActionList, CompleteStep>::value;
     620             :     return {Parallel::AlgorithmExecution::Continue,
     621             :             get<Convergence::Tags::HasConverged<OptionsGroup>>(box)
     622             :                 ? (this_action_index + 1)
     623             :                 : prepare_step_index};
     624             :   }
     625             : };
     626             : 
     627             : }  // namespace LinearSolver::gmres::detail

Generated by: LCOV version 1.14