pRC
A multi-purpose Tensor Train library for C++.
Loading...
Searching...
No Matches
reduce.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_CORE_TENSOR_VIEWS_REDUCE_H
4#define pRC_CORE_TENSOR_VIEWS_REDUCE_H
5
10
11namespace pRC::TensorViews
12{
/// Primary template of the Reduce tensor view. It is only declared here;
/// the usable definition is the partial specialization below, which
/// requires R to be a Sequence<Index, Rs...>.
13 template<class T, class N, class F, class R, class V>
14 class Reduce;
15
/// Tensor view that folds the wrapped view V with the functor F over the
/// dimensions of V listed in Rs. Nothing is precomputed at construction:
/// every element access below runs the complete fold for that element.
///
/// Template parameters, as far as this file shows:
///  - T:  type passed to F::template Identity<T>(); the fold starts there.
///  - N:  forwarded unchanged to View<T, N, ...>. NOTE(review): presumably
///        the extents of the reduced result -- confirm against view.hpp.
///  - F:  fold functor. Must support `F::template Identity<T>()`, be
///        default-constructible, and be callable as F()(accumulator, element).
///  - Rs: dimension positions of V that are folded away; they are used as
///        V::size(Rs) and as the chip<Rs...> arguments.
///  - V:  the wrapped view (checked by IsTensorView), stored by value.
16 template<class T, class N, class F, Index... Rs, class V>
17 class Reduce<T, N, F, Sequence<Index, Rs...>, V>
18 : public View<T, N, Reduce<T, N, F, Sequence<Index, Rs...>, V>>
19 {
20 static_assert(IsTensorView<V>());
21
22 private:
// NOTE(review): doc line 23 is missing from this extract; presumably it
// declares the `Base` alias (Base::Dimension, Base::Subscripts) used below.
24
25 public:
/// Wraps `a`, perfect-forwarding it into the by-value member mA.
/// Participates only when RemoveReference<X> is exactly V.
/// NOTE(review): if RemoveReference strips only the reference (the library
/// also offers RemoveConstReference), a `V const &` argument is rejected --
/// confirm that is intended. The constructor is also not `explicit`.
26 template<class X, If<IsSame<V, RemoveReference<X>>> = 0>
27 Reduce(X &&a)
28 : mA(forward<X>(a))
29 {
30 }
31
/// Returns the element of the reduced view at `indices...`.
/// Participates only when every index is convertible to Index and exactly
/// Base::Dimension() indices are given. The result is returned by value:
/// `decltype(auto)` on the unparenthesized local `c` deduces its
/// non-reference type.
32 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
33 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
34 constexpr decltype(auto) operator()(Is const... indices)
35 {
// Start the fold from F's identity element for T.
36 auto c = F::template Identity<T>();
37
// Zero-dimensional result: fold every element of mA directly.
// NOTE(review): presumably Rs then covers all dimensions of V.
38 if constexpr(typename Base::Dimension() == 0)
39 {
// Use flat subscripting when V supports it.
40 if constexpr(IsSubscriptable<V>())
41 {
// NOTE(review): doc line 42, which opens the iteration call this lambda
// is passed to, is missing from this extract.
43 [this, &c](auto const i)
44 {
45 c = F()(c, mA[i]);
46 });
47 }
48 else
49 {
// Otherwise visit the elements of mA by multi-index.
// NOTE(review): doc line 50, which opens the iteration call this lambda
// is passed to, is missing from this extract.
51 [this, &c](auto const... loop)
52 {
53 c = F()(c, mA(loop...));
54 });
55 }
56 }
57 else
58 {
// For each index tuple over the extents of the reduced dimensions
// (presumably the full product space of Sizes<V::size(Rs)...>), fix those
// dimensions with chip<Rs...>, read the element at `indices...`, and fold.
59 range<Sizes<V::size(Rs)...>>(
60 [this, &c, indices...](auto const... loop)
61 {
62 c = F()(c, chip<Rs...>(mA, loop...)(indices...));
63 });
64 }
65
66 return c;
67 }
68
/// Const counterpart of the element-access overload above, same contract.
/// NOTE(review): the two bodies are duplicated token for token and could
/// share a single helper.
69 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
70 If<IsSatisfied<(sizeof...(Is) == typename Base::Dimension())>> = 0>
71 constexpr decltype(auto) operator()(Is const... indices) const
72 {
// Start the fold from F's identity element for T.
73 auto c = F::template Identity<T>();
74
// Zero-dimensional result: fold every element of mA directly.
75 if constexpr(typename Base::Dimension() == 0)
76 {
// Use flat subscripting when V supports it.
77 if constexpr(IsSubscriptable<V>())
78 {
// NOTE(review): doc line 79, which opens the iteration call this lambda
// is passed to, is missing from this extract.
80 [this, &c](auto const i)
81 {
82 c = F()(c, mA[i]);
83 });
84 }
85 else
86 {
// Otherwise visit the elements of mA by multi-index.
// NOTE(review): doc line 87, which opens the iteration call this lambda
// is passed to, is missing from this extract.
88 [this, &c](auto const... loop)
89 {
90 c = F()(c, mA(loop...));
91 });
92 }
93 }
94 else
95 {
// Fold chip<Rs...>(mA, loop...)(indices...) over the reduced extents.
96 range<Sizes<V::size(Rs)...>>(
97 [this, &c, indices...](auto const... loop)
98 {
99 c = F()(c, chip<Rs...>(mA, loop...)(indices...));
100 });
101 }
102
103 return c;
104 }
105
/// Element access via a Base::Subscripts object; delegates to
/// this->call(subscripts). NOTE(review): `call` is not defined in this
/// file -- presumably View unpacks the subscripts and re-enters the
/// index-pack overload above.
106 constexpr decltype(auto) operator()(
107 typename Base::Subscripts const &subscripts)
108 {
109 return this->call(subscripts);
110 }
111
/// Const counterpart of the Subscripts overload above.
112 constexpr decltype(auto) operator()(
113 typename Base::Subscripts const &subscripts) const
114 {
115 return this->call(subscripts);
116 }
117
118 private:
/// The wrapped view, held by value.
119 V mA;
120 };
121}
122#endif // pRC_CORE_TENSOR_VIEWS_REDUCE_H
Definition sequence.hpp:56
Definition sequence.hpp:34
constexpr decltype(auto) operator()(Is const ... indices) const
Definition reduce.hpp:71
constexpr decltype(auto) operator()(Is const ... indices)
Definition reduce.hpp:34
Definition reduce.hpp:14
Definition type_traits.hpp:32
Definition diagonal.hpp:11
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
static constexpr auto loop(F &&f, Xs &&...args)
Applies a function element-wise to Tensors.
Definition loop.hpp:31
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:16
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:561
Size Index
Definition type_traits.hpp:21
Definition type_traits.hpp:262