cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
gmres.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_SOLVER_GMRES_H
4#define pRC_ALGORITHMS_SOLVER_GMRES_H
5
16
17namespace pRC::Solver
18{
// GMRES: restarted GMRES solver for A x = b.
// Each pass of the outer do/while loop below is one restart cycle of at most
// M inner (Arnoldi) steps; the Hessenberg matrix R is kept upper triangular
// with Givens/Jacobi rotations G, and g holds the rotated right-hand side.
//
// Template parameters:
//   M   - maximum number of inner steps per restart cycle.
//   MGS - true: modified Gram-Schmidt (each projection is subtracted right
//         after it is computed); false: classical Gram-Schmidt (all
//         projections are computed first, then subtracted).
//   K   - if K > 0, correction directions from previous cycles are kept in
//         `z` and used in place of the last z.size() Krylov vectors (looks
//         like LGMRES-style augmentation - confirm; the declaration of `z`
//         is on lines not shown in this listing, presumably holding at most
//         K vectors). requires(M > K) presumably keeps M - z.size() positive.
//
// If a run does not converge, the solver retries itself with M and K doubled
// (up to M = 512; see the end of operator()).
19 template<Size M = 32, Bool MGS = true, Size K = 0>
20 requires(M > K)
21 class GMRES
22 {
23 private:
// Default cap on the total number of inner iterations.
24 static constexpr Size defaultMaxIterations()
25 {
26 return 1000;
27 }
28
29 public:
// maxIterations: cap on the total number of inner iterations across all
// restart cycles (checked only between cycles, see operator()).
30 constexpr GMRES(Size const maxIterations = defaultMaxIterations())
31 : mMaxIterations(maxIterations)
32 {
33 }
34
// Returns the iteration cap passed to the constructor.
35 constexpr auto maxIterations() const
36 {
37 return mMaxIterations;
38 }
39
// Solves A x = b, starting from x0 (zero by default).
//   A         - operator, applied through apply<OT, OR, OH>(A, x).
//   b         - right-hand side.
//   x0        - initial guess.
//   tolerance - relative tolerance; the stopping threshold is
//               tolerance * norm(b).
// Returns x0 (forwarded) when the iterate x is a reference, otherwise the
// local iterate x by value.
// NOTE: the template header of this operator (original lines 40-42, 46, 48)
// is missing from this listing.
43 class RA = RemoveReference<XA>, class XB,
44 class RB = RemoveReference<XB>, class XX = decltype(zero<RB>()),
45 class RX = RemoveReference<XX>,
47 requires IsInvocable<Apply<OT, OR, OH>, XA, XX> &&
49 inline constexpr auto operator()(XA &&A, XB &&b, XX &&x0 = zero<RX>(),
50 VT const &tolerance = NumericLimits<VT>::tolerance()) const
51 {
// x is the iterate that is updated in place; its initializer (line 53) is
// not shown - presumably it binds to x0 when possible (see the
// IsReference<decltype(x)> check at the end). TODO confirm.
52 decltype(auto) x =
54
56 decltype(round(apply<OT, OR, OH>(forward<XA>(A), x)))>>;
57 using T = typename V::Type;
58
// Absolute convergence threshold: relative tolerance scaled by norm(b).
59 auto const scaledTolerance = tolerance * norm(b)();
60
62 Bool converged = false;
// Total number of inner iterations over all restart cycles.
63 Index iteration = 0;
// Outer loop: one restart cycle per pass.
64 do
65 {
// Original lines 66-69 (not shown) presumably declare the per-cycle
// workspace v, R, g, G - TODO confirm how R and g are initialized.
70
// Initial residual of this cycle: v[0] = b - A x.
71 v[0] = apply<OT, OR, OH>(A, x);
72 v[0] = b - v[0];
73
74 auto vNorm = norm(v[0])();
// The least-squares right-hand side starts as ||r|| in its first entry.
75 g(0) = vNorm;
76
// Inner loop: at most M Arnoldi steps; m counts completed steps.
77 Index m = 0;
78 while(m < M)
79 {
// |g(m)| is the current residual estimate (it is what gets logged as
// "Residual:" below); stop as soon as it reaches the target.
80 if(norm(g(m)) <= scaledTolerance)
81 {
82 converged = true;
83 break;
84 }
85
86 ++iteration;
// Normalize the current basis vector (vNorm is its norm, computed
// either above or at the end of the previous step).
87 v[m] /= vNorm;
88
// Next candidate vector: A v[m], or, for the last z.size() steps
// when K > 0, A applied to a stored correction direction instead.
89 if constexpr(K > 0)
90 {
91 if(m >= M - z.size())
92 {
93 v[m + 1] = apply<OT, OR, OH>(A,
94 z.front(m - (M - z.size())));
95 }
96 else
97 {
98 v[m + 1] = apply<OT, OR, OH>(A, v[m]);
99 }
100 }
101 else
102 {
103 v[m + 1] = apply<OT, OR, OH>(A, v[m]);
104 }
105
// Gram-Schmidt against v[0..m]: column m of R gets the projections.
// With MGS each projection is removed immediately...
106 for(Index i = 0; i < m + 1; ++i)
107 {
108 R(i, m) = innerProduct(v[m + 1], v[i])();
109 if constexpr(MGS)
110 {
111 v[m + 1] -= R(i, m) * v[i];
112 }
113 }
114
// ...without MGS (classical Gram-Schmidt) they are removed
// afterwards, all computed against the unmodified vector.
115 if constexpr(!MGS)
116 {
117 for(Index i = 0; i < m + 1; ++i)
118 {
119 v[m + 1] -= R(i, m) * v[i];
120 }
121 }
122
// Subdiagonal entry = norm of the orthogonalized vector; vNorm is
// reused to normalize v[m + 1] in the next step.
123 vNorm = norm(v[m + 1])();
124 R(m + 1, m) = vNorm;
125
// Bring the new column up to date with all earlier rotations.
126 for(Index i = 0; i < m; ++i)
127 {
128 apply(G[i], chip<1>(R, m), i, i + 1);
129 }
130
// New rotation built from rows m and m + 1 of column m (presumably
// annihilating R(m + 1, m)); apply it to R and to g.
131 G[m] = JacobiRotation<T>(chip<1>(R, m), m, m + 1);
132 apply(G[m], chip<1>(R, m), m, m + 1);
133 apply(G[m], g, m, m + 1);
134
135 ++m;
136
137 Logging::debug("GMRES", "Iteration:", iteration,
138 "Inner Iteration:", m, "Residual:", norm(g(m)),
139 "Target:", scaledTolerance);
140 }
141
// Solve the upper-triangular system R y = g for the correction
// coefficients.
// NOTE(review): the full M x M system is solved even when the loop
// exited with m < M; the result for the used entries then depends on
// how the hidden lines initialize the unused columns of R and entries
// of g - confirm this is well-posed (e.g. non-zero diagonal).
142 auto const y = solve<BackwardSubstitution>(slice<M, M>(R, 0, 0),
143 slice<M>(g, 0));
144
145 if constexpr(K > 0)
146 {
// Correction dx = sum of v[i] * y(i) over the Krylov part plus
// z.front(i) * y(...) over the stored directions.
// NOTE(review): v[0] * y(0) and all z terms are added even if
// m == 0 (converged before any step, v[0] not normalized) -
// confirm y is zero there.
147 auto dx = eval(v[0] * y(0));
148 for(Index i = 1; i < m && i < M - z.size(); ++i)
149 {
150 dx += v[i] * y(i);
151 }
152 for(Index i = 0; i < z.size(); ++i)
153 {
154 dx += z.front(i) * y((M - z.size()) + i);
155 }
156 x += dx;
157
// Keep the normalized correction as a new stored direction.
// NOTE(review): no guard against dxn == 0 (division by zero).
158 auto const dxn = eval(norm(dx));
159 z.pushFront(dx / dxn);
160 }
161 else
162 {
// x += V_m y, using only the m basis vectors built in this cycle.
163 for(Index i = 0; i < m; ++i)
164 {
165 x += v[i] * y(i);
166 }
167 }
168 }
// NOTE(review): maxIterations() is only checked between restart
// cycles, so up to M - 1 extra inner iterations can run past it.
169 while(iteration < maxIterations() && !converged);
170
// Not converged: retry from the current x with a larger method, or
// give up once M exceeds 512.
171 if(!converged)
172 {
173 if constexpr(M <= 512)
174 {
176 "GMRES failed to converge within allowed max "
177 "iterations for M =",
178 M, "and K =", K, "- doubling M and K and continuing");
// NOTE(review): the nested solver hard-codes MGS = true and uses
// the default iteration cap rather than this->maxIterations();
// pass MGS and maxIterations() through if that is not intended.
179 return GMRES<2 * M, true, 2 * K>()(A, b, x, tolerance);
180 }
181 else
182 {
184 "GMRES failed to converge within allowed max "
185 "iterations.");
186 }
187 }
188
// NOTE(review): in the M > 512 branch above nothing returns, so (unless
// the call opened on hidden line 183 aborts) execution continues here
// and still logs "GMRES converged after" below.
// High debug level only: check that isApprox(b, A x, tolerance) holds.
189 if constexpr(cDebugLevel >= DebugLevel::High)
190 {
191 auto const y = eval(apply<OT, OR, OH>(A, x));
192 if(!isApprox(b, y, tolerance))
193 {
194 Logging::error("GMRES failed.");
195 }
196 }
197
198 Logging::debug("GMRES converged after", iteration, "iterations.");
199
// If x is a reference (presumably to x0), return x0 with its original
// value category; otherwise return the local iterate by value.
200 if constexpr(IsReference<decltype(x)>)
201 {
202 return forward<XX>(x0);
203 }
204 else
205 {
206 return x;
207 }
208 }
209
210 private:
// Iteration cap set at construction (see maxIterations()).
211 Size const mMaxIterations;
212 };
213}
214#endif // pRC_ALGORITHMS_SOLVER_GMRES_H
Definition deque.hpp:15
Definition value.hpp:12
Definition gmres.hpp:22
constexpr auto maxIterations() const
Definition gmres.hpp:35
constexpr GMRES(Size const maxIterations=defaultMaxIterations())
Definition gmres.hpp:30
Definition tensor.hpp:25
Definition concepts.hpp:25
Definition value.hpp:24
Definition concepts.hpp:31
Definition concepts.hpp:19
int i
Definition gmock-matchers-comparisons_test.cc:603
Uncopyable z
Definition gmock-matchers-containers_test.cc:378
const double y
Definition gmock-matchers-containers_test.cc:377
int x
Definition gmock-matchers-containers_test.cc:376
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Restrict
Definition restrict.hpp:9
Hint
Definition hint.hpp:9
Transform
Definition transform.hpp:9
Definition backward_substitution.hpp:13
static constexpr decltype(auto) solve(Solver &&solver, XA &&A, XB &&b)
Definition solve.hpp:18
static constexpr Conditional< C, RemoveConstReference< X >, RemoveConst< X > > copy(X &&a)
Definition copy.hpp:13
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr decltype(auto) apply(JacobiRotation< T > const &r, X &&m, Index const p, Index const q)
Definition jacobi_rotation.hpp:321
JacobiRotation(Tensor< T, M, N > const &a, Index const p, Index const q) -> JacobiRotation< T >
static constexpr auto innerProduct(TA const &a, TB const &b)
Definition inner_product.hpp:11
static constexpr auto slice(X &&a, Os const ... offsets)
Definition slice.hpp:17
std::common_type_t< Ts... > Common
Definition basics.hpp:53
typename ValueType< T >::Type Value
Definition value.hpp:72
Conditional<((Ns *... *1) *sizeof(T) > cHugepageSizeByte), HeapArray< T, Ns... >, StackArray< T, Ns... > > Array
Definition declarations.hpp:21
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
static constexpr auto isApprox(XE &&expected, XA &&approx, TT const &tolerance=NumericLimits< TT >::tolerance())
Definition is_approx.hpp:14
constexpr auto cDebugLevel
Definition config.hpp:48
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto round(T const &a)
Definition round.hpp:11
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
Definition gtest_pred_impl_unittest.cc:54
Definition apply.hpp:17
Definition eval.hpp:11
Definition limits.hpp:13
Definition sub.hpp:11