cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
mamen.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_MAMEN_H
4#define cMHN_TT_MAMEN_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
// Iteratively refines a tensor-train (TT) approximation x, starting from x0,
// by alternating forward and backward sweeps that solve local two-core
// systems (via pRC::Solver::GMRES) built from the transformed operator `op`
// and the right-hand side b. A second TT, z, is maintained alongside x and is
// used to build the enrichment of x's cores (see the sweeps below).
//
// Template parameters (as far as visible here):
//   R  - rank parameter; the result is rounded to getRanks<D, R>().
//   D  - number of cores; the psi/phi tuples below hold D + 1 entries.
//   T  - value type of the operator and of the tolerance.
//   Tb - type of the right-hand side b; must provide Tb::Ranks and core<k>().
//   X  - type of the initial guess x0; must provide core<k>().
// Parameters:
//   MHNop     - MHN operator whose cores are used to build `op`.
//   b         - right-hand side in TT form.
//   tolerance - tolerance handed to the local solver (see NOTE in the
//               backward sweep) and used to derive toleranceLocal.
//   x0        - initial guess; its cores are copied into x.
// Returns: round<Ranks>(x) after two forward/backward iterations.
//
// NOTE(review): this listing has gaps (the embedded line numbers skip), so
// the template parameter that introduces OT, the declaration of x, the loop
// headers and the RNG/distribution setup are not visible here.
13 template<pRC::Size R,
15 pRC::Size D, class T, class Tb, class X>
16 auto mamen(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
// Rank sequences: Ranks for the result (R), ERanks for R + 4 (sizes of the
// enriched cores and of the psiphiA/psiphib entries), ZRanks for rank 4
// (used for z and the "hat" tuples), BRanks taken from b's type.
19 using ModeSizes = decltype(getModeSizes<D>());
20 using Ranks = decltype(getRanks<D, R>());
21 using ERanks = decltype(getRanks<D, R + pRC::Size(4)>());
22 using ZRanks = decltype(getRanks<D, pRC::Size(4)>());
23 using BRanks = typename Tb::Ranks;
24
// NOTE(review): the initializer of toleranceLocal is missing from this
// listing; it is only visibly consumed by the GMRES call in the forward sweep.
25 auto const toleranceLocal =
27
28 // get TT for 1-Q
// op is the result of transform<OT>() applied to an object assembled from all
// cores of MHNop (the call that wraps the cores is not visible here).
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](auto const... seq)
31 {
33 MHNop.template core<seq>()...);
34 }));
35
// For each k: write core k of x0 into a sub-block of core k of x starting at
// offset (0, 0, 0) - presumably a slice<>; the call name, the declaration of
// x and the enclosing loop are not visible here.
38 [&](auto const k)
39 {
41 decltype(x0.template core<k>())>::size(0),
42 2,
44 decltype(x0.template core<k>())>::size(2)>(
45 x.template core<k>(), 0, 0, 0) = x0.template core<k>();
46 });
47
48 // random setup
49 pRC::SeedSequence seq(8, 16);
52
53 // initialize approximation of residual with random TT with norm 1
// z is rounded to ZRanks and then divided by its norm.
54 auto z = round<ZRanks>(
56 dist));
57 z = z / norm(z);
58
59 // create the objects holding psis/phis for A (without values yet)
// One tensor per index of a (D + 1)-element series, bundled in a std::tuple.
60 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
61 [](auto const... seq)
62 {
63 return std::tuple(pRC::Tensor<T,
64 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
65 seq),
66 decltype(pRC::Sizes<1>(),
68 pRC::Sizes<1>())::size(seq),
69 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
70 seq)>()...);
71 });
72
73 // first and last psiphi is always 1
// NOTE(review): the right-hand sides of these two assignments are missing
// from this listing.
74 reshape<1, 1>(std::get<0>(psiphiA)) =
76 reshape<1, 1>(std::get<D>(psiphiA)) =
78
79 // create the objects holding psis/phis for b (without values yet)
// Same layout as psiphiA but two-dimensional: sizes from ERanks and BRanks.
80 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
81 [](auto const... seq)
82 {
83 return std::tuple(pRC::Tensor<T,
84 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
85 seq),
86 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
87 seq)>()...);
88 });
89
90 // first and last psiphi is always 1
91 reshape<1, 1>(std::get<0>(psiphib)) =
93 reshape<1, 1>(std::get<D>(psiphib)) =
95
96 // create the objects holding psi_hat/phi_hat for A (without values yet)
// "hat" variant: the first dimension comes from ZRanks (z) instead of ERanks.
97 auto psiphihatA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
98 [](auto const... seq)
99 {
100 return std::tuple(pRC::Tensor<T,
101 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
102 seq),
103 decltype(pRC::Sizes<1>(),
105 pRC::Sizes<1>())::size(seq),
106 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
107 seq)>()...);
108 });
109
110 // first and last psiphihatA is always 1
111 reshape<1, 1>(std::get<0>(psiphihatA)) =
113 reshape<1, 1>(std::get<D>(psiphihatA)) =
115
116 // create the objects holding psi_hat/phi_hat for b (without values yet)
117 auto psiphihatb = expand(pRC::makeSeries<pRC::Index, D + 1>(),
118 [](auto const... seq)
119 {
120 return std::tuple(pRC::Tensor<T,
121 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
122 seq),
123 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
124 seq)>()...);
125 });
126
127 // first and last psiphihatb is always 1
128 reshape<1, 1>(std::get<0>(psiphihatb)) =
130 reshape<1, 1>(std::get<D>(psiphihatb)) =
132
133 // calculate the first psiphiA and psiphib (also with hats) and do
134 // initial QR
// Per index k (loop header not visible): right-orthogonalize core k of x and
// of z, absorb the remaining factor l into core k - 1, then compute entry k
// of each of the four tuples from entry k + 1.
136 [&](auto const k)
137 {
138 auto const [l, q] =
139 orthogonalize<pRC::Position::Right>(x.template core<k>());
140
141 x.template core<k>() = q;
142 x.template core<k - 1>() =
143 contract<2, 0>(x.template core<k - 1>(), l);
144 {
145 // also orthogonalize z
// (inner scope so that l and q can be re-declared for z)
146 auto const [l, q] = orthogonalize<pRC::Position::Right>(
147 z.template core<k>());
148 z.template core<k>() = q;
149 z.template core<k - 1>() =
150 contract<2, 0>(z.template core<k - 1>(), l);
151 }
152
153 // k-1 in Alg.
// x-projected operator term: conj(x_k), op_k, x_k and entry k + 1.
154 std::get<k>(psiphiA) =
155 contract<1, 2, 1, 3>(conj(x.template core<k>()),
156 contract<2, 3, 1, 3>(op.template core<k>(),
157 eval(contract<2, 2>(x.template core<k>(),
158 std::get<k + 1>(psiphiA)))));
159
// x-projected right-hand-side term: conj(x_k), b_k and entry k + 1.
160 std::get<k>(psiphib) =
161 contract<1, 2, 1, 2>(conj(x.template core<k>()),
162 eval(contract<2, 1>(b.template core<k>(),
163 std::get<k + 1>(psiphib))));
164
// Same two updates with conj(z_k) on the left ("hat" tuples). The unqualified
// get<> below is presumably found via ADL for std::tuple.
165 std::get<k>(psiphihatA) =
166 contract<1, 2, 1, 3>(conj(z.template core<k>()),
167 contract<2, 3, 1, 3>(op.template core<k>(),
168 eval(contract<2, 2>(x.template core<k>(),
169 get<k + 1>(psiphihatA)))));
170
171 std::get<k>(psiphihatb) =
172 contract<1, 2, 1, 2>(conj(z.template core<k>()),
173 eval(contract<2, 1>(b.template core<k>(),
174 get<k + 1>(psiphihatb))));
175 });
176
177 // for now, just use 2 iterations
178 for(pRC::Index i = 0; i < 2; ++i)
179 {
180 // forward sweep
// Works on the core pair (k, k + 1); the loop header is not visible here.
182 [&](auto const k)
183 {
184 // local A and b
// sA / sB / sX: cores k and k + 1 of op / b / x contracted over their shared
// bond and reshaped into a single merged core.
185 auto const sA = pRC::reshape<
187 decltype(op.template core<k>())>::size(0),
188 4, 4,
190 decltype(op.template core<k + 1>())>::size(3)>(
191 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
192 op.template core<k>(), op.template core<k + 1>())));
193
194 auto const sB = pRC::reshape<
196 decltype(b.template core<k>())>::size(0),
197 4,
199 decltype(b.template core<k + 1>())>::size(2)>(
200 contract<2, 0>(b.template core<k>(),
201 b.template core<k + 1>()));
202
203 auto const sX = pRC::reshape<
205 decltype(x.template core<k>())>::size(0),
206 4,
208 decltype(x.template core<k + 1>())>::size(2)>(
209 contract<2, 0>(x.template core<k>(),
210 x.template core<k + 1>()));
211
// Local system (Ak, bk): merged cores enclosed by the x-based tuples at
// k and k + 2.
212 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
213 contract<1, 0>(std::get<k>(psiphiA),
214 eval(contract<3, 1>(sA,
215 std::get<k + 2>(psiphiA))))));
216
217 auto const bk = linearize(contract<1, 0>(
218 std::get<k>(psiphib),
219 eval(contract<2, 1>(sB, std::get<k + 2>(psiphib)))));
220
// (hatAk, hatbk): same construction with the z-based "hat" tuples on both
// sides; used below to update z.
221 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
222 contract<1, 0>(std::get<k>(psiphihatA),
223 eval(contract<3, 1>(sA,
224 std::get<k + 2>(psiphihatA))))));
225
226 auto const hatbk = linearize(contract<1, 0>(
227 std::get<k>(psiphihatb),
228 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
229
// (rAk, rbk): mixed construction - x-based tuple at k, "hat" tuple at k + 2;
// used below for the enrichment of x.
230 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
231 contract<1, 0>(std::get<k>(psiphiA),
232 eval(contract<3, 1>(sA,
233 std::get<k + 2>(psiphihatA))))));
234
235 auto const rbk = linearize(contract<1, 0>(
236 std::get<k>(psiphib),
237 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
238
239 // solve local system
// GMRES is called with (Ak, bk, linearize(sX), toleranceLocal); the third
// argument is presumably the initial guess - TODO confirm against gmres.hpp.
242 decltype(x.template core<k>())>::size(0),
243 2, 2,
245 decltype(x.template core<k + 1>())>::size(2)>
246 sol;
247 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
248 linearize(sX), toleranceLocal);
249
250 // left orthogonalization
// Split the two-core solution with an SVD parameterized by Ranks::size(k).
251 auto const [u, s, v] = svd<Ranks::size(k)>(matricize(sol));
252
// ex: candidate for core k with last dimension ERanks::size(k); its leading
// Ranks::size(k) part (offset (0, 0, 0)) receives u.
255 decltype(x.template core<k>())>::size(0),
256 2, ERanks::size(k)>
257 ex;
258 linearize(
260 decltype(x.template core<k>())>::size(0),
261 2, Ranks::size(k)>(ex, 0, 0, 0)) = linearize(u);
262
// Enrichment: fill the remaining part of ex (offset Ranks::size(k) in the
// last dimension) from an SVD of the residual rbk - rAk * sol.
// NOTE(review): if these sizes are pRC::Size (unsigned), the difference never
// goes negative, so "> 0" behaves like "!=" - confirm ERanks >= Ranks holds.
263 if constexpr(ERanks::size(k) - Ranks::size(k) > 0)
264 {
266 decltype(x.template core<k>())>::size(0);
268 decltype(z.template core<k + 1>())>::size(2);
270 auto const tt = eval(rbk - rAk * linearize(sol));
271 linearize(eta) = tt;
272
273 auto const [zu, zs, zv] =
274 svd<ERanks::size(k) - Ranks::size(k)>(
275 matricize(eta));
276
277 linearize(
279 decltype(x.template core<k>())>::size(0),
280 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
281 Ranks::size(k))) =
282 linearize(zu * fromDiagonal(zs));
283 }
284
// Update z: residual of the "hat" system, SVD parameterized by
// ZRanks::size(k), left factor becomes core k of z.
287 decltype(z.template core<k>())>::size(0),
288 2, 2,
290 decltype(z.template core<k + 1>())>::size(2)>
291 zz;
292 linearize(zz) = hatbk - hatAk * linearize(sol);
293
294 auto const [zu, zs, zv] =
295 svd<ZRanks::size(k)>(matricize(zz));
296
297 linearize(z.template core<k>()) = linearize(zu);
298
// tNext: remaining SVD factors diag(s) * adjoint(v), folded into core shape.
299 pRC::Tensor<T, Ranks::size(k), 2,
301 decltype(x.template core<k + 1>())>::size(2)>
302 tNext;
303 folding<pRC::Position::Right>(tNext) =
304 fromDiagonal(s) * adjoint(v);
// eNext: tNext padded with zeros up to ERanks::size(k) in the first dimension
// so that it matches the enlarged ex.
305 pRC::Tensor<T, ERanks::size(k), 2,
307 decltype(x.template core<k + 1>())>::size(2)>
308 eNext;
309 slice<Ranks::size(k), 2,
311 decltype(x.template core<k + 1>())>::size(2)>(eNext,
312 0, 0, 0) = tNext;
313 slice<ERanks::size(k) - Ranks::size(k), 2,
315 decltype(x.template core<k + 1>())>::size(2)>(eNext,
316 Ranks::size(k), 0, 0) = pRC::zero();
317
// Left-orthogonalize ex: qq becomes core k of x, rr is pushed into core k + 1.
318 auto const [qq, rr] =
319 orthogonalize<pRC::Position::Left>(ex);
320 x.template core<k>() = qq;
321 x.template core<k + 1>() = contract<1, 0>(rr, eNext);
322
// With core k of x and z updated, advance all four tuples from entry k to
// entry k + 1.
323 // calculate psiphiA(k+1)
324 std::get<k + 1>(psiphiA) =
325 contract<1, 0, 0, 3>(conj(x.template core<k>()),
326 contract<2, 0, 0, 3>(op.template core<k>(),
327 eval(contract<0, 2>(x.template core<k>(),
328 std::get<k>(psiphiA)))));
329
330 // calculate psiphib(k+1)
331 std::get<k + 1>(psiphib) =
332 contract<0, 1, 2, 0>(conj(x.template core<k>()),
333 eval(contract<0, 1>(b.template core<k>(),
334 std::get<k>(psiphib))));
335
336 // calculate psiphiAhat(k+1)
337 std::get<k + 1>(psiphihatA) =
338 contract<1, 0, 0, 3>(conj(z.template core<k>()),
339 contract<2, 0, 0, 3>(op.template core<k>(),
340 eval(contract<0, 2>(x.template core<k>(),
341 std::get<k>(psiphihatA)))));
342
343 // calculate psiphibhat(k+1)
344 std::get<k + 1>(psiphihatb) =
345 contract<0, 1, 2, 0>(conj(z.template core<k>()),
346 eval(contract<0, 1>(b.template core<k>(),
347 std::get<k>(psiphihatb))));
348 });
349
350 // backward sweep
// Mirror image of the forward sweep on the core pair (k - 1, k); the loop
// header is not visible here.
353 [&](auto const k)
354 {
355 // local A and b
// sA / sB / sX: cores k - 1 and k of op / b / x merged into one core.
356 auto const sA = pRC::reshape<
358 decltype(op.template core<k - 1>())>::size(0),
359 4, 4,
361 decltype(op.template core<k>())>::size(3)>(
362 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
363 op.template core<k - 1>(), op.template core<k>())));
364
365 auto const sB = pRC::reshape<
367 decltype(b.template core<k - 1>())>::size(0),
368 4,
370 decltype(b.template core<k>())>::size(2)>(
371 contract<2, 0>(b.template core<k - 1>(),
372 b.template core<k>()));
373
374 auto const sX = pRC::reshape<
376 decltype(x.template core<k - 1>())>::size(0),
377 4,
379 decltype(x.template core<k>())>::size(2)>(
380 contract<2, 0>(x.template core<k - 1>(),
381 x.template core<k>()));
382
// Local system (Ak, bk): x-based tuples at k - 1 and k + 1.
383 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
384 contract<1, 0>(std::get<k - 1>(psiphiA),
385 eval(contract<3, 1>(sA,
386 std::get<k + 1>(psiphiA))))));
387
388 auto const bk = linearize(contract<1, 0>(
389 std::get<k - 1>(psiphib),
390 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
391
// (hatAk, hatbk): "hat" tuples on both sides.
392 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
393 contract<1, 0>(std::get<k - 1>(psiphihatA),
394 eval(contract<3, 1>(sA,
395 std::get<k + 1>(psiphihatA))))));
396
397 auto const hatbk = linearize(contract<1, 0>(
398 std::get<k - 1>(psiphihatb),
399 eval(contract<2, 1>(sB, std::get<k + 1>(psiphihatb)))));
400
// (rAk, rbk): mixed, but with the sides swapped relative to the forward
// sweep - "hat" tuple at k - 1, x-based tuple at k + 1.
401 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
402 contract<1, 0>(std::get<k - 1>(psiphihatA),
403 eval(contract<3, 1>(sA,
404 std::get<k + 1>(psiphiA))))));
405
406 auto const rbk = linearize(contract<1, 0>(
407 std::get<k - 1>(psiphihatb),
408 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
409
410 // solve local system
// NOTE(review): this call passes `tolerance`, whereas the forward sweep
// passes `toleranceLocal` - confirm whether the difference is intentional.
413 decltype(x.template core<k - 1>())>::size(0),
414 2, 2,
416 decltype(x.template core<k>())>::size(2)>
417 sol;
418 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
419 linearize(sX), tolerance);
420
421 // right orthogonalization
// Split the two-core solution with an SVD parameterized by Ranks::size(k - 1).
422 auto const [u, s, v] =
423 svd<Ranks::size(k - 1)>(matricize(sol));
424
// ex: candidate for core k with first dimension ERanks::size(k - 1); its
// leading Ranks::size(k - 1) part receives adjoint(v).
425 pRC::Tensor<T, ERanks::size(k - 1), 2,
427 decltype(x.template core<k>())>::size(2)>
428 ex;
429 linearize(slice<Ranks::size(k - 1), 2,
431 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
432 0)) = linearize(adjoint(v));
433
// Enrichment: remaining rows of ex (offset Ranks::size(k - 1) in the first
// dimension) come from an SVD of the residual rbk - rAk * sol.
// NOTE(review): same unsigned-difference caveat as in the forward sweep.
434 if constexpr(ERanks::size(k - 1) - Ranks::size(k - 1) > 0)
435 {
437 decltype(z.template core<k - 1>())>::size(0);
439 decltype(x.template core<k>())>::size(2);
441 auto const tt = eval(rbk - rAk * linearize(sol));
442 linearize(eta) = tt;
443
444 auto const [zu, zs, zv] =
445 svd<ERanks::size(k - 1) - Ranks::size(k - 1)>(
446 matricize(eta));
447
448 linearize(
449 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
451 decltype(x.template core<k>())>::size(2)>(
452 ex, Ranks::size(k - 1), 0, 0)) =
453 linearize(fromDiagonal(zs) * adjoint(zv));
454 }
455
// Update z: residual of the "hat" system, SVD parameterized by
// ZRanks::size(k - 1), adjoint of the right factor becomes core k of z.
458 decltype(z.template core<k - 1>())>::size(0),
459 2, 2,
461 decltype(z.template core<k>())>::size(2)>
462 zz;
463 linearize(zz) = hatbk - hatAk * linearize(sol);
464
465 auto const [zu, zs, zv] =
466 svd<ZRanks::size(k - 1)>(matricize(zz));
467
468 linearize(z.template core<k>()) = linearize(adjoint(zv));
469
// tNext: remaining SVD factors u * diag(s), folded into core shape.
472 decltype(x.template core<k - 1>())>::size(0),
473 2, Ranks::size(k - 1)>
474 tNext;
475 folding<pRC::Position::Left>(tNext) = u * fromDiagonal(s);
// eNext: tNext padded with zeros up to ERanks::size(k - 1) in the last
// dimension so that it matches the enlarged ex.
478 decltype(x.template core<k - 1>())>::size(0),
479 2, ERanks::size(k - 1)>
480 eNext;
482 decltype(x.template core<k - 1>())>::size(0),
483 2, Ranks::size(k - 1)>(eNext, 0, 0, 0) = tNext;
484 linearize(
486 decltype(x.template core<k - 1>())>::size(0),
487 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(eNext,
488 0, 0, Ranks::size(k - 1))) = pRC::zero();
489
// Right-orthogonalize ex: qq becomes core k of x, ll is pushed into
// core k - 1.
490 auto const [ll, qq] =
491 orthogonalize<pRC::Position::Right>(ex);
492 x.template core<k>() = qq;
493 x.template core<k - 1>() = contract<2, 0>(eNext, ll);
494
// With core k of x and z updated, recompute entry k of all four tuples from
// entry k + 1 (same contractions as in the initial QR pass).
495 // k-1 in Alg.
496 // calculate psiphiA(k)
497 std::get<k>(psiphiA) =
498 contract<1, 2, 1, 3>(conj(x.template core<k>()),
499 contract<2, 3, 1, 3>(op.template core<k>(),
500 eval(contract<2, 2>(x.template core<k>(),
501 std::get<k + 1>(psiphiA)))));
502
503 // calculate psiphib(k)
504 std::get<k>(psiphib) =
505 contract<1, 2, 1, 2>(conj(x.template core<k>()),
506 eval(contract<2, 1>(b.template core<k>(),
507 std::get<k + 1>(psiphib))));
508
509 // calculate psiphihatA(k)
510 std::get<k>(psiphihatA) =
511 contract<1, 2, 1, 3>(conj(z.template core<k>()),
512 contract<2, 3, 1, 3>(op.template core<k>(),
513 eval(contract<2, 2>(x.template core<k>(),
514 std::get<k + 1>(psiphihatA)))));
515
516 // calculate psiphihatb
517 std::get<k>(psiphihatb) =
518 contract<1, 2, 1, 2>(conj(z.template core<k>()),
519 eval(contract<2, 1>(b.template core<k>(),
520 std::get<k + 1>(psiphihatb))));
521 });
522 }
// x carries ERanks-sized cores at this point; round back to the requested
// Ranks before returning.
523 return round<Ranks>(x);
524 }
525
// Convenience overload without an initial guess: builds a starting TT x,
// rounds it to getRanks<D, R>(), divides it by its norm and forwards to the
// four-argument mamen<R, OT>(MHNop, b, tolerance, x).
//
// Parameters:
//   MHNop     - MHN operator, forwarded unchanged.
//   b         - right-hand side, forwarded unchanged.
//   tolerance - solver tolerance, forwarded unchanged.
// Returns: whatever the four-argument overload returns.
//
// NOTE(review): the template parameter introducing OT and the lines that
// create the random generator/distribution and the random tensor are missing
// from this listing (the embedded line numbers skip).
526 template<pRC::Size R,
528 pRC::Size D, class T, class Tb>
529 auto mamen(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
530 {
531 using ModeSizes = decltype(getModeSizes<D>());
532 using Ranks = decltype(getRanks<D, R>());
533
534 // start from a random TT with norm 1
// Seed sequence is constructed from the fixed values (8, 16).
535 pRC::SeedSequence seq(8, 16);
538 auto x = round<Ranks>(
540 dist));
// Normalize the starting point.
541 x = x / norm(x);
542
543 return mamen<R, OT>(MHNop, b, tolerance, x);
544 }
545}
546
547#endif // cMHN_TT_MAMEN_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition sequence.hpp:29
Definition gmres.hpp:22
Definition declarations.hpp:16
Definition tensor.hpp:25
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
Uncopyable z
Definition gmock-matchers-containers_test.cc:378
int x
Definition gmock-matchers-containers_test.cc:376
Definition als.hpp:12
constexpr auto getRanks()
Definition utility.hpp:17
auto mamen(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition mamen.hpp:16
constexpr auto getModeSizes()
Definition utility.hpp:11
Transform
Definition transform.hpp:9
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto reshape(X &&a)
Definition reshape.hpp:14
static constexpr auto makeSeries()
Definition sequence.hpp:390
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13