Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
#include <memory>
#include <optional>
#include <pup.h>
#include <pup_stl.h>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include "NumericalAlgorithms/Convergence/HasConverged.hpp"
#include "Options/Auto.hpp"
#include "Options/String.hpp"
#include "Utilities/CallWithDynamicType.hpp"
#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/Gsl.hpp"
#include "Utilities/Registration.hpp"
#include "Utilities/Requires.hpp"
#include "Utilities/Serialization/CharmPupable.hpp"
#include "Utilities/Serialization/PupStlCpp17.hpp"
23 :
24 : namespace LinearSolver::Serial {
25 :
26 : /// Registrars for linear solvers
27 : namespace Registrars {}
28 :
29 : /*!
30 : * \brief Base class for serial linear solvers that supports factory-creation.
31 : *
32 : * Derive linear solvers from this class so they can be factory-created. If your
33 : * linear solver supports preconditioning, derive from
34 : * `PreconditionedLinearSolver` instead to inherit utility that allows using any
35 : * other factor-creatable linear solver as preconditioner.
36 : */
37 : template <typename LinearSolverRegistrars>
38 1 : class LinearSolver : public PUP::able {
39 : protected:
40 : /// \cond
41 : LinearSolver() = default;
42 : LinearSolver(const LinearSolver&) = default;
43 : LinearSolver(LinearSolver&&) = default;
44 : LinearSolver& operator=(const LinearSolver&) = default;
45 : LinearSolver& operator=(LinearSolver&&) = default;
46 : /// \endcond
47 :
48 : public:
49 0 : ~LinearSolver() override = default;
50 :
51 : /// \cond
52 : explicit LinearSolver(CkMigrateMessage* m);
53 : WRAPPED_PUPable_abstract(LinearSolver); // NOLINT
54 : /// \endcond
55 :
56 0 : using registrars = LinearSolverRegistrars;
57 0 : using creatable_classes = Registration::registrants<LinearSolverRegistrars>;
58 :
59 0 : virtual std::unique_ptr<LinearSolver<LinearSolverRegistrars>> get_clone()
60 : const = 0;
61 :
62 : /*!
63 : * \brief Solve the linear equation \f$Ax=b\f$ where \f$A\f$ is the
64 : * `linear_operator` and \f$b\f$ is the `source`.
65 : *
66 : * - The (approximate) solution \f$x\f$ is returned in the
67 : * `initial_guess_in_solution_out` buffer, which also serves to provide an
68 : * initial guess for \f$x\f$. Not all solvers take the initial guess into
69 : * account, but all expect the buffer is sized correctly.
70 : * - The `linear_operator` must be an invocable that takes a `VarsType` as
71 : * const-ref argument and returns a `SourceType` by reference. It also takes
72 : * all `OperatorArgs` as const-ref arguments.
73 : *
74 : * Each solve may mutate the private state of the solver, for example to cache
75 : * quantities to accelerate successive solves for the same operator. Invoke
76 : * `reset` to discard these caches.
77 : */
78 : template <typename LinearOperator, typename VarsType, typename SourceType,
79 : typename... OperatorArgs, typename... Args>
80 1 : Convergence::HasConverged solve(
81 : gsl::not_null<VarsType*> initial_guess_in_solution_out,
82 : const LinearOperator& linear_operator, const SourceType& source,
83 : const std::tuple<OperatorArgs...>& operator_args, Args&&... args) const;
84 :
85 : /// Discard caches from previous solves. Use before solving a different linear
86 : /// operator.
87 1 : virtual void reset() = 0;
88 : };
89 :
90 : /// \cond
91 : template <typename LinearSolverRegistrars>
92 : LinearSolver<LinearSolverRegistrars>::LinearSolver(CkMigrateMessage* m)
93 : : PUP::able(m) {}
94 : /// \endcond
95 :
96 : template <typename LinearSolverRegistrars>
97 : template <typename LinearOperator, typename VarsType, typename SourceType,
98 : typename... OperatorArgs, typename... Args>
99 0 : Convergence::HasConverged LinearSolver<LinearSolverRegistrars>::solve(
100 : const gsl::not_null<VarsType*> initial_guess_in_solution_out,
101 : const LinearOperator& linear_operator, const SourceType& source,
102 : const std::tuple<OperatorArgs...>& operator_args, Args&&... args) const {
103 : return call_with_dynamic_type<Convergence::HasConverged, creatable_classes>(
104 : this, [&initial_guess_in_solution_out, &linear_operator, &source,
105 : &operator_args, &args...](auto* const linear_solver) {
106 : return linear_solver->solve(initial_guess_in_solution_out,
107 : linear_operator, source, operator_args,
108 : std::forward<Args>(args)...);
109 : });
110 : }
111 :
112 : /// Indicates the linear solver uses no preconditioner. It may perform
113 : /// compile-time optimization for this case.
114 1 : struct NoPreconditioner {};
115 :
116 : /*!
117 : * \brief Base class for serial linear solvers that supports factory-creation
118 : * and nested preconditioning.
119 : *
120 : * To enable support for preconditioning in your derived linear solver class,
121 : * pass any type that has a `solve` and a `reset` function as the
122 : * `Preconditioner` template parameter. It can also be an abstract
123 : * `LinearSolver` type, which means that any other linear solver can be used as
124 : * preconditioner. Pass `NoPreconditioner` to disable support for
125 : * preconditioning.
126 : */
127 : template <typename Preconditioner, typename LinearSolverRegistrars>
128 1 : class PreconditionedLinearSolver : public LinearSolver<LinearSolverRegistrars> {
129 : private:
130 0 : using Base = LinearSolver<LinearSolverRegistrars>;
131 :
132 : public:
133 0 : using PreconditionerType =
134 : tmpl::conditional_t<std::is_abstract_v<Preconditioner>,
135 : std::unique_ptr<Preconditioner>, Preconditioner>;
136 0 : struct PreconditionerOption {
137 0 : static std::string name() { return "Preconditioner"; }
138 : // Support factory-creatable preconditioners by storing them as unique-ptrs
139 0 : using type = Options::Auto<PreconditionerType, Options::AutoLabel::None>;
140 0 : static constexpr Options::String help =
141 : "An approximate linear solve in every iteration that helps the "
142 : "algorithm converge.";
143 : };
144 :
145 : protected:
146 0 : PreconditionedLinearSolver() = default;
147 0 : PreconditionedLinearSolver(PreconditionedLinearSolver&&) = default;
148 0 : PreconditionedLinearSolver& operator=(PreconditionedLinearSolver&&) = default;
149 :
150 0 : explicit PreconditionedLinearSolver(
151 : std::optional<PreconditionerType> local_preconditioner);
152 :
153 0 : PreconditionedLinearSolver(const PreconditionedLinearSolver& rhs);
154 0 : PreconditionedLinearSolver& operator=(const PreconditionedLinearSolver& rhs);
155 :
156 : public:
157 0 : ~PreconditionedLinearSolver() override = default;
158 :
159 : /// \cond
160 : explicit PreconditionedLinearSolver(CkMigrateMessage* m);
161 : /// \endcond
162 :
163 0 : void pup(PUP::er& p) override { // NOLINT
164 : PUP::able::pup(p);
165 : if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
166 : p | preconditioner_;
167 : }
168 : }
169 :
170 : /// Whether or not a preconditioner is set
171 1 : bool has_preconditioner() const {
172 : if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
173 : return preconditioner_.has_value();
174 : } else {
175 : return false;
176 : }
177 : }
178 :
179 : /// @{
180 : /// Access to the preconditioner. Check `has_preconditioner()` before calling
181 : /// this function. Calling this function when `has_preconditioner()` returns
182 : /// `false` is an error.
183 : template <
184 : bool Enabled = not std::is_same_v<Preconditioner, NoPreconditioner>,
185 : Requires<Enabled and
186 : not std::is_same_v<Preconditioner, NoPreconditioner>> = nullptr>
187 1 : const Preconditioner& preconditioner() const {
188 : ASSERT(has_preconditioner(),
189 : "No preconditioner is set. Please use `has_preconditioner()` to "
190 : "check before trying to retrieve it.");
191 : if constexpr (std::is_abstract_v<Preconditioner>) {
192 : return **preconditioner_;
193 : } else {
194 : return *preconditioner_;
195 : }
196 : }
197 :
198 : template <
199 : bool Enabled = not std::is_same_v<Preconditioner, NoPreconditioner>,
200 : Requires<Enabled and
201 : not std::is_same_v<Preconditioner, NoPreconditioner>> = nullptr>
202 1 : Preconditioner& preconditioner() {
203 : ASSERT(has_preconditioner(),
204 : "No preconditioner is set. Please use `has_preconditioner()` to "
205 : "check before trying to retrieve it.");
206 : if constexpr (std::is_abstract_v<Preconditioner>) {
207 : return **preconditioner_;
208 : } else {
209 : return *preconditioner_;
210 : }
211 : }
212 :
213 : // Keep the function virtual so derived classes must provide an
214 : // implementation, but also provide an implementation below that derived
215 : // classes can use to reset the preconditioner
216 1 : void reset() override = 0;
217 :
218 : protected:
219 : /// Copy the preconditioner. Useful to implement `get_clone` when the
220 : /// preconditioner has an abstract type.
221 : template <
222 : bool Enabled = not std::is_same_v<Preconditioner, NoPreconditioner>,
223 : Requires<Enabled and
224 : not std::is_same_v<Preconditioner, NoPreconditioner>> = nullptr>
225 1 : std::optional<PreconditionerType> clone_preconditioner() const {
226 : if constexpr (std::is_abstract_v<Preconditioner>) {
227 : return has_preconditioner()
228 : ? std::optional((*preconditioner_)->get_clone())
229 : : std::nullopt;
230 : } else {
231 : return preconditioner_;
232 : }
233 : }
234 :
235 : private:
236 : // Only needed when preconditioning is enabled, but current C++ can't remove
237 : // this variable at compile-time. Keeping the variable shouldn't have any
238 : // noticeable overhead though.
239 1 : std::optional<PreconditionerType> preconditioner_{};
240 : };
241 :
242 : template <typename Preconditioner, typename LinearSolverRegistrars>
243 : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::
244 : PreconditionedLinearSolver(
245 : std::optional<PreconditionerType> local_preconditioner)
246 : : preconditioner_(std::move(local_preconditioner)) {}
247 :
248 : // Override copy constructors so they can clone abstract preconditioners
249 : template <typename Preconditioner, typename LinearSolverRegistrars>
250 : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::
251 : PreconditionedLinearSolver(const PreconditionedLinearSolver& rhs)
252 : : Base(rhs) {
253 : if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
254 : preconditioner_ = rhs.clone_preconditioner();
255 : }
256 : }
257 : template <typename Preconditioner, typename LinearSolverRegistrars>
258 : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>&
259 : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::operator=(
260 : const PreconditionedLinearSolver& rhs) {
261 : Base::operator=(rhs);
262 : if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
263 : preconditioner_ = rhs.clone_preconditioner();
264 : }
265 : return *this;
266 : }
267 :
268 : /// \cond
269 : template <typename Preconditioner, typename LinearSolverRegistrars>
270 : PreconditionedLinearSolver<Preconditioner, LinearSolverRegistrars>::
271 : PreconditionedLinearSolver(CkMigrateMessage* m)
272 : : Base(m) {}
273 : /// \endcond
274 :
275 : template <typename Preconditioner, typename LinearSolverRegistrars>
276 : void PreconditionedLinearSolver<Preconditioner,
277 : LinearSolverRegistrars>::reset() {
278 : if constexpr (not std::is_same_v<Preconditioner, NoPreconditioner>) {
279 : if (has_preconditioner()) {
280 : preconditioner().reset();
281 : }
282 : }
283 : }
284 :
285 : } // namespace LinearSolver::Serial
|