Product.hpp
Go to the documentation of this file.
1 // Distributed under the MIT License.
2 // See LICENSE.txt for details.
3 
4 /// \file
5 /// Defines ET for tensor products
6 
7 #pragma once
8 
#include <array>
#include <type_traits>

#include "Data/Tensor/Expressions/TensorExpression.hpp"
10 
11 namespace TensorExpressions {
12 
13 /*!
14  * \ingroup TensorExpressionsGroup
15  *
16  * @tparam T1
17  * @tparam T2
18  * @tparam ArgsList1
19  * @tparam ArgsList2
20  */
21 template <typename T1, typename T2, typename ArgsList1, typename ArgsList2>
22 struct Product;
23 
24 template <typename T1, typename T2, template <typename...> class ArgsList1,
25  template <typename...> class ArgsList2, typename... Args1,
26  typename... Args2>
27 struct Product<T1, T2, ArgsList1<Args1...>, ArgsList2<Args2...>>
28  : public TensorExpression<
29  Product<T1, T2, ArgsList1<Args1...>, ArgsList2<Args2...>>,
30  typename T1::type, double,
31  tmpl::append<typename T1::index_list, typename T2::index_list>,
32  tmpl::sort<
33  tmpl::append<typename T1::args_list, typename T2::args_list>>>,
34  public Expression {
36  "Cannot product Tensors holding different data types.");
37  using max_symm2 = tmpl::fold<typename T2::symmetry, tmpl::uint32_t<0>,
38  tmpl::max<tmpl::_state, tmpl::_element>>;
39 
40  using type = typename T1::type;
41  using symmetry = tmpl::append<
42  tmpl::transform<typename T1::symmetry, tmpl::plus<tmpl::_1, max_symm2>>,
43  typename T2::symmetry>;
44  using index_list =
45  tmpl::append<typename T1::index_list, typename T2::index_list>;
46  static constexpr auto num_tensor_indices =
47  tmpl::size<index_list>::value == 0 ? 1 : tmpl::size<index_list>::value;
48  using args_list =
49  tmpl::sort<tmpl::append<typename T1::args_list, typename T2::args_list>>;
50 
51  Product(const T1& t1, const T2& t2) : t1_(t1), t2_(t2) {}
52 
53  // TODO: The args will need to be reduced in a careful manner, which means
54  // they need to be reduced together, then split at the correct length so that
55  // the indexing is correct.
56  template <typename... LhsIndices, typename U>
58  get(const std::array<U, num_tensor_indices>& tensor_index) const {
59  return t1_.template get<LhsIndices...>(tensor_index) *
60  t2_.template get<LhsIndices...>(tensor_index);
61  }
62 
63  private:
64  const T1 t1_;
65  const T2 t2_;
66 };
67 
68 } // namespace TensorExpressions
69 
70 /*!
71  * @ingroup TensorExpressionsGroup
72  *
73  * @tparam T1
74  * @tparam T2
75  * @tparam X
76  * @tparam Symm1
77  * @tparam Symm2
78  * @tparam IndexList1
79  * @tparam IndexList2
80  * @tparam Args1
81  * @tparam Args2
82  * @param t1
83  * @param t2
84  * @return
85  */
86 template <typename T1, typename T2, typename X, typename Symm1, typename Symm2,
87  typename IndexList1, typename IndexList2, typename Args1,
88  typename Args2>
90  const TensorExpression<T1, X, Symm1, IndexList1, Args1>& t1,
91  const TensorExpression<T2, X, Symm2, IndexList2, Args2>& t2) {
92  // static_assert(tmpl::size<Args1>::value == tmpl::size<Args2>::value,
93  // "Tensor addition is only possible with the same rank
94  // tensors");
95  // static_assert(tmpl::equal_members<Args1, Args2>::value,
96  // "The indices when adding two tensors must be equal. This
97  // error "
98  // "occurs from expressions like A(_a, _b) + B(_c, _a)");
100  typename std::conditional<
102  TensorExpression<T1, X, Symm1, IndexList1, Args1>>::type,
103  typename std::conditional<
105  TensorExpression<T2, X, Symm2, IndexList2, Args2>>::type,
106  Args1, Args2>(~t1, ~t2);
107 }
auto operator*(const TensorExpression< T1, X, Symm1, IndexList1, Args1 > &t1, const TensorExpression< T2, X, Symm2, IndexList2, Args2 > &t2)
Definition: Product.hpp:89
Marks a class as being a TensorExpression.
Definition: TensorExpression.hpp:271
#define SPECTRE_ALWAYS_INLINE
Always inline a function. Only use this if you benchmarked the code.
Definition: ForceInline.hpp:20
Definition: Product.hpp:22
Definition: AddSubtract.hpp:28