SpECTRE Documentation Coverage Report
Current view: top level - ParallelAlgorithms/LinearSolver/AsynchronousSolvers - ElementActions.hpp Hit Total Coverage
Commit: a6a8ee404306bec9d92da8ab89f636b037aefc25 Lines: 3 37 8.1 %
Date: 2024-07-26 22:35:59
Legend: Lines: hit not hit

          Line data    Source code
       1           0 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : #pragma once
       5             : 
       6             : #include <cstddef>
       7             : #include <limits>
       8             : #include <optional>
       9             : #include <string>
      10             : #include <tuple>
      11             : #include <utility>
      12             : #include <vector>
      13             : 
      14             : #include "DataStructures/DataBox/DataBox.hpp"
      15             : #include "DataStructures/DataBox/PrefixHelpers.hpp"
      16             : #include "IO/Logging/Tags.hpp"
      17             : #include "IO/Logging/Verbosity.hpp"
      18             : #include "IO/Observer/Actions/RegisterWithObservers.hpp"
      19             : #include "IO/Observer/GetSectionObservationKey.hpp"
      20             : #include "IO/Observer/Helpers.hpp"
      21             : #include "IO/Observer/ObservationId.hpp"
      22             : #include "IO/Observer/ObserverComponent.hpp"
      23             : #include "IO/Observer/Protocols/ReductionDataFormatter.hpp"
      24             : #include "IO/Observer/ReductionActions.hpp"
      25             : #include "IO/Observer/Tags.hpp"
      26             : #include "IO/Observer/TypeOfObservation.hpp"
      27             : #include "NumericalAlgorithms/Convergence/HasConverged.hpp"
      28             : #include "NumericalAlgorithms/Convergence/Tags.hpp"
      29             : #include "NumericalAlgorithms/LinearSolver/Gmres.hpp"
      30             : #include "NumericalAlgorithms/LinearSolver/InnerProduct.hpp"
      31             : #include "Parallel/AlgorithmExecution.hpp"
      32             : #include "Parallel/ArrayComponentId.hpp"
      33             : #include "Parallel/GlobalCache.hpp"
      34             : #include "Parallel/Invoke.hpp"
      35             : #include "Parallel/Local.hpp"
      36             : #include "Parallel/Printf/Printf.hpp"
      37             : #include "Parallel/Reduction.hpp"
      38             : #include "ParallelAlgorithms/Amr/Protocols/Projector.hpp"
      39             : #include "ParallelAlgorithms/Initialization/MutateAssign.hpp"
      40             : #include "ParallelAlgorithms/LinearSolver/Tags.hpp"
      41             : #include "Utilities/Functional.hpp"
      42             : #include "Utilities/GetOutput.hpp"
      43             : #include "Utilities/Gsl.hpp"
      44             : #include "Utilities/PrettyType.hpp"
      45             : #include "Utilities/ProtocolHelpers.hpp"
      46             : #include "Utilities/Requires.hpp"
      47             : #include "Utilities/TMPL.hpp"
      48             : 
      49             : /// \cond
      50             : namespace tuples {
      51             : template <typename...>
      52             : class TaggedTuple;
      53             : }  // namespace tuples
      54             : namespace LinearSolver::async_solvers {
      55             : template <typename FieldsTag, typename OptionsGroup, typename SourceTag,
      56             :           typename Label, typename ArraySectionIdTag,
      57             :           bool ObserveInitialResidual>
      58             : struct CompleteStep;
      59             : }  // namespace LinearSolver::async_solvers
      60             : /// \endcond
      61             : 
      62             : /// Functionality shared between parallel linear solvers that have no global
      63             : /// synchronization points
      64           1 : namespace LinearSolver::async_solvers {
      65             : 
      66           0 : using reduction_data = Parallel::ReductionData<
      67             :     // Iteration
      68             :     Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
      69             :     // Residual
      70             :     Parallel::ReductionDatum<double, funcl::Plus<>, funcl::Sqrt<>>>;
      71             : 
      72             : template <typename OptionsGroup>
      73           0 : struct ResidualReductionFormatter
      74             :     : tt::ConformsTo<observers::protocols::ReductionDataFormatter> {
      75           0 :   using reduction_data = async_solvers::reduction_data;
      76           0 :   ResidualReductionFormatter() = default;
      77           0 :   ResidualReductionFormatter(std::string local_section_observation_key)
      78             :       : section_observation_key(std::move(local_section_observation_key)) {}
      79           0 :   std::string operator()(const size_t iteration_id,
      80             :                          const double residual) const {
      81             :     if (iteration_id == 0) {
      82             :       return pretty_type::name<OptionsGroup>() + section_observation_key +
      83             :              " initialized with residual: " + get_output(residual);
      84             :     } else {
      85             :       return pretty_type::name<OptionsGroup>() + section_observation_key + "(" +
      86             :              get_output(iteration_id) +
      87             :              ") iteration complete. Remaining residual: " +
      88             :              get_output(residual);
      89             :     }
      90             :   }
      91             :   // NOLINTNEXTLINE(google-runtime-references)
      92           0 :   void pup(PUP::er& p) { p | section_observation_key; }
      93           0 :   std::string section_observation_key{};
      94             : };
      95             : 
      96             : template <typename OptionsGroup, typename ParallelComponent,
      97             :           typename Metavariables, typename ArrayIndex>
      98           0 : void contribute_to_residual_observation(
      99             :     const size_t iteration_id, const double residual_magnitude_square,
     100             :     Parallel::GlobalCache<Metavariables>& cache, const ArrayIndex& array_index,
     101             :     const std::string& section_observation_key) {
     102             :   auto& local_observer = *Parallel::local_branch(
     103             :       Parallel::get_parallel_component<observers::Observer<Metavariables>>(
     104             :           cache));
     105             :   auto formatter =
     106             :       UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(cache) >=
     107             :                ::Verbosity::Quiet)
     108             :           ? std::make_optional(ResidualReductionFormatter<OptionsGroup>{
     109             :                 section_observation_key})
     110             :           : std::nullopt;
     111             :   Parallel::simple_action<observers::Actions::ContributeReductionData>(
     112             :       local_observer,
     113             :       observers::ObservationId(
     114             :           iteration_id,
     115             :           pretty_type::get_name<OptionsGroup>() + section_observation_key),
     116             :       Parallel::make_array_component_id<ParallelComponent>(array_index),
     117             :       std::string{"/" + pretty_type::name<OptionsGroup>() +
     118             :                   section_observation_key + "Residuals"},
     119             :       std::vector<std::string>{"Iteration", "Residual"},
     120             :       reduction_data{iteration_id, residual_magnitude_square},
     121             :       std::move(formatter));
     122             :   if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(cache) >=
     123             :                ::Verbosity::Debug)) {
     124             :     if (iteration_id == 0) {
     125             :       Parallel::printf(
     126             :           "%s %s initialized with local residual: %e\n",
     127             :           get_output(array_index),
     128             :           pretty_type::name<OptionsGroup>() + section_observation_key,
     129             :           sqrt(residual_magnitude_square));
     130             :     } else {
     131             :       Parallel::printf(
     132             :           "%s %s(%zu) iteration complete. Remaining local residual: %e\n",
     133             :           get_output(array_index),
     134             :           pretty_type::name<OptionsGroup>() + section_observation_key,
     135             :           iteration_id, sqrt(residual_magnitude_square));
     136             :     }
     137             :   }
     138             : }
     139             : 
     140             : template <typename FieldsTag, typename OptionsGroup, typename SourceTag>
     141           0 : struct InitializeElement : tt::ConformsTo<amr::protocols::Projector> {
     142             :  private:
     143           0 :   using fields_tag = FieldsTag;
     144           0 :   using operator_applied_to_fields_tag =
     145             :       db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, fields_tag>;
     146           0 :   using source_tag = SourceTag;
     147           0 :   using residual_tag =
     148             :       db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
     149           0 :   using residual_magnitude_square_tag =
     150             :       LinearSolver::Tags::MagnitudeSquare<residual_tag>;
     151             : 
     152             :  public:  // Iterable action
     153           0 :   using const_global_cache_tags =
     154             :       tmpl::list<Convergence::Tags::Iterations<OptionsGroup>>;
     155             : 
     156           0 :   using simple_tags = tmpl::list<Convergence::Tags::IterationId<OptionsGroup>,
     157             :                                  Convergence::Tags::HasConverged<OptionsGroup>,
     158             :                                  operator_applied_to_fields_tag>;
     159           0 :   using compute_tags =
     160             :       tmpl::list<LinearSolver::Tags::ResidualCompute<fields_tag, source_tag>>;
     161             : 
     162             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     163             :             typename ArrayIndex, typename ActionList,
     164             :             typename ParallelComponent>
     165           0 :   static Parallel::iterable_action_return_t apply(
     166             :       db::DataBox<DbTagsList>& box,
     167             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     168             :       const Parallel::GlobalCache<Metavariables>& /*cache*/,
     169             :       const ArrayIndex& /*array_index*/, const ActionList /*meta*/,
     170             :       const ParallelComponent* const /*meta*/) {
     171             :     // The `PrepareSolve` action populates these tags with initial
     172             :     // values, except for `operator_applied_to_fields_tag` which is
     173             :     // expected to be updated in every iteration of the algorithm
     174             :     Initialization::mutate_assign<
     175             :         tmpl::list<Convergence::Tags::IterationId<OptionsGroup>>>(
     176             :         make_not_null(&box), std::numeric_limits<size_t>::max());
     177             :     return {Parallel::AlgorithmExecution::Continue, std::nullopt};
     178             :   }
     179             : 
     180             :  public:  // amr::protocols::Projector
     181           0 :   using argument_tags = tmpl::list<>;
     182           0 :   using return_tags = simple_tags;
     183             : 
     184             :   template <typename... AmrData>
     185           0 :   static void apply(const gsl::not_null<size_t*> /*unused*/,
     186             :                     const AmrData&... /*all_items*/) {
     187             :     // No need to reset or initialize any of the items during AMR because they
     188             :     // will be set in `PrepareSolve`. AMR can't happen _during_ a solve.
     189             :   }
     190             : };
     191             : 
     192             : template <typename OptionsGroup, typename ArraySectionIdTag = void>
     193           0 : struct RegisterObservers {
     194             :   template <typename ParallelComponent, typename DbTagsList,
     195             :             typename ArrayIndex>
     196             :   static std::pair<observers::TypeOfObservation, observers::ObservationKey>
     197           0 :   register_info(const db::DataBox<DbTagsList>& box,
     198             :                 const ArrayIndex& /*array_index*/) {
     199             :     // Get the observation key, or "Unused" if the element does not belong
     200             :     // to a section with this tag. In the latter case, no observations will
     201             :     // ever be contributed.
     202             :     const std::optional<std::string> section_observation_key =
     203             :         observers::get_section_observation_key<ArraySectionIdTag>(box);
     204             :     ASSERT(section_observation_key != "Unused",
     205             :            "The identifier 'Unused' is reserved to indicate that no "
     206             :            "observations with this key will be contributed. Use a different "
     207             :            "key, or change the identifier 'Unused' to something else.");
     208             :     return {
     209             :         observers::TypeOfObservation::Reduction,
     210             :         observers::ObservationKey{pretty_type::get_name<OptionsGroup>() +
     211             :                                   section_observation_key.value_or("Unused")}};
     212             :   }
     213             : };
     214             : 
     215             : template <typename FieldsTag, typename OptionsGroup, typename SourceTag,
     216             :           typename ArraySectionIdTag = void>
     217           0 : using RegisterElement = observers::Actions::RegisterWithObservers<
     218             :     RegisterObservers<OptionsGroup, ArraySectionIdTag>>;
     219             : 
     220             : /*!
     221             :  * \brief Prepare the asynchronous linear solver for a solve
     222             :  *
     223             :  * This action resets the asynchronous linear solver to its initial state, and
     224             :  * optionally observes the initial residual. If the initial residual should be
     225             :  * observed, both the `SourceTag` as well as the
     226             :  * `db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, FieldsTag>`
     227             :  * must be up-to-date at the time this action is invoked.
     228             :  *
     229             :  * This action also provides an anchor point in the action list for looping the
     230             :  * linear solver, in the sense that the algorithm jumps back to the action
     231             :  * immediately following this one when a step is complete and the solver hasn't
     232             :  * yet converged (see `LinearSolver::async_solvers::CompleteStep`).
     233             :  *
     234             :  * \tparam FieldsTag The data `x` in the linear equation `Ax = b` to be solved.
     235             :  * Should hold the initial guess `x_0` at this point in the algorithm.
     236             :  * \tparam OptionsGroup An options group identifying the linear solver
     237             :  * \tparam SourceTag The data `b` in `Ax = b`
     238             :  * \tparam Label An optional compile-time label for the solver to distinguish
     239             :  * different solves with the same solver in the action list
     240             :  * \tparam ArraySectionIdTag Observe the residual norm separately for each
     241             :  * array section identified by this tag (see `Parallel::Section`). Set to `void`
     242             :  * to observe the residual norm over all elements of the array (default). The
     243             :  * section only affects observations of residuals and has no effect on the
     244             :  * solver algorithm.
     245             :  * \tparam ObserveInitialResidual Whether or not to observe the initial residual
     246             :  * `b - A x_0` (default: `true`). Disable when `b` or `A x_0` are not yet
     247             :  * available at the preparation stage.
     248             :  */
     249             : template <typename FieldsTag, typename OptionsGroup, typename SourceTag,
     250             :           typename Label = OptionsGroup, typename ArraySectionIdTag = void,
     251             :           bool ObserveInitialResidual = true>
     252           1 : struct PrepareSolve {
     253             :  private:
     254           0 :   using fields_tag = FieldsTag;
     255           0 :   using residual_tag =
     256             :       db::add_tag_prefix<LinearSolver::Tags::Residual, FieldsTag>;
     257             : 
     258             :  public:
     259           0 :   using const_global_cache_tags =
     260             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     261             : 
     262             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     263             :             typename ArrayIndex, typename ActionList,
     264             :             typename ParallelComponent>
     265           0 :   static Parallel::iterable_action_return_t apply(
     266             :       db::DataBox<DbTagsList>& box,
     267             :       const tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     268             :       Parallel::GlobalCache<Metavariables>& cache,
     269             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     270             :       const ParallelComponent* const /*meta*/) {
     271             :     constexpr size_t iteration_id = 0;
     272             : 
     273             :     if (UNLIKELY(get<logging::Tags::Verbosity<OptionsGroup>>(box) >=
     274             :                  ::Verbosity::Debug)) {
     275             :       Parallel::printf("%s %s: Prepare solve\n", get_output(array_index),
     276             :                        pretty_type::name<OptionsGroup>());
     277             :     }
     278             : 
     279             :     db::mutate<Convergence::Tags::IterationId<OptionsGroup>,
     280             :                Convergence::Tags::HasConverged<OptionsGroup>>(
     281             :         [](const gsl::not_null<size_t*> local_iteration_id,
     282             :            const gsl::not_null<Convergence::HasConverged*> has_converged,
     283             :            const size_t num_iterations) {
     284             :           *local_iteration_id = iteration_id;
     285             :           *has_converged =
     286             :               Convergence::HasConverged{num_iterations, iteration_id};
     287             :         },
     288             :         make_not_null(&box),
     289             :         get<Convergence::Tags::Iterations<OptionsGroup>>(box));
     290             : 
     291             :     if constexpr (ObserveInitialResidual) {
     292             :       // Observe the initial residual even if no steps are going to be performed
     293             :       const std::optional<std::string> section_observation_key =
     294             :           observers::get_section_observation_key<ArraySectionIdTag>(box);
     295             :       if (section_observation_key.has_value()) {
     296             :         const auto& residual = get<residual_tag>(box);
     297             :         const double residual_magnitude_square = magnitude_square(residual);
     298             :         contribute_to_residual_observation<OptionsGroup, ParallelComponent>(
     299             :             iteration_id, residual_magnitude_square, cache, array_index,
     300             :             *section_observation_key);
     301             :       }
     302             :     }
     303             : 
     304             :     // Skip steps entirely if the solve has already converged
     305             :     constexpr size_t step_end_index = tmpl::index_of<
     306             :         ActionList,
     307             :         CompleteStep<FieldsTag, OptionsGroup, SourceTag, Label,
     308             :                      ArraySectionIdTag, ObserveInitialResidual>>::value;
     309             :     constexpr size_t this_action_index =
     310             :         tmpl::index_of<ActionList, PrepareSolve>::value;
     311             :     return {Parallel::AlgorithmExecution::Continue,
     312             :             get<Convergence::Tags::HasConverged<OptionsGroup>>(box)
     313             :                 ? (step_end_index + 1)
     314             :                 : (this_action_index + 1)};
     315             :   }
     316             : };
     317             : 
     318             : /*!
     319             :  * \brief Complete a step of the asynchronous linear solver
     320             :  *
     321             :  * This action prepares the next step of the asynchronous linear solver, and
     322             :  * observes the residual. To observe the correct residual, make sure the
     323             :  * `db::add_tag_prefix<LinearSolver::Tags::OperatorAppliedTo, FieldsTag>` is
     324             :  * up-to-date at the time this action is invoked.
     325             :  *
     326             :  * This action checks if the algorithm has converged, i.e. it has completed the
     327             :  * requested number of steps. If it hasn't, the algorithm jumps back to the
     328             :  * action immediately following the `LinearSolver::async_solvers::PrepareSolve`
     329             :  * to perform another iteration. Make sure both actions use the same template
     330             :  * parameters.
     331             :  *
     332             :  * \tparam FieldsTag The data `x` in the linear equation `Ax = b` to be solved.
     333             :  * \tparam OptionsGroup An options group identifying the linear solver
     334             :  * \tparam SourceTag The data `b` in `Ax = b`
     335             :  * \tparam Label An optional compile-time label for the solver to distinguish
     336             :  * different solves with the same solver in the action list
     337             :  * \tparam ArraySectionIdTag Observe the residual norm separately for each
     338             :  * array section identified by this tag (see `Parallel::Section`). Set to `void`
     339             :  * to observe the residual norm over all elements of the array (default). The
     340             :  * section only affects observations of residuals and has no effect on the
     341             :  * solver algorithm.
     342             :  * \tparam ObserveInitialResidual Whether or not to observe the _initial_
     343             :  * residual `b - A x_0`. This parameter should match the one passed to
     344             :  * `PrepareSolve`.
     345             :  */
     346             : template <typename FieldsTag, typename OptionsGroup, typename SourceTag,
     347             :           typename Label = OptionsGroup, typename ArraySectionIdTag = void,
     348             :           bool ObserveInitialResidual = true>
     349           1 : struct CompleteStep {
     350             :  private:
     351           0 :   using fields_tag = FieldsTag;
     352           0 :   using residual_tag =
     353             :       db::add_tag_prefix<LinearSolver::Tags::Residual, fields_tag>;
     354             : 
     355             :  public:
     356           0 :   using const_global_cache_tags =
     357             :       tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
     358             : 
     359             :   template <typename DbTagsList, typename... InboxTags, typename Metavariables,
     360             :             typename ArrayIndex, typename ActionList,
     361             :             typename ParallelComponent>
     362           0 :   static Parallel::iterable_action_return_t apply(
     363             :       db::DataBox<DbTagsList>& box,
     364             :       tuples::TaggedTuple<InboxTags...>& /*inboxes*/,
     365             :       Parallel::GlobalCache<Metavariables>& cache,
     366             :       const ArrayIndex& array_index, const ActionList /*meta*/,
     367             :       const ParallelComponent* const /*meta*/) {
     368             :     // Prepare for next iteration
     369             :     db::mutate<Convergence::Tags::IterationId<OptionsGroup>,
     370             :                Convergence::Tags::HasConverged<OptionsGroup>>(
     371             :         [](const gsl::not_null<size_t*> iteration_id,
     372             :            const gsl::not_null<Convergence::HasConverged*> has_converged,
     373             :            const size_t num_iterations) {
     374             :           ++(*iteration_id);
     375             :           *has_converged =
     376             :               Convergence::HasConverged{num_iterations, *iteration_id};
     377             :         },
     378             :         make_not_null(&box),
     379             :         get<Convergence::Tags::Iterations<OptionsGroup>>(box));
     380             : 
     381             :     // Observe element-local residual magnitude
     382             :     const std::optional<std::string> section_observation_key =
     383             :         observers::get_section_observation_key<ArraySectionIdTag>(box);
     384             :     if (section_observation_key.has_value()) {
     385             :       const size_t completed_iterations =
     386             :           get<Convergence::Tags::IterationId<OptionsGroup>>(box);
     387             :       const auto& residual = get<residual_tag>(box);
     388             :       const double residual_magnitude_square = magnitude_square(residual);
     389             :       contribute_to_residual_observation<OptionsGroup, ParallelComponent>(
     390             :           completed_iterations, residual_magnitude_square, cache, array_index,
     391             :           *section_observation_key);
     392             :     }
     393             : 
     394             :     // Repeat steps until the solve has converged
     395             :     constexpr size_t step_begin_index =
     396             :         tmpl::index_of<
     397             :             ActionList,
     398             :             PrepareSolve<FieldsTag, OptionsGroup, SourceTag, Label,
     399             :                          ArraySectionIdTag, ObserveInitialResidual>>::value +
     400             :         1;
     401             :     constexpr size_t this_action_index =
     402             :         tmpl::index_of<ActionList, CompleteStep>::value;
     403             :     return {Parallel::AlgorithmExecution::Continue,
     404             :             get<Convergence::Tags::HasConverged<OptionsGroup>>(box)
     405             :                 ? (this_action_index + 1)
     406             :                 : step_begin_index};
     407             :   }
     408             : };
     409             : 
     410             : }  // namespace LinearSolver::async_solvers

Generated by: LCOV version 1.14