pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
gmres.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_SOLVER_GMRES_H
4#define pRC_ALGORITHMS_SOLVER_GMRES_H
5
18
19namespace pRC::Solver
20{
21 template<Size M = 32, Size K = 2, If<IsSatisfied<(M > K)>> = 0>
22 class GMRES
23 {
24 private:
25 static constexpr Size defaultMaxIterations()
26 {
27 return 10000;
28 }
29
30 public:
31 constexpr GMRES(Size const maxIterations = defaultMaxIterations())
32 : mMaxIterations(maxIterations)
33 {
34 }
35
36 constexpr auto maxIterations() const
37 {
38 return mMaxIterations;
39 }
40
44 class RA = RemoveConstReference<XA>, class TA = typename RA::Type,
45 class VA = typename TA::Value, class XB,
46 class RB = RemoveConstReference<XB>, class TB = typename RB::Type,
47 class VB = typename TB::Value, class XX = decltype(zero<RB>()),
48 class RX = RemoveConstReference<XX>, class TX = typename RX::Type,
49 class VX = typename TX::Value, class VT = Common<VA, VB, VX>,
51 inline constexpr auto operator()(XA &&A, XB &&b, XX &&x0 = zero<RX>(),
52 VT const &tolerance = NumericLimits<VT>::tolerance()) const
53 {
54 decltype(auto) x =
55 copy<!(!IsReference<XX>() && !IsConst<RX>())>(eval(x0));
56
58 decltype(round(apply<OT, OR, OH>(forward<XA>(A), x)))>>;
59 using T = typename V::Type;
60
61 auto const scaledTolerance = tolerance * norm(b)();
62
64 Bool converged = false;
65 Index iteration = 0;
66 do
67 {
72
73 v[0] = apply<OT, OR, OH>(A, x);
74 v[0] = b - v[0];
75
76 auto vNorm = eval(norm(v[0]));
77 auto error = vNorm();
78
79 Index m = 0;
80 while(m < M)
81 {
82 if(norm(error) <= scaledTolerance)
83 {
84 converged = true;
85 break;
86 }
87
88 ++iteration;
89 v[m] /= vNorm;
90
91 if constexpr(K > 0)
92 {
93 if(m >= M - z.size())
94 {
95 v[m + 1] = apply<OT, OR, OH>(A,
96 z.front(m - (M - z.size())));
97 }
98 else
99 {
100 v[m + 1] = apply<OT, OR, OH>(A, v[m]);
101 }
102 }
103 else
104 {
105 v[m + 1] = apply<OT, OR, OH>(A, v[m]);
106 }
107
108 for(Index i = 0; i < m + 1; ++i)
109 {
110 R(i, m) = innerProduct(v[m + 1], v[i])();
111 v[m + 1] -= R(i, m) * v[i];
112 }
113
114 for(Index i = 0; i < m; ++i)
115 {
116 apply(G[i], chip<1>(R, m), i, i + 1);
117 }
118
119 vNorm = norm(v[m + 1]);
120
121 G[m] = JacobiRotation<T>::MakeGivens(R(m, m), vNorm());
122 R(m, m) = G[m].c() * R(m, m) + conj(G[m].s()) * vNorm();
123
124 g(m) = G[m].c() * error;
125 error *= -G[m].s();
126
127 ++m;
128
129 Logging::debug("GMRES", "Iteration:", iteration,
130 "Inner Iteration:", m, "Residual:", norm(error),
131 "Target:", scaledTolerance);
132 }
133
134 auto const y = solve<BackwardSubstitution>(R, g);
135
136 if constexpr(K > 0)
137 {
138 auto dx = eval(v[0] * y(0));
139 for(Index i = 1; i < m && i < M - z.size(); ++i)
140 {
141 dx += v[i] * y(i);
142 }
143 for(Index i = 0; i < z.size(); ++i)
144 {
145 dx += z.front(i) * y((M - z.size()) + i);
146 }
147 x += dx;
148
149 auto const dxn = eval(norm(dx));
150 z.pushFront(dx / dxn);
151 }
152 else
153 {
154 for(Index i = 0; i < m; ++i)
155 {
156 x += v[i] * y(i);
157 }
158 }
159 }
160 while(iteration < maxIterations() && !converged);
161
162 if(!converged)
163 {
164 if constexpr (M <= 512)
165 {
167 "GMRES failed to converge within allowed max iterations for M =",
168 M, "and K =", K, "- doubling M and K and continuing");
169 x = GMRES<2*M, 2*K>()(A, b, x, tolerance);
170 return x;
171 }
172 else
173 {
175 "GMRES failed to converge within allowed max iterations.");
176 }
177 }
178
179 if constexpr(cDebugLevel >= DebugLevel::High)
180 {
181 auto const y = eval(apply<OT, OR, OH>(A, x));
182 if(!isApprox(y, b, tolerance))
183 {
184 Logging::error("GMRES failed.");
185 }
186 }
187
188 Logging::debug("GMRES converged after", iteration, "iterations.");
189
190 if constexpr(IsReference<decltype(x)>())
191 {
192 return forward<XX>(x0);
193 }
194 else
195 {
196 return x;
197 }
198 }
199
200 private:
201 Size const mMaxIterations;
202 };
203}
204#endif // pRC_ALGORITHMS_SOLVER_GMRES_H
Definition deque.hpp:15
static constexpr auto MakeGivens(R1 const &x, R2 const &y)
Definition jacobi_rotation.hpp:33
Definition gmres.hpp:23
constexpr auto maxIterations() const
Definition gmres.hpp:36
constexpr GMRES(Size const maxIterations=defaultMaxIterations())
Definition gmres.hpp:31
Class storing tensors.
Definition tensor.hpp:44
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Restrict
This enum's elements denote a restriction of an operator.
Definition restrict.hpp:27
Hint
This enum's elements denote a hint regarding an operator.
Definition hint.hpp:30
Transform
This enum's elements denote a transformation done to an operator.
Definition transform.hpp:23
Definition backward_substitution.hpp:19
static constexpr X eval(X &&a)
Definition eval.hpp:11
std::conjunction< Bs... > All
Definition type_traits.hpp:77
bool Bool
Definition type_traits.hpp:18
static constexpr decltype(auto) apply(JacobiRotation< T > const &r, X &&m, Index const p, Index const q)
Definition jacobi_rotation.hpp:334
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
static constexpr auto zero()
Definition zero.hpp:12
std::enable_if_t< B{}, int > If
Definition type_traits.hpp:68
std::size_t Size
Definition type_traits.hpp:20
Conditional< IsSatisfied<((Ns *... *1) *sizeof(T) > cHugepageSizeByte)>, HeapArray< T, Ns... >, StackArray< T, Ns... > > Array
Definition type_traits.hpp:60
typename CommonTypes< Ts... >::Type Common
Definition common.hpp:55
static constexpr auto isApprox(XA &&a, XB &&b, TT const &tolerance=NumericLimits< TT >::tolerance())
Checks if two pRC objects agree up to a given tolerance.
Definition is_approx.hpp:44
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
constexpr auto cDebugLevel
Definition config.hpp:46
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
static constexpr auto norm(Complex< T > const &a)
Definition norm.hpp:11
static constexpr auto innerProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition inner_product.hpp:16
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto round(Complex< T > const &a)
Definition round.hpp:12
static constexpr auto conj(Complex< T > const &a)
Definition conj.hpp:11
Size Index
Definition type_traits.hpp:21
static constexpr auto identity()
Definition identity.hpp:12
Definition eval.hpp:11
Definition type_traits.hpp:16
Definition limits.hpp:13