cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
round.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
5
10
// NOTE(review): this listing was scraped from Doxygen HTML. The number at the
// start of each line is the original round.hpp line number, and the gaps
// (6-9, 21, 27, 43, 61, 73-74, 103-104, 107, 121, 125, 142) are source lines
// the scrape dropped. Comments below that depend on those lines are hedged;
// consult round.hpp itself before changing any logic.
11namespace pRC
12{
// round<Ranks>(a, tolerance): truncates the bond ranks of the tensor train
// (or tensor-train operator) `a` so that internal bond k (the one to the
// right of core k) keeps at most Ranks::size(k).
//
// Structure of the implementation:
//   1. Left-to-right sweep: each core is reshaped into a matrix and
//      QR-factorized; Q (reshaped back) replaces the core, and R, scaled to
//      unit norm, is contracted into the next core.
//   2. Right-to-left sweep: each core is reshaped into a matrix and
//      decomposed with svd<cutoff>(..., tolerance); adjoint(V) (reshaped
//      back) replaces the core, and U * diag(S) is contracted into its left
//      neighbour.
//   3. The norms divided out along the way are collected in `totalNorm` as a
//      product of R::Dimension-th roots, and every final core is multiplied
//      by `totalNorm`, which restores the overall scale while spreading it
//      equally over the cores.
//
// Requirements: R is a TensorTrain tensor or operator with a floating-point
// value type, and Ranks supplies one maximum rank per internal bond
// (Ranks::Dimension == R::Dimension - 1).
// Parameters:
//   a         - tensor train to round; its cores are forwarded into the sweep.
//   tolerance - forwarded to svd; defaults to NumericLimits<VT>::epsilon().
// Returns the rounded tensor train built from the new cores (the building
// call sits on a dropped line, presumably fromCores(...)).
13 template<IsSizes Ranks, class X, class R = RemoveReference<X>,
14 IsFloat VT = Value<R>>
15 requires(TensorTrain::IsTensorish<R> ||
16 TensorTrain::IsOperatorish<R>) &&
17 IsFloat<Value<R>> && (Ranks::Dimension == R::Dimension - 1)
18 static inline constexpr auto round(X &&a,
19 VT const &tolerance = NumericLimits<VT>::epsilon())
20 {
// The outer lambda receives the core indices as the pack `seq` (the expand
// call providing it is on a dropped line) and hands every core
// a.core<seq>()... to the recursive sweep below.
22 [&a, &tolerance](auto const... seq)
23 {
// Running product of the R::Dimension-th roots of every norm divided out
// during the sweeps; starts at identity<typename R::Type>() (presumably 1)
// and is forced to zero if any divided-out factor is treated as zero.
24 auto totalNorm = identity<typename R::Type>();
25
// Recursive sweep. D is the sweep direction (its declaration is on a dropped
// line; below it is compared with and replaced by Direction values) and C
// counts the steps, starting at 0. `a` is the core being processed, `b` its
// neighbour in the sweep direction, `cores` the remaining cores. Each step
// calls itself with the updated neighbour first and the finished core
// appended last, so the argument list rotates through all cores.
26 return RecursiveLambda([&tolerance, &
28 Index C = 0, class XA, class XB, class... Xs>(
29 auto const &self, XA &&a, XB &&b, Xs &&...cores)
30
31 {
32 using RA = RemoveReference<XA>;
33 using RB = RemoveReference<XB>;
34
// RAL / RAR: size of the first and last dimension of `a`, i.e. its left and
// right bond.
35 constexpr auto RAL = RA::size(0);
36 constexpr auto RAR = RA::size(RA::Dimension - 1);
37
38 if constexpr(D == Direction::LeftToRight)
39 {
// All cores but the last have been replaced by Q factors: rebuild the
// argument list as (b, cores..., a), reorder it with an index sequence from
// a dropped line, and start the right-to-left sweep at the same step C.
40 if constexpr(C == R::Dimension - 1)
41 {
42 return expand(
44 [&self, &tolerance, &totalNorm](
45 auto const cores, auto const... seq)
46 {
47 return self.template
48 operator()<Direction::RightToLeft, C>(
49 get<seq>(cores)...);
50 },
51 forwardAsTuple(forward<XB>(b),
52 forward<Xs>(cores)..., forward<XA>(a)));
53 }
54 else
55 {
// QR of `a` viewed as a (RA::size() / RAR) x RAR matrix: every dimension
// except the right bond becomes the row index.
56 auto const [q, r] =
57 qr(reshape<RA::size() / RAR, RAR>(
58 forward<XA>(a)));
59
// Q reshaped back into a core: left bond RAL, the sizes RA::size(seq)...
// (seq comes from a dropped line, presumably the interior dimensions), and
// q's column count as the new right bond.
60 auto const newLeft = expand(
62 [&q = q](auto const... seq)
63 {
64 return reshape<RAL, RA::size(seq)...,
65 RemoveReference<decltype(q)>::size(
66 1)>(q);
67 });
68
// Norm of R. If it is at or below a threshold (the right-hand side of the
// comparison is on dropped lines), R is replaced by zero and totalNorm is
// set to zero, which zeroes every final core. Otherwise R is scaled to unit
// norm and the R::Dimension-th root of its norm is folded into totalNorm.
69 Tensor rNorm = norm(r);
70 using RN = decltype(rNorm);
71
72 auto const isCloseToZero = rNorm <=
75
// Contract the (normalized or zeroed) R factor into the right neighbour `b`.
76 auto const newRight = contract<1, 0>(
77 isCloseToZero ? eval(zero<decltype(r)>())
78 : eval(r / rNorm),
79 forward<XB>(b));
80
81 if(isCloseToZero)
82 {
83 totalNorm = zero();
84 }
85 else
86 {
87 totalNorm *= exp(log(rNorm) /
88 identity<RN>(R::Dimension))();
89 }
90
// The updated right neighbour becomes the next `a`; the finished core is
// moved to the end of the argument list.
91 return self.template operator()<D, C + 1>(
92 newRight, forward<Xs>(cores)..., newLeft);
93 }
94 }
95 else
96 {
// Final step of the right-to-left sweep: normalize the core passed as `b`
// (or replace it by zero when its norm is at or below the threshold on the
// dropped lines; the zero branch of the ternary is also dropped), fold the
// R::Dimension-th root of its norm into totalNorm, then multiply each core
// by totalNorm. The index sequence and the call assembling the result are
// on dropped lines.
97 if constexpr(C == 0)
98 {
99 Tensor const firstNorm = norm(b);
100 using RF = decltype(firstNorm);
101
102 auto const isCloseToZero = firstNorm <=
105
106 auto const newFirst = isCloseToZero
108 : eval(forward<XB>(b) / firstNorm);
109
110 if(isCloseToZero)
111 {
112 totalNorm = zero();
113 }
114 else
115 {
116 totalNorm *= exp(log(firstNorm) /
117 identity<RF>(R::Dimension))();
118 }
119
120 return expand(
122 [&totalNorm](auto const cores,
123 auto const... seq)
124 {
126 get<seq>(cores) * totalNorm...);
127 },
128 forwardAsTuple(newFirst,
129 forward<Xs>(cores)..., forward<XA>(a)));
130 }
131 else
132 {
// Truncated SVD of `a` viewed as an RAL x (RA::size() / RAL) matrix: the
// left bond becomes the row index. cutoff caps the kept rank at
// Ranks::size(C - 1) and at both matrix dimensions; tolerance is passed
// through to svd unchanged.
133 constexpr auto cutoff = min(Ranks::size(C - 1),
134 RAL, RA::size() / RAL);
135
136 auto const [u, s, v] =
137 svd<cutoff>(reshape<RAL, RA::size() / RAL>(
138 forward<XA>(a)),
139 tolerance);
140
// adjoint(V) reshaped back into a core with left bond `cutoff` and right
// bond RAR (seq comes from a dropped line).
141 auto const newRight = expand(
143 [&v = v](auto const... seq)
144 {
145 return reshape<cutoff, RA::size(seq)...,
146 RAR>(adjoint(v));
147 });
148
// Absorb U * diag(S) into the left neighbour `b` by contracting b's last
// dimension with the first dimension of U * diag(S).
149 Tensor const newLeft =
150 contract<RB::Dimension - 1, 0>(
151 forward<XB>(b),
152 eval(u * fromDiagonal(s)));
153
// The updated left neighbour becomes the next `a`; the finished core is
// moved to the end of the argument list.
154 return self.template operator()<D, C - 1>(
155 newLeft, forward<Xs>(cores)..., newRight);
156 }
157 }
158 })(forward<X>(a).template core<seq>()...);
159 });
160 }
161
// round(a, tolerance): rounds without lowering the rank bounds. Ranks is a's
// own R::Ranks, so any compression presumably comes from how svd applies
// `tolerance` or from the matrix-dimension caps in the cutoff.
162 template<class X, class R = RemoveReference<X>, IsFloat VT = Value<R>>
163 requires(TensorTrain::IsTensorish<R> ||
164 TensorTrain::IsOperatorish<R>) &&
165 IsFloat<Value<R>>
166 static inline constexpr auto round(X &&a,
167 VT const &tolerance = NumericLimits<VT>::epsilon())
168 {
169 return round<typename R::Ranks>(forward<X>(a), tolerance);
170 }
171
// round<C>(a, tolerance): caps every internal bond at the same maximum rank
// C, via makeConstantSequence<Size, R::Dimension - 1, C>() (presumably
// R::Dimension - 1 copies of C).
172 template<Size C, class X, class R = RemoveReference<X>,
173 IsFloat VT = Value<R>>
174 requires(TensorTrain::IsTensorish<R> ||
175 TensorTrain::IsOperatorish<R>) &&
176 IsFloat<Value<R>>
177 static inline constexpr auto round(X &&a,
178 VT const &tolerance = NumericLimits<VT>::epsilon())
179 {
180 return round<
181 decltype(makeConstantSequence<Size, R::Dimension - 1, C>())>(
182 forward<X>(a), tolerance);
183 }
184
// round<C0, C1, ...>(a, tolerance): one maximum rank per internal bond given
// as a pack. The sizeof...(Cs) > 1 constraint leaves the single-value case
// to the round<C> overload above.
185 template<Size... Cs, class X, class R = RemoveReference<X>,
186 IsFloat VT = Value<R>>
187 requires(TensorTrain::IsTensorish<R> ||
188 TensorTrain::IsOperatorish<R>) &&
189 IsFloat<Value<R>> && (sizeof...(Cs) > 1) &&
190 (sizeof...(Cs) == R::Dimension - 1)
191 static inline constexpr auto round(X &&a,
192 VT const &tolerance = NumericLimits<VT>::epsilon())
193 {
194 return round<Sizes<Cs...>>(forward<X>(a), tolerance);
195 }
196}
197#endif // pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
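pRC::Size const D (Doxygen cross-link only: the `D` used in round.hpp is the recursive lambda's `Direction` template parameter, not this variable)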
Definition CalculatePThetaTests.cpp:9
Definition sequence.hpp:29
Definition tensor.hpp:25
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
Definition cholesky.hpp:10
static constexpr auto contract(X &&a)
Definition contract.hpp:20
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto svd(X &&input)
Definition svd.hpp:15
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto reverse(Direction const D)
Definition direction.hpp:24
static constexpr auto qr(X &&input)
Definition qr.hpp:13
RecursiveLambda(X &&) -> RecursiveLambda< RemoveReference< X > >
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:17
static constexpr auto reshape(X &&a)
Definition reshape.hpp:14
static constexpr auto forwardAsTuple(Xs &&...args)
Definition basics.hpp:65
typename ValueType< T >::Type Value
Definition value.hpp:72
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:312
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto makeRange()
Definition sequence.hpp:421
Direction
Definition direction.hpp:9
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto log(T const &a)
Definition log.hpp:11
static constexpr auto round(T const &a)
Definition round.hpp:11
static constexpr auto exp(T const &a)
Definition exp.hpp:11
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
Definition limits.hpp:13