cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
round.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
4#define pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
5
13
14namespace pRC
15{
/**
 * Rounds (re-compresses) a tensor train or tensor-train operator `a` so
 * that its bond ranks are bounded by `Ranks`.
 *
 * Visible structure: every core `a.core<seq>()` is handed to a recursive
 * lambda parameterised by a sweep direction `D` and a core counter `C`.
 *  - Left-to-right: each core is reshaped to a matrix and QR-decomposed;
 *    Q (reshaped back to core shape) replaces the core, and R, divided by
 *    its norm (or replaced by zero), is passed on together with the next
 *    core.
 *  - Right-to-left: each core is truncated with `svd<cutoff>`, where
 *    cutoff = min(Ranks::size(C - 1), left rank, remaining size);
 *    adjoint(V) becomes the new core and U * diag(S) is contracted into
 *    the previous core.
 *  - At the first core the result is assembled with every core
 *    multiplied by `totalNorm`; `totalNorm` is set to zero whenever a
 *    norm is found to be close to zero.
 *
 * Template parameters:
 *  - Ranks: a Sizes list (IsSizes) of rank bounds whose Dimension must be
 *    exactly `R::Dimension() - 1`.
 *  - X / R: the tensor-train argument and its reference-stripped type;
 *    R must be TensorTrain tensorish or operatorish.
 *  - V / VT: `R::Value` and the tolerance type; both must be IsFloat.
 *
 * @param a         tensor train to round (forwarded; its cores are taken
 *                  via `core<seq>()`).
 * @param tolerance passed to `svd`; presumably also used for the
 *                  "close to zero" norm tests (those lines are not shown
 *                  here). Defaults to `NumericLimits<VT>::epsilon()`.
 * @return          the value produced by the final right-to-left step
 *                  (the assembling call itself is not visible here).
 *
 * NOTE(review): this rendering is missing source lines 29, 32, 46, 52,
 * 64, 77-78, 80, 85, 91, 96, 107-108, 111, 120, 125, 130, 146 and 159 of
 * this function (including the declarations of `D` and `totalNorm`); the
 * comments below only describe what is visible.
 */
16 template<class Ranks, class X, class R = RemoveReference<X>,
17 class V = typename R::Value, class VT = V,
18 If<All<IsFloat<V>, IsFloat<VT>>> = 0,
19 If<Any<TensorTrain::IsTensorish<R>, TensorTrain::IsOperatorish<R>>> = 0,
20 If<IsSizes<Ranks>> = 0,
21 If<IsSatisfied<(
22 typename Ranks::Dimension() == typename R::Dimension() - 1)>> = 0>
23 static inline constexpr auto round(X &&a,
24 VT const &tolerance = NumericLimits<VT>::epsilon())
25 {
26 return expand(makeSeries<Index, typename R::Dimension{}>(),
27 [&a, &tolerance](auto const... seq)
28 {
30
31 return RecursiveLambda([&tolerance, &
33 Index C = 0, class XA, class XB, class... Xs>(
34 auto const &self, XA &&a, XB &&b, Xs &&...cores) {
35 using RA = RemoveReference<XA>;
36 using RB = RemoveReference<XB>;
37
// RAL / RAR: sizes of the first and last mode of the current core `a`
// (its left and right rank).
38 constexpr auto RAL = RA::size(0);
39 constexpr auto RAR = RA::size(typename RA::Dimension() - 1);
40
// Left-to-right sweep: QR-orthogonalise one core per step.
41 if constexpr(D == Direction::LeftToRight)
42 {
43 if constexpr(C == typename R::Dimension() - 1)
44 {
// Last core reached: re-pack (b, cores..., a) into a tuple and hand
// the cores back to `self`. NOTE(review): the call target (line 52)
// is not visible; presumably it starts the opposite sweep - confirm.
45 return expand(
47 typename R::Dimension{}>()),
48 [&self, &tolerance, &totalNorm](
49 auto const cores, auto const... seq)
50 {
51 return self.template
53 get<seq>(cores)...);
54 },
55 forwardAsTuple(forward<XB>(b),
56 forward<Xs>(cores)..., forward<XA>(a)));
57 }
58 else
59 {
// QR of the core matricised as (all but last mode) x right rank.
60 auto const [q, r] = qr(
61 reshape<RA::size() / RAR, RAR>(forward<XA>(a)));
62
// Reshape Q back to core shape; its column count becomes the new
// right rank.
63 auto const newLeft = expand(
65 typename RA::Dimension() - 1>(),
66 [&q = q](auto const... seq)
67 {
68 return reshape<RAL, RA::size(seq)...,
69 RemoveReference<decltype(q)>::size(1)>(
70 q);
71 });
72
// R is divided by its norm, or replaced by zero when the norm is
// close to zero, before being passed on together with the next core
// b. NOTE(review): the comparison operand (lines 77-78) and the
// combining call (line 80) are not visible here.
73 Tensor rNorm = norm(r);
74 using RN = decltype(rNorm);
75
76 auto const isCloseToZero = rNorm <=
79
81 ? eval(zero<decltype(r)>())
82 : eval(r / rNorm),
83 forward<XB>(b));
84
86 {
87 totalNorm = zero();
88 }
89 else
90 {
92 identity<RN>(typename R::Dimension()))();
93 }
94
// Continue the sweep with the next core index.
95 return self.template operator()<D, C + 1>(newRight,
97 }
98 }
99 else
100 {
// Right-to-left sweep.
101 if constexpr(C == 0)
102 {
// First core reached: divide it by its norm (or use zero when the
// norm is close to zero).
103 Tensor const firstNorm = norm(b);
104 using RF = decltype(firstNorm);
105
106 auto const isCloseToZero = firstNorm <=
109
110 auto const newFirst = isCloseToZero
112 : eval(forward<XB>(b) / firstNorm);
113
114 if(isCloseToZero)
115 {
116 totalNorm = zero();
117 }
118 else
119 {
121 identity<RF>(typename R::Dimension()))();
122 }
123
// Assemble the result from (newFirst, cores..., a), multiplying every
// core by totalNorm. NOTE(review): the assembling call (line 130) is
// not visible; presumably fromCores - confirm.
124 return expand(
126 typename R::Dimension{}>()),
127 [&totalNorm](auto const cores,
128 auto const... seq)
129 {
131 get<seq>(cores) * totalNorm...);
132 },
133 forwardAsTuple(newFirst, forward<Xs>(cores)...,
134 forward<XA>(a)));
135 }
136 else
137 {
// Truncate the bond to the left of core C: keep at most the target
// rank Ranks::size(C - 1), and never more than the matrix allows.
138 constexpr auto cutoff =
139 min(Ranks::size(C - 1), RAL, RA::size() / RAL);
140
141 auto const [u, s, v] = svd<cutoff>(
142 reshape<RAL, RA::size() / RAL>(forward<XA>(a)),
143 tolerance);
144
// adjoint(V), reshaped to core shape with left rank `cutoff`, becomes
// the new current core.
145 auto const newRight =
147 typename RA::Dimension() - 1>(),
148 [&v = v](auto const... seq)
149 {
150 return reshape<cutoff, RA::size(seq)...,
151 RAR>(adjoint(v));
152 });
153
// Absorb U * diag(S) into the previous core b (b's last mode
// contracted with the first mode of U * diag(S)).
154 Tensor const newLeft =
155 contract<typename RB::Dimension() - 1, 0>(
156 forward<XB>(b), eval(u * fromDiagonal(s)));
157
// Continue the sweep leftwards.
158 return self.template operator()<D, C - 1>(newLeft,
160 }
161 }
162 })(forward<X>(a).template core<seq>()...);
163 });
164 }
165
166 template<class X, class R = RemoveReference<X>, class V = typename R::Value,
167 class VT = V, If<All<IsFloat<V>, IsFloat<VT>>> = 0,
168 If<Any<TensorTrain::IsTensorish<R>, TensorTrain::IsOperatorish<R>>> = 0>
169 static inline constexpr auto round(X &&a,
170 VT const &tolerance = NumericLimits<VT>::epsilon())
171 {
172 return round<typename R::Ranks>(forward<X>(a), tolerance);
173 }
174
175 template<Size C, class X, class R = RemoveReference<X>,
176 class V = typename R::Value, class VT = V,
177 If<All<IsFloat<V>, IsFloat<VT>>> = 0,
178 If<Any<TensorTrain::IsTensorish<R>, TensorTrain::IsOperatorish<R>>> = 0>
179 static inline constexpr auto round(X &&a,
180 VT const &tolerance = NumericLimits<VT>::epsilon())
181 {
182 return round<decltype(Sizes(
183 makeConstantSequence<Index, typename R::Dimension() - 1, C>()))>(
184 forward<X>(a), tolerance);
185 }
186
187 template<Size... Cs, class X, class R = RemoveReference<X>,
188 class V = typename R::Value, class VT = V,
190 If<Any<TensorTrain::IsTensorish<R>, TensorTrain::IsOperatorish<R>>> = 0,
191 If<IsSatisfied<(sizeof...(Cs) > 1)>> = 0,
192 If<IsSatisfied<(sizeof...(Cs) == typename R::Dimension() - 1)>> = 0>
193 static inline constexpr auto round(X &&a,
194 VT const &tolerance = NumericLimits<VT>::epsilon())
195 {
196 return round<Sizes<Cs...>>(forward<X>(a), tolerance);
197 }
198}
199#endif // pRC_TENSOR_TRAIN_COMMON_FUNCTIONS_ROUND_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition tensor.hpp:28
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:13
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto exp(Complex< T > const &a)
Definition exp.hpp:12
std::size_t Size
Definition type_traits.hpp:20
std::remove_reference_t< T > RemoveReference
Definition type_traits.hpp:56
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto reverse(Direction const D)
Definition direction.hpp:24
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
Constant< Bool, B > IsSatisfied
Definition type_traits.hpp:71
static constexpr auto contract(X &&a)
Definition contract.hpp:20
RecursiveLambda(X &&) -> RecursiveLambda< RemoveReference< X > >
static constexpr auto qr(X &&input)
Definition qr.hpp:24
static constexpr auto forwardAsTuple(Xs &&...args)
Definition type_traits.hpp:202
static constexpr auto makeRange()
Definition sequence.hpp:379
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:325
Sequence< Size, Ns... > Sizes
Definition type_traits.hpp:238
static constexpr auto makeSeries()
Definition sequence.hpp:351
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:21
static constexpr auto log(Complex< T > const &a)
Definition log.hpp:11
static constexpr auto reshape(X &&a)
Definition reshape.hpp:17
static constexpr auto norm(Complex< T > const &a)
Definition norm.hpp:11
static constexpr auto round(Complex< T > const &a)
Definition round.hpp:12
Direction
Definition direction.hpp:9
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344
Definition limits.hpp:13