SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/LinearSolver - LinearSolver.hpp
Commit: b5f497991094937944b0a3f519166bb54739d08a    Lines: 11 hit / 31 total (35.5 % coverage)
Date: 2024-03-28 18:20:13
Legend: Lines: hit not hit

          Line data    Source code
       1           0 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : #pragma once
       5             : 
       6             : #include <memory>
       7             : #include <optional>
       8             : #include <pup.h>
       9             : #include <pup_stl.h>
      10             : #include <type_traits>
      11             : #include <utility>
      12             : 
      13             : #include "NumericalAlgorithms/Convergence/HasConverged.hpp"
      14             : #include "Options/Auto.hpp"
      15             : #include "Options/String.hpp"
      16             : #include "Utilities/CallWithDynamicType.hpp"
      17             : #include "Utilities/ErrorHandling/Assert.hpp"
      18             : #include "Utilities/Gsl.hpp"
      19             : #include "Utilities/Registration.hpp"
      20             : #include "Utilities/Requires.hpp"
      21             : #include "Utilities/Serialization/CharmPupable.hpp"
      22             : #include "Utilities/Serialization/PupStlCpp17.hpp"
      23             : 
      24             : namespace LinearSolver::Serial {
      25             : 
      26             : /// Registrars for linear solvers
      27             : namespace Registrars {}
      28             : 
      29             : /*!
      30             :  * \brief Base class for serial linear solvers that supports factory-creation.
      31             :  *
      32             :  * Derive linear solvers from this class so they can be factory-created. If your
      33             :  * linear solver supports preconditioning, derive from
      34             :  * `PreconditionedLinearSolver` instead to inherit functionality that allows
      35             :  * using any other factory-creatable linear solver as a preconditioner.
      36             :  */
      37             : template <typename LinearSolverRegistrars>
      38           1 : class LinearSolver : public PUP::able {
      39             :  protected:
      40             :   /// \cond
      41             :   LinearSolver() = default;
      42             :   LinearSolver(const LinearSolver&) = default;
      43             :   LinearSolver(LinearSolver&&) = default;
      44             :   LinearSolver& operator=(const LinearSolver&) = default;
      45             :   LinearSolver& operator=(LinearSolver&&) = default;
      46             :   /// \endcond
      47             : 
      48             :  public:
      49           0 :   ~LinearSolver() override = default;
      50             : 
      51             :   /// \cond
      52             :   explicit LinearSolver(CkMigrateMessage* m);
      53             :   WRAPPED_PUPable_abstract(LinearSolver);  // NOLINT
      54             :   /// \endcond
      55             : 
      56           0 :   using registrars = LinearSolverRegistrars;
      57           0 :   using creatable_classes = Registration::registrants<LinearSolverRegistrars>;
      58             : 
      59           0 :   virtual std::unique_ptr<LinearSolver<LinearSolverRegistrars>> get_clone()
      60             :       const = 0;
      61             : 
      62             :   /*!
      63             :    * \brief Solve the linear equation \f$Ax=b\f$ where \f$A\f$ is the
      64             :    * `linear_operator` and \f$b\f$ is the `source`.
      65             :    *
      66             :    * - The (approximate) solution \f$x\f$ is returned in the
      67             :    *   `initial_guess_in_solution_out` buffer, which also serves to provide an
      68             :    *   initial guess for \f$x\f$. Not all solvers take the initial guess into
      69             :    *   account, but all expect the buffer to be sized correctly.
      70             :    * - The `linear_operator` must be an invocable that takes a `VarsType` as
      71             :    *   const-ref argument and returns a `SourceType` by reference. It also takes
      72             :    *   all `OperatorArgs` as const-ref arguments.
      73             :    *
      74             :    * Each solve may mutate the private state of the solver, for example to cache
      75             :    * quantities to accelerate successive solves for the same operator. Invoke
      76             :    * `reset` to discard these caches.
      77             :    */
      78             :   template <typename LinearOperator, typename VarsType, typename SourceType,
      79             :             typename... OperatorArgs, typename... Args>
      80           1 :   Convergence::HasConverged solve(
      81             :       gsl::not_null<VarsType*> initial_guess_in_solution_out,
      82             :       const LinearOperator& linear_operator, const SourceType& source,
      83             :       const std::tuple<OperatorArgs...>& operator_args, Args&&... args) const;
      84             : 
      85             :   /// Discard caches from previous solves. Use before solving a different linear
      86             :   /// operator.
      87           1 :   virtual void reset() = 0;
      88             : };
      89             : 
      90             : /// \cond
      91             : template <typename LinearSolverRegistrars>
      92             : LinearSolver<LinearSolverRegistrars>::LinearSolver(CkMigrateMessage* m)
      93             :     : PUP::able(m) {}
      94             : /// \endcond
      95             : 
      96             : template <typename LinearSolverRegistrars>
      97             : template <typename LinearOperator, typename VarsType, typename SourceType,
      98             :           typename... OperatorArgs, typename... Args>
      99           0 : Convergence::HasConverged LinearSolver<LinearSolverRegistrars>::solve(
     100             :     const gsl::not_null<VarsType*> initial_guess_in_solution_out,
     101             :     const LinearOperator& linear_operator, const SourceType& source,
     102             :     const std::tuple<OperatorArgs...>& operator_args, Args&&... args) const {
     103             :   return call_with_dynamic_type<Convergence::HasConverged, creatable_classes>(
     104             :       this, [&initial_guess_in_solution_out, &linear_operator, &source,
     105             :              &operator_args, &args...](auto* const linear_solver) {
     106             :         return linear_solver->solve(initial_guess_in_solution_out,
     107             :                                     linear_operator, source, operator_args,
     108             :                                     std::forward<Args>(args)...);
     109             :       });
     110             : }
     111             : 
     112             : /// Indicates that the linear solver uses no preconditioner. The solver may
     113             : /// perform compile-time optimizations for this case.
     114           1 : struct NoPreconditioner {};
     115             : 
     116             : /*!
     117             :  * \brief Base class for serial linear solvers that supports factory-creation
     118             :  * and nested preconditioning.
     119             :  *
     120             :  * To enable support for preconditioning in your derived linear solver class,
     121             :  * pass any type that has a `solve` and a `reset` function as the
     122             :  * `Preconditioner` template parameter. It can also be an abstract
     123             :  * `LinearSolver` type, which means that any other linear solver can be used as
     124             :  * a preconditioner. Pass `NoPreconditioner` to disable support for
     125             :  * preconditioning.
     126             :  */
     127             : template <typename Preconditioner, typename LinearSolverRegistrars>
     128           1 : class PreconditionedLinearSolver : public LinearSolver<LinearSolverRegistrars> {
     129             :  private:
     130           0 :   using Base = LinearSolver<LinearSolverRegistrars>;
     131             : 
     132             :  public:
     133           0 :   using PreconditionerType =
     134             :       tmpl::conditional_t<std::is_abstract_v<Preconditioner>,
     135             :                           std::unique_ptr<Preconditioner>, Preconditioner>;
     136           0 :   struct PreconditionerOption {
     137           0 :     static std::string name() { return "Preconditioner"; }
     138             :     // Support factory-creatable preconditioners by storing them as unique-ptrs
     139           0 :     using type = Options::Auto<PreconditionerType, Options::AutoLabel::None>;
     140           0 :     static constexpr Options::String help =
     141             :         "An approximate linear solve in every iteration that helps the "
     142             :         "algorithm converge.";
     143             :   };
     144             : 
     145             :  protected:
     146           0 :   PreconditionedLinearSolver() = default;
     147           0 :   PreconditionedLinearSolver(PreconditionedLinearSolver&&) = default;
     148           0 :   PreconditionedLinearSolver& operator=(PreconditionedLinearSolver&&) = default;
     149             : 
     150           0 :   explicit PreconditionedLinearSolver(
     151             :       std::optional<PreconditionerType> local_preconditioner);
     152             : 
     153           0 :   PreconditionedLinearSolver(const PreconditionedLinearSolver& rhs);
     154           0 :   PreconditionedLinearSolver& operator=(const PreconditionedLinearSolver& rhs);
     155             : 
     156             :  public:
     157           0 :   ~PreconditionedLinearSolver() override = default;
     158             : 
     159             :   /// \cond
     160             :   explicit PreconditionedLinearSolver(CkMigrateMessage* m);
     161             :   /// \endcond
     162             : 
     163           0 :   void pup(PUP::er& p) override {  // NOLINT
     164             :     PUP::able::pup(p);
     165             :     if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
     166             :       p | preconditioner_;
     167             :     }
     168             :   }
     169             : 
     170             :   /// Whether or not a preconditioner is set
     171           1 :   bool has_preconditioner() const {
     172             :     if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
     173             :       return preconditioner_.has_value();
     174             :     } else {
     175             :       return false;
     176             :     }
     177             :   }
     178             : 
     179             :   /// @{
     180             :   /// Access to the preconditioner. Check `has_preconditioner()` before calling
     181             :   /// this function. Calling this function when `has_preconditioner()` returns
     182             :   /// `false` is an error.
     183             :   template <
     184             :       bool Enabled = not std::is_same_v<Preconditioner, NoPreconditioner>,
     185             :       Requires<Enabled and
     186             :                not std::is_same_v<Preconditioner, NoPreconditioner>> = nullptr>
     187           1 :   const Preconditioner& preconditioner() const {
     188             :     ASSERT(has_preconditioner(),
     189             :            "No preconditioner is set. Please use `has_preconditioner()` to "
     190             :            "check before trying to retrieve it.");
     191             :     if constexpr (std::is_abstract_v<Preconditioner>) {
     192             :       return **preconditioner_;
     193             :     } else {
     194             :       return *preconditioner_;
     195             :     }
     196             :   }
     197             : 
     198             :   template <
     199             :       bool Enabled = not std::is_same_v<Preconditioner, NoPreconditioner>,
     200             :       Requires<Enabled and
     201             :                not std::is_same_v<Preconditioner, NoPreconditioner>> = nullptr>
     202           1 :   Preconditioner& preconditioner() {
     203             :     ASSERT(has_preconditioner(),
     204             :            "No preconditioner is set. Please use `has_preconditioner()` to "
     205             :            "check before trying to retrieve it.");
     206             :     if constexpr (std::is_abstract_v<Preconditioner>) {
     207             :       return **preconditioner_;
     208             :     } else {
     209             :       return *preconditioner_;
     210             :     }
     211             :   }
     212             : 
     213             :   // Keep the function pure virtual so derived classes must provide an
     214             :   // implementation, but also provide a definition below that derived
     215             :   // classes can call to reset the preconditioner.
     216           1 :   void reset() override = 0;
     217             : 
     218             :  protected:
     219             :   /// Copy the preconditioner. Useful to implement `get_clone` when the
     220             :   /// preconditioner has an abstract type.
     221             :   template <
     222             :       bool Enabled = not std::is_same_v<Preconditioner, NoPreconditioner>,
     223             :       Requires<Enabled and
     224             :                not std::is_same_v<Preconditioner, NoPreconditioner>> = nullptr>
     225           1 :   std::optional<PreconditionerType> clone_preconditioner() const {
     226             :     if constexpr (std::is_abstract_v<Preconditioner>) {
     227             :       return has_preconditioner()
     228             :                  ? std::optional((*preconditioner_)->get_clone())
     229             :                  : std::nullopt;
     230             :     } else {
     231             :       return preconditioner_;
     232             :     }
     233             :   }
     234             : 
     235             :  private:
     236             :   // Only needed when preconditioning is enabled, but current C++ can't remove
     237             :   // this variable at compile-time. Keeping the variable shouldn't have any
     238             :   // noticeable overhead though.
     239           1 :   std::optional<PreconditionerType> preconditioner_{};
     240             : };
     241             : 
     242             : template <typename Preconditioner, typename LinearSolverRegistrars>
     243             : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::
     244             :     PreconditionedLinearSolver(
     245             :         std::optional<PreconditionerType> local_preconditioner)
     246             :     : preconditioner_(std::move(local_preconditioner)) {}
     247             : 
     248             : // Define copy constructor and copy assignment to clone abstract preconditioners
     249             : template <typename Preconditioner, typename LinearSolverRegistrars>
     250             : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::
     251             :     PreconditionedLinearSolver(const PreconditionedLinearSolver& rhs)
     252             :     : Base(rhs) {
     253             :   if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
     254             :     preconditioner_ = rhs.clone_preconditioner();
     255             :   }
     256             : }
     257             : template <typename Preconditioner, typename LinearSolverRegistrars>
     258             : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>&
     259             : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::operator=(
     260             :     const PreconditionedLinearSolver& rhs) {
     261             :   Base::operator=(rhs);
     262             :   if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
     263             :     preconditioner_ = rhs.clone_preconditioner();
     264             :   }
     265             :   return *this;
     266             : }
     267             : 
     268             : /// \cond
     269             : template <typename Preconditioner, typename LinearSolverRegistrars>
     270             : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::
     271             :     PreconditionedLinearSolver(CkMigrateMessage* m)
     272             :     : Base(m) {}
     273             : /// \endcond
     274             : 
     275             : template <typename Preconditioner, typename LinearSolverRegistrars>
     276             : void PreconditionedLinearSolver<Preconditioner,
     277             :                                 LinearSolverRegistrars>::reset() {
     278             :   if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
     279             :     if (has_preconditioner()) {
     280             :       preconditioner().reset();
     281             :     }
     282             :   }
     283             : }
     284             : 
     285             : }  // namespace LinearSolver::Serial

Generated by: LCOV version 1.14