// Distributed under the MIT License.
// See LICENSE.txt for details.

/// \file
/// Defines the base class for all tensor expressions

#pragma once

#include <limits>

#include "Utilities/ForceInline.hpp"
#include "Utilities/TMPL.hpp"

/// \ingroup TensorExpressionsGroup
/// \brief Marks a class as being a TensorExpression
///
/// \details
/// The empty base class provides a simple means for checking if a type is a
/// TensorExpression.
struct Expression {};

/// @{
/// \ingroup TensorExpressionsGroup
/// \brief The base class all tensor expression implementations derive from
///
/// \details
/// ## Tensor equation construction
/// Each derived `TensorExpression` class should be thought of as an expression
/// tree that represents some operation done on or between tensor expressions.
/// Arithmetic operators and other mathematical functions of interest
/// (e.g. `sqrt`) have overloads defined that accept `TensorExpression`s and
/// return a new `TensorExpression` representing the result tensor of such an
/// operation. In this way, an equation written with `TensorExpression`s will
/// generate an expression tree where the internal and leaf nodes are instances
/// of the derived `TensorExpression` classes. For example, `tenex::AddSub`
/// defines an internal node for handling the addition and subtraction
/// operations between tensor expressions, while `tenex::TensorAsExpression`
/// defines a leaf node that represents a single `Tensor` that appears in the
/// equation.
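///
/// For example (a schematic sketch: `L`, `R`, and `S` are hypothetical
/// `Tensor`s), the equation
/// \code
/// tenex::evaluate<ti::a, ti::b>(make_not_null(&L),
///                               R(ti::a, ti::b) + S(ti::a, ti::b));
/// \endcode
/// generates a tree whose root is a `tenex::AddSub` internal node with two
/// `tenex::TensorAsExpression` leaves, one for `R` and one for `S`.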
///
/// ## Tensor equation evaluation
/// The overall tree for an equation and the order in which we traverse the
/// tree define the order of operations done to compute the resulting LHS
/// `Tensor`. The evaluation is done by `tenex::evaluate`, which traverses the
/// whole tree once for each unique LHS component in order to evaluate the full
/// LHS `Tensor`. There are two different traversals currently implemented;
/// which is chosen depends on the tensor equation being evaluated:
/// 1. **Evaluate the whole tree as one expression** using in-order traversal.
/// This is like generating and solving a one-liner of the whole equation.
/// 2. **Split up the tree into subexpressions** that are each evaluated with
/// in-order traversal to successively "accumulate" an LHS result component of
/// the equation. This is like splitting the equation up and solving pieces of
/// it at a time with multiple lines of assignments/updates (see details
/// below).
///
/// ## Equation splitting details
/// Splitting up the tree and evaluating subexpressions is beneficial when we
/// believe it will lead to a better runtime than computing the whole
/// expression as a one-liner. One important use case is when the `Tensor`s in
/// the equation hold components whose data type is `DataVector`. From
/// benchmarking, it was found that the runtime of `DataVector` expressions
/// scales poorly as we increase the number of operations. For example, for an
/// inner product with 256 sums of products, instead of adding 256 `DataVector`
/// products in one line (e.g. `result = A*B + C*D + E*F + ...;`), it's much
/// faster to, say, set the result to be the sum of the first 8 products, then
/// `+=` the next 8, and so forth. This is what is meant by "accumulating" the
/// LHS result tensor, and what the `TensorExpression` splitting emulates. Note
/// that while 8 is the number used in this example, the exact optimal number
/// of operations is hardware-dependent, though probably not something we need
/// to worry about fine-tuning. However, a ballpark estimate for a "good"
/// number of operations may vary greatly depending on the data type of the
/// components (e.g. `double` vs. `DataVector`), which is something important
/// to at least coarsely tune.
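///
/// Schematically, the split evaluation emulates the following (a sketch
/// only; the tensors and the grouping of 8 products per line are
/// illustrative):
/// \code
/// // one-liner: scales poorly for DataVector components
/// result = A*B + C*D + E*F + ...;  // all 256 products at once
/// // split: initialize with the first 8 products, then accumulate
/// result = A*B + C*D + ...;   // first 8 products
/// result += E*F + G*H + ...;  // next 8 products
/// // ... and so forth
/// \endcode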
///
/// ### How the tree is split up
/// Let's define the **primary path** to be the path in the tree going from the
/// root node to the leftmost leaf. The overall tree contains subtrees
/// represented by different `TensorExpression`s in the equation. Certain
/// subtrees are marked as the starting and/or ending points of these "pieces"
/// of the equation. Let's define a **leg** to be a "segment" along the primary
/// path delineated by a starting and ending expression subtree. These
/// delineations are made where we decide there are enough operations in a
/// subtree that it would be wise to split at that point. What is considered to
/// be "enough" operations is specialized based on the data type held by the
/// `Tensor`s in the expression (see `tenex::max_num_ops_in_sub_expression`).
///
/// ### How a split tree is traversed and evaluated
/// We recurse down the primary path, visiting each expression subtree until we
/// reach the start of the lowest leg, then initialize the LHS result component
/// we're wanting to compute to be the result of this lowest expression. Then,
/// we recurse back up to the expression subtree that is the starting point of
/// the leg "above" it and compute that subtree. This time, however, when
/// recursively evaluating this higher subtree, we substitute in the current
/// LHS result for that lower subtree that we have already computed. This is
/// repeated as we "climb up" the primary path to successively accumulate the
/// result component.
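///
/// For instance (a schematic tree; whether `A*B` alone holds enough
/// operations to warrant a split depends on the data type, so the split
/// points here are purely illustrative), consider `A*B + C*D` with one leg
/// starting at the `A*B` subtree and another at the root:
/// \code
///        (+)    <- leg 2: result = result + C*D
///       /   \
///     (*)   (*)
///     / \   / \
///    A   B C   D
///     ^ leg 1: result = A*B
/// \endcode
/// The traversal first recurses down the primary path (root -> left `(*)`
/// -> `A`) to the start of leg 1, initializes the LHS component to `A*B`,
/// then climbs back up and evaluates leg 2 with the accumulated component
/// substituted in place of the already-computed `A*B` subtree.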
///
/// **Note:** The primary path is currently implemented as the path specified
/// above, but there's no reason it couldn't be reimplemented to be a different
/// path. The idea with the current implementation is to select a path from
/// root to leaf that is long so we have more flexibility in splitting, should
/// we want to. When evaluating, we *could* implement the traversal to take a
/// different path, but currently, derived `TensorExpression`s that represent
/// commutative binary operations are instantiated with the larger subtree
/// being the left child and the smaller subtree being the right child. By
/// constructing it this way, we elongate the leftmost path, which will allow
/// for increased splitting.
///
/// ## Requirements for derived `TensorExpression` classes
/// Each derived `TensorExpression` class must define the following aliases and
/// members (see the schematic sketch after this list):
/// - `private` variables that store its operands' derived `TensorExpression`s.
/// We make these non-`const` to allow for move construction.
/// - Constructor that initializes the above `private` operand members
/// - alias `type`: The data type of the data being stored in the result of the
/// expression, e.g. `double`, `DataVector`
/// - alias `symmetry`: The ::Symmetry of the result of the expression
/// - alias `index_list`: The list of \ref SpacetimeIndex "TensorIndexType"s of
/// the result of the expression
/// - alias `args_list`: The list of generic `TensorIndex`s of the result of the
/// expression
/// - variable `static constexpr size_t num_tensor_indices`: The number of
/// tensor indices in the result of the expression
/// - variable `static constexpr size_t num_ops_left_child`: The number of
/// arithmetic tensor operations done in the subtree for the expression's left
/// operand. If the expression represents a unary operation, its only child is
/// considered the left child. If the expression is a leaf node, then this value
/// should be set to 0 since retrieving a value at the leaf involves 0
/// arithmetic tensor operations.
/// - variable `static constexpr size_t num_ops_right_child`: The number of
/// arithmetic tensor operations done in the expression's right operand. If the
/// expression represents a unary operation or is a leaf node, this should be
/// set to 0 because there is no right child.
/// - variable `static constexpr size_t num_ops_subtree`: The number of
/// arithmetic tensor operations done in the subtree represented by the
/// expression. For `AddSub`, for example, this is
/// `num_ops_left_child + num_ops_right_child + 1`, the sum of the number of
/// operations in each operand's subtrees plus one for the operation done for
/// the expression itself.
/// - variable `static constexpr size_t
/// height_relative_to_closest_tensor_leaf_in_subtree`: The height of an
/// expression's node in the overall expression tree relative to the closest
/// `TensorAsExpression` leaf in its subtree. This is stored so that we can
/// traverse from the root along the shortest path to a `Tensor` when
/// retrieving the size of a component from the RHS expression (see
/// `get_rhs_tensor_component_size()` below). Non-`Tensor` leaves (e.g.
/// `NumberAsExpression`) are defined to have maximum height
/// `std::numeric_limits<size_t>::max()` to encode that they are maximally
/// far away from their nearest `Tensor` descendant, since the expression's
/// subtree (a leaf) can never have a `TensorAsExpression` descendant. This
/// maximal height is leveraged by `get_rhs_tensor_component_size()` so that in
/// traversing the expression tree to find a `Tensor`, it will never take the
/// path that ends in a non-`Tensor` leaf because it is the worst path option.
/// - function `decltype(auto) get(const std::array<size_t, num_tensor_indices>&
/// result_multi_index) const`: Accepts a multi-index for the result tensor
/// represented by the expression and returns the computed result of the
/// expression at that multi-index. This should call the operands' `get`
/// functions in order to recursively compute the result of the expression.
/// - function template
/// `template <typename LhsTensor> void assert_lhs_tensor_not_in_rhs_expression(
/// const gsl::not_null<LhsTensor*> lhs_tensor) const`: Asserts that the LHS
/// `Tensor` we're computing does not also appear in the RHS `TensorExpression`.
/// We define this because if a tree is split up, then the LHS `Tensor` will
/// generally not be computed correctly because the LHS components will be
/// updated as we traverse the split tree.
/// - function template
/// \code
/// template <typename LhsTensorIndices, typename LhsTensor>
/// void assert_lhs_tensorindices_same_in_rhs(
///     const gsl::not_null<LhsTensor*> lhs_tensor) const;
/// \endcode
/// Asserts that any instance of the LHS `Tensor` in the RHS `TensorExpression`
/// uses the same generic index order that the LHS uses. We define this because
/// if a tree is not split up, it's safe to use the LHS `Tensor` on the RHS if
/// the generic index order is the same. In these cases, `tenex::update` should
/// be used instead of `tenex::evaluate`. See the documentation for
/// `tenex::update` for more details and `tenex::detail::evaluate_impl` for why
/// this is safe to do.
/// - function `size_t get_rhs_tensor_component_size() const`: Gets the size of
/// a component from a `Tensor` in an expression's subtree of the RHS
/// expression. This is used to size LHS components, if needed. Utilizes
/// `height_relative_to_closest_tensor_leaf_in_subtree` to recursively find the
/// nearest `TensorAsExpression` descendant leaf.
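///
/// As a schematic illustration of these requirements (a sketch only: the
/// class `SomeNegate` is hypothetical and not part of the library, and the
/// member definitions of real derived classes differ in detail), a unary
/// negation expression might look like:
/// \code
/// template <typename T>
/// struct SomeNegate
///     : public TensorExpression<SomeNegate<T>, typename T::type,
///                               typename T::symmetry, typename T::index_list,
///                               typename T::args_list> {
///   using type = typename T::type;
///   using symmetry = typename T::symmetry;
///   using index_list = typename T::index_list;
///   using args_list = typename T::args_list;
///   static constexpr auto num_tensor_indices = tmpl::size<index_list>::value;
///   // a unary operation's only child is treated as the left child
///   static constexpr size_t num_ops_left_child = T::num_ops_subtree;
///   static constexpr size_t num_ops_right_child = 0;
///   // one extra operation for the negation itself
///   static constexpr size_t num_ops_subtree = num_ops_left_child + 1;
///   // one level further from the nearest Tensor leaf than the child is,
///   // unless the child's subtree contains no Tensor leaf at all
///   static constexpr size_t
///       height_relative_to_closest_tensor_leaf_in_subtree =
///           T::height_relative_to_closest_tensor_leaf_in_subtree ==
///                   std::numeric_limits<size_t>::max()
///               ? std::numeric_limits<size_t>::max()
///               : T::height_relative_to_closest_tensor_leaf_in_subtree + 1;
///
///   explicit SomeNegate(T operand) : operand_(std::move(operand)) {}
///
///   decltype(auto) get(const std::array<size_t, num_tensor_indices>&
///                          result_multi_index) const {
///     // recursively evaluate the operand, then apply the operation
///     return -operand_.get(result_multi_index);
///   }
///
///   // ... plus the assert_* function templates and
///   // get_rhs_tensor_component_size() described above ...
///
///  private:
///   T operand_;  // non-const to allow move construction
/// };
/// \endcode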
///
/// Each derived `TensorExpression` class must also define the following
/// members, which have real meaning for the expression *only* if it ends up
/// belonging to the primary path of the tree that is traversed (see the
/// schematic sketch after this list):
/// - variable `static constexpr bool is_primary_start`: If on the primary path,
/// whether or not the expression is a starting point of a leg. This is true
/// when there are enough operations to warrant splitting (see
/// `tenex::max_num_ops_in_sub_expression`).
/// - variable `static constexpr bool is_primary_end`: If on the primary path,
/// whether or not the expression is an ending point of a leg. This is true when
/// the expression's child along the primary path is a starting point of a leg.
/// - variable `static constexpr size_t num_ops_to_evaluate_primary_left_child`:
/// If on the primary path, this is the remaining number of arithmetic tensor
/// operations that need to be done in the subtree of the child along the
/// primary path, given that we will have already computed the whole subtree at
/// the next lowest leg's starting point.
/// - variable
/// `static constexpr size_t num_ops_to_evaluate_primary_right_child`:
/// If on the primary path, this is the remaining number of arithmetic tensor
/// operations that need to be done in the right operand's subtree. Because
/// the branches off of the primary path currently are not split up in any way,
/// this currently should simply be equal to `num_ops_right_child`. If logic is
/// added to split up these branches, logic will need to be added to compute
/// this remaining number of operations in the right subtree.
/// - variable `static constexpr size_t num_ops_to_evaluate_primary_subtree`:
/// If on the primary path, this is the remaining number of arithmetic tensor
/// operations that need to be done for this expression's subtree, given that we
/// will have already computed the subtree at the next lowest leg's starting
/// point. For example, for `tenex::AddSub`, this is just
/// `num_ops_to_evaluate_primary_left_child +
/// num_ops_to_evaluate_primary_right_child + 1` (the extra 1 for the `+` or `-`
/// operation itself).
/// - variable
/// `static constexpr bool primary_child_subtree_contains_primary_start`:
/// If on the primary path, whether or not the expression's child along the
/// primary path is a subtree that contains a starting point of a leg along the
/// primary path. In other words, whether or not there is a split on the primary
/// path lower than this expression. When evaluating a split tree, this is
/// useful because it tells us we need to keep recursing down to a lower leg and
/// evaluate that lower subtree first before evaluating the current subtree.
/// - variable `static constexpr bool primary_subtree_contains_primary_start`:
/// If on the primary path, whether or not this subtree contains a starting
/// point of a leg along the primary path. In other words, whether or not there
/// is a split on the primary path at this expression or beneath it.
/// - function template
/// \code
/// template <typename ResultType>
/// decltype(auto) get_primary(
///     ResultType& result_component,
///     const std::array<size_t, num_tensor_indices>& result_multi_index) const
/// \endcode
/// This is similar to the required `get` function described above, but this
/// should be used when the tree is split up. The main difference with this
/// function is that it takes the current result component (that we're
/// computing) as an argument, and when we hit the starting point of the next
/// lowest leg on the primary path when recursively evaluating the current leg,
/// we substitute in the current LHS result for the subtree that we have already
/// computed. This function should call `get_primary` on the child on the
/// primary path and `get` on the other child, if one exists.
/// - function template
/// \code
/// template <typename ResultType>
/// void evaluate_primary_subtree(
///     ResultType& result_component,
///     const std::array<size_t, num_tensor_indices>& result_multi_index) const
/// \endcode
/// This should first recursively evaluate the legs beneath it on the primary
/// path, then if the expression itself is the start of a leg, it should
/// evaluate this leg by calling the expression's own `get_primary` to compute
/// it and update the result component being accumulated. `tenex::evaluate`
/// should call this function on the root node for the whole tree if it is
/// determined that there are any splits in the tree.
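///
/// Continuing the hypothetical `SomeNegate` sketch from above (the values
/// and logic shown are illustrative only; real derived classes differ in
/// detail), the primary-path members might look like:
/// \code
/// // start a leg here if this subtree holds enough operations to warrant
/// // splitting at this point
/// static constexpr bool is_primary_start =
///     num_ops_subtree >= tenex::max_num_ops_in_sub_expression<type>;
/// // this expression ends a leg if its child on the primary path (the only
/// // child, for a unary operation) starts one
/// static constexpr bool is_primary_end = T::is_primary_start;
/// // if the child starts a leg, its subtree will already have been computed
/// // by the time this expression's leg is evaluated
/// static constexpr size_t num_ops_to_evaluate_primary_left_child =
///     is_primary_end ? 0 : T::num_ops_to_evaluate_primary_subtree;
/// // no right child for a unary operation
/// static constexpr size_t num_ops_to_evaluate_primary_right_child = 0;
/// // remaining ops in the child's subtree plus one for the negation
/// static constexpr size_t num_ops_to_evaluate_primary_subtree =
///     num_ops_to_evaluate_primary_left_child +
///     num_ops_to_evaluate_primary_right_child + 1;
/// static constexpr bool primary_child_subtree_contains_primary_start =
///     T::primary_subtree_contains_primary_start;
/// static constexpr bool primary_subtree_contains_primary_start =
///     is_primary_start or primary_child_subtree_contains_primary_start;
///
/// template <typename ResultType>
/// decltype(auto) get_primary(
///     ResultType& result_component,
///     const std::array<size_t, num_tensor_indices>& result_multi_index)
///     const {
///   if constexpr (is_primary_end) {
///     // the child's subtree was already accumulated into the LHS result,
///     // so substitute it instead of recomputing the subtree
///     return -result_component;
///   } else {
///     return -operand_.get_primary(result_component, result_multi_index);
///   }
/// }
/// \endcode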
///
/// ## Data type support
/// Which types can be used, which operations with which types can be
/// performed, and other type-specific support and configuration can be found
/// in `DataStructures/Tensor/Expressions/DataTypeSupport.hpp`. To add support
/// for equation terms with a certain type or to modify the configuration for a
/// type that is already supported, see the contents of that file and modify
/// settings as necessary.
///
/// ## Current advice for improving and extending `TensorExpression`s
/// - Derived `TensorExpression` classes (or the overloads that produce them)
/// should include `static_assert`s for ensuring mathematical correctness
/// wherever reasonable
/// - Minimize breadth in the tree where possible because benchmarking inner
/// products has shown that increased tree breadth can cause slower runtimes.
/// In addition, more breadth means a decreased ability to split up the tree
/// along the primary path.
/// - Minimize the number of multi-index transformations that need to be done
/// when evaluating the tree. For some operations like addition, the associated
/// multi-indices of the two operands need to be computed from the multi-index
/// of the result, which may involve reordering and/or shifting the values of
/// the result index. It's good to minimize the number of these kinds of
/// transformations from result to operand multi-index where we can.
/// - Unless the implementation of Tensor_detail::Structure changes, it's not
/// advised for the derived `TensorExpression` classes to have anything that
/// would instantiate the Tensor_detail::Structure of the tensor that would
/// result from the expression. This is really only a problem when the result of
/// the expression would be a tensor with many components, because the compile
/// time of the mapping between storage indices and multi-indices within
/// Tensor_detail::Structure scales very poorly with the number of components.
/// It's important to keep in mind that while SpECTRE currently only supports
/// creating `Tensor`s up to rank 4, there is nothing preventing the represented
/// result tensor of an expression being higher rank, e.g.
/// `R(ti_j, ti_b, ti_A) * (S(ti_d, ti_a, ti_B, ti_C) * T(ti_J, ti_k, ti_l))`
/// contains an intermediate outer product expression
/// `S(ti_d, ti_a, ti_B, ti_C) * T(ti_J, ti_k, ti_l)` that represents a rank 7
/// tensor, even though a rank 7 `Tensor` is never instantiated. Having the
/// outer product expression instantiate the Tensor_detail::Structure of this
/// intermediate result currently leads to an unreasonable compile time.
///
/// \tparam Derived the derived class needed for
/// [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
/// \tparam DataType the type of the data being stored in the `Tensor`s
/// \tparam Symm the ::Symmetry of the Derived class
/// \tparam IndexList the list of \ref SpacetimeIndex "TensorIndexType"s
/// \tparam Args typelist of the tensor indices, e.g. types of `ti::a` and
/// `ti::b` in `F(ti::a, ti::b)`
/// \cond HIDDEN_SYMBOLS
template <typename Derived, typename DataType, typename Symm,
          typename IndexList, typename Args = tmpl::list<>,
          typename ReducedArgs = tmpl::list<>>
struct TensorExpression;
/// \endcond

template <typename Derived, typename DataType, typename Symm,
          typename... Indices, template <typename...> class ArgsList,
          typename... Args>
struct TensorExpression<Derived, DataType, Symm, tmpl::list<Indices...>,
                        ArgsList<Args...>> : public Expression {
  static_assert(sizeof...(Args) == 0 or sizeof...(Args) == sizeof...(Indices),
                "the number of Tensor indices must match the number of "
                "components specified in an expression.");
  /// The type of the data being stored in the `Tensor`s
  using type = DataType;
  /// The ::Symmetry of the `Derived` class
  using symmetry = Symm;
  /// The list of \ref SpacetimeIndex "TensorIndexType"s
  using index_list = tmpl::list<Indices...>;
  /// Typelist of the tensor indices, e.g. types of `ti_a` and `ti_b`
  /// in `F(ti_a, ti_b)`
  using args_list = ArgsList<Args...>;
  /// The number of tensor indices of the `Derived` class
  static constexpr auto num_tensor_indices = tmpl::size<index_list>::value;

  virtual ~TensorExpression() = 0;

  /// @{
  /// `*this` is cast down to the derived class. This is enabled by the
  /// [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
  ///
  /// \returns const TensorExpression<Derived, DataType, Symm, IndexList,
  /// ArgsList<Args...>>&
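  ///
  /// For example, internal evaluation code can write
  /// `(~expr).get(result_multi_index)` to invoke the derived expression's
  /// `get` (a hypothetical usage sketch, not a prescribed pattern)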
  SPECTRE_ALWAYS_INLINE const auto& operator~() const {
    return static_cast<const Derived&>(*this);
  }
  /// @}
};

template <typename Derived, typename DataType, typename Symm,
          typename... Indices, template <typename...> class ArgsList,
          typename... Args>
TensorExpression<Derived, DataType, Symm, tmpl::list<Indices...>,
                 ArgsList<Args...>>::~TensorExpression() = default;
/// @}