cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
direct_sum.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_VIEWS_DIRECT_SUM_H
4#define pRC_CORE_TENSOR_VIEWS_DIRECT_SUM_H
5
9
10namespace pRC::TensorViews
11{
// Tensor view over two operand views, stored as mA (type VA) and mB (type VB).
// Element access by subscripts yields:
//   - mA's element when every subscript is below VA::size of its dimension,
//   - mB's element, with each subscript shifted down by VA::size, when every
//     subscript is at or beyond VA::size of its dimension,
//   - zero() for any other (mixed) combination of subscripts.
// Both VA and VB must satisfy IsTensorView.
//
// NOTE(review): the embedded line numbers in this listing skip 17, 21-22, 46
// and 67, so some source lines are not visible here. In particular the
// declaration of `Base` and the call that receives the lambdas below are
// missing -- confirm against the original header.
12 template<class T, class N, class VA, class VB>
13 requires IsTensorView<VA> && IsTensorView<VB>
14 class DirectSum : public View<T, N, DirectSum<T, N, VA, VB>>
15 {
16 private:
18
19 public:
// Builds the view from its two operands; each argument is forwarded into
// the corresponding member (a -> mA, b -> mB).
20 template<class XA, class XB>
23 DirectSum(XA &&a, XB &&b)
24 : mA(forward<XA>(a))
25 , mB(forward<XB>(b))
26 {
27 }
28
// Element access with one index argument per dimension (exactly
// Base::Dimension arguments). Delegates to this->call(indices...), which is
// not defined in this class -- presumably inherited from View; confirm.
29 template<IsConvertible<Index>... Is>
30 requires(sizeof...(Is) == Base::Dimension)
31 constexpr decltype(auto) operator()(Is const... indices)
32 {
33 return this->call(indices...);
34 }
35
// Const overload of the per-index element access above; same delegation.
36 template<IsConvertible<Index>... Is>
37 requires(sizeof...(Is) == Base::Dimension)
38 constexpr decltype(auto) operator()(Is const... indices) const
39 {
40 return this->call(indices...);
41 }
42
// Element access by a Base::Subscripts object. The lambda receives a pack
// `seq` used both to index `subscripts` and as the argument of VA::size
// (presumably one entry per dimension -- the call that supplies the pack is
// on a line missing from this listing; the page's cross-references mention
// expand() and makeSeries()). The lambda's return type is T.
43 constexpr decltype(auto) operator()(
44 typename Base::Subscripts const &subscripts)
45 {
47 [this, &subscripts](auto const... seq) -> T
48 {
// Every subscript lies below VA's size: read from the first operand.
49 if(((subscripts[seq] < VA::size(seq)) && ...))
50 {
51 return mA(subscripts[seq]...);
52 }
// Every subscript lies at or beyond VA's size: read from the second
// operand, with each subscript shifted down by VA's size.
53 else if(((subscripts[seq] >= VA::size(seq)) && ...))
54 {
55 return mB((subscripts[seq] - VA::size(seq))...);
56 }
// Mixed case (some subscripts below, some at or beyond VA's size).
57 else
58 {
59 return zero();
60 }
61 });
62 }
63
// Const overload of the subscripts-based access; the selection logic is
// identical to the non-const overload above.
64 constexpr decltype(auto) operator()(
65 typename Base::Subscripts const &subscripts) const
66 {
68 [this, &subscripts](auto const... seq) -> T
69 {
// All subscripts below VA's size: first operand.
70 if(((subscripts[seq] < VA::size(seq)) && ...))
71 {
72 return mA(subscripts[seq]...);
73 }
// All subscripts at or beyond VA's size: second operand, shifted.
74 else if(((subscripts[seq] >= VA::size(seq)) && ...))
75 {
76 return mB((subscripts[seq] - VA::size(seq))...);
77 }
// Mixed case.
78 else
79 {
80 return zero();
81 }
82 });
83 }
84
// Access through a single Index via operator[] is explicitly deleted for
// this view, in both const and non-const form.
85 constexpr decltype(auto) operator[](Index const index) = delete;
86 constexpr decltype(auto) operator[](Index const index) const = delete;
87
88 private:
// The two operands, stored as the template types VA and VB.
89 VA mA;
90 VB mB;
91 };
92}
93#endif // pRC_CORE_TENSOR_VIEWS_DIRECT_SUM_H
Definition value.hpp:12
Definition direct_sum.hpp:15
DirectSum(XA &&a, XB &&b)
Definition direct_sum.hpp:23
Definition declarations.hpp:20
Definition concepts.hpp:28
Definition declarations.hpp:18
Size Index
Definition basics.hpp:32
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr auto zero()
Definition zero.hpp:12