cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
contract.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_VIEWS_CONTRACT_H
4#define pRC_CORE_TENSOR_VIEWS_CONTRACT_H
5
10
11namespace pRC::TensorViews
12{
13 template<class T, class N, class S1, class S2, class V>
15
16 template<class T, class N, Index... ILs, Index... IRs, class V>
18 V>
19 : public View<T, N,
20 ContractUnary<T, N, Sequence<Index, ILs...>,
21 Sequence<Index, IRs...>, V>>
22 {
23 static_assert(IsTensorView<V>());
24 static_assert(((V::size(ILs) == V::size(IRs)) && ...));
25
26 private:
28
29 public:
30 template<class X, If<IsSame<V, RemoveReference<X>>> = 0>
32 : mA(forward<X>(a))
33 {
34 }
35
36 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
37 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
38 constexpr decltype(auto) operator()(Is const... indices)
39 {
40 auto c = Add::template Identity<T>();
41
42 range<Sizes<V::size(ILs)...>>(
43 [this, &c, indices...](auto const... loop)
44 {
45 c += chip<ILs..., IRs...>(mA, loop..., loop...)(indices...);
46 });
47
48 return c;
49 }
50
51 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
52 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
53 constexpr decltype(auto) operator()(Is const... indices) const
54 {
55 auto c = Add::template Identity<T>();
56
57 range<Sizes<V::size(ILs)...>>(
58 [this, &c, indices...](auto const... loop)
59 {
60 c += chip<ILs..., IRs...>(mA, loop..., loop...)(indices...);
61 });
62
63 return c;
64 }
65
66 constexpr decltype(auto) operator()(
67 typename Base::Subscripts const &subscripts)
68 {
69 return this->call(subscripts);
70 }
71
72 constexpr decltype(auto) operator()(
73 typename Base::Subscripts const &subscripts) const
74 {
75 return this->call(subscripts);
76 }
77
78 private:
79 V mA;
80 };
81}
82#endif // pRC_CORE_TENSOR_VIEWS_CONTRACT_H
Definition sequence.hpp:56
Definition sequence.hpp:34
constexpr decltype(auto) operator()(Is const ... indices) const
Definition contract.hpp:53
constexpr decltype(auto) operator()(Is const ... indices)
Definition contract.hpp:38
Definition contract.hpp:14
Definition type_traits.hpp:32
pRC::Float<> T
Definition externs_nonTT.hpp:1
Definition diagonal.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr auto loop(F &&f, Xs &&...args)
Definition loop.hpp:22
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:16
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:551
Definition type_traits.hpp:262