cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
qr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_QR_H
4#define pRC_ALGORITHMS_QR_H
5
8
9namespace pRC
10{
 // QR decomposition of a rank-2 tensor using Householder reflections.
 //
 // Template parameters:
 //   B - column block size. The outer loop index b is scaled by B to give the
 //       offset s of the trailing sub-block (see `s = b * B`). The loop
 //       bounds are on lines missing from this listing; presumably columns are
 //       processed in chunks of B up to D = min(M, N) -- confirm.
 //   X - forwarded type of `input`.
 //   R - X with references removed. It must be a 2-dimensional tensor-like
 //       type whose value type is a floating type (see the requires-clause).
 //
 // Returns a pair (Q, R) built with `Tp` (alias declared on a line not shown
 // here) and read back with get<0> / get<1>:
 //   - if M > N: Q is the first N columns of adjoint(h) (M x N) and R is the
 //     leading N x N block of the transformed matrix (thin / economy form);
 //   - otherwise: Q is adjoint(h) (M x M) and R is the whole transformed
 //     M x N matrix.
 //
 // When cDebugLevel >= DebugLevel::High, the result is checked: input must be
 // approximately Q * R, and Q must be unitary. A failed check is only reported
 // through Logging::error; nothing is thrown.
11 template<Size B = 32, class X, IsTensorish R = RemoveReference<X>>
12 requires IsFloat<Value<R>> && (R::Dimension == 2)
13 static inline constexpr auto qr(X &&input)
14 {
 // Scalar type plus the compile-time input dimensions M x N.
 // D is not referenced in the visible lines; presumably it bounds the
 // Householder loop on the missing lines -- confirm.
15 using T = typename R::Type;
16 constexpr auto M = R::size(0);
17 constexpr auto N = R::size(1);
18 constexpr auto D = min(M, N);
19
 // `a` is the working matrix that is reduced in place to R.
 //   - The copy<> condition is truncated in this listing. Judging from the
 //     IsReference / IsConst terms, it likely avoids a copy when `input` is a
 //     non-const rvalue -- confirm against the full source.
 //   - `h` is declared on a line not shown here. Every reflector applied to
 //     `a` is also applied to `h`, and Q is taken as adjoint(h) at the end.
 //   - NOTE(review): the Q = adjoint(h) step only makes sense if `h` starts
 //     as the identity -- confirm.
20 decltype(auto) a = copy<!(!IsReference<X> && !IsConst<R> &&
22
24
 // Outer lambda: one call per column block b. s is the row/column offset
 // of that block.
26 [&](auto const b)
27 {
28 constexpr auto s = b * B;
 // Inner lambda: one Householder step for column k. k is a global index
 // into `a`; k - s is the same column's index within the trailing block bA.
30 [&](auto const k)
31 {
 // bA is the trailing (M - s) x (N - s) block of `a` starting at (s, s).
 // The updates below go through it and are expected to land in `a`, so
 // slice is used here as a view.
 // v is a copy of column k of that block, with every entry at and above
 // the diagonal (local rows 0..k-s) set to zero. Only the sub-diagonal
 // part remains.
32 auto bA = slice<M - s, N - s>(a, s, s);
33 Tensor v = chip<1>(bA, k - s);
34 for(Index i = 0; i <= k - s; ++i)
35 {
36 v(i) = zero();
37 }
38
 // Computes the Householder scalar tau and rescales v in place.
 // c is the diagonal entry a(k, k). norm<2, 1> is used here as a squared
 // 2-norm (cf. the name sqNorm) -- confirm.
 //
 // Skip case: if both the sub-diagonal part and imag(c) fall at or below
 // the threshold (on a line missing here), no reflection is applied.
 // v is zeroed and tau = 0, so the updates below leave bA and h unchanged.
 //
 // Otherwise:
 //   - beta = sqrt(|c|^2 + ||v||^2), given the sign of real(c) so that
 //     c + beta does not suffer cancellation;
 //   - v is divided by c + beta;
 //   - tau = conj((beta + c) / beta).
39 auto const tau = [&a, &v, &k]()
40 {
41 auto const c = a(k, k);
42 auto const sqNorm = norm<2, 1>(v)();
43 if(max(sqNorm, norm<2, 1>(imag(c))) <=
45 {
46 v = zero();
47 return zero<T>();
48 }
49 else
50 {
51 auto beta = sqrt(norm<2, 1>(c) + sqNorm);
52 if(real(c) < zero())
53 {
54 beta = -beta;
55 }
56 v /= c + beta;
57 return conj((beta + c) / beta);
58 }
59 }();
60
 // Set the leading (diagonal) entry of the Householder vector to 1.
61 v(k - s) = unit();
 // Apply the reflector I - tau * v * v^H from the left to the trailing
 // block: bA -= tau * v (v^H bA). Here conj(v) * bA plays the role of
 // v^H bA.
 // NOTE(review): no visible line sets the annihilated sub-diagonal entries
 // to exactly zero. R is therefore upper triangular only up to rounding,
 // and up to the skip threshold above. Confirm whether callers rely on
 // exact triangularity.
62 Tensor const vDa = conj(v) * bA;
63 bA -= tau * tensorProduct(v, vDa);
64
 // Apply the same reflector to rows s..M-1 of h, across all M columns.
 // Rows above s can be skipped because v is zero above its leading entry.
65 auto bH = slice<M - s, M>(h, s, 0);
66 Tensor const vDh = conj(v) * bH;
67 bH -= tau * tensorProduct(v, vDh);
68 });
69 });
70
 // Assemble (Q, R):
 //   - thin factorisation when M > N (first N columns of Q, top N x N of R);
 //   - full factorisation otherwise, moving `a` into the result as R.
71 auto const result = [&a, &h]()
72 {
74 if constexpr(M > N)
75 {
76 return Tp(slice<M, N>(adjoint(h), 0, 0), slice<N, N>(a, 0, 0));
77 }
78 else
79 {
80 return Tp(adjoint(h), move(a));
81 }
82 }();
83
 // Optional self-check, compiled in only at high debug levels. Failures
 // are logged, not thrown, and the result is returned either way.
84 if constexpr(cDebugLevel >= DebugLevel::High)
85 {
86 auto const &q = get<0>(result);
87 auto const &r = get<1>(result);
88
89 if(!isApprox(input, q * r))
90 {
91 Logging::error("QR decomposition failed.");
92 }
93
94 if(!isUnitary(q))
95 {
96 Logging::error("QR decomposition failed: Q is not unitary.");
97 }
98 }
99
100 return result;
101 }
102}
103#endif // pRC_ALGORITHMS_QR_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition value.hpp:12
Definition tensor.hpp:25
Definition concepts.hpp:25
Definition concepts.hpp:19
int i
Definition gmock-matchers-comparisons_test.cc:603
static void error(Xs &&...args)
Definition log.hpp:14
Definition cholesky.hpp:10
static constexpr auto unit()
Definition unit.hpp:13
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
static constexpr Conditional< C, RemoveConstReference< X >, RemoveConst< X > > copy(X &&a)
Definition copy.hpp:13
Size Index
Definition basics.hpp:32
std::tuple< Ts... > Tuple
Definition basics.hpp:23
static constexpr auto qr(X &&input)
Definition qr.hpp:13
static constexpr auto slice(X &&a, Os const ... offsets)
Definition slice.hpp:17
static constexpr auto conj(T const &a)
Definition conj.hpp:11
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:312
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
static constexpr auto tensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:17
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
static constexpr auto isApprox(XE &&expected, XA &&approx, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_approx.hpp:14
constexpr auto cDebugLevel
Definition config.hpp:48
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr decltype(auto) real(X &&a)
Definition real.hpp:12
static constexpr auto isUnitary(X &&a, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_unitary.hpp:14
static constexpr decltype(auto) imag(X &&a)
Definition imag.hpp:12
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13