cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
direct_sum.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_VIEWS_DIRECT_SUM_H
4#define pRC_CORE_TENSOR_VIEWS_DIRECT_SUM_H
5
9
10namespace pRC::TensorViews
11{
12 template<class T, class N, class VA, class VB>
13 class DirectSum : public View<T, N, DirectSum<T, N, VA, VB>>
14 {
15 static_assert(IsTensorView<VA>());
16 static_assert(IsTensorView<VB>());
17
18 private:
20
21 public:
22 template<class XA, class XB, If<IsSame<VA, RemoveReference<XA>>> = 0,
23 If<IsSame<VB, RemoveReference<XB>>> = 0>
24 DirectSum(XA &&a, XB &&b)
25 : mA(forward<XA>(a))
26 , mB(forward<XB>(b))
27 {
28 }
29
30 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
31 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
32 constexpr decltype(auto) operator()(Is const... indices)
33 {
34 return expand(
35 makeSeries<Index, typename Base::Dimension{}>(),
36 [this](auto const &indices, auto const... seq) -> T
37 {
38 if(((indices[seq] < VA::size(seq)) && ...))
39 {
40 return mA(indices[seq]...);
41 }
42 else if(((indices[seq] >= VA::size(seq)) && ...))
43 {
44 return mB((indices[seq] - VA::size(seq))...);
45 }
46 else
47 {
48 return zero();
49 }
50 },
51 Indices<sizeof...(Is)>(indices...));
52 }
53
54 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
55 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
56 constexpr decltype(auto) operator()(Is const... indices) const
57 {
58 return expand(
59 makeSeries<Index, typename Base::Dimension{}>(),
60 [this](auto const &indices, auto const... seq) -> T
61 {
62 if(((indices[seq] < VA::size(seq)) && ...))
63 {
64 return mA(indices[seq]...);
65 }
66 else if(((indices[seq] >= VA::size(seq)) && ...))
67 {
68 return mB((indices[seq] - VA::size(seq))...);
69 }
70 else
71 {
72 return zero();
73 }
74 },
75 Indices<sizeof...(Is)>(indices...));
76 }
77
78 constexpr decltype(auto) operator()(
79 typename Base::Subscripts const &subscripts)
80 {
81 return this->call(subscripts);
82 }
83
84 constexpr decltype(auto) operator()(
85 typename Base::Subscripts const &subscripts) const
86 {
87 return this->call(subscripts);
88 }
89
90 private:
91 VA mA;
92 VB mB;
93 };
94}
95#endif // pRC_CORE_TENSOR_VIEWS_DIRECT_SUM_H
Definition indices.hpp:15
Definition direct_sum.hpp:14
DirectSum(XA &&a, XB &&b)
Definition direct_sum.hpp:24
constexpr decltype(auto) operator()(Is const ... indices) const
Definition direct_sum.hpp:56
constexpr decltype(auto) operator()(Is const ... indices)
Definition direct_sum.hpp:32
Definition type_traits.hpp:32
Definition diagonal.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto zero()
Definition zero.hpp:12
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344