cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
amen.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_AMEN_H
4#define cMHN_TT_AMEN_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
// Approximately solves the TT linear system op * x = b, where op is the TT
// operator assembled from the cores of MHNop (the in-code comment below calls
// it "1-Q"), starting from the initial guess x0.
//
// The function sweeps alternately forward and backward over the cores. In each
// step, core k of x is replaced by the GMRES solution of a projected local
// system. That solution is truncated to rank Ranks and then enriched with up
// to ERanks - Ranks (= 4) extra directions taken from a local residual. A
// low-rank (ZRanks) approximation z of the residual is updated alongside.
// The name and the enrichment steps suggest this is the AMEn (alternating
// minimal energy) scheme -- NOTE(review): confirm against the algorithm
// referenced as "Alg." in the comments below.
//
// Template parameters:
//   R  - target TT rank of the returned solution (Ranks = getRanks<D, R>()).
//   D  - number of TT cores.
//   T  - scalar type (also the type of tolerance).
//   Tb - type of the right-hand side; must provide Tb::Ranks and core<k>().
//   X  - type of the initial guess; must provide core<k>().
//   NOTE(review): one template-parameter line (listing line 14) is missing
//   from this view; OT is used below and is presumably declared there.
// Parameters:
//   MHNop     - MHN operator whose cores are combined into op via transform<OT>.
//   b         - right-hand side TT.
//   tolerance - tolerance passed to every local GMRES solve.
//   x0        - initial guess; its cores are copied into the leading
//               sub-block of the corresponding cores of x.
// Returns: x after the sweeps, rounded to ranks Ranks.
13 template<pRC::Size R,
15 pRC::Size D, class T, class Tb, class X>
16 auto amen(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
// Rank profiles: x is kept at Ranks and enriched up to ERanks (R + 4),
// the residual approximation z uses ZRanks (rank 4), and BRanks are b's ranks.
// NOTE(review): ModeSizes is not referenced on any line visible in this listing.
19 using ModeSizes = decltype(getModeSizes<D>());
20 using Ranks = decltype(getRanks<D, R>());
21 using ERanks = decltype(getRanks<D, R + pRC::Size(4)>());
22 using ZRanks = decltype(getRanks<D, pRC::Size(4)>());
23 using BRanks = typename Tb::Ranks;
24
// NOTE(review): the initializer of toleranceLocal is on a line missing from
// this listing. toleranceLocal is not referenced on any visible line (the
// GMRES calls below receive tolerance) -- confirm where it is meant to be used.
25 auto const toleranceLocal =
27
28 // get TT for 1-Q
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](auto const... seq)
31 {
33 MHNop.template core<seq>()...);
34 }));
35
// Copy core k of x0 into the leading (offset 0, 0, 0) sub-block of core k of x.
// The declaration of x and the loop header are on lines missing from this listing.
38 [&](auto const k)
39 {
41 decltype(x0.template core<k>())>::size(0),
42 2,
44 decltype(x0.template core<k>())>::size(2)>(
45 x.template core<k>(), 0, 0, 0) = x0.template core<k>();
46 });
47
48 // random setup
// Constant seed values (8, 16), presumably so that the random initial z is
// reproducible; the RNG/distribution setup is on lines missing from this listing.
49 pRC::SeedSequence seq(8, 16);
52
53 // initialize approximation of residual with random TT with norm 1
54 auto z = round<ZRanks>(
56 dist));
57 z = z / norm(z);
58
59 // create the objects holding psis/phis for A (without values yet)
// A tuple of D + 1 interface tensors; entry k links cores k - 1 and k. The
// Sizes<1> padding at both ends gives entries 0 and D extent 1 in their outer
// dimensions. The rank list for the middle dimension is on a line missing
// from this listing (presumably the ranks of op).
60 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
61 [](auto const... seq)
62 {
63 return std::tuple(pRC::Tensor<T,
64 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
65 seq),
66 decltype(pRC::Sizes<1>(),
68 pRC::Sizes<1>())::size(seq),
69 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
70 seq)>()...);
71 });
72
73 // first and last psiphi is always 1
// (the right-hand sides of these assignments are on lines missing from this listing)
74 reshape<1, 1>(std::get<0>(psiphiA)) =
76 reshape<1, 1>(std::get<D>(psiphiA)) =
78
79 // create the objects holding psis/phis for b (without values yet)
// Entry k has dimensions (x-side rank, b-side rank) at interface k.
80 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
81 [](auto const... seq)
82 {
83 return std::tuple(pRC::Tensor<T,
84 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
85 seq),
86 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
87 seq)>()...);
88 });
89
90 // first and last psiphi is always 1
91 reshape<1, 1>(std::get<0>(psiphib)) =
93 reshape<1, 1>(std::get<D>(psiphib)) =
95
96 // create the objects holding psi_hat/phi_hat for A (without values yet)
// Like psiphiA, but the first dimension follows the ranks of z (ZRanks):
// these interfaces test with z and apply op to x.
97 auto psiphihatA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
98 [](auto const... seq)
99 {
100 return std::tuple(pRC::Tensor<T,
101 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
102 seq),
103 decltype(pRC::Sizes<1>(),
105 pRC::Sizes<1>())::size(seq),
106 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
107 seq)>()...);
108 });
109
110 // first and last psiphihatA is always 1
111 reshape<1, 1>(std::get<0>(psiphihatA)) =
113 reshape<1, 1>(std::get<D>(psiphihatA)) =
115
116 // create the objects holding psi_hat/phi_hat for b (without values yet)
// Entry k has dimensions (z-side rank, b-side rank) at interface k.
117 auto psiphihatb = expand(pRC::makeSeries<pRC::Index, D + 1>(),
118 [](auto const... seq)
119 {
120 return std::tuple(pRC::Tensor<T,
121 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
122 seq),
123 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
124 seq)>()...);
125 });
126
127 // first and last psiphihatb is always 1
128 reshape<1, 1>(std::get<0>(psiphihatb)) =
130 reshape<1, 1>(std::get<D>(psiphihatb)) =
132
133 // calculate the first psiphiA and psiphib (also with hats) and do
134 // initial QR
// For each k, right-orthogonalize core k of x and of z and absorb the
// non-orthogonal factor into core k - 1. Then build interface k from
// interface k + 1. Because of that dependency the loop presumably runs from
// k = D - 1 down to 1; its header is on a line missing from this listing.
136 [&](auto const k)
137 {
138 auto const [l, q] =
139 orthogonalize<pRC::Position::Right>(x.template core<k>());
140
141 x.template core<k>() = q;
142 x.template core<k - 1>() =
143 contract<2, 0>(x.template core<k - 1>(), l);
144 {
145 // also orthogonalize z
146 auto const [l, q] = orthogonalize<pRC::Position::Right>(
147 z.template core<k>());
148 z.template core<k>() = q;
149 z.template core<k - 1>() =
150 contract<2, 0>(z.template core<k - 1>(), l);
151 }
152
153 // k-1 in Alg.
154 std::get<k>(psiphiA) =
155 contract<1, 2, 1, 3>(conj(x.template core<k>()),
156 contract<2, 3, 1, 3>(op.template core<k>(),
157 eval(contract<2, 2>(x.template core<k>(),
158 std::get<k + 1>(psiphiA)))));
159
160 std::get<k>(psiphib) =
161 contract<1, 2, 1, 2>(conj(x.template core<k>()),
162 eval(contract<2, 1>(b.template core<k>(),
163 std::get<k + 1>(psiphib))));
164
// Hat interfaces: tested with z (conj of z's core), applied to x.
165 std::get<k>(psiphihatA) =
166 contract<1, 2, 1, 3>(conj(z.template core<k>()),
167 contract<2, 3, 1, 3>(op.template core<k>(),
168 eval(contract<2, 2>(x.template core<k>(),
169 get<k + 1>(psiphihatA)))));
170
171 std::get<k>(psiphihatb) =
172 contract<1, 2, 1, 2>(conj(z.template core<k>()),
173 eval(contract<2, 1>(b.template core<k>(),
174 get<k + 1>(psiphihatb))));
175 });
176
177 // for now, just use 2 iterations
// NOTE(review): the number of sweeps is fixed; the visible code performs no
// convergence check on the residual.
178 for(pRC::Index i = 0; i < 2; ++i)
179 {
180 // forward sweep
// Uses left interfaces get<k> and right interfaces get<k + 1>, and updates
// get<k + 1> at the end, so the loop (header on a missing line) presumably
// runs left to right.
182 [&](auto const k)
183 {
184 // local A and b
// Ak/bk: op and b projected onto x on both sides (local system for core k).
185 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
186 contract<1, 0>(get<k>(psiphiA),
187 eval(contract<3, 1>(op.template core<k>(),
188 get<k + 1>(psiphiA))))));
189
190 auto const bk = linearize(contract<1, 0>(get<k>(psiphib),
191 eval(contract<2, 1>(b.template core<k>(),
192 get<k + 1>(psiphib)))));
193
// hatAk/hatbk: projections tested with z on both sides; used below to
// update core k of the residual approximation z.
194 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
195 contract<1, 0>(get<k>(psiphihatA),
196 eval(contract<3, 1>(op.template core<k>(),
197 get<k + 1>(psiphihatA))))));
198
199 auto const hatbk =
200 linearize(contract<1, 0>(get<k>(psiphihatb),
201 eval(contract<2, 1>(b.template core<k>(),
202 get<k + 1>(psiphihatb)))));
203
// rAk/rbk: mixed projections (x interfaces on the left, z interfaces on the
// right); their residual supplies the enrichment directions below.
204 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
205 contract<1, 0>(get<k>(psiphiA),
206 eval(contract<3, 1>(op.template core<k>(),
207 get<k + 1>(psiphihatA))))));
208
209 auto const rbk = linearize(contract<1, 0>(get<k>(psiphib),
210 eval(contract<2, 1>(b.template core<k>(),
211 get<k + 1>(psiphihatb)))));
212
213 // solve local system
// GMRES on Ak * sol = bk, using the current core of x as the initial guess.
214 pRC::RemoveConstReference<decltype(x.template core<k>())>
215 sol;
216 linearize(sol) = pRC::Solver::GMRES<32, 0>()(Ak, bk,
217 linearize(x.template core<k>()), tolerance);
218
219 // orthogonalize
// Truncate the local solution to rank Ranks::size(k): q fills the leading
// block of ex, and r is pushed into core k + 1 further below.
220 auto const [q, r] =
221 truncate<Ranks::size(k), pRC::Position::Right>(sol);
222
223 // enrich current core
// ex (declared on lines missing from this listing) has ERanks::size(k) columns:
// q goes in the first Ranks::size(k) and residual directions in the rest.
226 decltype(x.template core<k>())>::size(0),
227 2, ERanks::size(k)>
228 ex;
230 decltype(x.template core<k>())>::size(0),
231 2, Ranks::size(k)>(ex, 0, 0, 0) = q;
232
233 if constexpr(ERanks::size(k) - Ranks::size(k) > 0)
234 {
// Enrichment: truncate the mixed local residual rbk - rAk * sol to the
// extra rank and append its left factor zl after column Ranks::size(k).
236 decltype(x.template core<k>())>::size(0);
238 auto const tt = eval(rbk - rAk * linearize(sol));
239 linearize(eta) = tt;
240
241 auto const [zl, zr] =
242 truncate<ERanks::size(k) - Ranks::size(k),
244
246 decltype(x.template core<k>())>::size(0),
247 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
248 Ranks::size(k)) = zl;
249 }
250
251 // update residual and orthogonalize
// Core k of z becomes the z-projected residual hatbk - hatAk * sol, then
// is left-orthogonalized; the factor rrr is discarded.
252 linearize(z.template core<k>()) =
253 hatbk - hatAk * linearize(sol);
254 auto const [rqq, rrr] = orthogonalize<pRC::Position::Left>(
255 z.template core<k>());
256 z.template core<k>() = rqq;
257
258 // enrich next core
// r * core<k + 1> fills the first Ranks::size(k) rows of enext; the remaining
// rows, matching the enrichment columns of ex, are set to zero.
259 auto tNext = contract<1, 0>(r, x.template core<k + 1>());
260 pRC::Tensor<T, ERanks::size(k), 2,
262 decltype(x.template core<k + 1>())>::size(2)>
263 enext;
264 slice<Ranks::size(k), 2,
266 decltype(x.template core<k + 1>())>::size(2)>(enext,
267 0, 0, 0) = tNext;
268 slice<ERanks::size(k) - Ranks::size(k), 2,
270 decltype(x.template core<k + 1>())>::size(2)>(enext,
271 Ranks::size(k), 0, 0) = pRC::zero();
272
273 // re-orthogonalize
// Left-orthogonalize the enriched core k and push its factor into core k + 1.
274 auto const [qq, rr] =
275 orthogonalize<pRC::Position::Left>(ex);
276 x.template core<k>() = qq;
277 x.template core<k + 1>() = contract<1, 0>(rr, enext);
278
279 // calculate next psiphiA and psiphib (also with hats)
// Interface k + 1 is built from interface k and the updated core k.
280 get<k + 1>(psiphiA) =
281 contract<1, 0, 0, 3>(conj(x.template core<k>()),
282 contract<2, 0, 0, 3>(op.template core<k>(),
283 eval(contract<0, 2>(x.template core<k>(),
284 get<k>(psiphiA)))));
285
286 get<k + 1>(psiphib) =
287 contract<0, 1, 2, 0>(conj(x.template core<k>()),
288 eval(contract<0, 1>(b.template core<k>(),
289 get<k>(psiphib))));
290
291 get<k + 1>(psiphihatA) =
292 contract<1, 0, 0, 3>(conj(z.template core<k>()),
293 contract<2, 0, 0, 3>(op.template core<k>(),
294 eval(contract<0, 2>(x.template core<k>(),
295 get<k>(psiphihatA)))));
296
297 get<k + 1>(psiphihatb) =
298 contract<0, 1, 2, 0>(conj(z.template core<k>()),
299 eval(contract<0, 1>(b.template core<k>(),
300 get<k>(psiphihatb))));
301 });
302
303 // backward sweep
// Mirror of the forward sweep: truncation and orthogonalization are from the
// left/right as appropriate, factors go into core k - 1, and interface k is
// rebuilt from interface k + 1. The loop header is on missing lines.
306 [&](auto const k)
307 {
308 // local A and b
309 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
310 contract<1, 0>(get<k>(psiphiA),
311 eval(contract<3, 1>(op.template core<k>(),
312 get<k + 1>(psiphiA))))));
313
314 auto const bk = linearize(contract<1, 0>(get<k>(psiphib),
315 eval(contract<2, 1>(b.template core<k>(),
316 get<k + 1>(psiphib)))));
317
318 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
319 contract<1, 0>(get<k>(psiphihatA),
320 eval(contract<3, 1>(op.template core<k>(),
321 get<k + 1>(psiphihatA))))));
322
323 auto const hatbk =
324 linearize(contract<1, 0>(get<k>(psiphihatb),
325 eval(contract<2, 1>(b.template core<k>(),
326 get<k + 1>(psiphihatb)))));
327
// Mixed projections for the backward direction: z interfaces on the left,
// x interfaces on the right.
328 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
329 contract<1, 0>(get<k>(psiphihatA),
330 eval(contract<3, 1>(op.template core<k>(),
331 get<k + 1>(psiphiA))))));
332
333 auto const rbk =
334 linearize(contract<1, 0>(get<k>(psiphihatb),
335 eval(contract<2, 1>(b.template core<k>(),
336 get<k + 1>(psiphib)))));
337
338 // solve local system
339 pRC::RemoveConstReference<decltype(x.template core<k>())>
340 sol;
341 linearize(sol) = pRC::Solver::GMRES<32, 0>()(Ak, bk,
342 linearize(x.template core<k>()), tolerance);
343
344 // orthogonalize
// Truncate to rank Ranks::size(k - 1) from the left: q fills the leading
// rows of ex, and l is pushed into core k - 1 further below.
345 auto const [l, q] =
346 truncate<Ranks::size(k - 1), pRC::Position::Left>(sol);
347
348 // enrich current core
349 pRC::Tensor<T, ERanks::size(k - 1), 2,
351 decltype(x.template core<k>())>::size(2)>
352 ex;
353 slice<Ranks::size(k - 1), 2,
355 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
356 0) = q;
357
358 if constexpr(ERanks::size(k - 1) - Ranks::size(k - 1) > 0)
359 {
// Enrichment: truncate the mixed local residual and append its right
// factor zr after row Ranks::size(k - 1).
361 decltype(x.template core<k>())>::size(2);
363 auto const tt = eval(rbk - rAk * linearize(sol));
364 linearize(eta) = tt;
365
366 auto const [zl, zr] =
367 truncate<ERanks::size(k - 1) - Ranks::size(k - 1),
369
370 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
372 decltype(x.template core<k>())>::size(2)>(ex,
373 Ranks::size(k - 1), 0, 0) = zr;
374 }
375
376 // update residual
// Core k of z becomes the z-projected residual, then is right-orthogonalized;
// the factor rll is discarded.
377 linearize(z.template core<k>()) =
378 hatbk - hatAk * linearize(sol);
379 auto const [rll, rqq] = orthogonalize<pRC::Position::Right>(
380 z.template core<k>());
381 z.template core<k>() = rqq;
382
383 // enrich next core
// core<k - 1> * l fills the first Ranks::size(k - 1) columns of enext; the
// remaining columns, matching the enrichment rows of ex, are set to zero.
384 auto tNext = contract<2, 0>(x.template core<k - 1>(), l);
387 decltype(x.template core<k - 1>())>::size(0),
388 2, ERanks::size(k - 1)>
389 enext;
391 decltype(x.template core<k - 1>())>::size(0),
392 2, Ranks::size(k - 1)>(enext, 0, 0, 0) = tNext;
393 linearize(
395 decltype(x.template core<k - 1>())>::size(0),
396 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(enext,
397 0, 0, Ranks::size(k - 1))) = pRC::zero();
398
399 // re-orthogonalize
// Right-orthogonalize the enriched core k and push its factor into core k - 1.
400 auto const [ll, qq] =
401 orthogonalize<pRC::Position::Right>(ex);
402 x.template core<k>() = qq;
403 x.template core<k - 1>() = contract<2, 0>(enext, ll);
404
405 // calculate next psiphiA and psiphib (also with hats)
406 // k -1 in Alg.
407 get<k>(psiphiA) =
408 contract<1, 2, 1, 3>(conj(x.template core<k>()),
409 contract<2, 3, 1, 3>(op.template core<k>(),
410 eval(contract<2, 2>(x.template core<k>(),
411 get<k + 1>(psiphiA)))));
412
413 get<k>(psiphib) =
414 contract<1, 2, 1, 2>(conj(x.template core<k>()),
415 eval(contract<2, 1>(b.template core<k>(),
416 get<k + 1>(psiphib))));
417
418 get<k>(psiphihatA) =
419 contract<1, 2, 1, 3>(conj(z.template core<k>()),
420 contract<2, 3, 1, 3>(op.template core<k>(),
421 eval(contract<2, 2>(x.template core<k>(),
422 get<k + 1>(psiphihatA)))));
423
424 get<k>(psiphihatb) =
425 contract<1, 2, 1, 2>(conj(z.template core<k>()),
426 eval(contract<2, 1>(b.template core<k>(),
427 get<k + 1>(psiphihatb))));
428 });
429 }
// Round the enriched iterate back to the target ranks Ranks.
430 return round<Ranks>(x);
431 }
432
// Convenience overload of amen that needs no initial guess. It builds a random
// TT with ranks Ranks, normalizes it to norm 1, and forwards it as x0 to the
// overload taking an initial guess.
// Template parameters R, D, T, Tb and the function parameters have the same
// meaning as in that overload.
// NOTE(review): one template-parameter line (listing line 434) is missing from
// this view; OT is forwarded below and is presumably declared there.
433 template<pRC::Size R,
435 pRC::Size D, class T, class Tb>
436 auto amen(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
437 {
// NOTE(review): ModeSizes is not referenced on any line visible in this listing.
438 using ModeSizes = decltype(getModeSizes<D>());
439 using Ranks = decltype(getRanks<D, R>());
440
441 // start from a random TT with norm 1
// NOTE(review): these seed values (8, 16) are the same ones used for the
// initial residual z in the overload taking x0. The RNG/distribution setup is
// on lines missing from this listing -- confirm whether x and z are meant to
// share a random stream.
442 pRC::SeedSequence seq(8, 16);
445 auto x = round<Ranks>(
447 dist));
448 x = x / norm(x);
449
// Delegate to the overload taking an initial guess.
450 return amen<R, OT>(MHNop, b, tolerance, x);
451 }
452}
453
454#endif // cMHN_TT_AMEN_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition sequence.hpp:29
Definition gmres.hpp:22
Definition declarations.hpp:16
Definition tensor.hpp:25
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
Uncopyable z
Definition gmock-matchers-containers_test.cc:378
int x
Definition gmock-matchers-containers_test.cc:376
Definition als.hpp:12
constexpr auto getRanks()
Definition utility.hpp:17
auto amen(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition amen.hpp:16
constexpr auto getModeSizes()
Definition utility.hpp:11
Transform
Definition transform.hpp:9
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto makeSeries()
Definition sequence.hpp:390
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13