cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
contract.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
4#define pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
5
10
// Tensor contraction entry points: contract<Is...>(a) contracts index pairs of a
// single tensor, contract<Is...>(a, b) contracts index pairs across two tensors.
// NOTE(review): this Doxygen listing omits source lines 17, 22, 24-25, 41, 47-49,
// 68 and 76-78 (the line numbers jump), so some constraints and one return
// expression are not visible here.
11namespace pRC
12{
// Function-object counterpart of contract(); only its declaration is shown here.
// The eval() overloads below use it via IsInvocable to decide when they apply.
13 template<Index... Is>
14 struct Contract;
15
// Contracts pairs of dimensions of a single tensor-like operand.
// Is lists the dimension indices to contract; its length must be even and must
// not exceed R::Dimension(). Entries are split into two interleaved halves
// (presumably cut<2, k> keeps every second entry starting at position k -
// confirm in sequence.hpp), and the i-th entry of each half forms one pair.
16 template<Index... Is, class X, class R = RemoveReference<X>,
18 If<IsSatisfied<(isEven(sizeof...(Is)))>> = 0,
19 If<IsSatisfied<(sizeof...(Is) <= typename R::Dimension())>> = 0>
20 static inline constexpr auto contract(X &&a)
21 {
23
26
// S1: first index of each pair, S2: second index of each pair.
27 using S1 = decltype(cut<2, 0>(Sequence<Index, Is...>()));
28 using S2 = decltype(cut<2, 1>(Sequence<Index, Is...>()));
29
// Iterate over the sizeof...(Is) / 2 pairs (presumably makeSeries yields
// 0, 1, ..., N-1 - TODO confirm) so that S1/S2 can be indexed per pair.
30 return expand(makeSeries<Index, sizeof...(Is) / 2>(),
31 [&a](auto const... seq)
32 {
// Compile-time check: both dimensions of every pair must have the same size.
33 static_assert(select<S1::value(seq)...>(typename R::Sizes()) ==
34 select<S2::value(seq)...>(typename R::Sizes()),
35 "Sizes of dimensions to be contracted differ.");
36
// Sizes of the result: R::Sizes() with all contracted dimensions
// chipped out (presumed semantics of chip - verify in sequence.hpp).
37 using Sizes =
38 decltype(chip<S1::value(seq)..., S2::value(seq)...>(
39 typename R::Sizes()));
40
// NOTE(review): source line 41 (the start of the returned expression)
// is missing from this listing; only its argument view(forward<X>(a))
// is visible, so the exact return type cannot be stated from here.
42 view(forward<X>(a)));
43 });
44 }
45
// Contracts dimensions of tensor a with dimensions of tensor b.
// Is lists (a-index, b-index) entries; its length must be even and the number of
// pairs must not exceed the smaller of the two tensors' dimensions. The
// contraction is expressed as a single-tensor contract on tensorProduct(a, b).
46 template<Index... Is, class XA, class XB, class RA = RemoveReference<XA>,
50 If<IsSatisfied<(isEven(sizeof...(Is)))>> = 0,
51 If<IsSatisfied<(sizeof...(Is) / 2 <=
52 min(typename RA::Dimension(), typename RB::Dimension()))>> = 0>
53 static inline constexpr auto contract(XA &&a, XB &&b)
54 {
// SA: indices into a. SB: indices into b, offset by RA::Dimension so they
// address b's dimensions inside tensorProduct(a, b) (assumes the product
// lists a's dimensions first and that Sequence + Constant adds element-wise).
55 using SA = decltype(cut<2, 0>(Sequence<Index, Is...>()));
56 using SB = decltype(cut<2, 1>(Sequence<Index, Is...>()) +
57 Constant<Index, typename RA::Dimension{}>());
58
// NOTE(review): `(SA(), SB())` relies on pRC overloading operator, for
// Sequence. With the built-in comma operator SA would be discarded. The
// combined sequence must also interleave (a0, b0, a1, b1, ...), because the
// single-tensor overload splits Is by alternating position (cut<2, 0> /
// cut<2, 1>); plain concatenation would mis-pair indices once there is more
// than one pair. Confirm against sequence.hpp.
59 return expand((SA(), SB()),
60 [&a, &b](auto const... indices)
61 {
62 return contract<indices...>(
63 tensorProduct(forward<XA>(a), forward<XB>(b)));
64 });
65 }
66
// Overload used when Contract<Is...> is invocable on an lvalue of X (other
// constraints are on omitted line 68). It contracts `a` as an lvalue and
// passes the result through eval().
67 template<Index... Is, class X, class R = RemoveReference<X>,
69 If<IsInvocable<Contract<Is...>, X &>> = 0>
70 static inline constexpr auto contract(X &&a)
71 {
72 return eval(contract<Is...>(a));
73 }
74
// Two-operand counterpart of the overload above (further constraints are on
// omitted lines 76-78). It contracts lvalues a and b and passes the result
// through eval().
75 template<Index... Is, class XA, class XB, class RA = RemoveReference<XA>,
79 If<IsInvocable<Contract<Is...>, XA &, XB &>> = 0>
80 static inline constexpr auto contract(XA &&a, XB &&b)
81 {
82 return eval(contract<Is...>(a, b));
83 }
84}
85#endif // pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
Definition sequence.hpp:56
Definition sequence.hpp:34
Definition contract.hpp:14
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto select(Sequence< T, Is... > const)
Definition sequence.hpp:579
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr X view(X &&a)
Definition view.hpp:12
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr auto contract(X &&a)
Definition contract.hpp:20
static constexpr auto isEven(T const a)
Definition is_even.hpp:11
static constexpr auto makeSeries()
Definition sequence.hpp:351
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
std::integral_constant< T, V > Constant
Definition type_traits.hpp:34
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:551
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344
static constexpr auto tensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:19
std::is_invocable< F, Args... > IsInvocable
Definition type_traits.hpp:134