cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
qr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_QR_H
4#define pRC_ALGORITHMS_QR_H
5
6#include <prc/config.hpp>
18
// NOTE(review): this is a rendered listing of qr.hpp, not compilable source.
// Source lines 7-17, 32, 34, 36, 40, 55 and 63 are missing from the listing:
// the include list, the rest of the `copy` condition, the declaration of `h`,
// the two loop drivers, the deflation tolerance and the sign test for `beta`.
// Consult the original header for them.
19namespace pRC
20{
// qr: QR decomposition of a rank-2 tensor (an M x N matrix) using Householder
// reflections.
//
// Template parameters:
//   B - block size (default 32). Columns are processed in chunks of B;
//       s = b * B is the first column of chunk b.
//   X - forwarded argument type; R is X with the reference removed.
// The `If<...> = 0` parameters constrain R three ways: it must be tensor-like,
// it must have a floating-point value type, and it must have exactly two
// dimensions.
//
// Returns tuple<Tensor<T, M, D>, Tensor<T, D, N>> holding (Q, R), where
// D = min(M, N):
//   - if M > N: Q is the first N columns of adjoint(h), and R is the leading
//     N x N block of the reduced matrix `a`;
//   - otherwise: Q = adjoint(h) (M x M), and R is all of `a` (M x N).
// NOTE(review): nothing in the visible code writes exact zeros below the
// diagonal of `a`. Those entries of R hold whatever the reflections leave
// there, expected to be near zero up to rounding. Confirm whether callers
// rely on an exactly upper-triangular R.
//
// When cDebugLevel >= DebugLevel::High, the result is checked (input ~ Q * R,
// Q unitary). Failures are reported via Logging::error; whether that aborts
// is up to Logging.
21 template<Size B = 32, class X, class R = RemoveReference<X>,
22 If<IsTensorish<R>> = 0, If<IsFloat<typename R::Value>> = 0,
23 If<IsSatisfied<(typename R::Dimension() == 2)>> = 0>
24 static inline constexpr auto qr(X &&input)
25 {
26 using T = typename R::Type;
27 constexpr auto M = R::size(0);
28 constexpr auto N = R::size(1);
29 constexpr auto D = min(M, N);
30
// `a` is the working matrix that is reduced in place to R.
// `copy` skips copying only when all of these hold:
//   - X is not a reference (rvalue argument);
//   - R is not const;
//   - the condition continued on missing source line 32 holds.
// NOTE(review): in that no-copy case `a` is constructed from the forwarded
// argument and later moved into the result (line 91). The debug check at
// line 100 still reads `input`; confirm `input` keeps its value there, or
// compare against a saved copy instead.
31 decltype(auto) a = copy<!(!IsReference<X>() && !IsConst<R>() &&
// Source lines 32-36 are missing here. `h` is declared on one of them; below
// it accumulates the applied reflectors and Q = adjoint(h). Presumably it is
// initialized to the M x M identity (`identity` appears in this page's symbol
// list). TODO confirm.
33
35
// Per-block body: called once per block index b (the driver is on a missing
// line). s = b * B is the first column of the block.
37 [&](auto const b)
38 {
39 constexpr auto s = b * B;
// Per-column body: called once per column k of the current block (the range
// of k is set on missing source line 40).
41 [&](auto const k)
42 {
// bA is the trailing (M - s) x (N - s) slice of `a` starting at (s, s).
// Updates to bA are expected to reach `a`; presumably slice is a view.
// v is a Tensor built from column k - s of bA. Its entries 0..k-s (the
// diagonal of `a` and above) are zeroed, leaving only the sub-diagonal part
// of column k.
43 auto bA = slice<M - s, N - s>(a, s, s);
44 Tensor v = chip<1>(bA, k - s);
45 for(Index i = 0; i <= k - s; ++i)
46 {
47 v(i) = zero();
48 }
49
// Build the reflector for column k and return its scalar tau.
// c = a(k, k) is the pivot; sqNorm is the squared norm of the sub-diagonal
// part stored in v.
50 auto const tau = [&a, &v, &k]()
51 {
52 auto const c = a(k, k);
53 auto const sqNorm = norm<2, 1>(v)();
// If both sqNorm and |imag(c)|^2 are at or below the threshold (given on
// missing source line 55), no reflection is needed. v is cleared and tau = 0,
// which turns the updates below into no-ops.
54 if(max(sqNorm, norm<2, 1>(imag(c))) <=
56 {
57 v = zero();
58 return zero<T>();
59 }
60 else
61 {
// beta = sqrt(|c|^2 + sqNorm); its sign is flipped under a condition on
// missing source line 63. Presumably the sign is chosen so that c + beta
// does not cancel (TODO confirm). v is then divided by c + beta.
62 auto beta = sqrt(norm<2, 1>(c) + sqNorm);
64 {
65 beta = -beta;
66 }
67 v /= c + beta;
68 return conj((beta + c) / beta);
69 }
70 }();
71
// Set the pivot entry of v to 1, then apply (I - tau * v * v^H) from the
// left, written as bA -= tau * v (x) (conj(v) * bA):
//   - to the trailing block of `a`;
//   - to rows s..M-1 (all M columns) of h.
72 v(k - s) = unit();
73 Tensor const vDa = conj(v) * bA;
74 bA -= tau * tensorProduct(v, vDa);
75
76 auto bH = slice<M - s, M>(h, s, 0);
77 Tensor const vDh = conj(v) * bH;
78 bH -= tau * tensorProduct(v, vDh);
79 });
80 });
81
// Package the factors. Q = adjoint(h), truncated to its first N columns when
// M > N. R comes from the reduced matrix `a`: its leading N x N block when
// M > N, otherwise all of `a`, moved out.
82 auto const result = [&a, &h]()
83 {
84 using Tp = tuple<Tensor<T, M, D>, Tensor<T, D, N>>;
85 if constexpr(M > N)
86 {
87 return Tp(slice<M, N>(adjoint(h), 0, 0), slice<N, N>(a, 0, 0));
88 }
89 else
90 {
91 return Tp(adjoint(h), move(a));
92 }
93 }();
94
// Optional self-check, compiled in only for cDebugLevel >= DebugLevel::High.
95 if constexpr(cDebugLevel >= DebugLevel::High)
96 {
97 auto const &q = get<0>(result);
98 auto const &r = get<1>(result);
99
100 if(!isApprox(input, q * r))
101 {
102 Logging::error("QR decomposition failed.");
103 }
104
105 if(!isUnitary(q))
106 {
107 Logging::error("QR decomposition failed: Q is not unitary.");
108 }
109 }
110
111 return result;
112 }
113}
114#endif // pRC_ALGORITHMS_QR_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition tensor.hpp:28
static void error(Xs &&...args)
Definition log.hpp:14
Definition cholesky.hpp:18
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr decltype(auto) imag(X &&a)
Definition imag.hpp:11
static constexpr decltype(auto) real(X &&a)
Definition real.hpp:11
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto qr(X &&input)
Definition qr.hpp:24
static constexpr auto slice(X &&a, Os const ... offsets)
Definition slice.hpp:20
static constexpr auto adjoint(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:325
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto isApprox(XA &&a, XB &&b, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_approx.hpp:24
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
static constexpr auto unit()
Definition unit.hpp:12
constexpr auto cDebugLevel
Definition config.hpp:46
static constexpr auto sqrt(Complex< T > const &a)
Definition sqrt.hpp:12
static constexpr auto tensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:19
static constexpr auto conj(Complex< T > const &a)
Definition conj.hpp:11
static constexpr auto identity()
Definition identity.hpp:12
static constexpr auto isUnitary(X &&a, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_unitary.hpp:19
static constexpr X max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13