cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
contract.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
4#define pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
5
12
13namespace pRC
14{
// contract<Is...>(a): contracts pairs of modes of a single tensor-like
// argument.
//
// Template parameters:
//   Is - an even number of mode indices of `a`. The pack is split into two
//        halves via cut<2, 0> / cut<2, 1>; the k-th index of the first half
//        is paired with the k-th index of the second half (see the
//        S1::value(seq) / S2::value(seq) pairing below).
//        NOTE(review): "halves" assumes cut<2, I> splits a sequence into
//        two equal parts and returns part I -- confirm in sequence.hpp.
//   X  - forwarding-reference type of the argument.
//   R  - X with the reference removed; must satisfy IsTensorish.
//
// Constraints (requires-clause):
//   - sizeof...(Is) is even and not larger than R::Dimension;
//   - every index in Is is smaller than R::Dimension;
//   - the sizes selected by the first half of Is compare equal to the sizes
//     selected by the second half, i.e. paired modes have equal extents.
//
// Returns: if View is invocable with X, an expression built on
// view(forward<X>(a)) (presumably a lazy contraction view -- the line that
// constructs it, original line 39, is missing from this listing; TODO
// confirm). Otherwise, eval() of contracting the named (lvalue) `a`.
15 template<Index... Is, class X, IsTensorish R = RemoveReference<X>>
16 requires(isEven(sizeof...(Is)) && sizeof...(Is) <= R::Dimension) &&
17 (max(Is...) < R::Dimension) &&
18 (cut<2, 0>(select<Is...>(typename R::Sizes())) ==
19 cut<2, 1>(select<Is...>(typename R::Sizes())))
20 static inline constexpr auto contract(X &&a)
21 {
// View path: taken when a View can be formed from X.
22 if constexpr(IsInvocable<View, X>)
23 {
// NOTE(review): original lines 24 and 26-27 are missing from this listing.
25
28
// S1 / S2: the first and second halves of the index pack Is.
29 using S1 = decltype(cut<2, 0>(Sequence<Index, Is...>()));
30 using S2 = decltype(cut<2, 1>(Sequence<Index, Is...>()));
31
// makeSeries<Index, N>() with N = number of index pairs; expand() passes its
// elements as the pack `seq`, so S1::value(seq)... and S2::value(seq)...
// enumerate the paired indices pair by pair (presumably 0 .. N-1 -- confirm).
32 return expand(makeSeries<Index, sizeof...(Is) / 2>(),
33 [&a](auto const... seq)
34 {
// Result sizes: chip<> applied to R::Sizes with every contracted index;
// presumably this removes the contracted modes -- confirm against chip().
35 using Sizes =
36 decltype(chip<S1::value(seq)..., S2::value(seq)...>(
37 typename R::Sizes()));
38
// NOTE(review): original line 39, which opens the returned expression
// wrapping the view below, is missing from this listing.
40 view(forward<X>(a)));
41 });
42 }
43 else
44 {
// Fallback: call contract on the named (lvalue) `a` -- presumably viewable
// as an lvalue -- and materialize the result with eval().
45 return eval(contract<Is...>(a));
46 }
47 }
48
// contract<Is...>(a, b): contracts modes of `a` against modes of `b`.
//
// Template parameters:
//   Is     - an even number of indices; the first half (cut<2, 0>) indexes
//            modes of `a`, the second half (cut<2, 1>) indexes modes of `b`.
//   XA, XB - forwarding-reference types of the two arguments.
//   RA, RB - XA / XB with the reference removed; must satisfy IsTensorish.
//
// Constraints (requires-clause):
//   - sizeof...(Is) is even, and the number of index pairs does not exceed
//     min(RA::Dimension, RB::Dimension);
//   - the largest first-half index is < RA::Dimension and the largest
//     second-half index is < RB::Dimension;
//   - the sizes of the selected modes of `a` (select<> on RA::Sizes) compare
//     equal to those of `b` (select<> on RB::Sizes).
//     NOTE(review): original lines 56 and 61, which open these two
//     expressions, are missing from this listing.
//
// Implementation: builds tensorProduct(a, b) and applies the single-tensor
// contract<> overload to it with indices derived from SA and SB. Otherwise
// falls back to eval() of contracting the named lvalues.
49 template<Index... Is, class XA, class XB,
50 IsTensorish RA = RemoveReference<XA>,
51 IsTensorish RB = RemoveReference<XB>>
52 requires(isEven(sizeof...(Is)) &&
53 sizeof...(Is) / 2 <= min(RA::Dimension, RB::Dimension)) &&
54 (reduce<Max>(cut<2, 0>(Sizes<Is...>())) < RA::Dimension &&
55 reduce<Max>(cut<2, 1>(Sizes<Is...>())) < RB::Dimension) &&
57 [](auto const... indices)
58 {
59 return select<indices...>(typename RA::Sizes());
60 }) ==
62 [](auto const... indices)
63 {
64 return select<indices...>(typename RB::Sizes());
65 }))
66 static inline constexpr auto contract(XA &&a, XB &&b)
67 {
// NOTE(review): original line 68 (the condition selecting this branch vs.
// the else-branch below) is missing from this listing.
69 {
// SA: indices into `a`, the first half of Is.
70 using SA = decltype(cut<2, 0>(Sequence<Index, Is...>()));
// SB: the second half of Is plus a term continued on original line 72
// (missing here); presumably an offset of RA::Dimension so the indices
// address b's modes inside the tensor product -- confirm.
71 using SB = decltype(cut<2, 1>(Sequence<Index, Is...>()) +
73
// NOTE(review): as rendered, `(SA(), SB())` is a comma expression that would
// discard SA(); a token (e.g. a sequence-concatenation call) was likely
// dropped from this listing -- check against the real header.
74 return expand((SA(), SB()),
75 [&a, &b](auto const... indices)
76 {
77 return contract<indices...>(
78 tensorProduct(forward<XA>(a), forward<XB>(b)));
79 });
80 }
81 else
82 {
// Fallback: call contract on the named (lvalue) `a` and `b` and
// materialize the result with eval().
83 return eval(contract<Is...>(a, b));
84 }
85 }
86}
87#endif // pRC_CORE_TENSOR_FUNCTIONS_CONTRACT_H
Definition value.hpp:12
Definition sequence.hpp:29
Definition contract.hpp:13
Definition concepts.hpp:31
Definition cholesky.hpp:10
static constexpr auto contract(X &&a)
Definition contract.hpp:20
static constexpr auto select(Sequence< T, Is... > const)
Definition sequence.hpp:610
static constexpr auto isEven(T const a)
Definition is_even.hpp:11
Size Index
Definition basics.hpp:32
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr decltype(auto) view(X &&a)
Definition view.hpp:13
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
Sequence< Size, Ns... > Sizes
Definition sequence.hpp:100
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto reduce(Sequence< T, I1, I2, Is... > const)
Definition sequence.hpp:458
static constexpr auto tensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:17
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
std::integral_constant< T, V > Constant
Definition basics.hpp:38
static constexpr auto cut(Sequence< T, Is... > const)
Definition sequence.hpp:631
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13