cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
svd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_SVD_H
4#define pRC_ALGORITHMS_SVD_H
5
6#include <prc/config.hpp>
20
21namespace pRC
22{
/**
 * Singular value decomposition of a matrix (2-D tensorish input) whose
 * value type is a floating-point type.
 *
 * Returns a tuple (u, s, v) with D = min(M, N):
 *  - u : Tensor<T, M, D>
 *  - s : Tensor<typename T::NonComplex, D>, the singular values, computed
 *        as absolute values of the diagonal of the working matrix and
 *        multiplied by `scale`, sorted in descending order
 *  - v : third tuple element; its type is declared on a source line not
 *        visible in this view (presumably Tensor<T, N, D> -- confirm)
 * With cDebugLevel >= DebugLevel::High the result is checked for
 * input ~= u * fromDiagonal(s) * adjoint(v), for unitarity of u and v,
 * and for negative singular values; failures are reported via
 * Logging::error.
 *
 * Outline of the visible algorithm:
 *  1. Divide the input by `scale`. For non-square input, reduce it first
 *     with qr(): use input itself when M > N, and adjoint(input) when
 *     M < N.
 *  2. Repeat sweeps over all pairs q < p, applying 2x2 Jacobi rotations
 *     from both sides to the working matrix `a`. Accumulate them into u
 *     and v. Stop when no off-diagonal pair exceeds threshold(maxDiagEntry).
 *  3. Read the singular values off |diag(a)|, rescale, and sort them.
 *
 * NOTE(review): this rendering omits several source lines, including:
 *  - the declarations of `scale`, `a`, `maxDiagEntry` and `m`;
 *  - the body of `threshold`;
 *  - the construction of `rot`;
 *  - some conditions.
 * Comments below hedge where they depend on them.
 */
23 template<class X, class R = RemoveReference<X>, If<IsTensorish<R>> = 0,
24 If<IsFloat<typename R::Value>> = 0,
25 If<IsSatisfied<(typename R::Dimension() == 2)>> = 0>
26 static inline constexpr auto svd(X &&input)
27 {
28 using T = typename R::Type;
29 constexpr auto M = R::size(0);
30 constexpr auto N = R::size(1);
31 constexpr auto D = min(M, N);
32
// `scale` is declared on lines not visible here. This branch resets it to
// identity(), presumably to avoid dividing by a zero scale -- confirm.
35 {
36 scale = identity();
37 }
38
39 tuple<Tensor<T, M, D>, Tensor<typename T::NonComplex, D>,
41 result;
42 auto &u = get<0>(result);
43 auto &s = get<1>(result);
44 auto &v = get<2>(result);
46
// Reduce to a square working matrix `a` (declared on a line not visible
// here):
//  - square input: use the scaled input directly;
//  - tall input (M > N): Q of the QR factorisation goes into u, R into a;
//  - wide input (M < N): factor adjoint(input); Q goes into v and
//    a = adjoint(R).
47 if constexpr(M == N)
48 {
49 a = input / scale;
50 u = identity();
51 v = identity();
52 }
53 else if constexpr(M > N)
54 {
55 tie(u, a) = qr(input / scale);
56 v = identity();
57 }
58 else if constexpr(M < N)
59 {
60 tie(v, a) = qr(adjoint(input / scale));
61 a = eval(adjoint(a));
62 u = identity();
63 }
64
66
// Returns the largest of |a(p, p)|, |a(q, q)| and `old`.
// NOTE(review): no call is visible in this view. It is presumably used to
// update maxDiagEntry on the missing lines 134-135 / 164-165.
67 auto const updateMaxDiagEntry =
68 [](auto const &a, Index const p, Index const q, auto const &old)
69 {
70 return max(abs(a(p, p)), abs(a(q, q)), old);
71 };
72
// Convergence threshold as a function of the largest diagonal magnitude.
// Its body is on lines not visible here.
73 auto const threshold = [](auto const &diag)
74 {
78 };
79
// Jacobi sweeps: repeat until one full sweep finds no pair (p, q) whose
// off-diagonal entries exceed the threshold.
80 Bool finished = false;
81 while(!finished)
82 {
83 finished = true;
// Called per row index p by a call on a line not visible here (presumably
// a compile-time range over 0..D-1). Visits every pair with q < p.
85 [&](auto const p)
86 {
87 for(Index q = 0; q < p; ++q)
88 {
89 if(max(abs(a(p, q)), abs(a(q, p))) >
90 threshold(maxDiagEntry))
91 {
92 finished = false;
93
// Complex case: make the 2x2 block (p, q) real before the real step below.
// Tools used:
//  - unit-modulus phases z = |x| / x, which turn the entry x into |x|;
//  - a complex rotation built from the block's first column.
// Each transformation is compensated in u or v.
94 if constexpr(typename T::IsComplexified{})
95 {
96 if(auto const n = sqrt(norm<2, 1>(a(p, p)) +
97 norm<2, 1>(a(q, p)));
98 n == zero())
99 {
// First column of the block is zero: clear it explicitly. If a(p, q) has
// an imaginary part above the (not visible) threshold, make it real with a
// phase on row p, compensated in column p of u.
100 a(p, p) = zero();
101 a(q, p) = zero();
102
103 if(abs(imag(a(p, q))) >
105 {
106 auto const z = abs(a(p, q)) / a(p, q);
107 chip<0>(a, p) *= z;
108 chip<1>(u, p) *= conj(z);
109 }
110 }
111 else
112 {
// Rotate rows p and q with c = conj(a(p, p)) / n, s = a(q, p) / n, and
// accumulate adjoint(rot) into u. Then make a(p, q) real with a phase on
// column q, mirrored in column q of v.
113 JacobiRotation rot(conj(a(p, p)) / n,
114 a(q, p) / n);
115 apply(rot, a, p, q);
116 apply(u, adjoint(rot), p, q);
117
118 if(abs(imag(a(p, q))) >
120 {
121 auto const z = abs(a(p, q)) / a(p, q);
122 chip<1>(a, q) *= z;
123 chip<1>(v, q) *= z;
124 }
125 }
// Make a(q, q) real with a phase on row q, compensated in column q of u.
126 if(abs(imag(a(q, q))) >
128 {
129 auto const z = abs(a(q, q)) / a(q, q);
130 chip<0>(a, q) *= z;
131 chip<1>(u, q) *= conj(z);
132 }
133
136
// After preconditioning the block may already be diagonal within the
// threshold. In that case skip the real 2x2 step for this pair.
137 if(max(abs(a(p, q)), abs(a(q, p))) <=
138 threshold(maxDiagEntry))
139 {
140 continue;
141 }
142 }
143
// Real 2x2 SVD step on the real parts of block (p, q), copied into `m`
// (declared on a line not visible here).
//  - `rot` is built from t = m(0,0) + m(1,1) and d = m(1,0) - m(0,1) on a
//    line not visible here. It presumably symmetrises m; cf. Eigen's
//    real_2x2_jacobi_svd -- confirm.
//  - jRight is the rotation JacobiRotation(m, 0, 1).
//  - jLeft = rot * transpose(jRight) is the matching left rotation.
145 m(0, 0) = real(a(p, p));
146 m(1, 0) = real(a(q, p));
147 m(0, 1) = real(a(p, q));
148 m(1, 1) = real(a(q, q));
149
150 auto const t = m(0, 0) + m(1, 1);
151 auto const d = m(1, 0) - m(0, 1);
152 auto const rot =
154
155 apply(rot, m, 0, 1);
156 auto const jRight = JacobiRotation(m, 0, 1);
157 auto const jLeft = rot * transpose(jRight);
158
// Apply the rotations to rows/columns p and q of `a`. Accumulate
// transpose(jLeft) into u and jRight into v.
159 apply(jLeft, a, p, q);
160 apply(a, jRight, p, q);
161 apply(u, transpose(jLeft), p, q);
162 apply(v, jRight, p, q);
163
// NOTE(review): lines 164-165 are not visible. maxDiagEntry is presumably
// updated there.
166 }
167 }
168 });
169 }
170
// Singular values are the magnitudes of the diagonal entries of `a`. The
// remaining phase is moved into the matching column of u:
//  - complex case (under a condition on a line not visible here): multiply
//    by a(i, i) / s(i);
//  - real case: negate the column for negative entries.
171 range<D>(
172 [&a, &u, &s](auto const i)
173 {
175 {
176 s(i) = abs(a(i, i));
177 chip<1>(u, i) *= a(i, i) / s(i);
178 }
179 else
180 {
181 auto r = real(a(i, i));
182 s(i) = abs(r);
183 if(r < zero())
184 {
185 chip<1>(u, i) = -chip<1>(u, i);
186 }
187 }
188 });
189
// Undo the initial division of the input by `scale`.
190 s *= scale;
191
// Selection sort of s into descending order. Columns of u and v are
// swapped in step. Sorting stops early once the largest remaining value
// is zero.
192 for(Index i = 0; i < D; ++i)
193 {
194 Index iMax = i;
195 for(Index p = i + 1; p < D; ++p)
196 {
197 if(s(p) > s(iMax))
198 {
199 iMax = p;
200 }
201 }
202
203 if(s(iMax) == zero())
204 {
205 break;
206 }
207
208 if(iMax != i)
209 {
210 swap(s(i), s(iMax));
211 swap(chip<1>(u, i), chip<1>(u, iMax));
212 swap(chip<1>(v, i), chip<1>(v, iMax));
213 }
214 }
215
// Self-checks, compiled in only at high debug levels. Failures are
// reported via Logging::error; the result is returned regardless.
216 if constexpr(cDebugLevel >= DebugLevel::High)
217 {
218 if(!isApprox(input, u * fromDiagonal(s) * adjoint(v)))
219 {
220 Logging::error("SVD decomposition failed.");
221 }
222 if(!isUnitary(u))
223 {
224 Logging::error("SVD decomposition failed: U is not unitary.");
225 }
226 if(!isUnitary(v))
227 {
228 Logging::error("SVD decomposition failed: V is not unitary.");
229 }
230 if(reduce<Min>(s)() < zero())
231 {
233 "SVD decomposition failed: Found negative singular "
234 "values.");
235 }
236 }
237
238 return result;
239 }
240
241 template<Size C, class X, class R = RemoveReference<X>,
242 If<IsTensorish<R>> = 0, class V = typename R::Value, class VT = V,
243 If<All<IsFloat<V>, IsFloat<VT>>> = 0,
244 If<IsSatisfied<(typename R::Dimension() == 2)>> = 0,
245 If<IsSatisfied<(C < min(R::size(0), R::size(1)))>> = 0>
246 static inline constexpr auto svd(X &&input,
247 VT const &tolerance = NumericLimits<VT>::tolerance())
248 {
249 auto const [fU, fS, fV] = svd(forward<X>(input));
250
251 constexpr auto M = decltype(fU)::size(0);
252 constexpr auto D = decltype(fS)::size();
253 constexpr auto N = decltype(fV)::size(0);
254
255 auto const sNorm = norm<2, 1>(fS)();
256 auto sError = norm<2, 1>(slice<D - C>(fS, C))();
257
258 auto result = tuple(eval(slice<M, C>(fU, 0, 0)), eval(slice<C>(fS, 0)),
259 eval(slice<N, C>(fV, 0, 0)));
260
261 auto &u = get<0>(result);
262 auto &s = get<1>(result);
263 auto &v = get<2>(result);
264
265 if(sError > square(tolerance) * sNorm)
266 {
267 /*
268 Logging::info(
269 "Truncated SVD: Relative error is larger than requested "
270 "tolerance due to specified cutoff:",
271 sqrt(sError / sNorm));
272 */
273
274 return result;
275 }
276
277 for(Index k = C; k != 0;)
278 {
279 if(sError + square(s(--k)) > square(tolerance) * sNorm)
280 {
281 break;
282 }
283
284 sError += square(s(k));
285
286 s(k) = zero();
287 chip<1>(u, k) = zero();
288 chip<1>(v, k) = zero();
289 }
290
291 return result;
292 }
293
294 template<Size C, class X, class R = RemoveReference<X>,
295 If<IsTensorish<R>> = 0, class V = typename R::Value, class VT = V,
296 If<All<IsFloat<V>, IsFloat<VT>>> = 0,
297 If<IsSatisfied<(typename R::Dimension() == 2)>> = 0,
298 If<IsSatisfied<(C == min(R::size(0), R::size(1)))>> = 0>
299 static inline constexpr auto svd(X &&input,
300 VT const &tolerance = NumericLimits<VT>::tolerance())
301 {
302 auto result = svd(forward<X>(input));
303 auto &u = get<0>(result);
304 auto &s = get<1>(result);
305 auto &v = get<2>(result);
306
307 auto const sNorm = norm<2, 1>(s)();
308 auto sError = zero<decltype(sNorm)>();
309
310 for(Index k = C; k != 0;)
311 {
312 if(sError + square(s(--k)) > square(tolerance) * sNorm)
313 {
314 break;
315 }
316
317 sError += square(s(k));
318
319 s(k) = zero();
320 chip<1>(u, k) = zero();
321 chip<1>(v, k) = zero();
322 }
323
324 return result;
325 }
326}
327#endif // pRC_ALGORITHMS_SVD_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition jacobi_rotation.hpp:17
static constexpr auto MakeGivens(R1 const &x, R2 const &y)
Definition jacobi_rotation.hpp:33
Definition tensor.hpp:28
static void error(Xs &&...args)
Definition log.hpp:14
Definition cholesky.hpp:18
static constexpr auto extractDiagonal(X &&a)
Definition extract_diagonal.hpp:17
static constexpr X eval(X &&a)
Definition eval.hpp:11
bool Bool
Definition type_traits.hpp:18
static constexpr decltype(auto) apply(JacobiRotation< T > const &r, X &&m, Index const p, Index const q)
Definition jacobi_rotation.hpp:334
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr decltype(auto) imag(X &&a)
Definition imag.hpp:11
static constexpr decltype(auto) real(X &&a)
Definition real.hpp:11
static constexpr auto zero()
Definition zero.hpp:12
JacobiRotation(Tensor< T, M, N > const &a, Index const p, Index const q) -> JacobiRotation< T >
static constexpr auto qr(X &&input)
Definition qr.hpp:24
static constexpr auto transpose(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:319
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:325
static constexpr auto square(Complex< T > const &a)
Definition square.hpp:14
static constexpr auto isApprox(XA &&a, XB &&b, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_approx.hpp:24
static constexpr auto fromDiagonal(X &&a)
Definition from_diagonal.hpp:21
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
constexpr auto cDebugLevel
Definition config.hpp:46
static constexpr auto sqrt(Complex< T > const &a)
Definition sqrt.hpp:12
static constexpr auto swap(XA &&a, XB &&b)
Definition swap.hpp:21
static constexpr auto svd(X &&input)
Definition svd.hpp:26
static constexpr auto conj(Complex< T > const &a)
Definition conj.hpp:11
static constexpr auto identity()
Definition identity.hpp:12
static constexpr auto isUnitary(X &&a, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_unitary.hpp:19
static constexpr X max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13