pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
tensor.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_TENSOR_TENSOR_H
4#define pRC_TENSOR_TRAIN_TENSOR_TENSOR_H
5
16
17namespace pRC::TensorTrain
18{
19 template<class T, class N, class Ranks, class F>
21
22 template<class T, Size... Ns, Size... Rs>
23 class Tensor<T, Sizes<Ns...>, Sizes<Rs...>,
24 If<All<IsSatisfied<(sizeof...(Ns) - 1 == sizeof...(Rs))>,
26 {
27 public:
// Mode sizes of the tensor train, as a pRC::Sizes pack.
28 using N = pRC::Sizes<Ns...>;
// Alias of N.
29 using L = N;
// Alias of N. Note: inside this class the name `Sizes` now refers to this
// alias, which is why the pRC::Sizes template is spelled fully qualified below.
30 using Sizes = N;
31
35
// Number of modes, taken from N::Dimension.
36 using Dimension = typename N::Dimension;
37
// Internal TT ranks; the class-head constraint requires exactly
// sizeof...(Ns) - 1 of them, i.e. one per pair of adjacent modes.
38 using Ranks = pRC::Sizes<Rs...>;
39 template<class S,
42
43 template<Index C>
45 N::size(C), pRC::Sizes<1, Rs..., 1>::size(C + 1)>;
46
47 using Type = T;
48 template<class C>
50
51 using Value = typename T::Value;
52 template<class V, If<IsValue<V>> = 0>
55
56 using Signed = typename T::Signed;
57 template<Bool R>
60
61 using Width = typename T::Width;
62 template<Size Q>
65
66 using IsComplexified = typename T::IsComplexified;
69
70 template<class E = typename N::IsLinearizable, If<E> = 0>
71 static constexpr auto n()
72 {
73 return N::size();
74 }
75
76 static constexpr auto n(Index const dimension)
77 {
78 return N::size(dimension);
79 }
80
81 template<class E = typename Sizes::IsLinearizable, If<E> = 0>
82 static constexpr auto size()
83 {
84 return Sizes::size();
85 }
86
87 static constexpr auto size(Index const dimension)
88 {
89 return Sizes::size(dimension);
90 }
91
92 template<class X, class... Is, If<IsConstructible<T, X>> = 0,
93 If<IsSatisfied<(sizeof...(Is) == Dimension())>> = 0>
94 static inline constexpr auto Single(X &&value, Is const... indices)
95 {
96 return Single(forward<X>(value), Subscripts(indices...));
97 }
98
/**
 * @brief Builds a tensor train from one value placed at `subscripts`.
 *
 * Defines a core generator `f`, templated on the core index C:
 *   - core 0 is Cores<0>::Single(value, 0, subscripts[0], 0);
 *   - every later core C is Cores<C>::Single(identity<T>(), 0, subscripts[C], 0).
 *
 * NOTE(review): presumably Cores<C>::Single(v, i, j, k) yields a core whose
 * only nonzero entry is v at (i, j, k), giving a rank-one train with a single
 * nonzero element; confirm against pRC::Tensor::Single.
 *
 * @param value      Converted to T once and captured by copy in `f`.
 * @param subscripts One index per mode; captured by copy in `f`.
 */
99 template<class X, If<IsConstructible<T, X>> = 0>
100 static inline constexpr auto Single(X &&value,
101 Subscripts const &subscripts)
102 {
// Generator invoked as f.template operator()<C>() for each core index C.
103 auto const f = [subscripts, value = T(forward<X>(value))]<Index C>()
104 {
// Only the first core carries `value`; the rest contribute identity<T>().
105 if constexpr(C == 0)
106 {
107 return Cores<C>::Single(value, 0, subscripts[C], 0);
108 }
109 else
110 {
111 return Cores<C>::Single(identity<T>(), 0, subscripts[C], 0);
112 }
113 };
// Decayed closure type. Presumably used as the generator parameter F of the
// Tensor<T, N, Ranks, F> template declared at the top of the namespace.
// TODO confirm.
114 using F = RemoveConstReference<decltype(f)>;
116 }
117
118 public:
119 ~Tensor() = default;
120 constexpr Tensor(Tensor const &) = default;
121 constexpr Tensor(Tensor &&) = default;
122 constexpr Tensor &operator=(Tensor const &) & = default;
123 constexpr Tensor &operator=(Tensor &&) & = default;
124 constexpr Tensor() = default;
125
/**
 * @brief Converting constructor.
 *
 * Default-initializes the members (there is no member-initializer list),
 * then assigns `forward<X>(other)` through operator=.
 *
 * NOTE(review): not marked explicit, so it also serves as an implicit
 * conversion for every X admitted by the template constraint.
 */
126 template<class X,
128 constexpr Tensor(X &&other)
129 {
130 *this = forward<X>(other);
131 }
132
/**
 * @brief Converting assignment; lvalue receivers only.
 *
 * Writes `forward<X>(rhs)` through view(*this).
 *
 * @return Reference to this tensor.
 */
133 template<class X,
135 constexpr auto &operator=(X &&rhs) &
136 {
137 view(*this) = forward<X>(rhs);
138 return *this;
139 }
140
141 template<Index C>
142 constexpr decltype(auto) core() &&
143 {
144 return get<C>(move(mCores));
145 }
146
147 template<Index C>
148 constexpr decltype(auto) core() const &&
149 {
150 return get<C>(move(mCores));
151 }
152
153 template<Index C>
154 constexpr decltype(auto) core() &
155 {
156 return get<C>(mCores);
157 }
158
159 template<Index C>
160 constexpr decltype(auto) core() const &
161 {
162 return get<C>(mCores);
163 }
164
165 template<class... Is, If<All<IsConvertible<Is, Index>...>> = 0,
166 If<IsSatisfied<(sizeof...(Is) == Dimension())>> = 0>
167 constexpr decltype(auto) operator()(Is const... indices) const
168 {
169 return view(*this)(indices...);
170 }
171
172 constexpr decltype(auto) operator()(Subscripts const &subscripts) const
173 {
174 return view(*this)(subscripts);
175 }
176
177 template<class X, If<IsInvocable<Add, Tensor &, X>> = 0>
178 constexpr auto &operator+=(X &&rhs) &
179 {
180 return *this = *this + forward<X>(rhs);
181 }
182
183 template<class X, If<IsInvocable<Sub, Tensor &, X>> = 0>
184 constexpr auto &operator-=(X &&rhs) &
185 {
186 return *this = *this - forward<X>(rhs);
187 }
188
189 template<class X, If<IsInvocable<Mul, X, Tensor &>> = 0>
190 constexpr auto &applyOnTheLeft(X &&lhs) &
191 {
192 view(*this).applyOnTheLeft(forward<X>(lhs));
193 return *this;
194 }
195
196 template<class X, If<IsInvocable<Mul, Tensor &, X>> = 0>
197 constexpr auto &applyOnTheRight(X &&rhs) &
198 {
199 view(*this).applyOnTheRight(forward<X>(rhs));
200 return *this;
201 }
202
203 template<class X, If<IsInvocable<Mul, Tensor &, X>> = 0>
204 constexpr auto &operator*=(X &&rhs) &
205 {
206 view(*this) *= forward<X>(rhs);
207 return *this;
208 }
209
210 template<class X, If<IsInvocable<Div, Tensor &, X>> = 0>
211 constexpr auto &operator/=(X &&rhs) &
212 {
213 return *this = *this / forward<X>(rhs);
214 }
215
/**
 * @brief Explicit conversion to a dense pRC::Tensor<T, Ns...>.
 *
 * Constructs the dense tensor from view(*this).
 *
 * NOTE(review): the first conjunct `(Ns * ... * 1) <= NumericLimits<Size>::max()`
 * can never be false. The Ns are unsigned Size values, so the fold product
 * wraps modulo 2^w instead of exceeding max() (assuming NumericLimits<Size>::max()
 * is the largest Size value). Any real overflow protection therefore has to come
 * from Sizes::IsLinearizable; confirm, or replace the first conjunct with a
 * division-based overflow check.
 */
216 template<class E =
217 IsSatisfied<(Ns * ... * 1) <= NumericLimits<Size>::max() &&
218 typename Sizes::IsLinearizable()>,
219 If<E> = 0>
220 explicit constexpr operator pRC::Tensor<T, Ns...>() const
221 {
222 return pRC::Tensor(view(*this));
223 }
224
225 private:
226 template<Index... seq>
227 static constexpr auto coreTypes(Sequence<Index, seq...>)
228 {
229 return tuple<Cores<seq>...>{};
230 }
231
232 using CoreTypes = decltype(coreTypes(makeSeries<Index, Dimension{}>()));
233
234 private:
236 };
237}
238
239namespace pRC
240{
241 template<class T, Size... Ns, class R>
243}
244#endif // pRC_TENSOR_TRAIN_TENSOR_TENSOR_H
Definition sequence.hpp:56
static constexpr auto size()
Definition sequence.hpp:88
Constant< Size, sizeof...(Ns)> Dimension
Definition sequence.hpp:74
Constant< Bool, linearizable()> IsLinearizable
Definition sequence.hpp:75
Definition sequence.hpp:34
Definition subscripts.hpp:20
Definition enumerate.hpp:19
Definition type_traits.hpp:37
Definition type_traits.hpp:17
Class storing tensors.
Definition tensor.hpp:44
static constexpr auto size()
Returns the number of entries of the Tensor class.
Definition tensor.hpp:84
Definition from_cores.hpp:11
Definition cholesky.hpp:18
std::conjunction< Bs... > All
Definition type_traits.hpp:77
static constexpr X view(X &&a)
Returns a TensorView obtained from a TensorView.
Definition view.hpp:22
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto makeSeries()
Definition sequence.hpp:361
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
std::disjunction< Bs... > Any
Definition type_traits.hpp:80
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
Size Index
Definition type_traits.hpp:21
Definition type_traits.hpp:15
Definition limits.hpp:13