cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
svd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_SVD_H
4#define pRC_ALGORITHMS_SVD_H
5
10
11namespace pRC
12{
// Singular value decomposition of a two-dimensional, floating-point valued
// tensor-like `input` (see the requires-clause).
//
// Returns `result`, whose elements 0, 1 and 2 are bound below to u, s and v.
// The debug checks at the end of the function expect
// isApprox(input, u * fromDiagonal(s) * adjoint(v)), unitary u and v, and
// no negative entry in s.
//
// NOTE(review): this listing is not contiguous inside the function. Its
// embedded line numbers skip 28, 33, 63, 72, 92, 107, 115, 141 and 220, so
// some declarations and sub-expressions (e.g. those of `result` and `a`) are
// not visible here. Comments below only describe what is visible.
13 template<class X, IsTensorish R = RemoveReference<X>>
14 requires IsFloat<Value<R>> && (R::Dimension == 2)
15 static inline constexpr auto svd(X &&input)
16 {
17 using T = typename R::Type;
// M x N is the shape of the input; D is the smaller of the two extents.
18 constexpr auto M = R::size(0);
19 constexpr auto N = R::size(1);
20 constexpr auto D = min(M, N);
21
// Largest absolute entry of the input. The input is divided by it before the
// iteration and s is multiplied by it afterwards. If it does not exceed
// NumericLimits<...>::min() it is replaced by identity(), presumably to avoid
// dividing by a vanishing scale.
22 Tensor scale = reduce<Max>(abs(input));
23 if(abs(scale)() <= NumericLimits<Value<T>>::min())
24 {
25 scale = identity();
26 }
27
// NOTE(review): the start of the declaration of `result` is on a line that is
// missing from this listing.
29 result;
30 auto &u = get<0>(result);
31 auto &s = get<1>(result);
32 auto &v = get<2>(result);
// NOTE(review): the working matrix `a` used below is declared on a line that
// is missing from this listing.
34
// Initialise a, u and v from the scaled input, depending on its shape.
35 if constexpr(M == N)
36 {
37 a = input / scale;
38 u = identity();
39 v = identity();
40 }
41 else if constexpr(M > N)
42 {
// More rows than columns: the two results of qr() on the scaled input are
// stored in u and a.
43 tie(u, a) = qr(input / scale);
44 v = identity();
45 }
46 else if constexpr(M < N)
47 {
// More columns than rows: qr() is applied to the adjoint of the scaled input,
// its results are stored in v and a, and a is then replaced by its adjoint.
48 tie(v, a) = qr(adjoint(input / scale));
49 a = eval(adjoint(a));
50 u = identity();
51 }
52
// Running maximum of the absolute diagonal entries of a; it is the argument
// passed to threshold() in the sweep below.
53 auto maxDiagEntry = reduce<Max>(extractDiagonal(abs(a)))();
54
// Returns the maximum of |a(p, p)|, |a(q, q)| and the previous maximum.
55 auto const updateMaxDiagEntry =
56 [](auto const &a, Index const p, Index const q, auto const &old)
57 {
58 return max(abs(a(p, p)), abs(a(q, q)), old);
59 };
60
// Off-diagonal tolerance as a function of the current maximal diagonal entry.
// NOTE(review): the first line of the body is missing from this listing; the
// visible part is epsilon() * 2 * diag.
61 auto const threshold = [](auto const &diag)
62 {
64 NumericLimits<Value<T>>::epsilon() * identity<Value<T>>(2) *
65 diag);
66 };
67
// Sweep until one complete pass finds no pair (p, q), q < p, whose
// off-diagonal entries a(p, q) / a(q, p) exceed the threshold.
68 Bool finished = false;
69 while(!finished)
70 {
71 finished = true;
// NOTE(review): the call that receives this lambda (and thereby defines the
// values of p) is on a line missing from this listing.
73 [&](auto const p)
74 {
75 for(Index q = 0; q < p; ++q)
76 {
77 if(max(abs(a(p, q)), abs(a(q, p))) >
78 threshold(maxDiagEntry))
79 {
80 finished = false;
81
// Extra preprocessing for complexified value types. Presumably this makes the
// (p, q) sub-block entries real, since only their real parts are used in the
// 2x2 step further down - TODO confirm.
82 if constexpr(IsComplexified<T>)
83 {
84 if(auto const n = sqrt(norm<2, 1>(a(p, p)) +
85 norm<2, 1>(a(q, p)));
86 n == zero())
87 {
// Both entries a(p, p) and a(q, p) vanish: set them to exact zero.
88 a(p, p) = zero();
89 a(q, p) = zero();
90
// NOTE(review): the right-hand side of this comparison is on a missing line.
91 if(abs(imag(a(p, q))) >
93 {
// z has unit modulus; chip<0>(a, p) is scaled by z and chip<1>(u, p) by
// conj(z).
94 auto const z = abs(a(p, q)) / a(p, q);
95 chip<0>(a, p) *= z;
96 chip<1>(u, p) *= conj(z);
97 }
98 }
99 else
100 {
// Rotation built from conj(a(p, p)) / n and a(q, p) / n; it is applied to a,
// and its adjoint is applied to u.
101 JacobiRotation rot(conj(a(p, p)) / n,
102 a(q, p) / n);
103 apply(rot, a, p, q);
104 apply(u, adjoint(rot), p, q);
105
// NOTE(review): the right-hand side of this comparison is on a missing line.
106 if(abs(imag(a(p, q))) >
108 {
109 auto const z = abs(a(p, q)) / a(p, q);
110 chip<1>(a, q) *= z;
111 chip<1>(v, q) *= z;
112 }
113 }
// NOTE(review): the right-hand side of this comparison is on a missing line.
114 if(abs(imag(a(q, q))) >
116 {
117 auto const z = abs(a(q, q)) / a(q, q);
118 chip<0>(a, q) *= z;
119 chip<1>(u, q) *= conj(z);
120 }
121
122 maxDiagEntry =
123 updateMaxDiagEntry(a, p, q, maxDiagEntry);
124
// The preprocessing above may already have brought both off-diagonal entries
// below the threshold; in that case skip the 2x2 step for this pair.
125 if(max(abs(a(p, q)), abs(a(q, p))) <=
126 threshold(maxDiagEntry))
127 {
128 continue;
129 }
130 }
131
// Real 2x2 problem built from the real parts of the (p, q) sub-block of a.
132 Tensor<NonComplex<T>, 2, 2> m;
133 m(0, 0) = real(a(p, p));
134 m(1, 0) = real(a(q, p));
135 m(0, 1) = real(a(p, q));
136 m(1, 1) = real(a(q, q));
137
// t is the sum of the diagonal entries of m, d the difference of its
// off-diagonal entries.
138 auto const t = m(0, 0) + m(1, 1);
139 auto const d = m(1, 0) - m(0, 1);
// NOTE(review): the initializer of `rot` is on a line missing from this
// listing.
140 auto const rot =
142
143 apply(rot, m, 0, 1);
144 auto const jRight = JacobiRotation(m, 0, 1);
145 auto const jLeft = rot * transpose(jRight);
146
// Apply the left/right rotations to a and accumulate them into u and v.
147 apply(jLeft, a, p, q);
148 apply(a, jRight, p, q);
149 apply(u, transpose(jLeft), p, q);
150 apply(v, jRight, p, q);
151
152 maxDiagEntry =
153 updateMaxDiagEntry(a, p, q, maxDiagEntry);
154 }
155 }
156 });
157 }
158
// Read the singular values off the diagonal of a as non-negative numbers; the
// phase (complex case) or sign (real case) is moved into chip<1>(u, i).
159 range<D>(
160 [&a, &u, &s](auto const i)
161 {
162 if(abs(imag(a(i, i))) > NumericLimits<Value<T>>::min())
163 {
164 s(i) = abs(a(i, i));
165 chip<1>(u, i) *= a(i, i) / s(i);
166 }
167 else
168 {
169 auto r = real(a(i, i));
170 s(i) = abs(r);
171 if(r < zero())
172 {
173 chip<1>(u, i) = -chip<1>(u, i);
174 }
175 }
176 });
177
// Undo the normalisation by `scale` applied to the input at the start.
178 s *= scale;
179
// Selection sort of s into descending order; the matching chips of u and v
// are swapped along with the values.
180 for(Index i = 0; i < D; ++i)
181 {
182 Index iMax = i;
183 for(Index p = i + 1; p < D; ++p)
184 {
185 if(s(p) > s(iMax))
186 {
187 iMax = p;
188 }
189 }
190
// The largest remaining value is exactly zero, so nothing is left to reorder.
191 if(s(iMax) == zero())
192 {
193 break;
194 }
195
196 if(iMax != i)
197 {
198 swap(s(i), s(iMax));
199 swap(chip<1>(u, i), chip<1>(u, iMax));
200 swap(chip<1>(v, i), chip<1>(v, iMax));
201 }
202 }
203
// Self-checks, compiled in only when cDebugLevel >= DebugLevel::High. A
// failed check is reported through Logging::error; the result is returned
// regardless.
204 if constexpr(cDebugLevel >= DebugLevel::High)
205 {
206 if(!isApprox(input, u * fromDiagonal(s) * adjoint(v)))
207 {
208 Logging::error("SVD decomposition failed.");
209 }
210 if(!isUnitary(u))
211 {
212 Logging::error("SVD decomposition failed: U is not unitary.");
213 }
214 if(!isUnitary(v))
215 {
216 Logging::error("SVD decomposition failed: V is not unitary.");
217 }
218 if(reduce<Min>(s)() < zero())
219 {
// NOTE(review): the call that takes this message is on a line missing from
// this listing.
221 "SVD decomposition failed: Found negative singular "
222 "values.");
223 }
224 }
225
226 return result;
227 }
228
229 template<Size C, class X, IsTensorish R = RemoveReference<X>,
230 IsFloat V = Value<R>, IsFloat VT = V>
231 requires(R::Dimension == 2) && (C < reduce<Min>(typename R::Sizes()))
232 static inline constexpr auto svd(X &&input,
233 VT const &tolerance = NumericLimits<VT>::tolerance())
234 {
235 auto const [fU, fS, fV] = svd(forward<X>(input));
236
237 constexpr auto M = decltype(fU)::size(0);
238 constexpr auto D = decltype(fS)::size();
239 constexpr auto N = decltype(fV)::size(0);
240
241 auto const sNorm = norm<2, 1>(fS)();
242 auto sError = norm<2, 1>(slice<D - C>(fS, C))();
243
244 auto result = Tuple(eval(slice<M, C>(fU, 0, 0)), eval(slice<C>(fS, 0)),
245 eval(slice<N, C>(fV, 0, 0)));
246
247 auto &u = get<0>(result);
248 auto &s = get<1>(result);
249 auto &v = get<2>(result);
250
251 if(sError > square(tolerance) * sNorm)
252 {
253 /*
254 Logging::info(
255 "Truncated SVD: Relative error is larger than requested "
256 "tolerance due to specified cutoff:",
257 sqrt(sError / sNorm));
258 */
259
260 return result;
261 }
262
263 for(Index k = C; k != 0;)
264 {
265 if(sError + square(s(--k)) > square(tolerance) * sNorm)
266 {
267 break;
268 }
269
270 sError += square(s(k));
271
272 s(k) = zero();
273 chip<1>(u, k) = zero();
274 chip<1>(v, k) = zero();
275 }
276
277 return result;
278 }
279
280 template<Size C, class X, IsTensorish R = RemoveReference<X>,
281 IsFloat V = Value<R>, IsFloat VT = V>
282 requires(R::Dimension == 2) && (C == reduce<Min>(typename R::Sizes()))
283 static inline constexpr auto svd(X &&input,
284 VT const &tolerance = NumericLimits<VT>::tolerance())
285 {
286 auto result = svd(forward<X>(input));
287 auto &u = get<0>(result);
288 auto &s = get<1>(result);
289 auto &v = get<2>(result);
290
291 auto const sNorm = norm<2, 1>(s)();
292 auto sError = zero<decltype(sNorm)>();
293
294 for(Index k = C; k != 0;)
295 {
296 if(sError + square(s(--k)) > square(tolerance) * sNorm)
297 {
298 break;
299 }
300
301 sError += square(s(k));
302
303 s(k) = zero();
304 chip<1>(u, k) = zero();
305 chip<1>(v, k) = zero();
306 }
307
308 return result;
309 }
310}
311#endif // pRC_ALGORITHMS_SVD_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition value.hpp:12
Definition jacobi_rotation.hpp:13
static constexpr auto MakeGivens(R1 const &x, R2 const &y)
Definition jacobi_rotation.hpp:24
Definition tensor.hpp:25
Definition complex.hpp:197
int i
Definition gmock-matchers-comparisons_test.cc:603
Uncopyable z
Definition gmock-matchers-containers_test.cc:378
const char * p
Definition gmock-matchers-containers_test.cc:379
static void error(Xs &&...args)
Definition log.hpp:14
Definition cholesky.hpp:10
static constexpr auto svd(X &&input)
Definition svd.hpp:15
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::tuple< Ts... > Tuple
Definition basics.hpp:23
static constexpr decltype(auto) apply(JacobiRotation< T > const &r, X &&m, Index const p, Index const q)
Definition jacobi_rotation.hpp:321
JacobiRotation(Tensor< T, M, N > const &a, Index const p, Index const q) -> JacobiRotation< T >
static constexpr auto qr(X &&input)
Definition qr.hpp:13
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:17
static constexpr auto slice(X &&a, Os const ... offsets)
Definition slice.hpp:17
static constexpr auto conj(T const &a)
Definition conj.hpp:11
static constexpr auto transpose(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:306
static constexpr auto abs(T const &a)
Definition abs.hpp:11
typename ValueType< T >::Type Value
Definition value.hpp:72
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:312
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
static constexpr auto reduce(Sequence< T, I1, I2, Is... > const)
Definition sequence.hpp:458
static constexpr auto swap(XA &&a, XB &&b)
Definition swap.hpp:20
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
static constexpr auto extractDiagonal(X &&a)
Definition extract_diagonal.hpp:16
static constexpr auto isApprox(XE &&expected, XA &&approx, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_approx.hpp:14
constexpr auto cDebugLevel
Definition config.hpp:48
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr decltype(auto) real(X &&a)
Definition real.hpp:12
static constexpr auto isUnitary(X &&a, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_unitary.hpp:14
static constexpr auto square(T const &a)
Definition square.hpp:11
static constexpr decltype(auto) imag(X &&a)
Definition imag.hpp:12
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition gtest_pred_impl_unittest.cc:54
Definition limits.hpp:13