cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
mamen.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_MAMEN_H
4#define cMHN_TT_MAMEN_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
// MAMEN (4-argument overload): iteratively updates the TT approximation x of the
// solution of the linear system whose operator is the TT "1-Q" assembled from
// MHNop (see "get TT for 1-Q" below) and whose right-hand side is the TT b.
// Each step solves a local two-core system with GMRES, truncates it by SVD, and
// enriches x's cores with information from a low-rank residual approximation z.
// NOTE(review): this looks like an AMEn-type scheme; confirm against the
// reference algorithm the code was written from.
//
// Template parameters (as far as visible in this listing):
//   R  -- rank parameter; local solutions are truncated to getRanks<D, R>().
//   D  -- number of TT cores (op/x/z/b cores are indexed 0..D-1).
//   T  -- scalar type.
//   Tb -- TT type of b; must provide a nested Ranks type.
//   X  -- TT type of the initial guess x0.
// NOTE(review): the template parameter used as OT in transform<OT> is declared
// on a line that is missing from this listing.
//
// Parameters:
//   MHNop     -- operator whose cores MHNop.core<k>() are transformed into op.
//   b         -- right-hand side TT.
//   tolerance -- tolerance passed to GMRES in the backward sweep.
//   x0        -- initial guess; it is copied into x and not modified.
// Returns: the updated x (same type X as x0).
13 template<pRC::Size R,
15 pRC::Size D, class T, class Tb, class X>
16 auto MAMEN(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
// Rank layouts used below:
//   Ranks  -- truncation ranks for x (from R),
//   ERanks -- ranks with R + 4; the size of the enriched cores ex/eNext,
//   ZRanks -- ranks of the residual approximation z,
//   BRanks -- ranks of b.
19 using ModeSizes = decltype(getModeSizes<D>());
20 using Ranks = decltype(getRanks<D, R>());
21 using ERanks = decltype(getRanks<D, R + pRC::Size(4)>());
22 using ZRanks = decltype(getRanks<D, pRC::Size(4)>());
23 using BRanks = typename Tb::Ranks;
24
// NOTE(review): the initializer of toleranceLocal is on a line missing from this
// listing. It is only used by the forward-sweep GMRES solves.
25 auto const toleranceLocal =
27
28 // get TT for 1-Q
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](auto const... seq)
31 {
33 MHNop.template core<seq>()...);
34 }));
35
36 auto x = x0;
37
38 // random setup
// NOTE(review): the seed is fixed, so every call draws the same random z. The
// 3-argument overload seeds its random x with the same SeedSequence(8, 16).
// Confirm that both behaviors are intended.
39 pRC::SeedSequence seq(8, 16);
40 pRC::RandomEngine rng(seq);
42
43 // initialize approximation of residual with random TT with norm 1
44 auto z = round<ZRanks>(
46 dist));
47 z = z / norm(z);
48
// The four tuples below each hold D + 1 interface tensors indexed 0..D.
// Entry k is recomputed from core k and entry k + 1 in the right-to-left passes,
// and from core k - 1 and entry k - 1 in the forward sweep. Entries 0 and D are
// the 1x1 boundary values.
49 // create the objects holding psis/phis for A (without values yet)
50 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
51 [](auto const... seq)
52 {
53 return std::tuple(pRC::Tensor<T,
54 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
55 seq),
56 decltype(pRC::Sizes<1>(),
58 pRC::Sizes<1>())::size(seq),
59 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
60 seq)>()...);
61 });
62
63 // first and last psiphi is always 1
64 reshape<1, 1>(std::get<0>(psiphiA)) =
66 reshape<1, 1>(std::get<D>(psiphiA)) =
68
69 // create the objects holding psis/phis for b (without values yet)
70 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
71 [](auto const... seq)
72 {
73 return std::tuple(pRC::Tensor<T,
74 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
75 seq),
76 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
77 seq)>()...);
78 });
79
80 // first and last psiphi is always 1
81 reshape<1, 1>(std::get<0>(psiphib)) =
83 reshape<1, 1>(std::get<D>(psiphib)) =
85
86 // create the objects holding psi_hat/phi_hat for A (without values yet)
87 auto psiphihatA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
88 [](auto const... seq)
89 {
90 return std::tuple(pRC::Tensor<T,
91 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
92 seq),
93 decltype(pRC::Sizes<1>(),
95 pRC::Sizes<1>())::size(seq),
96 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
97 seq)>()...);
98 });
99
100 // first and last psiphihatA is always 1
101 reshape<1, 1>(std::get<0>(psiphihatA)) =
103 reshape<1, 1>(std::get<D>(psiphihatA)) =
105
106 // create the objects holding psi_hat/phi_hat for b (without values yet)
107 auto psiphihatb = expand(pRC::makeSeries<pRC::Index, D + 1>(),
108 [](auto const... seq)
109 {
110 return std::tuple(pRC::Tensor<T,
111 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
112 seq),
113 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
114 seq)>()...);
115 });
116
117 // first and last psiphihatb is always 1
118 reshape<1, 1>(std::get<0>(psiphihatb)) =
120 reshape<1, 1>(std::get<D>(psiphihatb)) =
122
// Right-to-left pass; its range header is on a line missing from this listing.
// For each k: right-orthogonalize core k of x and of z, absorb the remaining
// factor into core k - 1, then compute interface k from interface k + 1.
123 // calculate the first psiphiA and psiphib (also with hats) and do
124 // initial QR
126 [&](auto const k)
127 {
128 auto const [lambda, l, q] =
129 orthogonalize<pRC::Position::Right>(x.template core<k>());
130
131 x.template core<k>() = q;
132 x.template core<k - 1>() =
133 lambda * contract<2, 0>(x.template core<k - 1>(), l);
134 {
135 // also orthogonalize z
136 auto const [lambda, l, q] =
137 orthogonalize<pRC::Position::Right>(
138 z.template core<k>());
139 z.template core<k>() = q;
140 z.template core<k - 1>() =
141 lambda * contract<2, 0>(z.template core<k - 1>(), l);
142 }
143
144 // k-1 in Alg.
145 std::get<k>(psiphiA) =
146 contract<1, 2, 1, 3>(conj(x.template core<k>()),
147 contract<2, 3, 1, 3>(op.template core<k>(),
148 eval(contract<2, 2>(x.template core<k>(),
149 std::get<k + 1>(psiphiA)))));
150
151 std::get<k>(psiphib) =
152 contract<1, 2, 1, 2>(conj(x.template core<k>()),
153 eval(contract<2, 1>(b.template core<k>(),
154 std::get<k + 1>(psiphib))));
155
// NOTE(review): the next two updates call unqualified `get<k + 1>`, while every
// other access uses std::get. This relies on a visible `get` template plus ADL;
// std::get would be consistent.
156 std::get<k>(psiphihatA) =
157 contract<1, 2, 1, 3>(conj(z.template core<k>()),
158 contract<2, 3, 1, 3>(op.template core<k>(),
159 eval(contract<2, 2>(x.template core<k>(),
160 get<k + 1>(psiphihatA)))));
161
162 std::get<k>(psiphihatb) =
163 contract<1, 2, 1, 2>(conj(z.template core<k>()),
164 eval(contract<2, 1>(b.template core<k>(),
165 get<k + 1>(psiphihatb))));
166 });
167
// NOTE(review): the comment below says 2 iterations, but the condition `i < 1`
// runs exactly one forward + backward sweep. One of the two is stale.
168 // for now, just use 2 iterations
169 for(pRC::Index i = 0; i < 1; ++i)
170 {
// Forward sweep; its range header is on a line missing from this listing.
// For each core pair (k, k + 1):
//  - build the merged operator sA, right-hand side sB and current iterate sX;
//  - project them with interfaces k and k + 2:
//      Ak/bk        use the x-interfaces,
//      hatAk/hatbk  use the z-interfaces on both sides,
//      rAk/rbk      use x on the left and z on the right;
//  - solve Ak * sol = bk with GMRES, starting from sX;
//  - truncate sol by SVD to Ranks::size(k) and enrich the left core;
//  - update z's core k;
//  - left-orthogonalize the enriched core into x and recompute interfaces k + 1.
171 // forward sweep
173 [&](auto const k)
174 {
175 // local A and b
176 auto const sA = pRC::reshape<
178 decltype(op.template core<k>())>::size(0),
179 4, 4,
181 decltype(op.template core<k + 1>())>::size(3)>(
182 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
183 op.template core<k>(), op.template core<k + 1>())));
184
185 auto const sB = pRC::reshape<
187 decltype(b.template core<k>())>::size(0),
188 4,
190 decltype(b.template core<k + 1>())>::size(2)>(
191 contract<2, 0>(b.template core<k>(),
192 b.template core<k + 1>()));
193
194 auto const sX = pRC::reshape<
196 decltype(x.template core<k>())>::size(0),
197 4,
199 decltype(x.template core<k + 1>())>::size(2)>(
200 contract<2, 0>(x.template core<k>(),
201 x.template core<k + 1>()));
202
203 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
204 contract<1, 0>(std::get<k>(psiphiA),
205 eval(contract<3, 1>(sA,
206 std::get<k + 2>(psiphiA))))));
207
208 auto const bk = linearize(contract<1, 0>(
209 std::get<k>(psiphib),
210 eval(contract<2, 1>(sB, std::get<k + 2>(psiphib)))));
211
212 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
213 contract<1, 0>(std::get<k>(psiphihatA),
214 eval(contract<3, 1>(sA,
215 std::get<k + 2>(psiphihatA))))));
216
217 auto const hatbk = linearize(contract<1, 0>(
218 std::get<k>(psiphihatb),
219 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
220
221 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
222 contract<1, 0>(std::get<k>(psiphiA),
223 eval(contract<3, 1>(sA,
224 std::get<k + 2>(psiphihatA))))));
225
226 auto const rbk = linearize(contract<1, 0>(
227 std::get<k>(psiphib),
228 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
229
230 // solve local system
233 decltype(x.template core<k>())>::size(0),
234 2, 2,
236 decltype(x.template core<k + 1>())>::size(2)>
237 sol;
238 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
239 linearize(sX), toleranceLocal);
240
// The first Ranks::size(k) columns of the enriched core ex hold u. Any remaining
// columns are filled below from the projected residual.
241 // left orthogonalization
242 auto const [u, s, v] = svd<Ranks::size(k)>(matricize(sol));
243
246 decltype(x.template core<k>())>::size(0),
247 2, ERanks::size(k)>
248 ex;
249 linearize(
251 decltype(x.template core<k>())>::size(0),
252 2, Ranks::size(k)>(ex, 0, 0, 0)) = linearize(u);
253
// Enrichment: the remaining ERanks::size(k) - Ranks::size(k) columns of ex become
// the leading left singular vectors, scaled by their singular values, of the
// residual rbk - rAk * sol.
254 if constexpr(ERanks::size(k) - Ranks::size(k) > 0)
255 {
257 decltype(x.template core<k>())>::size(0);
259 decltype(z.template core<k + 1>())>::size(2);
261 auto const tt = eval(rbk - rAk * linearize(sol));
262 linearize(eta) = tt;
263
264 auto const [zu, zs, zv] =
265 svd<ERanks::size(k) - Ranks::size(k)>(
266 matricize(eta));
267
268 linearize(
270 decltype(x.template core<k>())>::size(0),
271 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
272 Ranks::size(k))) =
273 linearize(zu * fromDiagonal(zs));
274 }
275
// Residual approximation update: z's core k becomes the leading
// ZRanks::size(k) left singular vectors of hatbk - hatAk * sol.
278 decltype(z.template core<k>())>::size(0),
279 2, 2,
281 decltype(z.template core<k + 1>())>::size(2)>
282 zz;
283 linearize(zz) = hatbk - hatAk * linearize(sol);
284
285 auto const [zu, zs, zv] =
286 svd<ZRanks::size(k)>(matricize(zz));
287
288 linearize(z.template core<k>()) = linearize(zu);
289
// Next core candidate: s * v^H in the first Ranks::size(k) rows of eNext. The
// rows matching the enrichment columns of ex are zero.
290 pRC::Tensor<T, Ranks::size(k), 2,
292 decltype(x.template core<k + 1>())>::size(2)>
293 tNext;
294 folding<pRC::Position::Right>(tNext) =
295 fromDiagonal(s) * adjoint(v);
296 pRC::Tensor<T, ERanks::size(k), 2,
298 decltype(x.template core<k + 1>())>::size(2)>
299 eNext;
300 slice<Ranks::size(k), 2,
302 decltype(x.template core<k + 1>())>::size(2)>(eNext,
303 0, 0, 0) = tNext;
304 slice<ERanks::size(k) - Ranks::size(k), 2,
306 decltype(x.template core<k + 1>())>::size(2)>(eNext,
307 Ranks::size(k), 0, 0) = pRC::zero();
308
// Left-orthogonalize ex. Q becomes x's core k; the remaining factor (scaled by
// llambda) is multiplied into eNext to form x's core k + 1.
309 auto const [llambda, qq, rr] =
310 orthogonalize<pRC::Position::Left>(ex);
311 x.template core<k>() = qq;
312 x.template core<k + 1>() =
313 llambda * contract<1, 0>(rr, eNext);
314
315 // calculate psiphiA(k+1)
316 std::get<k + 1>(psiphiA) =
317 contract<1, 0, 0, 3>(conj(x.template core<k>()),
318 contract<2, 0, 0, 3>(op.template core<k>(),
319 eval(contract<0, 2>(x.template core<k>(),
320 std::get<k>(psiphiA)))));
321
322 // calculate psiphib(k+1)
323 std::get<k + 1>(psiphib) =
324 contract<0, 1, 2, 0>(conj(x.template core<k>()),
325 eval(contract<0, 1>(b.template core<k>(),
326 std::get<k>(psiphib))));
327
328 // calculate psiphiAhat(k+1)
329 std::get<k + 1>(psiphihatA) =
330 contract<1, 0, 0, 3>(conj(z.template core<k>()),
331 contract<2, 0, 0, 3>(op.template core<k>(),
332 eval(contract<0, 2>(x.template core<k>(),
333 std::get<k>(psiphihatA)))));
334
335 // calculate psiphibhat(k+1)
336 std::get<k + 1>(psiphihatb) =
337 contract<0, 1, 2, 0>(conj(z.template core<k>()),
338 eval(contract<0, 1>(b.template core<k>(),
339 std::get<k>(psiphihatb))));
340 });
341
342 // backward sweep
// Backward sweep; its range header is on a line missing from this listing.
// It mirrors the forward sweep for the core pair (k - 1, k):
//  - adjoint(v) fills the first Ranks::size(k - 1) rows of ex;
//  - enrichment rows come from the residual rbk - rAk * sol, projected with z on
//    the left and x on the right;
//  - z's core k is taken from the leading right singular vectors of the local
//    z-residual;
//  - ex is right-orthogonalized into x's core k, the remaining factor goes into
//    core k - 1, and interfaces k are recomputed.
345 [&](auto const k)
346 {
347 // local A and b
348 auto const sA = pRC::reshape<
350 decltype(op.template core<k - 1>())>::size(0),
351 4, 4,
353 decltype(op.template core<k>())>::size(3)>(
354 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
355 op.template core<k - 1>(), op.template core<k>())));
356
357 auto const sB = pRC::reshape<
359 decltype(b.template core<k - 1>())>::size(0),
360 4,
362 decltype(b.template core<k>())>::size(2)>(
363 contract<2, 0>(b.template core<k - 1>(),
364 b.template core<k>()));
365
366 auto const sX = pRC::reshape<
368 decltype(x.template core<k - 1>())>::size(0),
369 4,
371 decltype(x.template core<k>())>::size(2)>(
372 contract<2, 0>(x.template core<k - 1>(),
373 x.template core<k>()));
374
375 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
376 contract<1, 0>(std::get<k - 1>(psiphiA),
377 eval(contract<3, 1>(sA,
378 std::get<k + 1>(psiphiA))))));
379
380 auto const bk = linearize(contract<1, 0>(
381 std::get<k - 1>(psiphib),
382 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
383
384 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
385 contract<1, 0>(std::get<k - 1>(psiphihatA),
386 eval(contract<3, 1>(sA,
387 std::get<k + 1>(psiphihatA))))));
388
389 auto const hatbk = linearize(contract<1, 0>(
390 std::get<k - 1>(psiphihatb),
391 eval(contract<2, 1>(sB, std::get<k + 1>(psiphihatb)))));
392
393 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
394 contract<1, 0>(std::get<k - 1>(psiphihatA),
395 eval(contract<3, 1>(sA,
396 std::get<k + 1>(psiphiA))))));
397
398 auto const rbk = linearize(contract<1, 0>(
399 std::get<k - 1>(psiphihatb),
400 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
401
402 // solve local system
405 decltype(x.template core<k - 1>())>::size(0),
406 2, 2,
408 decltype(x.template core<k>())>::size(2)>
409 sol;
// NOTE(review): this solve passes `tolerance`, while the forward sweep passes
// `toleranceLocal`. Confirm which tolerance is intended for the local systems.
410 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
411 linearize(sX), tolerance);
412
413 // right orthogonalization
414 auto const [u, s, v] =
415 svd<Ranks::size(k - 1)>(matricize(sol));
416
417 pRC::Tensor<T, ERanks::size(k - 1), 2,
419 decltype(x.template core<k>())>::size(2)>
420 ex;
421 linearize(slice<Ranks::size(k - 1), 2,
423 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
424 0)) = linearize(adjoint(v));
425
// Enrichment rows: scaled leading right singular vectors of rbk - rAk * sol.
426 if constexpr(ERanks::size(k - 1) - Ranks::size(k - 1) > 0)
427 {
429 decltype(z.template core<k - 1>())>::size(0);
431 decltype(x.template core<k>())>::size(2);
433 auto const tt = eval(rbk - rAk * linearize(sol));
434 linearize(eta) = tt;
435
436 auto const [zu, zs, zv] =
437 svd<ERanks::size(k - 1) - Ranks::size(k - 1)>(
438 matricize(eta));
439
440 linearize(
441 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
443 decltype(x.template core<k>())>::size(2)>(
444 ex, Ranks::size(k - 1), 0, 0)) =
445 linearize(fromDiagonal(zs) * adjoint(zv));
446 }
447
// Residual approximation update: z's core k becomes adjoint(zv) from the SVD of
// hatbk - hatAk * sol, truncated to ZRanks::size(k - 1).
450 decltype(z.template core<k - 1>())>::size(0),
451 2, 2,
453 decltype(z.template core<k>())>::size(2)>
454 zz;
455 linearize(zz) = hatbk - hatAk * linearize(sol);
456
457 auto const [zu, zs, zv] =
458 svd<ZRanks::size(k - 1)>(matricize(zz));
459
460 linearize(z.template core<k>()) = linearize(adjoint(zv));
461
// Previous core candidate: u * s in the first Ranks::size(k - 1) columns of
// eNext. The columns matching the enrichment rows of ex are zero.
464 decltype(x.template core<k - 1>())>::size(0),
465 2, Ranks::size(k - 1)>
466 tNext;
467 folding<pRC::Position::Left>(tNext) = u * fromDiagonal(s);
470 decltype(x.template core<k - 1>())>::size(0),
471 2, ERanks::size(k - 1)>
472 eNext;
474 decltype(x.template core<k - 1>())>::size(0),
475 2, Ranks::size(k - 1)>(eNext, 0, 0, 0) = tNext;
476 linearize(
478 decltype(x.template core<k - 1>())>::size(0),
479 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(eNext,
480 0, 0, Ranks::size(k - 1))) = pRC::zero();
481
// Right-orthogonalize ex. Q becomes x's core k; the remaining factor (scaled by
// llambda) is multiplied onto eNext to form x's core k - 1.
482 auto const [llambda, ll, qq] =
483 orthogonalize<pRC::Position::Right>(ex);
484 x.template core<k>() = qq;
485 x.template core<k - 1>() =
486 llambda * contract<2, 0>(eNext, ll);
487
488 // k-1 in Alg.
489 // calculate psiphiA(k)
490 std::get<k>(psiphiA) =
491 contract<1, 2, 1, 3>(conj(x.template core<k>()),
492 contract<2, 3, 1, 3>(op.template core<k>(),
493 eval(contract<2, 2>(x.template core<k>(),
494 std::get<k + 1>(psiphiA)))));
495
496 // calculate psiphib(k)
497 std::get<k>(psiphib) =
498 contract<1, 2, 1, 2>(conj(x.template core<k>()),
499 eval(contract<2, 1>(b.template core<k>(),
500 std::get<k + 1>(psiphib))));
501
502 // calculate psiphihatA(k)
503 std::get<k>(psiphihatA) =
504 contract<1, 2, 1, 3>(conj(z.template core<k>()),
505 contract<2, 3, 1, 3>(op.template core<k>(),
506 eval(contract<2, 2>(x.template core<k>(),
507 std::get<k + 1>(psiphihatA)))));
508
509 // calculate psiphihatb
510 std::get<k>(psiphihatb) =
511 contract<1, 2, 1, 2>(conj(z.template core<k>()),
512 eval(contract<2, 1>(b.template core<k>(),
513 std::get<k + 1>(psiphihatb))));
514 });
515 }
516 return x;
517 }
518
// MAMEN (3-argument overload): builds a random TT x with ranks getRanks<D, R>(),
// normalizes it to norm 1, and forwards to the 4-argument overload with x as the
// initial guess, passing `tolerance` through unchanged.
// NOTE(review): the extra template parameter OT and the distribution `dist` are
// declared on lines missing from this listing.
// NOTE(review): the fixed SeedSequence(8, 16) is the same seed the 4-argument
// overload uses for its residual approximation z. Confirm this is intended.
519 template<pRC::Size R,
521 pRC::Size D, class T, class Tb>
522 auto MAMEN(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
523 {
524 using ModeSizes = decltype(getModeSizes<D>());
525 using Ranks = decltype(getRanks<D, R>());
526
527 // start from a random TT with norm 1
528 pRC::SeedSequence seq(8, 16);
529 pRC::RandomEngine rng(seq);
531 auto x = round<Ranks>(
533 dist));
534 x = x / norm(x);
535
536 return MAMEN<R, OT>(MHNop, b, tolerance, x);
537 }
538}
539
540#endif // cMHN_TT_MAMEN_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition type_traits.hpp:57
Definition seq.hpp:13
Definition sequence.hpp:56
Definition gmres.hpp:23
Definition type_traits.hpp:17
Definition tensor.hpp:28
Definition threefry.hpp:24
pRC::Float<> T
Definition externs_nonTT.hpp:1
Definition als.hpp:12
auto MAMEN(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition mamen.hpp:16
Transform
Definition transform.hpp:11
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
static constexpr auto random(RandomEngine &rng, D &distribution)
Definition random.hpp:12
Size Index
Definition type_traits.hpp:21
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:16
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
static constexpr auto reshape(X &&a)
Definition reshape.hpp:17
static constexpr auto sqrt(Complex< T > const &a)
Definition sqrt.hpp:12
static constexpr auto identity()
Definition identity.hpp:12