SpECTRE Documentation Coverage Report
Current view: top level - DataStructures/Tensor/Expressions - Contract.hpp Hit Total Coverage
Commit: 5f37f3d7c5afe86be8ec8102ab4a768be82d2177 Lines: 47 51 92.2 %
Date: 2024-04-26 23:32:03
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Defines Expression Templates for contracting tensor indices on a single
       6             : /// Tensor
       7             : 
       8             : #pragma once
       9             : 
      10             : #include <array>
      11             : #include <cstddef>
      12             : #include <limits>
      13             : #include <type_traits>
      14             : #include <utility>
      15             : 
      16             : #include "DataStructures/Tensor/Expressions/NumberAsExpression.hpp"
      17             : #include "DataStructures/Tensor/Expressions/TensorExpression.hpp"
      18             : #include "DataStructures/Tensor/Expressions/TensorIndex.hpp"
      19             : #include "DataStructures/Tensor/Expressions/TimeIndex.hpp"
      20             : #include "DataStructures/Tensor/IndexType.hpp"
      21             : #include "DataStructures/Tensor/Symmetry.hpp"
      22             : #include "Utilities/ForceInline.hpp"
      23             : #include "Utilities/Gsl.hpp"
      24             : #include "Utilities/MakeArray.hpp"
      25             : #include "Utilities/TMPL.hpp"
      26             : 
      27             : /*!
      28             :  * \ingroup TensorExpressionsGroup
      29             :  * Holds all possible TensorExpressions currently implemented
      30             :  */
      31             : namespace tenex {
      32             : namespace detail {
      33             : template <typename I1, typename I2>
      34             : using indices_contractible = std::bool_constant<
      35             :     I1::ul != I2::ul and
      36             :     std::is_same_v<typename I1::Frame, typename I2::Frame> and
      37             :     ((I1::index_type == I2::index_type and I1::dim == I2::dim) or
      38             :      // If one index is spacetime and the other is spatial, the indices can
      39             :      // be contracted if they have the same number of spatial dimensions
      40             :      (I1::index_type == IndexType::Spacetime and I1::dim == I2::dim + 1) or
      41             :      (I2::index_type == IndexType::Spacetime and I1::dim + 1 == I2::dim))>;
      42             : 
      43             : template <size_t NumUncontractedIndices>
      44             : constexpr size_t get_num_contracted_index_pairs(
      45             :     const std::array<size_t, NumUncontractedIndices>&
      46             :         uncontracted_tensor_index_values) {
      47             :   size_t count = 0;
      48             :   for (size_t i = 0; i < NumUncontractedIndices; i++) {
      49             :     const size_t current_value = gsl::at(uncontracted_tensor_index_values, i);
      50             :     // Concrete time indices are not contracted
      51             :     if (not detail::is_time_index_value(current_value)) {
      52             :       const size_t opposite_value_to_find =
      53             :           get_tensorindex_value_with_opposite_valence(current_value);
      54             :       for (size_t j = i + 1; j < NumUncontractedIndices; j++) {
      55             :         if (opposite_value_to_find ==
      56             :             gsl::at(uncontracted_tensor_index_values, j)) {
      57             :           // We found both the lower and upper version of a generic index in the
      58             :           // list of generic indices, so we return this pair's positions
      59             :           count++;
      60             :         }
      61             :       }
      62             :     }
      63             :   }
      64             : 
      65             :   return count;
      66             : }
      67             : 
      68             : /// \brief Computes the mapping from the positions of indices in the resultant
      69             : /// contracted tensor to their positions in the operand uncontracted tensor, as
      70             : /// well as the positions of the index pairs in the operand uncontracted tensor
      71             : /// that we wish to contract
      72             : ///
      73             : /// \details
      74             : /// Both quantities returned are computed in and returned from the same
      75             : /// function so as to not repeat overlapping necessary work that would be in two
      76             : /// separate functions
      77             : ///
      78             : /// \tparam NumContractedIndexPairs the number of pairs of indices that will be
      79             : /// contracted
      80             : /// \tparam NumUncontractedIndices the number of indices in the operand
      81             : /// expression that we wish to contract
      82             : /// \param uncontracted_tensor_index_values the values of the `TensorIndex`s
      83             : /// used to generically represent the uncontracted operand expression
      84             : /// \return the mapping from the positions of indices in the resultant
      85             : /// contracted tensor to their positions in the operand uncontracted tensor, as
      86             : /// well as the positions of the index pairs in the operand uncontracted tensor
      87             : /// that we wish to contract
      88             : template <size_t NumContractedIndexPairs, size_t NumUncontractedIndices>
      89             : constexpr std::pair<
      90             :     std::array<size_t, NumUncontractedIndices - NumContractedIndexPairs * 2>,
      91             :     std::array<std::pair<size_t, size_t>, NumContractedIndexPairs>>
      92             : get_index_transformation_and_contracted_pair_positions(
      93             :     const std::array<size_t, NumUncontractedIndices>&
      94             :         uncontracted_tensor_index_values) {
      95             :   static_assert(NumUncontractedIndices >= 2,
      96             :                 "There should be at least 2 indices");
      97             :   // Positions of indices in the result tensor (ones that are not contracted)
      98             :   // mapped to their locations in the uncontracted operand expression
      99             :   std::array<size_t, NumUncontractedIndices - NumContractedIndexPairs * 2>
     100             :       index_transformation{};
     101             :   // Positions of contracted index pairs in the uncontracted operand expression
     102             :   std::array<std::pair<size_t, size_t>, NumContractedIndexPairs>
     103             :       contracted_index_pair_positions{};
     104             : 
     105             :   // Marks whether or not we have already paired an index in the uncontracted
     106             :   // operand expression with an index to contract it with
     107             :   std::array<bool, NumUncontractedIndices> index_mapping_set =
     108             :       make_array<NumUncontractedIndices, bool>(false);
     109             : 
     110             :   // Index of `contracted_index_pair_positions` that we're currently assigning
     111             :   size_t contracted_map_index_to_assign = 0;
     112             :   // Index of `index_transformation` that we're currently assigning
     113             :   size_t not_contracted_map_index_to_assign =
     114             :       NumUncontractedIndices - NumContractedIndexPairs * 2 - 1;
     115             :   // Iteration is performed backwards here for the reasons below, but note that
     116             :   // no benchmarking has been done to confirm the backwards iteration order
     117             :   // makes a meaningful improvement in runtime vs. iterating forward:
     118             :   //
     119             :   // Here, we iterate backwards to find the "rightmost" contracted indices and
     120             :   // proceed leftwards to find the other contracted index pairs so that the
     121             :   // "rightmost" pairs in the expression to contract appear first in the list of
     122             :   // contracted index pairs (`contracted_index_pair_positions`). Then, when we
     123             :   // later iterate over all of the multi-indices to sum, each time we go to grab
     124             :   // the next multi-index to sum and we need to compute what that next
     125             :   // multi-index is, we can choose to increment the concrete values of the
     126             :   // rightmost index pair.
     127             :   //
     128             :   // This has not been benchmarked to confirm, but the thought with making this
     129             :   // choice to order the contracted pairs from right to left is that this may
     130             :   // help with spatial locality for caching when we are contracting a single
     131             :   // tensor (`TensorAsExpression`) that is non-symmetric. For example, let's say
     132             :   // we are contracting `R(ti::A, ti::B, ti::a, ti::b)` and the current
     133             :   // multi-index we just accessed (one of the components to sum) is
     134             :   // `{0, 0, 0, 0}`, representing \f$R^{00}{}_{00}\f$. To find a next
     135             :   // multi-index to sum, we can simply increase one of the pairs' concrete
     136             :   // values by 1, e.g. the next multi-index to access could be `{0, 1, 0, 1}`
     137             :   // for \f$R^{01}{}_{01}\f$ or `{1, 0, 1, 0}` \f$R^{10}{}_{10}\f$. In this
     138             :   // case, the idea is that choosing to increment the concrete values of the
     139             :   // rightmost pair could provide better spatial locality, as `{0, 1, 0, 1}` is
     140             :   // closer in memory to `{0, 0, 0, 0}` than `{1, 0, 1, 0}` is. It's important
     141             :   // to note that this, of course, depends on the implementation of
     142             :   // `Tensor_detail::Structure` - specifically, the order in which the
     143             :   // components are laid out in memory.
     144             :   //
     145             :   // Note: the loop terminates when underflow causes `i` to wrap back around to
     146             :   // the maximum `size_t` value. If we use the condition `i > 0`, we miss the
     147             :   // final iteration, and if we use `i >= 0`, we never terminate because `i`
     148             :   // is always positive.
     149             :   for (size_t i = NumUncontractedIndices - 1; i < NumUncontractedIndices; i--) {
     150             :     if (not gsl::at(index_mapping_set, i)) {
     151             :       const size_t current_value = gsl::at(uncontracted_tensor_index_values, i);
     152             :       // Concrete time indices are not contracted
     153             :       if (not detail::is_time_index_value(current_value)) {
     154             :         const size_t opposite_value_to_find =
     155             :             get_tensorindex_value_with_opposite_valence(current_value);
     156             :         for (size_t j = i - 1; j < NumUncontractedIndices; j--) {
     157             :           if (opposite_value_to_find ==
     158             :               gsl::at(uncontracted_tensor_index_values, j)) {
     159             :             // We found both the lower and upper version of a generic index in
     160             :             // the list of generic indices, pair them up
     161             :             gsl::at(contracted_index_pair_positions,
     162             :                     contracted_map_index_to_assign)
     163             :                 .first = i;
     164             :             gsl::at(contracted_index_pair_positions,
     165             :                     contracted_map_index_to_assign)
     166             :                 .second = j;
     167             :             contracted_map_index_to_assign++;
     168             :             // Mark that we've found contraction partners for these two indices
     169             :             gsl::at(index_mapping_set, i) = true;
     170             :             gsl::at(index_mapping_set, j) = true;
     171             :             break;
     172             :           }
     173             :         }
     174             :       }
     175             :       if (not gsl::at(index_mapping_set, i)) {
     176             :         // If we haven't assigned this index to a partner, it is not an index
     177             :         // that is contracted, so record its position mapping from contracted to
     178             :         // uncontracted tensor indices
     179             :         gsl::at(index_transformation, not_contracted_map_index_to_assign) = i;
     180             :         not_contracted_map_index_to_assign--;
     181             :       }
     182             :     }
     183             :   }
     184             : 
     185             :   return std::pair{index_transformation, contracted_index_pair_positions};
     186             : }
     187             : 
     188             : /// \brief Computes type information for the tensor expression that results from
     189             : /// a contraction, as well as information internally useful for carrying out the
     190             : /// contraction
     191             : ///
     192             : /// \tparam UncontractedTensorExpression the operand uncontracted
     193             : /// `TensorExpression` being contracted
     194             : /// \tparam DataType the data type of the `Tensor` components
     195             : /// \tparam UncontractedSymm the ::Symmetry of the operand uncontracted
     196             : /// `TensorExpression`
     197             : /// \tparam UncontractedIndexList the list of
     198             : /// \ref SpacetimeIndex "TensorIndexType"s of the operand uncontracted
     199             : /// `TensorExpression`
     200             : /// \tparam UncontractedTensorIndexList the list of generic `TensorIndex`s used
     201             : /// for the the operand uncontracted `TensorExpression`
     202             : /// \tparam NumContractedIndices the number of indices in the resultant tensor
     203             : /// after contracting
     204             : /// \tparam NumIndexPairsToContract the number of pairs of indices that will be
     205             : /// contracted
     206             : template <typename UncontractedTensorExpression, typename DataType,
     207             :           typename UncontractedSymm, typename UncontractedIndexList,
     208             :           typename UncontractedTensorIndexList, size_t NumContractedIndices,
     209             :           size_t NumIndexPairsToContract,
     210             :           typename ContractedIndexSequence =
     211             :               std::make_index_sequence<NumContractedIndices>,
     212             :           typename IndexPairsToContractSequence =
     213             :               std::make_index_sequence<NumIndexPairsToContract>>
     214             : struct ContractedType;
     215             : 
     216             : template <typename UncontractedTensorExpression, typename DataType,
     217             :           template <typename...> class UncontractedSymmList,
     218             :           typename... UncontractedSymm,
     219             :           template <typename...> class UncontractedIndexList,
     220             :           typename... UncontractedIndices,
     221             :           template <typename...> class UncontractedTensorIndexList,
     222             :           typename... UncontractedTensorIndices, size_t NumContractedIndices,
     223             :           size_t NumIndexPairsToContract, size_t... ContractedInts,
     224             :           size_t... IndexPairsToContractInts>
     225             : struct ContractedType<UncontractedTensorExpression, DataType,
     226             :                       UncontractedSymmList<UncontractedSymm...>,
     227             :                       UncontractedIndexList<UncontractedIndices...>,
     228             :                       UncontractedTensorIndexList<UncontractedTensorIndices...>,
     229             :                       NumContractedIndices, NumIndexPairsToContract,
     230             :                       std::index_sequence<ContractedInts...>,
     231             :                       std::index_sequence<IndexPairsToContractInts...>> {
     232             :   static constexpr size_t num_uncontracted_tensor_indices =
     233             :       sizeof...(UncontractedTensorIndices);
     234             :   static constexpr std::array<size_t, num_uncontracted_tensor_indices>
     235             :       uncontracted_tensorindex_values = {{UncontractedTensorIndices::value...}};
     236             :   static constexpr size_t num_indices_to_contract =
     237             :       num_uncontracted_tensor_indices - NumContractedIndices;
     238             :   static constexpr size_t num_contracted_index_pairs =
     239             :       num_indices_to_contract / 2;
     240             :   // First item in pair:
     241             :   // - index transformation: mapping from the positions of indices in the
     242             :   // resultant contracted tensor to their positions in the operand
     243             :   // uncontracted tensor
     244             :   // Second item in pair:
     245             :   // contracted index pair positions: positions of the index pairs in the
     246             :   // operand uncontracted tensor that we wish to contract
     247             :   static constexpr inline std::pair<
     248             :       std::array<size_t, NumContractedIndices>,
     249             :       std::array<std::pair<size_t, size_t>, num_contracted_index_pairs>>
     250             :       index_transformation_and_contracted_pair_positions =
     251             :           get_index_transformation_and_contracted_pair_positions<
     252             :               num_contracted_index_pairs, num_uncontracted_tensor_indices>(
     253             :               uncontracted_tensorindex_values);
     254             : 
     255             :   // Make sure it's mathematically legal to perform the requested contraction
     256             :   static_assert(((... and
     257             :                   (indices_contractible<
     258             :                       typename tmpl::at_c<
     259             :                           tmpl::list<UncontractedIndices...>,
     260             :                           index_transformation_and_contracted_pair_positions
     261             :                               .second[IndexPairsToContractInts]
     262             :                               .first>,
     263             :                       typename tmpl::at_c<
     264             :                           tmpl::list<UncontractedIndices...>,
     265             :                           index_transformation_and_contracted_pair_positions
     266             :                               .second[IndexPairsToContractInts]
     267             :                               .second>>::value))),
     268             :                 "Cannot contract the requested indices.");
     269             : 
     270             :   static constexpr inline std::array<IndexType, num_uncontracted_tensor_indices>
     271             :       uncontracted_index_types = {{UncontractedIndices::index_type...}};
     272             : 
     273             :   // First concrete values of contracted indices to sum. This is to handle
     274             :   // cases when we have generic spatial `TensorIndex`s used for spacetime
     275             :   // indices, as the first concrete index value to contract will be 1 (first
     276             :   // spatial index) instead of 0 (the time index). Contracted index pairs will
     277             :   // have different "starting" concrete indices when one index in the pair is a
     278             :   // spatial spacetime index and the other is not.
     279             :   static constexpr inline std::array<std::pair<size_t, size_t>,
     280             :                                      num_contracted_index_pairs>
     281             :       contracted_index_first_values = []() {
     282             :         std::array<std::pair<size_t, size_t>, num_contracted_index_pairs>
     283             :             first_values{};
     284             :         for (size_t i = 0; i < num_contracted_index_pairs; i++) {
     285             :           // Assign the value for first index in a pair to be the smallest value
     286             :           // used in the terms being summed: assign to 1 if we have a spacetime
     287             :           // index where a generic spatial index has been used, otherwise assign
     288             :           // to 0.
     289             :           gsl::at(first_values, i).first = static_cast<size_t>(
     290             :               gsl::at(
     291             :                   uncontracted_index_types,
     292             :                   gsl::at(
     293             :                       index_transformation_and_contracted_pair_positions.second,
     294             :                       i)
     295             :                       .first) == IndexType::Spacetime and
     296             :               gsl::at(
     297             :                   uncontracted_tensorindex_values,
     298             :                   gsl::at(
     299             :                       index_transformation_and_contracted_pair_positions.second,
     300             :                       i)
     301             :                       .first) >= TensorIndex_detail::spatial_sentinel);
     302             :           // Assign the value for second index in a pair to be the smallest
     303             :           // value used in the terms being summed (assigned with same logic
     304             :           // described above for the first index in the pair)
     305             :           gsl::at(first_values, i).second = static_cast<size_t>(
     306             :               gsl::at(
     307             :                   uncontracted_index_types,
     308             :                   gsl::at(
     309             :                       index_transformation_and_contracted_pair_positions.second,
     310             :                       i)
     311             :                       .second) == IndexType::Spacetime and
     312             :               gsl::at(
     313             :                   uncontracted_tensorindex_values,
     314             :                   gsl::at(
     315             :                       index_transformation_and_contracted_pair_positions.second,
     316             :                       i)
     317             :                       .second) >= TensorIndex_detail::spatial_sentinel);
     318             :         }
     319             :         return first_values;
     320             :       }();
     321             : 
     322             :   static constexpr inline std::array<size_t, num_uncontracted_tensor_indices>
     323             :       uncontracted_index_dims = {{UncontractedIndices::dim...}};
     324             : 
     325             :   // The number of terms to sum for this expression's contraction
     326             :   static constexpr size_t num_terms_summed = []() {
     327             :     size_t num_terms =
     328             :         gsl::at(
     329             :             uncontracted_index_dims,
     330             :             gsl::at(index_transformation_and_contracted_pair_positions.second,
     331             :                     0)
     332             :                 .first) -
     333             :         gsl::at(contracted_index_first_values, 0).first;
     334             :     for (size_t i = 1; i < num_contracted_index_pairs; i++) {
     335             :       num_terms *=
     336             :           gsl::at(
     337             :               uncontracted_index_dims,
     338             :               gsl::at(index_transformation_and_contracted_pair_positions.second,
     339             :                       i)
     340             :                   .first) -
     341             :           gsl::at(contracted_index_first_values, i).first;
     342             :     }
     343             :     return num_terms;
     344             :   }();
     345             :   static_assert(num_terms_summed > 0,
     346             :                 "There should be a non-zero number of components to sum in the "
     347             :                 "contraction.");
     348             :   // The ::Symmetry of the result of the contraction
     349             :   using symmetry =
     350             :       Symmetry<tmpl::at_c<UncontractedSymmList<UncontractedSymm...>,
     351             :                           index_transformation_and_contracted_pair_positions
     352             :                               .first[ContractedInts]>::value...>;
     353             :   // The list of \ref SpacetimeIndex "TensorIndexType"s of the result of the
     354             :   // contraction
     355             :   using index_list =
     356             :       tmpl::list<tmpl::at_c<UncontractedIndexList<UncontractedIndices...>,
     357             :                             index_transformation_and_contracted_pair_positions
     358             :                                 .first[ContractedInts]>...>;
     359             :   // The list of generic `TensorIndex`s of the result of the contraction
     360             :   using tensorindex_list = tmpl::list<
     361             :       tmpl::at_c<UncontractedTensorIndexList<UncontractedTensorIndices...>,
     362             :                  index_transformation_and_contracted_pair_positions
     363             :                      .first[ContractedInts]>...>;
     364             :   // The `TensorExpression` type that results from performing the contraction
     365             :   using type = TensorExpression<UncontractedTensorExpression, DataType,
     366             :                                 symmetry, index_list, tensorindex_list>;
     367             : };
     368             : }  // namespace detail
     369             : 
     370             : /*!
     371             :  * \ingroup TensorExpressionsGroup
     372             :  */
     373             : template <typename T, typename X, typename Symm, typename IndexList,
     374             :           typename ArgsList, size_t NumContractedIndices>
     375           0 : struct TensorContract
     376             :     : public TensorExpression<
     377             :           TensorContract<T, X, Symm, IndexList, ArgsList, NumContractedIndices>,
     378             :           X,
     379             :           typename detail::ContractedType<
     380             :               T, X, Symm, IndexList, ArgsList, NumContractedIndices,
     381             :               (tmpl::size<Symm>::value - NumContractedIndices) /
     382             :                   2>::type::symmetry,
     383             :           typename detail::ContractedType<
     384             :               T, X, Symm, IndexList, ArgsList, NumContractedIndices,
     385             :               (tmpl::size<Symm>::value - NumContractedIndices) /
     386             :                   2>::type::index_list,
     387             :           typename detail::ContractedType<
     388             :               T, X, Symm, IndexList, ArgsList, NumContractedIndices,
     389             :               (tmpl::size<Symm>::value - NumContractedIndices) /
     390             :                   2>::type::args_list> {
     391             :   /// Stores internally useful information regarding the contraction. See
     392             :   /// `detail::ContractedType` for more details
     393           1 :   using contracted_type = typename detail::ContractedType<
     394             :       T, X, Symm, IndexList, ArgsList, NumContractedIndices,
     395             :       (tmpl::size<Symm>::value - NumContractedIndices) / 2>;
     396             :   /// The `TensorExpression` type that results from performing the contraction
     397           1 :   using new_type = typename contracted_type::type;
     398             : 
     399             :   // === Index properties ===
     400             :   /// The type of the data being stored in the result of the expression
     401           1 :   using type = X;
     402             :   /// The ::Symmetry of the result of the expression
     403           1 :   using symmetry = typename new_type::symmetry;
     404             :   /// The list of \ref SpacetimeIndex "TensorIndexType"s of the result of the
     405             :   /// expression
     406           1 :   using index_list = typename new_type::index_list;
     407             :   /// The list of generic `TensorIndex`s of the result of the expression
     408           1 :   using args_list = typename new_type::args_list;
     409             :   /// The number of tensor indices in the result of the expression
     410           1 :   static constexpr size_t num_tensor_indices = NumContractedIndices;
     411             :   /// The number of tensor indices in the operand expression being contracted
     412           1 :   static constexpr size_t num_uncontracted_tensor_indices =
     413             :       tmpl::size<Symm>::value;
     414             :   /// The number of tensor indices in the operand expression that will be
     415             :   /// contracted
     416           1 :   static constexpr size_t num_indices_to_contract =
     417             :       contracted_type::num_indices_to_contract;
     418             :   static_assert(num_indices_to_contract > 0,
     419             :                 "There are no indices to contract that were found.");
     420             :   static_assert(num_indices_to_contract % 2 == 0,
     421             :                 "Cannot contract an odd number of indices.");
     422             :   /// The number of tensor index pairs in the operand expression that will be
     423             :   /// contracted
     424           1 :   static constexpr size_t num_contracted_index_pairs =
     425             :       contracted_type::num_contracted_index_pairs;
     426             :   /// Mapping from the positions of indices in the resultant contracted tensor
     427             :   /// to their positions in the operand uncontracted tensor
     428             :   static constexpr inline std::array<size_t, NumContractedIndices>
     429           1 :       index_transformation =
     430             :           contracted_type::index_transformation_and_contracted_pair_positions
     431             :               .first;
     432             :   /// Positions of the index pairs in the operand uncontracted tensor that we
     433             :   /// wish to contract
     434             :   static constexpr inline std::array<std::pair<size_t, size_t>,
     435             :                                      num_contracted_index_pairs>
     436           1 :       contracted_index_pair_positions =
     437             :           contracted_type::index_transformation_and_contracted_pair_positions
     438             :               .second;
     439             :   /// First concrete values of contracted indices to sum. This is to handle
     440             :   /// cases when we have generic spatial `TensorIndex`s used for spacetime
     441             :   /// indices, as the first concrete index value to contract will be 1 (first
     442             :   /// spatial index) instead of 0 (the time index). Contracted index pairs will
     443             :   /// have different "starting" concrete indices when one index in the pair is a
     444             :   /// spatial spacetime index and the other is not.
     445             :   static constexpr inline std::array<std::pair<size_t, size_t>,
     446             :                                      num_contracted_index_pairs>
     447           1 :       contracted_index_first_values =
     448             :           contracted_type::contracted_index_first_values;
     449             :   /// The dimensions of the indices in the uncontracted operand expression
     450             :   static constexpr inline std::array<size_t, num_uncontracted_tensor_indices>
     451           1 :       uncontracted_index_dims = contracted_type::uncontracted_index_dims;
     452             :   /// The number of terms to sum for this expression's contraction
     453           1 :   static constexpr size_t num_terms_summed = contracted_type::num_terms_summed;
     454             : 
     455             :   // === Expression subtree properties ===
     456             :   /// The number of arithmetic tensor operations done in the subtree for the
     457             :   /// left operand
     458           1 :   static constexpr size_t num_ops_left_child =
     459             :       T::num_ops_subtree * num_terms_summed + num_terms_summed - 1;
     460             :   /// The number of arithmetic tensor operations done in the subtree for the
     461             :   /// right operand. This is 0 because this expression represents a unary
     462             :   /// operation.
     463           1 :   static constexpr size_t num_ops_right_child = 0;
     464             :   /// The total number of arithmetic tensor operations done in this expression's
     465             :   /// whole subtree
     466           1 :   static constexpr size_t num_ops_subtree = num_ops_left_child;
     467             :   /// The height of this expression's node in the expression tree relative to
     468             :   /// the closest `TensorAsExpression` leaf in its subtree
     469           1 :   static constexpr size_t height_relative_to_closest_tensor_leaf_in_subtree =
     470             :       T::height_relative_to_closest_tensor_leaf_in_subtree !=
     471             :               std::numeric_limits<size_t>::max()
     472             :           ? T::height_relative_to_closest_tensor_leaf_in_subtree + 1
     473             :           : T::height_relative_to_closest_tensor_leaf_in_subtree;
     474             : 
     475             :   // === Properties for splitting up subexpressions along the primary path ===
     476             :   // These definitions only have meaning if this expression actually ends up
     477             :   // being along the primary path that is taken when evaluating the whole tree.
     478             :   // See documentation for `TensorExpression` for more details.
     479             :   /// If on the primary path, whether or not the expression is an ending point
     480             :   /// of a leg
     481           1 :   static constexpr bool is_primary_end = T::is_primary_start;
     482             :   /// If on the primary path, this is the remaining number of arithmetic tensor
     483             :   /// operations that need to be done in the subtree of the child along the
     484             :   /// primary path, given that we will have already computed the whole subtree
     485             :   /// at the next lowest leg's starting point.
     486           1 :   static constexpr size_t num_ops_to_evaluate_primary_left_child =
     487             :       is_primary_end
     488             :           ? num_ops_subtree - T::num_ops_subtree
     489             :           : T::num_ops_subtree * (num_terms_summed - 1) +
     490             :                 T::num_ops_to_evaluate_primary_subtree + num_terms_summed - 1;
     491             :   /// If on the primary path, this is the remaining number of arithmetic tensor
     492             :   /// operations that need to be done in the right operand's subtree. No
     493             :   /// splitting is currently done, so this is just `num_ops_right_child`.
     494           1 :   static constexpr size_t num_ops_to_evaluate_primary_right_child =
     495             :       num_ops_right_child;
     496             :   /// If on the primary path, this is the remaining number of arithmetic tensor
     497             :   /// operations that need to be done for this expression's subtree, given that
     498             :   /// we will have already computed the subtree at the next lowest leg's
     499             :   /// starting point
     500           1 :   static constexpr size_t num_ops_to_evaluate_primary_subtree =
     501             :       num_ops_to_evaluate_primary_left_child +
     502             :       num_ops_to_evaluate_primary_right_child;
     503             :   /// If on the primary path, whether or not the expression is a starting point
     504             :   /// of a leg
     505           1 :   static constexpr bool is_primary_start =
     506             :       num_ops_to_evaluate_primary_subtree >
     507             :       // Multiply by 2 because each term has a + and * operation, while other
     508             :       // arithmetic expression types do one operation
     509             :       2 * detail::max_num_ops_in_sub_expression<type>;
     510             :   /// If on the primary path, whether or not the expression's child along the
     511             :   /// primary path is a subtree that contains a starting point of a leg along
     512             :   /// the primary path
     513           1 :   static constexpr bool primary_child_subtree_contains_primary_start =
     514             :       T::primary_subtree_contains_primary_start;
     515             :   /// If on the primary path, whether or not this subtree contains a starting
     516             :   /// point of a leg along the primary path
     517           1 :   static constexpr bool primary_subtree_contains_primary_start =
     518             :       is_primary_start or primary_child_subtree_contains_primary_start;
     519             :   /// Number of arithmetic tensor operations done in the subtree of the operand
     520             :   /// expression being contracted
     521           1 :   static constexpr size_t num_ops_subexpression = T::num_ops_subtree;
     522             :   /// In the subtree for this contraction, how many terms we sum together for
     523             :   /// each leg of the contraction
     524           1 :   static constexpr size_t leg_length = []() {
     525             :     if constexpr (not is_primary_start) {
     526             :       // If we're not even stopping at the beginning of the contraction, it's
     527             :       // because there weren't enough terms to justify any splitting, so the
     528             :       // leg_length is just the total number of terms to sum
     529             :       return num_terms_summed;
     530             :     } else if constexpr (num_ops_subexpression >=
     531             :                          detail::max_num_ops_in_sub_expression<type>) {
     532             :       // If the subexpression itself has more than the max # of ops
     533             :       return 0;
     534             :     } else {
     535             :       // Otherwise, find how many terms to sum in each leg
     536             :       size_t length = 1;
     537             :       while (2 * (length * (num_ops_subexpression + 1) - 1) <=
     538             :              detail::max_num_ops_in_sub_expression<type>) {
     539             :         length *= 2;
     540             :       }
     541             :       return length;
     542             :     }
     543             :   }();
     544             :   /// After dividing up the contraction subtree into legs, the number of legs
     545             :   /// whose length is equal to `leg_length`
     546           1 :   static constexpr size_t num_full_legs =
     547             :       leg_length == 0 ? num_terms_summed : num_terms_summed / leg_length;
     548             :   /// After dividing up the contraction subtree into legs of even length, the
     549             :   /// number of terms we still have left to sum
     550           1 :   static constexpr size_t last_leg_length =
     551             :       leg_length == 0 ? 0 : num_terms_summed % leg_length;
     552             :   /// When evaluating along a primary path, whether each term's subtrees should
     553             :   /// be evaluated separately. Since `DataVector` expression runtime scales
     554             :   /// poorly with increased number of operations, evaluating individual terms'
     555             :   /// subtrees separately like this is beneficial when each term, itself,
     556             :   /// involves many tensor operations.
     557           1 :   static constexpr bool evaluate_terms_separately = leg_length == 0;
     558             : 
     559           0 :   explicit TensorContract(
     560             :       const TensorExpression<T, X, Symm, IndexList, ArgsList>& t)
     561             :       : t_(~t) {}
     562           0 :   ~TensorContract() override = default;
     563             : 
     564             :   /// \brief Assert that the LHS tensor of the equation does not also appear in
     565             :   /// this expression's subtree
     566             :   template <typename LhsTensor>
     567           1 :   SPECTRE_ALWAYS_INLINE void assert_lhs_tensor_not_in_rhs_expression(
     568             :       const gsl::not_null<LhsTensor*> lhs_tensor) const {
     569             :     if constexpr (not std::is_base_of_v<MarkAsNumberAsExpression, T>) {
     570             :       t_.assert_lhs_tensor_not_in_rhs_expression(lhs_tensor);
     571             :     }
     572             :   }
     573             : 
     574             :   /// \brief Assert that each instance of the LHS tensor in the RHS tensor
     575             :   /// expression uses the same generic index order that the LHS uses
     576             :   ///
     577             :   /// \tparam LhsTensorIndices the list of generic `TensorIndex`s of the LHS
     578             :   /// result `Tensor` being computed
     579             :   /// \param lhs_tensor the LHS result `Tensor` being computed
     580             :   template <typename LhsTensorIndices, typename LhsTensor>
     581           1 :   SPECTRE_ALWAYS_INLINE void assert_lhs_tensorindices_same_in_rhs(
     582             :       const gsl::not_null<LhsTensor*> lhs_tensor) const {
     583             :     if constexpr (not std::is_base_of_v<MarkAsNumberAsExpression, T>) {
     584             :       t_.template assert_lhs_tensorindices_same_in_rhs<LhsTensorIndices>(
     585             :           lhs_tensor);
     586             :     }
     587             :   }
     588             : 
     589             :   /// \brief Get the size of a component from a `Tensor` in this expression's
     590             :   /// subtree of the RHS `TensorExpression`
     591             :   ///
     592             :   /// \return the size of a component from a `Tensor` in this expression's
     593             :   /// subtree of the RHS `TensorExpression`
     594           1 :   SPECTRE_ALWAYS_INLINE size_t get_rhs_tensor_component_size() const {
     595             :     return t_.get_rhs_tensor_component_size();
     596             :   }
     597             : 
     598             :   /// \brief Return the highest multi-index between the components being summed
     599             :   /// in the contraction
     600             :   ///
     601             :   /// \details
     602             :   /// Example:
     603             :   /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
     604             :   /// \f$L_b = R^{a}{}_{ba}\f$. If the `contracted_multi_index` is `{1}`, which
     605             :   /// represents \f$L_1 = R^{a}{}_{1a}\f$, and the dimension of \f$a\f$ is 3,
     606             :   /// then we will need to sum the following terms: \f$R^{0}{}_{10}\f$,
     607             :   /// \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$. Between the terms being
     608             :   /// summed, the multi-index whose values are the largest is
     609             :   /// \f$R^{2}{}_{12}\f$, so this function would return `{2, 1, 2}`.
     610             :   ///
     611             :   /// \param contracted_multi_index the multi-index of a component of the
     612             :   /// contracted expression
     613             :   /// \return the highest multi-index between the components being summed in
     614             :   /// the contraction
     615             :   SPECTRE_ALWAYS_INLINE static constexpr std::array<
     616             :       size_t, num_uncontracted_tensor_indices>
     617           1 :   get_highest_multi_index_to_sum(
     618             :       const std::array<size_t, num_tensor_indices>& contracted_multi_index) {
     619             :     // Initialize with placeholders for debugging
     620             :     auto highest_multi_index = make_array<num_uncontracted_tensor_indices>(
     621             :         std::numeric_limits<size_t>::max());
     622             : 
     623             :     // Fill uncontracted indices
     624             :     for (size_t i = 0; i < num_tensor_indices; i++) {
     625             :       gsl::at(highest_multi_index, gsl::at(index_transformation, i)) =
     626             :           gsl::at(contracted_multi_index, i);
     627             :     }
     628             : 
     629             :     // Fill contracted indices
     630             :     for (size_t i = 0; i < num_contracted_index_pairs; i++) {
     631             :       const size_t first_index_position_in_pair =
     632             :           gsl::at(contracted_index_pair_positions, i).first;
     633             :       const size_t second_index_position_in_pair =
     634             :           gsl::at(contracted_index_pair_positions, i).second;
     635             :       gsl::at(highest_multi_index, first_index_position_in_pair) =
     636             :           gsl::at(uncontracted_index_dims, first_index_position_in_pair) - 1;
     637             :       gsl::at(highest_multi_index, second_index_position_in_pair) =
     638             :           gsl::at(uncontracted_index_dims, second_index_position_in_pair) - 1;
     639             :     }
     640             : 
     641             :     return highest_multi_index;
     642             :   }
     643             : 
     644             :   /// \brief Return the lowest multi-index between the components being summed
     645             :   /// in the contraction
     646             :   ///
     647             :   /// \details
     648             :   /// Example:
     649             :   /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
     650             :   /// \f$L_b = R^{a}{}_{ba}\f$. If the `contracted_multi_index` is `{1}`, which
     651             :   /// represents \f$L_1 = R^{a}{}_{1a}\f$, and the dimension of \f$a\f$ is 3,
     652             :   /// then we will need to sum the following terms: \f$R^{0}{}_{10}\f$,
     653             :   /// \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$. Between the terms being
     654             :   /// summed, the multi-index whose values are the smallest is
     655             :   /// \f$R^{0}{}_{10}\f$, so this function would return `{0, 1, 0}`.
     656             :   ///
     657             :   /// \param contracted_multi_index the multi-index of a component of the
     658             :   /// contracted expression
     659             :   /// \return the lowest multi-index between the components being summed in
     660             :   /// the contraction
     661             :   SPECTRE_ALWAYS_INLINE static constexpr std::array<
     662             :       size_t, num_uncontracted_tensor_indices>
     663           1 :   get_lowest_multi_index_to_sum(
     664             :       const std::array<size_t, num_tensor_indices>& contracted_multi_index) {
     665             :     // Initialize with placeholders for debugging
     666             :     auto lowest_multi_index = make_array<num_uncontracted_tensor_indices>(
     667             :         std::numeric_limits<size_t>::max());
     668             : 
     669             :     // Fill uncontracted indices
     670             :     for (size_t i = 0; i < num_tensor_indices; i++) {
     671             :       gsl::at(lowest_multi_index, gsl::at(index_transformation, i)) =
     672             :           gsl::at(contracted_multi_index, i);
     673             :     }
     674             : 
     675             :     // Fill contracted indices
     676             :     for (size_t i = 0; i < num_contracted_index_pairs; i++) {
     677             :       const size_t first_index_position_in_pair =
     678             :           gsl::at(contracted_index_pair_positions, i).first;
     679             :       const size_t second_index_position_in_pair =
     680             :           gsl::at(contracted_index_pair_positions, i).second;
     681             :       gsl::at(lowest_multi_index, first_index_position_in_pair) =
     682             :           gsl::at(contracted_index_first_values, i).first;
     683             :       gsl::at(lowest_multi_index, second_index_position_in_pair) =
     684             :           gsl::at(contracted_index_first_values, i).second;
     685             :     }
     686             : 
     687             :     return lowest_multi_index;
     688             :   }
     689             : 
     690             :   /// \brief Given the multi-index of one term being summed in the contraction,
     691             :   /// return the next highest multi-index of a component being summed
     692             :   ///
     693             :   /// \details
     694             :   /// What is meant by "next highest" is implementation defined, but generally
     695             :   /// means, of the components being summed, return the multi-index that results
     696             :   /// from lowering one of the contracted index pairs' values by one.
     697             :   ///
     698             :   /// Example:
     699             :   /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
     700             :   /// \f$L_b = R^{a}{}_{ba}\f$. If we are evaluating \f$L_1 = R^{a}{}_{1a}\f$
     701             :   ///  and the dimension of \f$a\f$ is 3, then we will need to sum the following
     702             :   /// terms: \f$R^{0}{}_{10}\f$, \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$.
     703             :   /// If `uncontracted_multi_index` is `{1, 1, 1}`, then the "next highest"
     704             :   /// multi-index is the result of lowering the values of the \f$a\f$ indices by
     705             :   /// 1. The component with that resulting multi-index is \f$R^{0}{}_{10}\f$, so
     706             :   /// this function would return `{0, 1, 0}`.
     707             :   ///
     708             :   /// Note: this function should perform the inverse functionality of
     709             :   /// `get_next_highest_multi_index_to_sum`. If the implementation of this
     710             :   /// function or the other changes what is meant by "next highest" or "next
     711             :   /// lowest," the other function should be updated in accordance.
     712             :   ///
     713             :   /// \param uncontracted_multi_index the multi-index of one of the components
     714             :   /// of the uncontracted operand expression to sum
     715             :   /// \return the next highest multi-index between the components being summed
     716             :   /// in the contraction
     717             :   SPECTRE_ALWAYS_INLINE static std::array<size_t,
     718             :                                           num_uncontracted_tensor_indices>
     719           1 :   get_next_highest_multi_index_to_sum(
     720             :       const std::array<size_t, num_uncontracted_tensor_indices>&
     721             :           uncontracted_multi_index) {
     722             :     std::array<size_t, num_uncontracted_tensor_indices>
     723             :         next_highest_uncontracted_multi_index = uncontracted_multi_index;
     724             : 
     725             :     size_t i = 0;
     726             :     while (i < num_contracted_index_pairs) {
     727             :       // the position of the first index in a pair being contracted
     728             :       const size_t current_index_first_position =
     729             :           gsl::at(contracted_index_pair_positions, i).first;
     730             :       // the position of the second index in a pair being contracted
     731             :       const size_t current_index_second_position =
     732             :           gsl::at(contracted_index_pair_positions, i).second;
     733             :       // of the values being summed over, the lowest concrete value of the first
     734             :       // index in the contracted pair
     735             :       const size_t current_index_first_first_value =
     736             :           gsl::at(contracted_index_first_values, i).first;
     737             : 
     738             :       // decrement the current index pair's values
     739             :       gsl::at(next_highest_uncontracted_multi_index,
     740             :               current_index_first_position)--;
     741             :       gsl::at(next_highest_uncontracted_multi_index,
     742             :               current_index_second_position)--;
     743             : 
     744             :       // If the index values of the index pair being contracted aren't lower
     745             :       // than the minimum values included in the summation, then we're done
     746             :       // computing this next multi-index
     747             :       if (not(gsl::at(next_highest_uncontracted_multi_index,
     748             :                       current_index_first_position) <
     749             :                   current_index_first_first_value or
     750             :               gsl::at(next_highest_uncontracted_multi_index,
     751             :                       current_index_first_position) >
     752             :                   gsl::at(uncontracted_index_dims,
     753             :                           current_index_first_position))) {
     754             :         break;
     755             :       }
     756             :       // Otherwise, we've wrapped around the lowest value being summed over for
     757             :       // this index, so we need to set it back to the maximum values being
     758             :       // summed and "carry" the decrementing over to the next contracted pair's
     759             :       // values
     760             :       gsl::at(next_highest_uncontracted_multi_index,
     761             :               current_index_first_position) =
     762             :           gsl::at(uncontracted_index_dims, current_index_first_position) - 1;
     763             :       gsl::at(next_highest_uncontracted_multi_index,
     764             :               current_index_second_position) =
     765             :           gsl::at(uncontracted_index_dims, current_index_second_position) - 1;
     766             : 
     767             :       i++;
     768             :     }
     769             : 
     770             :     return next_highest_uncontracted_multi_index;
     771             :   }
     772             : 
     773             :   /// \brief Given the multi-index of one term being summed in the contraction,
     774             :   /// return the next lowest multi-index of a component being summed
     775             :   ///
     776             :   /// \details
     777             :   /// What is meant by "next lowest" is implementation defined, but generally
     778             :   /// means, of the components being summed, return the multi-index that results
     779             :   /// from raising one of the contracted index pairs' values by one.
     780             :   ///
     781             :   /// Example:
     782             :   /// We have expression `R(ti::A, ti::b, ti::a)` to represent the contraction
     783             :   /// \f$L_b = R^{a}{}_{ba}\f$. If we are evaluating \f$L_1 = R^{a}{}_{1a}\f$
     784             :   ///  and the dimension of \f$a\f$ is 3, then we will need to sum the following
     785             :   /// terms: \f$R^{0}{}_{10}\f$, \f$R^{1}{}_{11}\f$, and \f$R^{2}{}_{12}\f$.
     786             :   /// If `uncontracted_multi_index` is `{1, 1, 1}`, then the "next lowest"
     787             :   /// multi-index is the result of raising the values of the \f$a\f$ indices by
     788             :   /// 1. The component with that resulting multi-index is \f$R^{2}{}_{12}\f$, so
     789             :   /// this function would return `{2, 1, 2}`.
     790             :   ///
     791             :   /// Note: this function should perform the inverse functionality of
     792             :   /// `get_next_lowest_multi_index_to_sum`. If the implementation of this
     793             :   /// function or the other changes what is meant by "next highest" or "next
     794             :   /// lowest," the other function should be updated in accordance.
     795             :   ///
     796             :   /// \param uncontracted_multi_index the multi-index of one of the components
     797             :   /// of the uncontracted operand expression to sum
     798             :   /// \return the next lowest multi-index between the components being summed in
     799             :   /// the contraction
     800             :   SPECTRE_ALWAYS_INLINE static std::array<size_t,
     801             :                                           num_uncontracted_tensor_indices>
     802           1 :   get_next_lowest_multi_index_to_sum(
     803             :       const std::array<size_t, num_uncontracted_tensor_indices>&
     804             :           uncontracted_multi_index) {
     805             :     std::array<size_t, num_uncontracted_tensor_indices>
     806             :         next_lowest_uncontracted_multi_index = uncontracted_multi_index;
     807             : 
     808             :     size_t i = 0;
     809             :     while (i < num_contracted_index_pairs) {
     810             :       // the position of the first index in a pair being contracted
     811             :       const size_t current_index_first_position =
     812             :           gsl::at(contracted_index_pair_positions, i).first;
     813             :       // the position of the second index in a pair being contracted
     814             :       const size_t current_index_second_position =
     815             :           gsl::at(contracted_index_pair_positions, i).second;
     816             : 
     817             :       // increment the current index pair's values
     818             :       gsl::at(next_lowest_uncontracted_multi_index,
     819             :               current_index_first_position)++;
     820             :       gsl::at(next_lowest_uncontracted_multi_index,
     821             :               current_index_second_position)++;
     822             : 
     823             :       // if the previous index value is > dim, then we've wrapped around
     824             :       // and we need to go again
     825             :       // If the index values of the index pair being contracted aren't higher
     826             :       // than the maximum values included in the summation, then we're done
     827             :       // computing this next multi-index
     828             :       if (not(gsl::at(next_lowest_uncontracted_multi_index,
     829             :                       current_index_first_position) >
     830             :               gsl::at(uncontracted_index_dims, current_index_first_position) -
     831             :                   1)) {
     832             :         break;
     833             :       }
     834             :       // Otherwise, we've wrapped around the highest value being summed over for
     835             :       // this index, so we need to set it back to the minimum values being
     836             :       // summed and "carry" the incrementing over to the next contracted pair's
     837             :       // values
     838             :       gsl::at(next_lowest_uncontracted_multi_index,
     839             :               current_index_first_position) =
     840             :           gsl::at(contracted_index_first_values, i).first;
     841             :       gsl::at(next_lowest_uncontracted_multi_index,
     842             :               current_index_second_position) =
     843             :           gsl::at(contracted_index_first_values, i).second;
     844             : 
     845             :       i++;
     846             :     }
     847             : 
     848             :     return next_lowest_uncontracted_multi_index;
     849             :   }
     850             : 
     851             :   /// \brief Computes the value of a component in the resultant contracted
     852             :   /// tensor
     853             :   ///
     854             :   /// \details
     855             :   /// The contraction is computed by recursively adding up each component in the
     856             :   /// summation, across all index pairs being contracted in the operand
     857             :   /// expression. This function is called `Iteration = num_terms_summed` times,
     858             :   /// once for each uncontracted tensor component being summed. It should
     859             :   /// externally be called for the first time with `Iteration == 0` and
     860             :   /// `current_multi_index == <highest multi index to sum>` (see
     861             :   /// `get_next_highest_multi_index_to_sum` for details).
     862             :   ///
     863             :   /// In performing the recursive summation, the recursion is
     864             :   /// specifically done "to the left," in that this function returns
     865             :   /// `compute_contraction(next index) + get(this_index)` as opposed to
     866             :   /// `get(this_index) + compute_contraction`. Benchmarking has shown that
     867             :   /// increased breadth in an equation's expression tree can slow down runtime.
     868             :   /// By "recursing left" here, we  minimize breadth in the overall tree for an
     869             :   /// equation, as both `AddSub` addition and `OuterProduct` (other expressions
     870             :   /// with two children) make efforts to make their operands with larger
     871             :   /// subtrees be their left operand.
     872             :   ///
     873             :   /// \tparam Iteration the nth term to sum, where n is between
     874             :   /// [0, num_terms_summed)
     875             :   /// \param t the expression contained within this contraction expression
     876             :   /// \param current_multi_index the multi-index of the uncontracted tensor
     877             :   /// component to retrieve
     878             :   /// \return the value of a component of the resulant contracted tensor
     879             :   template <size_t Iteration>
     880           1 :   SPECTRE_ALWAYS_INLINE static decltype(auto) compute_contraction(
     881             :       const T& t, const std::array<size_t, num_uncontracted_tensor_indices>&
     882             :                       current_multi_index) {
     883             :     if constexpr (Iteration < num_terms_summed - 1) {
     884             :       // We have more than one component left to sum
     885             :       return compute_contraction<Iteration + 1>(
     886             :                  t, get_next_highest_multi_index_to_sum(current_multi_index)) +
     887             :              t.get(current_multi_index);
     888             :     } else {
     889             :       // We only have one final component to sum
     890             :       return t.get(current_multi_index);
     891             :     }
     892             :   }
     893             : 
     894             :   /// \brief Return the value of the component of the resultant contracted
     895             :   /// tensor at a given multi-index
     896             :   ///
     897             :   /// \param contracted_multi_index the multi-index of the resultant contracted
     898             :   /// tensor component to retrieve
     899             :   /// \return the value of the component at `contracted_multi_index` in the
     900             :   /// resultant contracted tensor
     901           1 :   decltype(auto) get(const std::array<size_t, num_tensor_indices>&
     902             :                          contracted_multi_index) const {
     903             :     return compute_contraction<0>(
     904             :         t_, get_highest_multi_index_to_sum(contracted_multi_index));
     905             :   }
     906             : 
     907             :   /// \brief Computes the result of an internal leg of the contraction
     908             :   ///
     909             :   /// \details
     910             :   /// This function differs from `compute_contraction` and
     911             :   /// `compute_contraction_primary` in that it only computes one leg of the
     912             :   /// whole contraction, as opposed to the whole contraction.
     913             :   ///
     914             :   /// The leg being summed is defined by the `current_multi_index` and
     915             :   /// `Iteration` passed in from the inital external call: consecutive terms
     916             :   /// will be summed until the base case `Iteration == 0` is reached.
     917             :   ///
     918             :   /// \tparam Iteration the nth term in the leg to sum, where n is between
     919             :   /// [0, leg_length)
     920             :   /// \param t the expression contained within this contraction expression
     921             :   /// \param current_multi_index the multi-index of the uncontracted tensor
     922             :   /// component to retrieve as part of this leg's summation
     923             :   /// \param next_leg_starting_multi_index in the final iteration, the
     924             :   /// multi-index to update to be the next leg's starting multi-index
     925             :   /// \return the result of summing up the terms in the given leg
     926             :   template <size_t Iteration>
     927           1 :   SPECTRE_ALWAYS_INLINE static decltype(auto) compute_contraction_leg(
     928             :       const T& t,
     929             :       const std::array<size_t, num_uncontracted_tensor_indices>&
     930             :           current_multi_index,
     931             :       std::array<size_t, num_uncontracted_tensor_indices>&
     932             :           next_leg_starting_multi_index) {
     933             :     if constexpr (Iteration != 0) {
     934             :       // We have more than one component left to sum
     935             :       (void)next_leg_starting_multi_index;
     936             :       return compute_contraction_leg<Iteration - 1>(
     937             :                  t, get_next_highest_multi_index_to_sum(current_multi_index),
     938             :                  next_leg_starting_multi_index) +
     939             :              t.get(current_multi_index);
     940             :     } else {
     941             :       // We only have one final component to sum
     942             :       next_leg_starting_multi_index =
     943             :           get_next_highest_multi_index_to_sum(current_multi_index);
     944             :       return t.get(current_multi_index);
     945             :     }
     946             :   }
     947             : 
     948             :   /// \brief Computes the value of a component in the resultant contracted
     949             :   /// tensor
     950             :   ///
     951             :   /// \details
     952             :   /// First see `compute_contraction` for details on basic functionality.
     953             :   ///
     954             :   /// This function differs from `compute_contraction` in that it takes into
     955             :   /// account whether we have already computed part of the result component at a
     956             :   /// lower subtree. In recursively computing this contraction, the current
     957             :   /// result component will be substituted in for the most recent (highest)
     958             :   /// subtree below it that has already been evaluated.
     959             :   ///
     960             :   /// \tparam Iteration the nth term to sum, where n is between
     961             :   /// [0, num_terms_summed)
     962             :   /// \param t the expression contained within this contraction expression
     963             :   /// \param result_component the LHS tensor component to evaluate
     964             :   /// \param current_multi_index the multi-index of the uncontracted tensor
     965             :   /// component to retrieve
     966             :   /// \return the value of a component of the resulant contracted tensor
     967             :   template <size_t Iteration>
     968           1 :   SPECTRE_ALWAYS_INLINE static decltype(auto) compute_contraction_primary(
     969             :       const T& t, const type& result_component,
     970             :       const std::array<size_t, num_uncontracted_tensor_indices>&
     971             :           current_multi_index) {
     972             :     if constexpr (is_primary_end) {
     973             :       // We've already computed the whole subtree of the term being summed that
     974             :       // is at the lowest depth in the tree
     975             :       if constexpr (Iteration < num_terms_summed - 1) {
     976             :         // We have more than one component left to sum
     977             :         return compute_contraction_primary<Iteration + 1>(
     978             :                    t, result_component,
     979             :                    get_next_highest_multi_index_to_sum(current_multi_index)) +
     980             :                t.get(current_multi_index);
     981             :       } else {
     982             :         // The deepest term in the contraction subtree that is being summed is
     983             :         // just our current result, so return it
     984             :         return result_component;
     985             :       }
     986             :     } else {
     987             :       // We've haven't yet computed the whole subtree of the term being summed
     988             :       // that is at the lowest depth in the tree
     989             :       if constexpr (Iteration < num_terms_summed - 1) {
     990             :         // We have more than one component left to sum
     991             :         return compute_contraction_primary<Iteration + 1>(
     992             :                    t, result_component,
     993             :                    get_next_highest_multi_index_to_sum(current_multi_index)) +
     994             :                t.get(current_multi_index);
     995             :       } else {
     996             :         // We only have one final component to sum
     997             :         return t.get_primary(result_component, current_multi_index);
     998             :       }
     999             :     }
    1000             :   }
    1001             : 
    1002             :   /// \brief Return the value of the component of the resultant contracted
    1003             :   /// tensor at a given multi-index
    1004             :   ///
    1005             :   /// \details
    1006             :   /// This function differs from `get` in that it takes into account whether we
    1007             :   /// have already computed part of the result component at a lower subtree.
    1008             :   /// In recursively computing this contraction, the current result component
    1009             :   /// will be substituted in for the most recent (highest) subtree below it that
    1010             :   /// has already been evaluated.
    1011             :   ///
    1012             :   /// \param result_component the LHS tensor component to evaluate
    1013             :   /// \param contracted_multi_index the multi-index of the resultant contracted
    1014             :   /// tensor component to retrieve
    1015             :   /// \return the value of the component at `contracted_multi_index` in the
    1016             :   /// resultant contracted tensor
    1017             :   template <typename ResultType>
    1018           1 :   SPECTRE_ALWAYS_INLINE decltype(auto) get_primary(
    1019             :       const ResultType& result_component,
    1020             :       const std::array<size_t, num_tensor_indices>& contracted_multi_index)
    1021             :       const {
    1022             :     return compute_contraction_primary<0>(
    1023             :         t_, result_component,
    1024             :         get_highest_multi_index_to_sum(contracted_multi_index));
    1025             :   }
    1026             : 
    1027             :   /// \brief Successively evaluate the LHS Tensor's result component at each
    1028             :   /// leg of summations within the contraction expression
    1029             :   ///
    1030             :   /// \details
    1031             :   /// This function takes into account whether we have already computed part of
    1032             :   /// the result component at a lower subtree. In recursively computing this
    1033             :   /// contraction, the current result component will be substituted in for the
    1034             :   /// most recent (highest) subtree below it that has already been evaluated.
    1035             :   ///
    1036             :   /// \param result_component the LHS tensor component to evaluate
    1037             :   /// \param contracted_multi_index the multi-index of the component of the
    1038             :   /// contracted result tensor to evaluate
    1039             :   /// \param lowest_multi_index the lowest multi-index between the components
    1040             :   /// being summed in the contraction (see `get_lowest_multi_index_to_sum`)
    1041           1 :   SPECTRE_ALWAYS_INLINE void evaluate_primary_contraction(
    1042             :       type& result_component,
    1043             :       const std::array<size_t, num_tensor_indices>& contracted_multi_index,
    1044             :       const std::array<size_t, num_uncontracted_tensor_indices>&
    1045             :           lowest_multi_index) const {
    1046             :     if constexpr (not is_primary_end) {
    1047             :       // We need to first evaluate the subtree of the term being summed that
    1048             :       // is deepest in the tree
    1049             :       result_component =
    1050             :           t_.template get_primary(result_component, lowest_multi_index);
    1051             :     }
    1052             : 
    1053             :     if constexpr (evaluate_terms_separately) {
    1054             :       // Case 1: Evaluate all of the remaining terms, one TERM at a time
    1055             :       (void)contracted_multi_index;
    1056             :       std::array<size_t, num_uncontracted_tensor_indices> current_multi_index =
    1057             :           lowest_multi_index;
    1058             :       for (size_t i = 1; i < num_terms_summed; i++) {
    1059             :         const std::array<size_t, num_uncontracted_tensor_indices>
    1060             :             next_lowest_multi_index_to_sum =
    1061             :                 get_next_lowest_multi_index_to_sum(current_multi_index);
    1062             :         result_component += t_.get(next_lowest_multi_index_to_sum);
    1063             :         current_multi_index = next_lowest_multi_index_to_sum;
    1064             :       }
    1065             :     } else {
    1066             :       // Case 2: Evaluate all of the remaining terms, one LEG at a time
    1067             :       (void)lowest_multi_index;
    1068             :       std::array<size_t, num_uncontracted_tensor_indices>
    1069             :           next_leg_starting_multi_index =
    1070             :               get_highest_multi_index_to_sum(contracted_multi_index);
    1071             :       if constexpr (last_leg_length > 0) {
    1072             :         // Case 2a: We have a remainder of terms that don't make up a full leg
    1073             :         // length
    1074             : 
    1075             :         // Evaluate all the full-length legs
    1076             :         for (size_t i = 0; i < num_full_legs; i++) {
    1077             :           const std::array<size_t, num_uncontracted_tensor_indices>
    1078             :               current_multi_index = next_leg_starting_multi_index;
    1079             :           result_component += compute_contraction_leg<leg_length - 1>(
    1080             :               t_, current_multi_index, next_leg_starting_multi_index);
    1081             :         }
    1082             :         if constexpr (last_leg_length > 1) {
    1083             :           // Get rest of the deepest (partial-length) leg if there are more
    1084             :           // terms in it than just the one deepest term we already computed
    1085             :           const std::array<size_t, num_uncontracted_tensor_indices>
    1086             :               current_multi_index = next_leg_starting_multi_index;
    1087             :           result_component +=
    1088             :               // start at last_leg_length - 2 because we already computed one of
    1089             :               // the terms in this deepest leg (the deepest term)
    1090             :               compute_contraction_leg<last_leg_length - 2>(
    1091             :                   t_, current_multi_index, next_leg_starting_multi_index);
    1092             :         }
    1093             :       } else {
    1094             :         // Case 2b: We don't have remaining terms that only make up a
    1095             :         // partial leg length (i.e. we only have full-length legs)
    1096             : 
    1097             :         // Evaluate all but the deepest leg
    1098             :         for (size_t i = 1; i < num_full_legs; i++) {
    1099             :           const std::array<size_t, num_uncontracted_tensor_indices>
    1100             :               current_multi_index = next_leg_starting_multi_index;
    1101             : 
    1102             :           result_component += compute_contraction_leg<leg_length - 1>(
    1103             :               t_, current_multi_index, next_leg_starting_multi_index);
    1104             :         }
    1105             : 
    1106             :         if constexpr (leg_length > 1) {
    1107             :           // Get rest of the deepest leg if there are more terms in it than
    1108             :           // just the one deepest term we already computed
    1109             :           const std::array<size_t, num_uncontracted_tensor_indices>
    1110             :               current_multi_index = next_leg_starting_multi_index;
    1111             :           result_component +=
    1112             :               // start at leg_length - 2 because we already computed one of the
    1113             :               // terms in this deepest leg (the deepest term)
    1114             :               compute_contraction_leg<leg_length - 2>(
    1115             :                   t_, current_multi_index, next_leg_starting_multi_index);
    1116             :         }
    1117             :       }
    1118             :     }
    1119             :   }
    1120             : 
    1121             :   /// \brief Successively evaluate the LHS Tensor's result component at each
    1122             :   /// leg in this expression's subtree
    1123             :   ///
    1124             :   /// \details
    1125             :   /// This function takes into account whether we have already computed part of
    1126             :   /// the result component at a lower subtree. In recursively computing this
    1127             :   /// contraction, the current result component will be substituted in for the
    1128             :   /// most recent (highest) subtree below it that has already been evaluated.
    1129             :   ///
    1130             :   /// If this contraction expression is the beginning of a leg,
    1131             :   /// `evaluate_primary_contraction` is called to evaluate each individual
    1132             :   /// leg of summations within the contraction.
    1133             :   ///
    1134             :   /// \param result_component the LHS tensor component to evaluate
    1135             :   /// \param contracted_multi_index the multi-index of the component of the
    1136             :   /// contracted result tensor to evaluate
    1137             :   template <typename ResultType>
    1138           1 :   SPECTRE_ALWAYS_INLINE void evaluate_primary_subtree(
    1139             :       ResultType& result_component,
    1140             :       const std::array<size_t, num_tensor_indices>& contracted_multi_index)
    1141             :       const {
    1142             :     const auto lowest_multi_index_to_sum =
    1143             :         get_lowest_multi_index_to_sum(contracted_multi_index);
    1144             :     if constexpr (primary_child_subtree_contains_primary_start) {
    1145             :       // The primary child's subtree contains at least one leg, so recurse down
    1146             :       // and evaluate that first. Here, we evaluate the lowest multi-index
    1147             :       // because, according to `compute_contraction`, the lowest multi-index is
    1148             :       // the one in the last/leaf/final call to `compute_contraction` (i.e. the
    1149             :       // multi-index of the final term to sum)
    1150             :       t_.template evaluate_primary_subtree(result_component,
    1151             :                                            lowest_multi_index_to_sum);
    1152             :     }
    1153             :     if constexpr (is_primary_start) {
    1154             :       // We want to evaluate the subtree for this expression, one leg of
    1155             :       // summations at a time
    1156             :       evaluate_primary_contraction(result_component, contracted_multi_index,
    1157             :                                    lowest_multi_index_to_sum);
    1158             :     }
    1159             :   }
    1160             : 
    1161             :  private:
    1162             :   /// Operand expression being contracted
    1163           1 :   T t_;
    1164             : };
    1165             : 
    1166             : template <typename T, typename X, typename Symm, typename IndexList,
    1167             :           typename... TensorIndices>
    1168           0 : SPECTRE_ALWAYS_INLINE static constexpr auto contract(
    1169             :     const TensorExpression<T, X, Symm, IndexList, tmpl::list<TensorIndices...>>&
    1170             :         t) {
    1171             :   // Number of indices in the tensor expression we're wanting to contract
    1172             :   constexpr size_t num_uncontracted_indices = sizeof...(TensorIndices);
    1173             :   // The number of index pairs to contract
    1174             :   constexpr size_t num_contracted_index_pairs =
    1175             :       detail::get_num_contracted_index_pairs<num_uncontracted_indices>(
    1176             :           {{TensorIndices::value...}});
    1177             : 
    1178             :   if constexpr (num_contracted_index_pairs == 0) {
    1179             :     // There aren't any indices to contract, so we just return the input
    1180             :     return ~t;
    1181             :   } else {
    1182             :     // We have at least one pair of indices to contract
    1183             :     return TensorContract<T, X, Symm, IndexList, tmpl::list<TensorIndices...>,
    1184             :                           num_uncontracted_indices -
    1185             :                               (num_contracted_index_pairs * 2)>{t};
    1186             :   }
    1187             : }
    1188             : }  // namespace tenex

Generated by: LCOV version 1.14