Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #include <array>
7 : #include <cstddef>
8 : #include <ostream>
9 :
10 : #include "DataStructures/Matrix.hpp"
11 : #include "DataStructures/Variables.hpp"
12 : #include "Utilities/DereferenceWrapper.hpp"
13 : #include "Utilities/ErrorHandling/Assert.hpp"
14 : #include "Utilities/Gsl.hpp"
15 :
16 : /// \cond
17 : template <size_t Dim>
18 : class Index;
19 : // IWYU pragma: no_forward_declare Variables
20 : /// \endcond
21 :
22 : namespace apply_matrices_detail {
23 : template <typename ElementType, size_t Dim, bool... DimensionIsIdentity>
24 : struct Impl {
25 : template <typename MatrixType>
26 : static void apply(gsl::not_null<ElementType*> result,
27 : const std::array<MatrixType, Dim>& matrices,
28 : const ElementType* data, const Index<Dim>& extents,
29 : size_t number_of_independent_components);
30 : };
31 :
32 : template <typename MatrixType, size_t Dim>
33 : size_t result_size(const std::array<MatrixType, Dim>& matrices,
34 : const Index<Dim>& extents) {
35 : size_t num_points_result = 1;
36 : for (size_t d = 0; d < Dim; ++d) {
37 : const size_t cols = dereference_wrapper(gsl::at(matrices, d)).columns();
38 : if (cols == 0) {
39 : // An empty matrix is treated as the identity.
40 : num_points_result *= extents[d];
41 : } else {
42 : ASSERT(cols == extents[d],
43 : "Matrix " << d << " has wrong number of columns: " << cols
44 : << " (expected " << extents[d] << ")");
45 : num_points_result *= dereference_wrapper(gsl::at(matrices, d)).rows();
46 : }
47 : }
48 : return num_points_result;
49 : }
50 : } // namespace apply_matrices_detail
51 :
52 : /// @{
53 : /// \ingroup NumericalAlgorithmsGroup
54 : /// \brief Multiply by matrices in each dimension
55 : ///
56 : /// Multiplies each stripe in the first dimension of `u` by
57 : /// `matrices[0]`, each stripe in the second dimension of `u` by
58 : /// `matrices[1]`, and so on. If any of the matrices are empty they
59 : /// will be treated as the identity, but the matrix multiplications
60 : /// will be skipped for increased efficiency.
61 : ///
62 : /// \note The element type stored in the vectors to be transformed may be either
63 : /// `double` or `std::complex<double>`. The matrix, however, must be real. In
64 : /// the case of acting on a vector of complex values, the matrix is treated as
65 : /// having zero imaginary part. This is chosen for efficiency in all
66 : /// use-cases for spectral matrix arithmetic so far encountered.
67 : template <typename VariableTags, typename MatrixType, size_t Dim>
68 1 : void apply_matrices(const gsl::not_null<Variables<VariableTags>*> result,
69 : const std::array<MatrixType, Dim>& matrices,
70 : const Variables<VariableTags>& u,
71 : const Index<Dim>& extents) {
72 : ASSERT(u.number_of_grid_points() == extents.product(),
73 : "Mismatch between extents (" << extents.product()
74 : << ") and variables ("
75 : << u.number_of_grid_points() << ").");
76 : ASSERT(result->number_of_grid_points() ==
77 : apply_matrices_detail::result_size(matrices, extents),
78 : "result has wrong size. Expected "
79 : << apply_matrices_detail::result_size(matrices, extents)
80 : << ", received " << result->number_of_grid_points());
81 : apply_matrices_detail::Impl<typename Variables<VariableTags>::value_type,
82 : Dim>::apply(result->data(), matrices, u.data(),
83 : extents,
84 : u.number_of_independent_components);
85 : }
86 :
87 : template <typename VariableTags, typename MatrixType, size_t Dim>
88 1 : Variables<VariableTags> apply_matrices(
89 : const std::array<MatrixType, Dim>& matrices,
90 : const Variables<VariableTags>& u, const Index<Dim>& extents) {
91 : Variables<VariableTags> result(
92 : apply_matrices_detail::result_size(matrices, extents));
93 : apply_matrices(make_not_null(&result), matrices, u, extents);
94 : return result;
95 : }
96 :
97 : // clang tidy mistakenly fails to identify this as a function definition
98 : template <typename ResultType, typename MatrixType, typename VectorType,
99 : size_t Dim>
100 1 : void apply_matrices(const gsl::not_null<ResultType*> result, // NOLINT
101 : const std::array<MatrixType, Dim>& matrices,
102 : const VectorType& u, const Index<Dim>& extents) {
103 : const size_t number_of_independent_components = u.size() / extents.product();
104 : ASSERT(u.size() == number_of_independent_components * extents.product(),
105 : "The size of the vector u ("
106 : << u.size()
107 : << ") must be a multiple of the number of grid points ("
108 : << extents.product() << ").");
109 : ASSERT(result->size() ==
110 : number_of_independent_components *
111 : apply_matrices_detail::result_size(matrices, extents),
112 : "result has wrong size. Expected "
113 : << number_of_independent_components *
114 : apply_matrices_detail::result_size(matrices, extents)
115 : << ", received " << result->size());
116 : apply_matrices_detail::Impl<typename VectorType::ElementType, Dim>::apply(
117 : result->data(), matrices, u.data(), extents,
118 : number_of_independent_components);
119 : }
120 :
121 : template <typename MatrixType, typename VectorType, size_t Dim>
122 1 : VectorType apply_matrices(const std::array<MatrixType, Dim>& matrices,
123 : const VectorType& u, const Index<Dim>& extents) {
124 : const size_t number_of_independent_components = u.size() / extents.product();
125 : VectorType result(number_of_independent_components *
126 : apply_matrices_detail::result_size(matrices, extents));
127 : apply_matrices(make_not_null(&result), matrices, u, extents);
128 : return result;
129 : }
130 :
131 : template <typename ResultType, typename MatrixType, typename VectorType,
132 : size_t Dim>
133 1 : ResultType apply_matrices(const std::array<MatrixType, Dim>& matrices,
134 : const VectorType& u, const Index<Dim>& extents) {
135 : const size_t number_of_independent_components = u.size() / extents.product();
136 : ResultType result(number_of_independent_components *
137 : apply_matrices_detail::result_size(matrices, extents));
138 : apply_matrices(make_not_null(&result), matrices, u, extents);
139 : return result;
140 : }
141 : /// @}
|