pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
contract.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
4#define pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
5
10
11namespace pRC
12{
/// Forward declaration of the Contract functor template, parameterized by the
/// same index pack Is that the contract<Is...>() overloads below take.
/// The eager overloads use IsInvocable<Contract<Is...>, ...> to decide
/// whether they are viable. The definition of Contract is not part of this
/// listing.
13 template<Index... Is>
14 struct Contract;
15
/// Contracts given indices of a single tensor (overload returning an
/// expression built on view(), not calling eval()).
///
/// @tparam Is  indices to contract. Their count must be even (isEven) and
///             must not exceed R::Dimension().
/// @param  a   the operand. It is captured by reference and forwarded into
///             view(forward<X>(a)).
/// @return an expression constructed from view(forward<X>(a)). Its leading
///         part (source line 54) is missing from this listing, so the exact
///         result type cannot be read here.
///
/// NOTE(review): this rendered listing is missing source lines 30, 35,
/// 37-38 and 54 (at least one template constraint, part of the body, and
/// the head of the return expression).
29 template<Index... Is, class X, class R = RemoveReference<X>,
31 If<IsSatisfied<(isEven(sizeof...(Is)))>> = 0,
32 If<IsSatisfied<(sizeof...(Is) <= typename R::Dimension())>> = 0>
33 static inline constexpr auto contract(X &&a)
34 {
36
39
// Split Is into two index groups with cut<2, 0> / cut<2, 1>.
// Presumably the k-th entry of S1 is contracted with the k-th entry of S2;
// confirm the cut<> semantics in sequence.hpp.
40 using S1 = decltype(cut<2, 0>(Sequence<Index, Is...>()));
41 using S2 = decltype(cut<2, 1>(Sequence<Index, Is...>()));
42
// expand() passes the values of makeSeries<Index, sizeof...(Is) / 2>() to
// the lambda as seq...; S1::value(seq)... and S2::value(seq)... then
// enumerate the two index groups.
43 return expand(makeSeries<Index, sizeof...(Is) / 2>(),
44 [&a](auto const... seq)
45 {
// Compile-time check: the sizes selected by S1 must equal the sizes
// selected by S2.
46 static_assert(select<S1::value(seq)...>(typename R::Sizes()) ==
47 select<S2::value(seq)...>(typename R::Sizes()),
48 "Sizes of dimensions to be contracted differ.");
49
// chip<> applied to R::Sizes with every contracted index. This
// presumably yields the sizes of the result; it is used on the line
// missing from this listing.
50 using Sizes =
51 decltype(chip<S1::value(seq)..., S2::value(seq)...>(
52 typename R::Sizes()));
53
55 view(forward<X>(a)));
56 });
57 }
58
/// Contracts indices across two tensors (overload that does not call eval()).
///
/// @tparam Is    even-length index list. sizeof...(Is) / 2 must not exceed
///               min(RA::Dimension(), RB::Dimension()).
/// @param  a, b  the two operands, captured by reference in the lambda below.
/// @return the result of the single-operand contract<indices...>(...) call.
///         Its argument (source line 92) is missing from this listing;
///         presumably it combines a and b, e.g. via tensorProduct.
///
/// NOTE(review): source lines 76-78 (including the declaration of the
/// template parameter RB used below) and 92 are missing from this listing.
75 template<Index... Is, class XA, class XB, class RA = RemoveReference<XA>,
79 If<IsSatisfied<(isEven(sizeof...(Is)))>> = 0,
80 If<IsSatisfied<(sizeof...(Is) / 2 <=
81 min(typename RA::Dimension(), typename RB::Dimension()))>> = 0>
82 static inline constexpr auto contract(XA &&a, XB &&b)
83 {
// SA: the indices taken by cut<2, 0>, used unchanged.
// SB: the indices taken by cut<2, 1>, each shifted by RA::Dimension. This
// presumably makes them address b's dimensions placed after a's in the
// combined operand.
84 using SA = decltype(cut<2, 0>(Sequence<Index, Is...>()));
85 using SB = decltype(cut<2, 1>(Sequence<Index, Is...>()) +
86 Constant<Index, typename RA::Dimension{}>());
87
// NOTE(review): (SA(), SB()) relies on an overloaded operator, for
// Sequence (presumably concatenation). With the built-in comma operator only
// SB() would be passed to expand(); confirm in sequence.hpp.
88 return expand((SA(), SB()),
89 [&a, &b](auto const... indices)
90 {
91 return contract<indices...>(
93 });
94 }
95
/// Single-operand contraction that returns eval() of contract<Is...>(a).
/// Enabled only when Contract<Is...> is invocable with X &.
///
/// NOTE(review): source line 113 (a further template constraint that
/// distinguishes this overload from the one above) is missing from this
/// listing.
112 template<Index... Is, class X, class R = RemoveReference<X>,
114 If<IsInvocable<Contract<Is...>, X &>> = 0>
115 static inline constexpr auto contract(X &&a)
116 {
// a is passed as an lvalue. The intermediate result is handed to eval()
// within the same full expression, before this function returns.
117 return eval(contract<Is...>(a));
118 }
119
/// Two-operand contraction that returns eval() of contract<Is...>(a, b).
/// Enabled only when Contract<Is...> is invocable with XA & and XB &.
///
/// NOTE(review): source lines 140-142 (further template parameters and
/// constraints) are missing from this listing.
139 template<Index... Is, class XA, class XB, class RA = RemoveReference<XA>,
143 If<IsInvocable<Contract<Is...>, XA &, XB &>> = 0>
144 static inline constexpr auto contract(XA &&a, XB &&b)
145 {
// a and b are passed as lvalues. The intermediate result is handed to
// eval() within the same full expression, before this function returns.
146 return eval(contract<Is...>(a, b));
147 }
148}
149#endif // pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
Definition sequence.hpp:56
Definition sequence.hpp:34
Definition contract.hpp:14
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto select(Sequence< T, Is... > const)
Definition sequence.hpp:589
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr X view(X &&a)
Returns a TensorView obtained from a TensorView.
Definition view.hpp:22
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
static constexpr auto contract(X &&a)
Contracts given indices of a Tensor.
Definition contract.hpp:33
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
static constexpr auto isEven(T const a)
Definition is_even.hpp:11
std::is_invocable< F, Args... > IsInvocable
Definition type_traits.hpp:134
std::integral_constant< T, V > Constant
Definition type_traits.hpp:34
static constexpr auto makeSeries()
Definition sequence.hpp:361
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:561
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Forwards the values of a pRC::Sequence to a function as arguments.
Definition sequence.hpp:354
static constexpr auto tensorProduct(XA &&a, XB &&b)
Calculates the tensor product of two Tensors.
Definition tensor_product.hpp:32
Size Index
Definition type_traits.hpp:21