Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Defines Expression Templates for contracting tensor indices on a single
6 : /// Tensor
7 :
8 : #pragma once
9 :
10 : #include <array>
11 : #include <cstddef>
12 : #include <limits>
13 : #include <type_traits>
14 : #include <utility>
15 :
16 : #include "DataStructures/Tensor/Expressions/NumberAsExpression.hpp"
17 : #include "DataStructures/Tensor/Expressions/TensorExpression.hpp"
18 : #include "DataStructures/Tensor/Expressions/TensorIndex.hpp"
19 : #include "DataStructures/Tensor/Expressions/TimeIndex.hpp"
20 : #include "DataStructures/Tensor/IndexType.hpp"
21 : #include "DataStructures/Tensor/Symmetry.hpp"
22 : #include "Utilities/ForceInline.hpp"
23 : #include "Utilities/Gsl.hpp"
24 : #include "Utilities/MakeArray.hpp"
25 : #include "Utilities/TMPL.hpp"
26 :
27 : /*!
28 : * \ingroup TensorExpressionsGroup
29 : * Holds all possible TensorExpressions currently implemented
30 : */
31 : namespace tenex {
32 : namespace detail {
33 : template <typename I1, typename I2>
34 : using indices_contractible = std::bool_constant<
35 : I1::ul != I2::ul and
36 : std::is_same_v<typename I1::Frame, typename I2::Frame> and
37 : ((I1::index_type == I2::index_type and I1::dim == I2::dim) or
38 : // If one index is spacetime and the other is spatial, the indices can
39 : // be contracted if they have the same number of spatial dimensions
40 : (I1::index_type == IndexType::Spacetime and I1::dim == I2::dim + 1) or
41 : (I2::index_type == IndexType::Spacetime and I1::dim + 1 == I2::dim))>;
42 :
43 : template <size_t NumUncontractedIndices>
44 : constexpr size_t get_num_contracted_index_pairs(
45 : const std::array<size_t, NumUncontractedIndices>&
46 : uncontracted_tensor_index_values) {
47 : size_t count = 0;
48 : for (size_t i = 0; i < NumUncontractedIndices; i++) {
49 : const size_t current_value = gsl::at(uncontracted_tensor_index_values, i);
50 : // Concrete time indices are not contracted
51 : if (not detail::is_time_index_value(current_value)) {
52 : const size_t opposite_value_to_find =
53 : get_tensorindex_value_with_opposite_valence(current_value);
54 : for (size_t j = i + 1; j < NumUncontractedIndices; j++) {
55 : if (opposite_value_to_find ==
56 : gsl::at(uncontracted_tensor_index_values, j)) {
57 : // We found both the lower and upper version of a generic index in the
58 : // list of generic indices, so we return this pair's positions
59 : count++;
60 : }
61 : }
62 : }
63 : }
64 :
65 : return count;
66 : }
67 :
68 : /// \brief Computes the mapping from the positions of indices in the resultant
69 : /// contracted tensor to their positions in the operand uncontracted tensor, as
70 : /// well as the positions of the index pairs in the operand uncontracted tensor
71 : /// that we wish to contract
72 : ///
73 : /// \details
74 : /// Both quantities returned are computed in and returned from the same
75 : /// function so as to not repeat overlapping necessary work that would be in two
76 : /// separate functions
77 : ///
78 : /// \tparam NumContractedIndexPairs the number of pairs of indices that will be
79 : /// contracted
80 : /// \tparam NumUncontractedIndices the number of indices in the operand
81 : /// expression that we wish to contract
82 : /// \param uncontracted_tensor_index_values the values of the `TensorIndex`s
83 : /// used to generically represent the uncontracted operand expression
84 : /// \return the mapping from the positions of indices in the resultant
85 : /// contracted tensor to their positions in the operand uncontracted tensor, as
86 : /// well as the positions of the index pairs in the operand uncontracted tensor
87 : /// that we wish to contract
88 : template <size_t NumContractedIndexPairs, size_t NumUncontractedIndices>
89 : constexpr std::pair<
90 : std::array<size_t, NumUncontractedIndices - NumContractedIndexPairs * 2>,
91 : std::array<std::pair<size_t, size_t>, NumContractedIndexPairs>>
92 : get_index_transformation_and_contracted_pair_positions(
93 : const std::array<size_t, NumUncontractedIndices>&
94 : uncontracted_tensor_index_values) {
95 : static_assert(NumUncontractedIndices >= 2,
96 : "There should be at least 2 indices");
97 : // Positions of indices in the result tensor (ones that are not contracted)
98 : // mapped to their locations in the uncontracted operand expression
99 : std::array<size_t, NumUncontractedIndices - NumContractedIndexPairs * 2>
100 : index_transformation{};
101 : // Positions of contracted index pairs in the uncontracted operand expression
102 : std::array<std::pair<size_t, size_t>, NumContractedIndexPairs>
103 : contracted_index_pair_positions{};
104 :
105 : // Marks whether or not we have already paired an index in the uncontracted
106 : // operand expression with an index to contract it with
107 : std::array<bool, NumUncontractedIndices> index_mapping_set =
108 : make_array<NumUncontractedIndices, bool>(false);
109 :
110 : // Index of `contracted_index_pair_positions` that we're currently assigning
111 : size_t contracted_map_index_to_assign = 0;
112 : // Index of `index_transformation` that we're currently assigning
113 : size_t not_contracted_map_index_to_assign =
114 : NumUncontractedIndices - NumContractedIndexPairs * 2 - 1;
115 : // Iteration is performed backwards here for the reasons below, but note that
116 : // no benchmarking has been done to confirm the backwards iteration order
117 : // makes a meaningful improvement in runtime vs. iterating forward:
118 : //
119 : // Here, we iterate backwards to find the "rightmost" contracted indices and
120 : // proceed leftwards to find the other contracted index pairs so that the
121 : // "rightmost" pairs in the expression to contract appear first in the list of
122 : // contracted index pairs (`contracted_index_pair_positions`). Then, when we
123 : // later iterate over all of the multi-indices to sum, each time we go to grab
124 : // the next multi-index to sum and we need to compute what that next
125 : // multi-index is, we can choose to increment the concrete values of the
126 : // rightmost index pair.
127 : //
128 : // This has not been benchmarked to confirm, but the thought with making this
129 : // choice to order the contracted pairs from right to left is that this may
130 : // help with spatial locality for caching when we are contracting a single
131 : // tensor (`TensorAsExpression`) that is non-symmetric. For example, let's say
132 : // we are contracting `R(ti::A, ti::B, ti::a, ti::b)` and the current
133 : // multi-index we just accessed (one of the components to sum) is
134 : // `{0, 0, 0, 0}`, representing \f$R^{00}{}_{00}\f$. To find a next
135 : // multi-index to sum, we can simply increase one of the pairs' concrete
136 : // values by 1, e.g. the next multi-index to access could be `{0, 1, 0, 1}`
137 : // for \f$R^{01}{}_{01}\f$ or `{1, 0, 1, 0}` \f$R^{10}{}_{10}\f$. In this
138 : // case, the idea is that choosing to increment the concrete values of the
139 : // rightmost pair could provide better spatial locality, as `{0, 1, 0, 1}` is
140 : // closer in memory to `{0, 0, 0, 0}` than `{1, 0, 1, 0}` is. It's important
141 : // to note that this, of course, depends on the implementation of
142 : // `Tensor_detail::Structure` - specifically, the order in which the
143 : // components are laid out in memory.
144 : //
145 : // Note: the loop terminates when underflow causes `i` to wrap back around to
146 : // the maximum `size_t` value. If we use the condition `i > 0`, we miss the
147 : // final iteration, and if we use `i >= 0`, we never terminate because `i`
148 : // is always positive.
149 : for (size_t i = NumUncontractedIndices - 1; i < NumUncontractedIndices; i--) {
150 : if (not gsl::at(index_mapping_set, i)) {
151 : const size_t current_value = gsl::at(uncontracted_tensor_index_values, i);
152 : // Concrete time indices are not contracted
153 : if (not detail::is_time_index_value(current_value)) {
154 : const size_t opposite_value_to_find =
155 : get_tensorindex_value_with_opposite_valence(current_value);
156 : for (size_t j = i - 1; j < NumUncontractedIndices; j--) {
157 : if (opposite_value_to_find ==
158 : gsl::at(uncontracted_tensor_index_values, j)) {
159 : // We found both the lower and upper version of a generic index in
160 : // the list of generic indices, pair them up
161 : gsl::at(contracted_index_pair_positions,
162 : contracted_map_index_to_assign)
163 : .first = i;
164 : gsl::at(contracted_index_pair_positions,
165 : contracted_map_index_to_assign)
166 : .second = j;
167 : contracted_map_index_to_assign++;
168 : // Mark that we've found contraction partners for these two indices
169 : gsl::at(index_mapping_set, i) = true;
170 : gsl::at(index_mapping_set, j) = true;
171 : break;
172 : }
173 : }
174 : }
175 : if (not gsl::at(index_mapping_set, i)) {
176 : // If we haven't assigned this index to a partner, it is not an index
177 : // that is contracted, so record its position mapping from contracted to
178 : // uncontracted tensor indices
179 : gsl::at(index_transformation, not_contracted_map_index_to_assign) = i;
180 : not_contracted_map_index_to_assign--;
181 : }
182 : }
183 : }
184 :
185 : return std::pair{index_transformation, contracted_index_pair_positions};
186 : }
187 :
188 : /// \brief Computes type information for the tensor expression that results from
189 : /// a contraction, as well as information internally useful for carrying out the
190 : /// contraction
191 : ///
192 : /// \tparam UncontractedTensorExpression the operand uncontracted
193 : /// `TensorExpression` being contracted
194 : /// \tparam DataType the data type of the `Tensor` components
195 : /// \tparam UncontractedSymm the ::Symmetry of the operand uncontracted
196 : /// `TensorExpression`
197 : /// \tparam UncontractedIndexList the list of
198 : /// \ref SpacetimeIndex "TensorIndexType"s of the operand uncontracted
199 : /// `TensorExpression`
200 : /// \tparam UncontractedTensorIndexList the list of generic `TensorIndex`s used
201 : /// for the the operand uncontracted `TensorExpression`
202 : /// \tparam NumContractedIndices the number of indices in the resultant tensor
203 : /// after contracting
204 : /// \tparam NumIndexPairsToContract the number of pairs of indices that will be
205 : /// contracted
206 : template <typename UncontractedTensorExpression, typename DataType,
207 : typename UncontractedSymm, typename UncontractedIndexList,
208 : typename UncontractedTensorIndexList, size_t NumContractedIndices,
209 : size_t NumIndexPairsToContract,
210 : typename ContractedIndexSequence =
211 : std::make_index_sequence<NumContractedIndices>,
212 : typename IndexPairsToContractSequence =
213 : std::make_index_sequence<NumIndexPairsToContract>>
214 : struct ContractedType;
215 :
216 : template <typename UncontractedTensorExpression, typename DataType,
217 : template <typename...> class UncontractedSymmList,
218 : typename... UncontractedSymm,
219 : template <typename...> class UncontractedIndexList,
220 : typename... UncontractedIndices,
221 : template <typename...> class UncontractedTensorIndexList,
222 : typename... UncontractedTensorIndices, size_t NumContractedIndices,
223 : size_t NumIndexPairsToContract, size_t... ContractedInts,
224 : size_t... IndexPairsToContractInts>
225 : struct ContractedType<UncontractedTensorExpression, DataType,
226 : UncontractedSymmList<UncontractedSymm...>,
227 : UncontractedIndexList<UncontractedIndices...>,
228 : UncontractedTensorIndexList<UncontractedTensorIndices...>,
229 : NumContractedIndices, NumIndexPairsToContract,
230 : std::index_sequence<ContractedInts...>,
231 : std::index_sequence<IndexPairsToContractInts...>> {
232 : static constexpr size_t num_uncontracted_tensor_indices =
233 : sizeof...(UncontractedTensorIndices);
234 : static constexpr std::array<size_t, num_uncontracted_tensor_indices>
235 : uncontracted_tensorindex_values = {{UncontractedTensorIndices::value...}};
236 : static constexpr size_t num_indices_to_contract =
237 : num_uncontracted_tensor_indices - NumContractedIndices;
238 : static constexpr size_t num_contracted_index_pairs =
239 : num_indices_to_contract / 2;
240 : // First item in pair:
241 : // - index transformation: mapping from the positions of indices in the
242 : // resultant contracted tensor to their positions in the operand
243 : // uncontracted tensor
244 : // Second item in pair:
245 : // contracted index pair positions: positions of the index pairs in the
246 : // operand uncontracted tensor that we wish to contract
247 : static constexpr inline std::pair<
248 : std::array<size_t, NumContractedIndices>,
249 : std::array<std::pair<size_t, size_t>, num_contracted_index_pairs>>
250 : index_transformation_and_contracted_pair_positions =
251 : get_index_transformation_and_contracted_pair_positions<
252 : num_contracted_index_pairs, num_uncontracted_tensor_indices>(
253 : uncontracted_tensorindex_values);
254 :
255 : // Make sure it's mathematically legal to perform the requested contraction
256 : static_assert(((... and
257 : (indices_contractible<
258 : typename tmpl::at_c<
259 : tmpl::list<UncontractedIndices...>,
260 : index_transformation_and_contracted_pair_positions
261 : .second[IndexPairsToContractInts]
262 : .first>,
263 : typename tmpl::at_c<
264 : tmpl::list<UncontractedIndices...>,
265 : index_transformation_and_contracted_pair_positions
266 : .second[IndexPairsToContractInts]
267 : .second>>::value))),
268 : "Cannot contract the requested indices.");
269 :
270 : static constexpr inline std::array<IndexType, num_uncontracted_tensor_indices>
271 : uncontracted_index_types = {{UncontractedIndices::index_type...}};
272 :
273 : // First concrete values of contracted indices to sum. This is to handle
274 : // cases when we have generic spatial `TensorIndex`s used for spacetime
275 : // indices, as the first concrete index value to contract will be 1 (first
276 : // spatial index) instead of 0 (the time index). Contracted index pairs will
277 : // have different "starting" concrete indices when one index in the pair is a
278 : // spatial spacetime index and the other is not.
279 : static constexpr inline std::array<std::pair<size_t, size_t>,
280 : num_contracted_index_pairs>
281 : contracted_index_first_values = []() {
282 : std::array<std::pair<size_t, size_t>, num_contracted_index_pairs>
283 : first_values{};
284 : for (size_t i = 0; i < num_contracted_index_pairs; i++) {
285 : // Assign the value for first index in a pair to be the smallest value
286 : // used in the terms being summed: assign to 1 if we have a spacetime
287 : // index where a generic spatial index has been used, otherwise assign
288 : // to 0.
289 : gsl::at(first_values, i).first = static_cast<size_t>(
290 : gsl::at(
291 : uncontracted_index_types,
292 : gsl::at(
293 : index_transformation_and_contracted_pair_positions.second,
294 : i)
295 : .first) == IndexType::Spacetime and
296 : gsl::at(
297 : uncontracted_tensorindex_values,
298 : gsl::at(
299 : index_transformation_and_contracted_pair_positions.second,
300 : i)
301 : .first) >= TensorIndex_detail::spatial_sentinel);
302 : // Assign the value for second index in a pair to be the smallest
303 : // value used in the terms being summed (assigned with same logic
304 : // described above for the first index in the pair)
305 : gsl::at(first_values, i).second = static_cast<size_t>(
306 : gsl::at(
307 : uncontracted_index_types,
308 : gsl::at(
309 : index_transformation_and_contracted_pair_positions.second,
310 : i)
311 : .second) == IndexType::Spacetime and
312 : gsl::at(
313 : uncontracted_tensorindex_values,
314 : gsl::at(
315 : index_transformation_and_contracted_pair_positions.second,
316 : i)
317 : .second) >= TensorIndex_detail::spatial_sentinel);
318 : }
319 : return first_values;
320 : }();
321 :
322 : static constexpr inline std::array<size_t, num_uncontracted_tensor_indices>
323 : uncontracted_index_dims = {{UncontractedIndices::dim...}};
324 :
325 : // The number of terms to sum for this expression's contraction
326 : static constexpr size_t num_terms_summed = []() {
327 : size_t num_terms =
328 : gsl::at(
329 : uncontracted_index_dims,
330 : gsl::at(index_transformation_and_contracted_pair_positions.second,
331 : 0)
332 : .first) -
333 : gsl::at(contracted_index_first_values, 0).first;
334 : for (size_t i = 1; i < num_contracted_index_pairs; i++) {
335 : num_terms *=
336 : gsl::at(
337 : uncontracted_index_dims,
338 : gsl::at(index_transformation_and_contracted_pair_positions.second,
339 : i)
340 : .first) -
341 : gsl::at(contracted_index_first_values, i).first;
342 : }
343 : return num_terms;
344 : }();
345 : static_assert(num_terms_summed > 0,
346 : "There should be a non-zero number of components to sum in the "
347 : "contraction.");
348 : // The ::Symmetry of the result of the contraction
349 : using symmetry =
350 : Symmetry<tmpl::at_c<UncontractedSymmList<UncontractedSymm...>,
351 : index_transformation_and_contracted_pair_positions
352 : .first[ContractedInts]>::value...>;
353 : // The list of \ref SpacetimeIndex "TensorIndexType"s of the result of the
354 : // contraction
355 : using index_list =
356 : tmpl::list<tmpl::at_c<UncontractedIndexList<UncontractedIndices...>,
357 : index_transformation_and_contracted_pair_positions
358 : .first[ContractedInts]>...>;
359 : // The list of generic `TensorIndex`s of the result of the contraction
360 : using tensorindex_list = tmpl::list<
361 : tmpl::at_c<UncontractedTensorIndexList<UncontractedTensorIndices...>,
362 : index_transformation_and_contracted_pair_positions
363 : .first[ContractedInts]>...>;
364 : // The `TensorExpression` type that results from performing the contraction
365 : using type = TensorExpression<UncontractedTensorExpression, DataType,
366 : symmetry, index_list, tensorindex_list>;
367 : };
368 : } // namespace detail
369 :
370 : /*!
371 : * \ingroup TensorExpressionsGroup
372 : */
373 : template <typename T, typename X, typename Symm, typename IndexList,
374 : typename ArgsList, size_t NumContractedIndices>
375 0 : struct TensorContract
376 : : public TensorExpression<
377 : TensorContract<T, X, Symm, IndexList, ArgsList, NumContractedIndices>,
378 : X,
379 : typename detail::ContractedType<
380 : T, X, Symm, IndexList, ArgsList, NumContractedIndices,
381 : (tmpl::size<Symm>::value - NumContractedIndices) /
382 : 2>::type::symmetry,
383 : typename detail::ContractedType<
384 : T, X, Symm, IndexList, ArgsList, NumContractedIndices,
385 : (tmpl::size<Symm>::value - NumContractedIndices) /
386 : 2>::type::index_list,
387 : typename detail::ContractedType<
388 : T, X, Symm, IndexList, ArgsList, NumContractedIndices,
389 : (tmpl::size<Symm>::value - NumContractedIndices) /
390 : 2>::type::args_list> {
391 : /// Stores internally useful information regarding the contraction. See
392 : /// `detail::ContractedType` for more details
393 1 : using contracted_type = typename detail::ContractedType<
394 : T, X, Symm, IndexList, ArgsList, NumContractedIndices,
395 : (tmpl::size<Symm>::value - NumContractedIndices) / 2>;
396 : /// The `TensorExpression` type that results from performing the contraction
397 1 : using new_type = typename contracted_type::type;
398 :
399 : // === Index properties ===
400 : /// The type of the data being stored in the result of the expression
401 1 : using type = X;
402 : /// The ::Symmetry of the result of the expression
403 1 : using symmetry = typename new_type::symmetry;
404 : /// The list of \ref SpacetimeIndex "TensorIndexType"s of the result of the
405 : /// expression
406 1 : using index_list = typename new_type::index_list;
407 : /// The list of generic `TensorIndex`s of the result of the expression
408 1 : using args_list = typename new_type::args_list;
409 : /// The number of tensor indices in the result of the expression
410 1 : static constexpr size_t num_tensor_indices = NumContractedIndices;
411 : /// The number of tensor indices in the operand expression being contracted
412 1 : static constexpr size_t num_uncontracted_tensor_indices =
413 : tmpl::size<Symm>::value;
414 : /// The number of tensor indices in the operand expression that will be
415 : /// contracted
416 1 : static constexpr size_t num_indices_to_contract =
417 : contracted_type::num_indices_to_contract;
418 : static_assert(num_indices_to_contract > 0,
419 : "There are no indices to contract that were found.");
420 : static_assert(num_indices_to_contract % 2 == 0,
421 : "Cannot contract an odd number of indices.");
422 : /// The number of tensor index pairs in the operand expression that will be
423 : /// contracted
424 1 : static constexpr size_t num_contracted_index_pairs =
425 : contracted_type::num_contracted_index_pairs;
426 : /// Mapping from the positions of indices in the resultant contracted tensor
427 : /// to their positions in the operand uncontracted tensor
428 : static constexpr inline std::array<size_t, NumContractedIndices>
429 1 : index_transformation =
430 : contracted_type::index_transformation_and_contracted_pair_positions
431 : .first;
432 : /// Positions of the index pairs in the operand uncontracted tensor that we
433 : /// wish to contract
434 : static constexpr inline std::array<std::pair<size_t, size_t>,
435 : num_contracted_index_pairs>
436 1 : contracted_index_pair_positions =
437 : contracted_type::index_transformation_and_contracted_pair_positions
438 : .second;
439 : /// First concrete values of contracted indices to sum. This is to handle
440 : /// cases when we have generic spatial `TensorIndex`s used for spacetime
441 : /// indices, as the first concrete index value to contract will be 1 (first
442 : /// spatial index) instead of 0 (the time index). Contracted index pairs will
443 : /// have different "starting" concrete indices when one index in the pair is a
444 : /// spatial spacetime index and the other is not.
445 : static constexpr inline std::array<std::pair<size_t, size_t>,
446 : num_contracted_index_pairs>
447 1 : contracted_index_first_values =
448 : contracted_type::contracted_index_first_values;
449 : /// The dimensions of the indices in the uncontracted operand expression
450 : static constexpr inline std::array<size_t, num_uncontracted_tensor_indices>
451 1 : uncontracted_index_dims = contracted_type::uncontracted_index_dims;
452 : /// The number of terms to sum for this expression's contraction
453 1 : static constexpr size_t num_terms_summed = contracted_type::num_terms_summed;
454 :
455 : // === Expression subtree properties ===
456 : /// The number of arithmetic tensor operations done in the subtree for the
457 : /// left operand
458 1 : static constexpr size_t num_ops_left_child =
459 : T::num_ops_subtree * num_terms_summed + num_terms_summed - 1;
460 : /// The number of arithmetic tensor operations done in the subtree for the
461 : /// right operand. This is 0 because this expression represents a unary
462 : /// operation.
463 1 : static constexpr size_t num_ops_right_child = 0;
464 : /// The total number of arithmetic tensor operations done in this expression's
465 : /// whole subtree
466 1 : static constexpr size_t num_ops_subtree = num_ops_left_child;
467 : /// The height of this expression's node in the expression tree relative to
468 : /// the closest `TensorAsExpression` leaf in its subtree
469 1 : static constexpr size_t height_relative_to_closest_tensor_leaf_in_subtree =
470 : T::height_relative_to_closest_tensor_leaf_in_subtree !=
471 : std::numeric_limits<size_t>::max()
472 : ? T::height_relative_to_closest_tensor_leaf_in_subtree + 1
473 : : T::height_relative_to_closest_tensor_leaf_in_subtree;
474 :
475 : // === Properties for splitting up subexpressions along the primary path ===
476 : // These definitions only have meaning if this expression actually ends up
477 : // being along the primary path that is taken when evaluating the whole tree.
478 : // See documentation for `TensorExpression` for more details.
479 : /// If on the primary path, whether or not the expression is an ending point
480 : /// of a leg
481 1 : static constexpr bool is_primary_end = T::is_primary_start;
482 : /// If on the primary path, this is the remaining number of arithmetic tensor
483 : /// operations that need to be done in the subtree of the child along the
484 : /// primary path, given that we will have already computed the whole subtree
485 : /// at the next lowest leg's starting point.
486 1 : static constexpr size_t num_ops_to_evaluate_primary_left_child =
487 : is_primary_end
488 : ? num_ops_subtree - T::num_ops_subtree
489 : : T::num_ops_subtree * (num_terms_summed - 1) +
490 : T::num_ops_to_evaluate_primary_subtree + num_terms_summed - 1;
491 : /// If on the primary path, this is the remaining number of arithmetic tensor
492 : /// operations that need to be done in the right operand's subtree. No
493 : /// splitting is currently done, so this is just `num_ops_right_child`.
494 1 : static constexpr size_t num_ops_to_evaluate_primary_right_child =
495 : num_ops_right_child;
496 : /// If on the primary path, this is the remaining number of arithmetic tensor
497 : /// operations that need to be done for this expression's subtree, given that
498 : /// we will have already computed the subtree at the next lowest leg's
499 : /// starting point
500 1 : static constexpr size_t num_ops_to_evaluate_primary_subtree =
501 : num_ops_to_evaluate_primary_left_child +
502 : num_ops_to_evaluate_primary_right_child;
503 : /// If on the primary path, whether or not the expression is a starting point
504 : /// of a leg
505 1 : static constexpr bool is_primary_start =
506 : num_ops_to_evaluate_primary_subtree >
507 : // Multiply by 2 because each term has a + and * operation, while other
508 : // arithmetic expression types do one operation
509 : 2 * detail::max_num_ops_in_sub_expression<type>;
510 : /// If on the primary path, whether or not the expression's child along the
511 : /// primary path is a subtree that contains a starting point of a leg along
512 : /// the primary path
513 1 : static constexpr bool primary_child_subtree_contains_primary_start =
514 : T::primary_subtree_contains_primary_start;
515 : /// If on the primary path, whether or not this subtree contains a starting
516 : /// point of a leg along the primary path
517 1 : static constexpr bool primary_subtree_contains_primary_start =
518 : is_primary_start or primary_child_subtree_contains_primary_start;
519 : /// Number of arithmetic tensor operations done in the subtree of the operand
520 : /// expression being contracted
521 1 : static constexpr size_t num_ops_subexpression = T::num_ops_subtree;
522 : /// In the subtree for this contraction, how many terms we sum together for
523 : /// each leg of the contraction
524 1 : static constexpr size_t leg_length = []() {
525 : if constexpr (not is_primary_start) {
526 : // If we're not even stopping at the beginning of the contraction, it's
527 : // because there weren't enough terms to justify any splitting, so the
528 : // leg_length is just the total number of terms to sum
529 : return num_terms_summed;
530 : } else if constexpr (num_ops_subexpression >=
531 : detail::max_num_ops_in_sub_expression<type>) {
532 : // If the subexpression itself has more than the max # of ops
533 : return 0;
534 : } else {
535 : // Otherwise, find how many terms to sum in each leg
536 : size_t length = 1;
537 : while (2 * (length * (num_ops_subexpression + 1) - 1) <=
538 : detail::max_num_ops_in_sub_expression<type>) {
539 : length *= 2;
540 : }
541 : return length;
542 : }
543 : }();
544 : /// After dividing up the contraction subtree into legs, the number of legs
545 : /// whose length is equal to `leg_length`
546 1 : static constexpr size_t num_full_legs =
547 : leg_length == 0 ? num_terms_summed : num_terms_summed / leg_length;
548 : /// After dividing up the contraction subtree into legs of even length, the
549 : /// number of terms we still have left to sum
550 1 : static constexpr size_t last_leg_length =
551 : leg_length == 0 ? 0 : num_terms_summed % leg_length;
552 : /// When evaluating along a primary path, whether each term's subtrees should
553 : /// be evaluated separately. Since `DataVector` expression runtime scales
554 : /// poorly with increased number of operations, evaluating individual terms'
555 : /// subtrees separately like this is beneficial when each term, itself,
556 : /// involves many tensor operations.
557 1 : static constexpr bool evaluate_terms_separately = leg_length == 0;
558 :
559 0 : explicit TensorContract(
560 : const TensorExpression<T, X, Symm, IndexList, ArgsList>& t)
561 : : t_(~t) {}
562 0 : ~TensorContract() override = default;
563 :
564 : /// \brief Assert that the LHS tensor of the equation does not also appear in
565 : /// this expression's subtree
566 : template <typename LhsTensor>
567 1 : SPECTRE_ALWAYS_INLINE void assert_lhs_tensor_not_in_rhs_expression(
568 : const gsl::not_null<LhsTensor*> lhs_tensor) const {
569 : if constexpr (not std::is_base_of_v<MarkAsNumberAsExpression, T>) {
570 : t_.assert_lhs_tensor_not_in_rhs_expression(lhs_tensor);
571 : }
572 : }
573 :
574 : /// \brief Assert that each instance of the LHS tensor in the RHS tensor
575 : /// expression uses the same generic index order that the LHS uses
576 : ///
577 : /// \tparam LhsTensorIndices the list of generic `TensorIndex`s of the LHS
578 : /// result `Tensor` being computed
579 : /// \param lhs_tensor the LHS result `Tensor` being computed
580 : template <typename LhsTensorIndices, typename LhsTensor>
581 1 : SPECTRE_ALWAYS_INLINE void assert_lhs_tensorindices_same_in_rhs(
582 : const gsl::not_null<LhsTensor*> lhs_tensor) const {
583 : if constexpr (not std::is_base_of_v<MarkAsNumberAsExpression, T>) {
584 : t_.template assert_lhs_tensorindices_same_in_rhs<LhsTensorIndices>(
585 : lhs_tensor);
586 : }
587 : }
588 :
589 : /// \brief Get the size of a component from a `Tensor` in this expression's
590 : /// subtree of the RHS `TensorExpression`
591 : ///
592 : /// \return the size of a component from a `Tensor` in this expression's
593 : /// subtree of the RHS `TensorExpression`
594 1 : SPECTRE_ALWAYS_INLINE size_t get_rhs_tensor_component_size() const {
595 : return t_.get_rhs_tensor_component_size();
596 : }
597 :
598 : /// \brief Return the highest multi-index between the components being summed
599 : /// in the contraction
600 : ///
601 : /// \details
602 : /// Example:
603 : /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
604 : /// \f$L_b = R^{a}{}_{ba}\f$. If the `contracted_multi_index` is `{1}`, which
605 : /// represents \f$L_1 = R^{a}{}_{1a}\f$, and the dimension of \f$a\f$ is 3,
606 : /// then we will need to sum the following terms: \f$R^{0}{}_{10}\f$,
607 : /// \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$. Between the terms being
608 : /// summed, the multi-index whose values are the largest is
609 : /// \f$R^{2}{}_{12}\f$, so this function would return `{2, 1, 2}`.
610 : ///
611 : /// \param contracted_multi_index the multi-index of a component of the
612 : /// contracted expression
613 : /// \return the highest multi-index between the components being summed in
614 : /// the contraction
615 : SPECTRE_ALWAYS_INLINE static constexpr std::array<
616 : size_t, num_uncontracted_tensor_indices>
617 1 : get_highest_multi_index_to_sum(
618 : const std::array<size_t, num_tensor_indices>& contracted_multi_index) {
619 : // Initialize with placeholders for debugging
620 : auto highest_multi_index = make_array<num_uncontracted_tensor_indices>(
621 : std::numeric_limits<size_t>::max());
622 :
623 : // Fill uncontracted indices
624 : for (size_t i = 0; i < num_tensor_indices; i++) {
625 : gsl::at(highest_multi_index, gsl::at(index_transformation, i)) =
626 : gsl::at(contracted_multi_index, i);
627 : }
628 :
629 : // Fill contracted indices
630 : for (size_t i = 0; i < num_contracted_index_pairs; i++) {
631 : const size_t first_index_position_in_pair =
632 : gsl::at(contracted_index_pair_positions, i).first;
633 : const size_t second_index_position_in_pair =
634 : gsl::at(contracted_index_pair_positions, i).second;
635 : gsl::at(highest_multi_index, first_index_position_in_pair) =
636 : gsl::at(uncontracted_index_dims, first_index_position_in_pair) - 1;
637 : gsl::at(highest_multi_index, second_index_position_in_pair) =
638 : gsl::at(uncontracted_index_dims, second_index_position_in_pair) - 1;
639 : }
640 :
641 : return highest_multi_index;
642 : }
643 :
644 : /// \brief Return the lowest multi-index between the components being summed
645 : /// in the contraction
646 : ///
647 : /// \details
648 : /// Example:
649 : /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
650 : /// \f$L_b = R^{a}{}_{ba}\f$. If the `contracted_multi_index` is `{1}`, which
651 : /// represents \f$L_1 = R^{a}{}_{1a}\f$, and the dimension of \f$a\f$ is 3,
652 : /// then we will need to sum the following terms: \f$R^{0}{}_{10}\f$,
653 : /// \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$. Between the terms being
654 : /// summed, the multi-index whose values are the smallest is
655 : /// \f$R^{0}{}_{10}\f$, so this function would return `{0, 1, 0}`.
656 : ///
657 : /// \param contracted_multi_index the multi-index of a component of the
658 : /// contracted expression
659 : /// \return the lowest multi-index between the components being summed in
660 : /// the contraction
661 : SPECTRE_ALWAYS_INLINE static constexpr std::array<
662 : size_t, num_uncontracted_tensor_indices>
663 1 : get_lowest_multi_index_to_sum(
664 : const std::array<size_t, num_tensor_indices>& contracted_multi_index) {
665 : // Initialize with placeholders for debugging
666 : auto lowest_multi_index = make_array<num_uncontracted_tensor_indices>(
667 : std::numeric_limits<size_t>::max());
668 :
669 : // Fill uncontracted indices
670 : for (size_t i = 0; i < num_tensor_indices; i++) {
671 : gsl::at(lowest_multi_index, gsl::at(index_transformation, i)) =
672 : gsl::at(contracted_multi_index, i);
673 : }
674 :
675 : // Fill contracted indices
676 : for (size_t i = 0; i < num_contracted_index_pairs; i++) {
677 : const size_t first_index_position_in_pair =
678 : gsl::at(contracted_index_pair_positions, i).first;
679 : const size_t second_index_position_in_pair =
680 : gsl::at(contracted_index_pair_positions, i).second;
681 : gsl::at(lowest_multi_index, first_index_position_in_pair) =
682 : gsl::at(contracted_index_first_values, i).first;
683 : gsl::at(lowest_multi_index, second_index_position_in_pair) =
684 : gsl::at(contracted_index_first_values, i).second;
685 : }
686 :
687 : return lowest_multi_index;
688 : }
689 :
690 : /// \brief Given the multi-index of one term being summed in the contraction,
691 : /// return the next highest multi-index of a component being summed
692 : ///
693 : /// \details
694 : /// What is meant by "next highest" is implementation defined, but generally
695 : /// means, of the components being summed, return the multi-index that results
696 : /// from lowering one of the contracted index pairs' values by one.
697 : ///
698 : /// Example:
699 : /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
700 : /// \f$L_b = R^{a}{}_{ba}\f$. If we are evaluating \f$L_1 = R^{a}{}_{1a}\f$
701 : /// and the dimension of \f$a\f$ is 3, then we will need to sum the following
702 : /// terms: \f$R^{0}{}_{10}\f$, \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$.
703 : /// If `uncontracted_multi_index` is `{1, 1, 1}`, then the "next highest"
704 : /// multi-index is the result of lowering the values of the \f$a\f$ indices by
705 : /// 1. The component with that resulting multi-index is \f$R^{0}{}_{10}\f$, so
706 : /// this function would return `{0, 1, 0}`.
707 : ///
708 : /// Note: this function should perform the inverse functionality of
709 : /// `get_next_highest_multi_index_to_sum`. If the implementation of this
710 : /// function or the other changes what is meant by "next highest" or "next
711 : /// lowest," the other function should be updated in accordance.
712 : ///
713 : /// \param uncontracted_multi_index the multi-index of one of the components
714 : /// of the uncontracted operand expression to sum
715 : /// \return the next highest multi-index between the components being summed
716 : /// in the contraction
717 : SPECTRE_ALWAYS_INLINE static std::array<size_t,
718 : num_uncontracted_tensor_indices>
719 1 : get_next_highest_multi_index_to_sum(
720 : const std::array<size_t, num_uncontracted_tensor_indices>&
721 : uncontracted_multi_index) {
722 : std::array<size_t, num_uncontracted_tensor_indices>
723 : next_highest_uncontracted_multi_index = uncontracted_multi_index;
724 :
725 : size_t i = 0;
726 : while (i < num_contracted_index_pairs) {
727 : // the position of the first index in a pair being contracted
728 : const size_t current_index_first_position =
729 : gsl::at(contracted_index_pair_positions, i).first;
730 : // the position of the second index in a pair being contracted
731 : const size_t current_index_second_position =
732 : gsl::at(contracted_index_pair_positions, i).second;
733 : // of the values being summed over, the lowest concrete value of the first
734 : // index in the contracted pair
735 : const size_t current_index_first_first_value =
736 : gsl::at(contracted_index_first_values, i).first;
737 :
738 : // decrement the current index pair's values
739 : gsl::at(next_highest_uncontracted_multi_index,
740 : current_index_first_position)--;
741 : gsl::at(next_highest_uncontracted_multi_index,
742 : current_index_second_position)--;
743 :
744 : // If the index values of the index pair being contracted aren't lower
745 : // than the minimum values included in the summation, then we're done
746 : // computing this next multi-index
747 : if (not(gsl::at(next_highest_uncontracted_multi_index,
748 : current_index_first_position) <
749 : current_index_first_first_value or
750 : gsl::at(next_highest_uncontracted_multi_index,
751 : current_index_first_position) >
752 : gsl::at(uncontracted_index_dims,
753 : current_index_first_position))) {
754 : break;
755 : }
756 : // Otherwise, we've wrapped around the lowest value being summed over for
757 : // this index, so we need to set it back to the maximum values being
758 : // summed and "carry" the decrementing over to the next contracted pair's
759 : // values
760 : gsl::at(next_highest_uncontracted_multi_index,
761 : current_index_first_position) =
762 : gsl::at(uncontracted_index_dims, current_index_first_position) - 1;
763 : gsl::at(next_highest_uncontracted_multi_index,
764 : current_index_second_position) =
765 : gsl::at(uncontracted_index_dims, current_index_second_position) - 1;
766 :
767 : i++;
768 : }
769 :
770 : return next_highest_uncontracted_multi_index;
771 : }
772 :
773 : /// \brief Given the multi-index of one term being summed in the contraction,
774 : /// return the next lowest multi-index of a component being summed
775 : ///
776 : /// \details
777 : /// What is meant by "next lowest" is implementation defined, but generally
778 : /// means, of the components being summed, return the multi-index that results
779 : /// from raising one of the contracted index pairs' values by one.
780 : ///
781 : /// Example:
782 : /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
783 : /// \f$L_b = R^{a}{}_{ba}\f$. If we are evaluating \f$L_1 = R^{a}{}_{1a}\f$
784 : /// and the dimension of \f$a\f$ is 3, then we will need to sum the following
785 : /// terms: \f$R^{0}{}_{10}\f$, \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$.
786 : /// If `uncontracted_multi_index` is `{1, 1, 1}`, then the "next lowest"
787 : /// multi-index is the result of raising the values of the \f$a\f$ indices by
788 : /// 1. The component with that resulting multi-index is \f$R^{2}{}_{12}\f$, so
789 : /// this function would return `{2, 1, 2}`.
790 : ///
791 : /// Note: this function should perform the inverse functionality of
792 : /// `get_next_lowest_multi_index_to_sum`. If the implementation of this
793 : /// function or the other changes what is meant by "next highest" or "next
794 : /// lowest," the other function should be updated in accordance.
795 : ///
796 : /// \param uncontracted_multi_index the multi-index of one of the components
797 : /// of the uncontracted operand expression to sum
798 : /// \return the next lowest multi-index between the components being summed in
799 : /// the contraction
800 : SPECTRE_ALWAYS_INLINE static std::array<size_t,
801 : num_uncontracted_tensor_indices>
802 1 : get_next_lowest_multi_index_to_sum(
803 : const std::array<size_t, num_uncontracted_tensor_indices>&
804 : uncontracted_multi_index) {
805 : std::array<size_t, num_uncontracted_tensor_indices>
806 : next_lowest_uncontracted_multi_index = uncontracted_multi_index;
807 :
808 : size_t i = 0;
809 : while (i < num_contracted_index_pairs) {
810 : // the position of the first index in a pair being contracted
811 : const size_t current_index_first_position =
812 : gsl::at(contracted_index_pair_positions, i).first;
813 : // the position of the second index in a pair being contracted
814 : const size_t current_index_second_position =
815 : gsl::at(contracted_index_pair_positions, i).second;
816 :
817 : // increment the current index pair's values
818 : gsl::at(next_lowest_uncontracted_multi_index,
819 : current_index_first_position)++;
820 : gsl::at(next_lowest_uncontracted_multi_index,
821 : current_index_second_position)++;
822 :
823 : // if the previous index value is > dim, then we've wrapped around
824 : // and we need to go again
825 : // If the index values of the index pair being contracted aren't higher
826 : // than the maximum values included in the summation, then we're done
827 : // computing this next multi-index
828 : if (not(gsl::at(next_lowest_uncontracted_multi_index,
829 : current_index_first_position) >
830 : gsl::at(uncontracted_index_dims, current_index_first_position) -
831 : 1)) {
832 : break;
833 : }
834 : // Otherwise, we've wrapped around the highest value being summed over for
835 : // this index, so we need to set it back to the minimum values being
836 : // summed and "carry" the incrementing over to the next contracted pair's
837 : // values
838 : gsl::at(next_lowest_uncontracted_multi_index,
839 : current_index_first_position) =
840 : gsl::at(contracted_index_first_values, i).first;
841 : gsl::at(next_lowest_uncontracted_multi_index,
842 : current_index_second_position) =
843 : gsl::at(contracted_index_first_values, i).second;
844 :
845 : i++;
846 : }
847 :
848 : return next_lowest_uncontracted_multi_index;
849 : }
850 :
851 : /// \brief Computes the value of a component in the resultant contracted
852 : /// tensor
853 : ///
854 : /// \details
855 : /// The contraction is computed by recursively adding up each component in the
856 : /// summation, across all index pairs being contracted in the operand
857 : /// expression. This function is called `Iteration = num_terms_summed` times,
858 : /// once for each uncontracted tensor component being summed. It should
859 : /// externally be called for the first time with `Iteration == 0` and
860 : /// `current_multi_index == <highest multi index to sum>` (see
861 : /// `get_next_highest_multi_index_to_sum` for details).
862 : ///
863 : /// In performing the recursive summation, the recursion is
864 : /// specifically done "to the left," in that this function returns
865 : /// `compute_contraction(next index) + get(this_index)` as opposed to
866 : /// `get(this_index) + compute_contraction`. Benchmarking has shown that
867 : /// increased breadth in an equation's expression tree can slow down runtime.
868 : /// By "recursing left" here, we minimize breadth in the overall tree for an
869 : /// equation, as both `AddSub` addition and `OuterProduct` (other expressions
870 : /// with two children) make efforts to make their operands with larger
871 : /// subtrees be their left operand.
872 : ///
873 : /// \tparam Iteration the nth term to sum, where n is between
874 : /// [0, num_terms_summed)
875 : /// \param t the expression contained within this contraction expression
876 : /// \param current_multi_index the multi-index of the uncontracted tensor
877 : /// component to retrieve
878 : /// \return the value of a component of the resulant contracted tensor
879 : template <size_t Iteration>
880 1 : SPECTRE_ALWAYS_INLINE static decltype(auto) compute_contraction(
881 : const T& t, const std::array<size_t, num_uncontracted_tensor_indices>&
882 : current_multi_index) {
883 : if constexpr (Iteration < num_terms_summed - 1) {
884 : // We have more than one component left to sum
885 : return compute_contraction<Iteration + 1>(
886 : t, get_next_highest_multi_index_to_sum(current_multi_index)) +
887 : t.get(current_multi_index);
888 : } else {
889 : // We only have one final component to sum
890 : return t.get(current_multi_index);
891 : }
892 : }
893 :
894 : /// \brief Return the value of the component of the resultant contracted
895 : /// tensor at a given multi-index
896 : ///
897 : /// \param contracted_multi_index the multi-index of the resultant contracted
898 : /// tensor component to retrieve
899 : /// \return the value of the component at `contracted_multi_index` in the
900 : /// resultant contracted tensor
901 1 : decltype(auto) get(const std::array<size_t, num_tensor_indices>&
902 : contracted_multi_index) const {
903 : return compute_contraction<0>(
904 : t_, get_highest_multi_index_to_sum(contracted_multi_index));
905 : }
906 :
907 : /// \brief Computes the result of an internal leg of the contraction
908 : ///
909 : /// \details
910 : /// This function differs from `compute_contraction` and
911 : /// `compute_contraction_primary` in that it only computes one leg of the
912 : /// whole contraction, as opposed to the whole contraction.
913 : ///
914 : /// The leg being summed is defined by the `current_multi_index` and
915 : /// `Iteration` passed in from the inital external call: consecutive terms
916 : /// will be summed until the base case `Iteration == 0` is reached.
917 : ///
918 : /// \tparam Iteration the nth term in the leg to sum, where n is between
919 : /// [0, leg_length)
920 : /// \param t the expression contained within this contraction expression
921 : /// \param current_multi_index the multi-index of the uncontracted tensor
922 : /// component to retrieve as part of this leg's summation
923 : /// \param next_leg_starting_multi_index in the final iteration, the
924 : /// multi-index to update to be the next leg's starting multi-index
925 : /// \return the result of summing up the terms in the given leg
926 : template <size_t Iteration>
927 1 : SPECTRE_ALWAYS_INLINE static decltype(auto) compute_contraction_leg(
928 : const T& t,
929 : const std::array<size_t, num_uncontracted_tensor_indices>&
930 : current_multi_index,
931 : std::array<size_t, num_uncontracted_tensor_indices>&
932 : next_leg_starting_multi_index) {
933 : if constexpr (Iteration != 0) {
934 : // We have more than one component left to sum
935 : (void)next_leg_starting_multi_index;
936 : return compute_contraction_leg<Iteration - 1>(
937 : t, get_next_highest_multi_index_to_sum(current_multi_index),
938 : next_leg_starting_multi_index) +
939 : t.get(current_multi_index);
940 : } else {
941 : // We only have one final component to sum
942 : next_leg_starting_multi_index =
943 : get_next_highest_multi_index_to_sum(current_multi_index);
944 : return t.get(current_multi_index);
945 : }
946 : }
947 :
948 : /// \brief Computes the value of a component in the resultant contracted
949 : /// tensor
950 : ///
951 : /// \details
952 : /// First see `compute_contraction` for details on basic functionality.
953 : ///
954 : /// This function differs from `compute_contraction` in that it takes into
955 : /// account whether we have already computed part of the result component at a
956 : /// lower subtree. In recursively computing this contraction, the current
957 : /// result component will be substituted in for the most recent (highest)
958 : /// subtree below it that has already been evaluated.
959 : ///
960 : /// \tparam Iteration the nth term to sum, where n is between
961 : /// [0, num_terms_summed)
962 : /// \param t the expression contained within this contraction expression
963 : /// \param result_component the LHS tensor component to evaluate
964 : /// \param current_multi_index the multi-index of the uncontracted tensor
965 : /// component to retrieve
966 : /// \return the value of a component of the resulant contracted tensor
967 : template <size_t Iteration>
968 1 : SPECTRE_ALWAYS_INLINE static decltype(auto) compute_contraction_primary(
969 : const T& t, const type& result_component,
970 : const std::array<size_t, num_uncontracted_tensor_indices>&
971 : current_multi_index) {
972 : if constexpr (is_primary_end) {
973 : // We've already computed the whole subtree of the term being summed that
974 : // is at the lowest depth in the tree
975 : if constexpr (Iteration < num_terms_summed - 1) {
976 : // We have more than one component left to sum
977 : return compute_contraction_primary<Iteration + 1>(
978 : t, result_component,
979 : get_next_highest_multi_index_to_sum(current_multi_index)) +
980 : t.get(current_multi_index);
981 : } else {
982 : // The deepest term in the contraction subtree that is being summed is
983 : // just our current result, so return it
984 : return result_component;
985 : }
986 : } else {
987 : // We've haven't yet computed the whole subtree of the term being summed
988 : // that is at the lowest depth in the tree
989 : if constexpr (Iteration < num_terms_summed - 1) {
990 : // We have more than one component left to sum
991 : return compute_contraction_primary<Iteration + 1>(
992 : t, result_component,
993 : get_next_highest_multi_index_to_sum(current_multi_index)) +
994 : t.get(current_multi_index);
995 : } else {
996 : // We only have one final component to sum
997 : return t.get_primary(result_component, current_multi_index);
998 : }
999 : }
1000 : }
1001 :
1002 : /// \brief Return the value of the component of the resultant contracted
1003 : /// tensor at a given multi-index
1004 : ///
1005 : /// \details
1006 : /// This function differs from `get` in that it takes into account whether we
1007 : /// have already computed part of the result component at a lower subtree.
1008 : /// In recursively computing this contraction, the current result component
1009 : /// will be substituted in for the most recent (highest) subtree below it that
1010 : /// has already been evaluated.
1011 : ///
1012 : /// \param result_component the LHS tensor component to evaluate
1013 : /// \param contracted_multi_index the multi-index of the resultant contracted
1014 : /// tensor component to retrieve
1015 : /// \return the value of the component at `contracted_multi_index` in the
1016 : /// resultant contracted tensor
1017 : template <typename ResultType>
1018 1 : SPECTRE_ALWAYS_INLINE decltype(auto) get_primary(
1019 : const ResultType& result_component,
1020 : const std::array<size_t, num_tensor_indices>& contracted_multi_index)
1021 : const {
1022 : return compute_contraction_primary<0>(
1023 : t_, result_component,
1024 : get_highest_multi_index_to_sum(contracted_multi_index));
1025 : }
1026 :
1027 : /// \brief Successively evaluate the LHS Tensor's result component at each
1028 : /// leg of summations within the contraction expression
1029 : ///
1030 : /// \details
1031 : /// This function takes into account whether we have already computed part of
1032 : /// the result component at a lower subtree. In recursively computing this
1033 : /// contraction, the current result component will be substituted in for the
1034 : /// most recent (highest) subtree below it that has already been evaluated.
1035 : ///
1036 : /// \param result_component the LHS tensor component to evaluate
1037 : /// \param contracted_multi_index the multi-index of the component of the
1038 : /// contracted result tensor to evaluate
1039 : /// \param lowest_multi_index the lowest multi-index between the components
1040 : /// being summed in the contraction (see `get_lowest_multi_index_to_sum`)
1041 1 : SPECTRE_ALWAYS_INLINE void evaluate_primary_contraction(
1042 : type& result_component,
1043 : const std::array<size_t, num_tensor_indices>& contracted_multi_index,
1044 : const std::array<size_t, num_uncontracted_tensor_indices>&
1045 : lowest_multi_index) const {
1046 : if constexpr (not is_primary_end) {
1047 : // We need to first evaluate the subtree of the term being summed that
1048 : // is deepest in the tree
1049 : result_component =
1050 : t_.template get_primary(result_component, lowest_multi_index);
1051 : }
1052 :
1053 : if constexpr (evaluate_terms_separately) {
1054 : // Case 1: Evaluate all of the remaining terms, one TERM at a time
1055 : (void)contracted_multi_index;
1056 : std::array<size_t, num_uncontracted_tensor_indices> current_multi_index =
1057 : lowest_multi_index;
1058 : for (size_t i = 1; i < num_terms_summed; i++) {
1059 : const std::array<size_t, num_uncontracted_tensor_indices>
1060 : next_lowest_multi_index_to_sum =
1061 : get_next_lowest_multi_index_to_sum(current_multi_index);
1062 : result_component += t_.get(next_lowest_multi_index_to_sum);
1063 : current_multi_index = next_lowest_multi_index_to_sum;
1064 : }
1065 : } else {
1066 : // Case 2: Evaluate all of the remaining terms, one LEG at a time
1067 : (void)lowest_multi_index;
1068 : std::array<size_t, num_uncontracted_tensor_indices>
1069 : next_leg_starting_multi_index =
1070 : get_highest_multi_index_to_sum(contracted_multi_index);
1071 : if constexpr (last_leg_length > 0) {
1072 : // Case 2a: We have a remainder of terms that don't make up a full leg
1073 : // length
1074 :
1075 : // Evaluate all the full-length legs
1076 : for (size_t i = 0; i < num_full_legs; i++) {
1077 : const std::array<size_t, num_uncontracted_tensor_indices>
1078 : current_multi_index = next_leg_starting_multi_index;
1079 : result_component += compute_contraction_leg<leg_length - 1>(
1080 : t_, current_multi_index, next_leg_starting_multi_index);
1081 : }
1082 : if constexpr (last_leg_length > 1) {
1083 : // Get rest of the deepest (partial-length) leg if there are more
1084 : // terms in it than just the one deepest term we already computed
1085 : const std::array<size_t, num_uncontracted_tensor_indices>
1086 : current_multi_index = next_leg_starting_multi_index;
1087 : result_component +=
1088 : // start at last_leg_length - 2 because we already computed one of
1089 : // the terms in this deepest leg (the deepest term)
1090 : compute_contraction_leg<last_leg_length - 2>(
1091 : t_, current_multi_index, next_leg_starting_multi_index);
1092 : }
1093 : } else {
1094 : // Case 2b: We don't have remaining terms that only make up a
1095 : // partial leg length (i.e. we only have full-length legs)
1096 :
1097 : // Evaluate all but the deepest leg
1098 : for (size_t i = 1; i < num_full_legs; i++) {
1099 : const std::array<size_t, num_uncontracted_tensor_indices>
1100 : current_multi_index = next_leg_starting_multi_index;
1101 :
1102 : result_component += compute_contraction_leg<leg_length - 1>(
1103 : t_, current_multi_index, next_leg_starting_multi_index);
1104 : }
1105 :
1106 : if constexpr (leg_length > 1) {
1107 : // Get rest of the deepest leg if there are more terms in it than
1108 : // just the one deepest term we already computed
1109 : const std::array<size_t, num_uncontracted_tensor_indices>
1110 : current_multi_index = next_leg_starting_multi_index;
1111 : result_component +=
1112 : // start at leg_length - 2 because we already computed one of the
1113 : // terms in this deepest leg (the deepest term)
1114 : compute_contraction_leg<leg_length - 2>(
1115 : t_, current_multi_index, next_leg_starting_multi_index);
1116 : }
1117 : }
1118 : }
1119 : }
1120 :
1121 : /// \brief Successively evaluate the LHS Tensor's result component at each
1122 : /// leg in this expression's subtree
1123 : ///
1124 : /// \details
1125 : /// This function takes into account whether we have already computed part of
1126 : /// the result component at a lower subtree. In recursively computing this
1127 : /// contraction, the current result component will be substituted in for the
1128 : /// most recent (highest) subtree below it that has already been evaluated.
1129 : ///
1130 : /// If this contraction expression is the beginning of a leg,
1131 : /// `evaluate_primary_contraction` is called to evaluate each individual
1132 : /// leg of summations within the contraction.
1133 : ///
1134 : /// \param result_component the LHS tensor component to evaluate
1135 : /// \param contracted_multi_index the multi-index of the component of the
1136 : /// contracted result tensor to evaluate
1137 : template <typename ResultType>
1138 1 : SPECTRE_ALWAYS_INLINE void evaluate_primary_subtree(
1139 : ResultType& result_component,
1140 : const std::array<size_t, num_tensor_indices>& contracted_multi_index)
1141 : const {
1142 : const auto lowest_multi_index_to_sum =
1143 : get_lowest_multi_index_to_sum(contracted_multi_index);
1144 : if constexpr (primary_child_subtree_contains_primary_start) {
1145 : // The primary child's subtree contains at least one leg, so recurse down
1146 : // and evaluate that first. Here, we evaluate the lowest multi-index
1147 : // because, according to `compute_contraction`, the lowest multi-index is
1148 : // the one in the last/leaf/final call to `compute_contraction` (i.e. the
1149 : // multi-index of the final term to sum)
1150 : t_.template evaluate_primary_subtree(result_component,
1151 : lowest_multi_index_to_sum);
1152 : }
1153 : if constexpr (is_primary_start) {
1154 : // We want to evaluate the subtree for this expression, one leg of
1155 : // summations at a time
1156 : evaluate_primary_contraction(result_component, contracted_multi_index,
1157 : lowest_multi_index_to_sum);
1158 : }
1159 : }
1160 :
1161 : private:
1162 : /// Operand expression being contracted
1163 1 : T t_;
1164 : };
1165 :
1166 : template <typename T, typename X, typename Symm, typename IndexList,
1167 : typename... TensorIndices>
1168 0 : SPECTRE_ALWAYS_INLINE static constexpr auto contract(
1169 : const TensorExpression<T, X, Symm, IndexList, tmpl::list<TensorIndices...>>&
1170 : t) {
1171 : // Number of indices in the tensor expression we're wanting to contract
1172 : constexpr size_t num_uncontracted_indices = sizeof...(TensorIndices);
1173 : // The number of index pairs to contract
1174 : constexpr size_t num_contracted_index_pairs =
1175 : detail::get_num_contracted_index_pairs<num_uncontracted_indices>(
1176 : {{TensorIndices::value...}});
1177 :
1178 : if constexpr (num_contracted_index_pairs == 0) {
1179 : // There aren't any indices to contract, so we just return the input
1180 : return ~t;
1181 : } else {
1182 : // We have at least one pair of indices to contract
1183 : return TensorContract<T, X, Symm, IndexList, tmpl::list<TensorIndices...>,
1184 : num_uncontracted_indices -
1185 : (num_contracted_index_pairs * 2)>{t};
1186 : }
1187 : }
1188 : } // namespace tenex
|