cMHN 1.0
C++ library for learning MHNs with pRC
mamen.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_MAMEN_H
4#define cMHN_TT_MAMEN_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
13 template<pRC::Size R,
14 pRC::Operator::Transform OT = pRC::Operator::Transform::None,
15 pRC::Size D, class T, class Tb>
16 auto MAMEN(MHNOperator<T, D> const &MHNop, Tb const &b,
17 T const &tolerance)
18 {
19 using ModeSizes = decltype(getModeSizes<D>());
20 using Ranks = decltype(getRanks<D, R>());
21 using ERanks = decltype(getRanks<D, R + pRC::Size(4)>());
22 using ZRanks = decltype(getRanks<D, pRC::Size(4)>());
23 using BRanks = typename Tb::Ranks;
24
25 auto const toleranceLocal =
26 tolerance / pRC::sqrt(pRC::identity<pRC::Float<>>(D));
27
28 // get TT for 1-Q
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](auto const... seq)
31 {
32 return pRC::TensorTrain::fromCores(
33 MHNop.template core<seq>()...);
34 }));
35
36 // start from a random TT with norm 1
37 pRC::SeedSequence seq(8, 16);
38 pRC::RandomEngine rng(seq);
39 pRC::GaussianDistribution<pRC::Float<>> dist;
40 auto x = round<ERanks>(
41 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, ERanks>>(rng,
42 dist));
43 x = x / norm(x);
44
45 // initialize approximation of residual with random TT with norm 1
46 auto z = round<ZRanks>(
47 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, ZRanks>>(rng,
48 dist));
49 z = z / norm(z);
50
51 // create the objects holding psis/phis for A (without values yet)
52 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
53 [](auto const... seq)
54 {
55 return std::tuple(pRC::Tensor<T,
56 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
57 seq),
58 decltype(pRC::Sizes<1>(),
59 pRC::makeConstantSequence<pRC::Size, D - 1, D + 1>(),
60 pRC::Sizes<1>())::size(seq),
61 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
62 seq)>()...);
63 });
64
65 // first and last psiphi is always 1
66 reshape<1, 1>(std::get<0>(psiphiA)) =
67 pRC::identity<pRC::Tensor<T, 1, 1>>();
68 reshape<1, 1>(std::get<D>(psiphiA)) =
69 pRC::identity<pRC::Tensor<T, 1, 1>>();
70
71 // create the objects holding psis/phis for b (without values yet)
72 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
73 [](auto const... seq)
74 {
75 return std::tuple(pRC::Tensor<T,
76 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
77 seq),
78 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
79 seq)>()...);
80 });
81
82 // first and last psiphi is always 1
83 reshape<1, 1>(std::get<0>(psiphib)) =
84 pRC::identity<pRC::Tensor<T, 1, 1>>();
85 reshape<1, 1>(std::get<D>(psiphib)) =
86 pRC::identity<pRC::Tensor<T, 1, 1>>();
87
88 // create the objects holding psi_hat/phi_hat for A (without values yet)
89 auto psiphihatA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
90 [](auto const... seq)
91 {
92 return std::tuple(pRC::Tensor<T,
93 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
94 seq),
95 decltype(pRC::Sizes<1>(),
96 pRC::makeConstantSequence<pRC::Size, D - 1, D + 1>(),
97 pRC::Sizes<1>())::size(seq),
98 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
99 seq)>()...);
100 });
101
102 // first and last psiphihatA is always 1
103 reshape<1, 1>(std::get<0>(psiphihatA)) =
104 pRC::identity<pRC::Tensor<T, 1, 1>>();
105 reshape<1, 1>(std::get<D>(psiphihatA)) =
106 pRC::identity<pRC::Tensor<T, 1, 1>>();
107
108 // create the objects holding psi_hat/phi_hat for b (without values yet)
109 auto psiphihatb = expand(pRC::makeSeries<pRC::Index, D + 1>(),
110 [](auto const... seq)
111 {
112 return std::tuple(pRC::Tensor<T,
113 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
114 seq),
115 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
116 seq)>()...);
117 });
118
119 // first and last psiphihatb is always 1
120 reshape<1, 1>(std::get<0>(psiphihatb)) =
121 pRC::identity<pRC::Tensor<T, 1, 1>>();
122 reshape<1, 1>(std::get<D>(psiphihatb)) =
123 pRC::identity<pRC::Tensor<T, 1, 1>>();
124
125 // calculate the first psiphiA and psiphib (also with hats) and do
126 // initial QR
127 pRC::range<pRC::Context::CompileTime, 1, D, pRC::Direction::Backwards>(
128 [&](auto const k)
129 {
130 auto const [lambda, l, q] =
131 orthogonalize<pRC::Position::Right>(x.template core<k>());
132
133 x.template core<k>() = q;
134 x.template core<k - 1>() =
135 lambda * contract<2, 0>(x.template core<k - 1>(), l);
136 {
137 // also orthogonalize z
138 auto const [lambda, l, q] =
139 orthogonalize<pRC::Position::Right>(
140 z.template core<k>());
141 z.template core<k>() = q;
142 z.template core<k - 1>() =
143 lambda * contract<2, 0>(z.template core<k - 1>(), l);
144 }
145
146 // k-1 in Alg.
147 std::get<k>(psiphiA) =
148 contract<1, 2, 1, 3>(conj(x.template core<k>()),
149 contract<2, 3, 1, 3>(op.template core<k>(),
150 eval(contract<2, 2>(x.template core<k>(),
151 std::get<k + 1>(psiphiA)))));
152
153 std::get<k>(psiphib) =
154 contract<1, 2, 1, 2>(conj(x.template core<k>()),
155 eval(contract<2, 1>(b.template core<k>(),
156 std::get<k + 1>(psiphib))));
157
158 std::get<k>(psiphihatA) =
159 contract<1, 2, 1, 3>(conj(z.template core<k>()),
160 contract<2, 3, 1, 3>(op.template core<k>(),
161 eval(contract<2, 2>(x.template core<k>(),
162 get<k + 1>(psiphihatA)))));
163
164 std::get<k>(psiphihatb) =
165 contract<1, 2, 1, 2>(conj(z.template core<k>()),
166 eval(contract<2, 1>(b.template core<k>(),
167 get<k + 1>(psiphihatb))));
168 });
169
170 // for now, just use 2 iterations
171 for(pRC::Index i = 0; i < 1; ++i)
172 {
173 // forward sweep
174 pRC::range<pRC::Context::CompileTime, D - 1>(
175 [&](auto const k)
176 {
177 // local A and b
178 auto const sA = pRC::reshape<
179 pRC::RemoveConstReference<
180 decltype(op.template core<k>())>::size(0),
181 4, 4,
182 pRC::RemoveConstReference<
183 decltype(op.template core<k + 1>())>::size(3)>(
184 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
185 op.template core<k>(), op.template core<k + 1>())));
186
187 auto const sB = pRC::reshape<
188 pRC::RemoveConstReference<
189 decltype(b.template core<k>())>::size(0),
190 4,
191 pRC::RemoveConstReference<
192 decltype(b.template core<k + 1>())>::size(2)>(
193 contract<2, 0>(b.template core<k>(),
194 b.template core<k + 1>()));
195
196 auto const sX = pRC::reshape<
197 pRC::RemoveConstReference<
198 decltype(x.template core<k>())>::size(0),
199 4,
200 pRC::RemoveConstReference<
201 decltype(x.template core<k + 1>())>::size(2)>(
202 contract<2, 0>(x.template core<k>(),
203 x.template core<k + 1>()));
204
205 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
206 contract<1, 0>(std::get<k>(psiphiA),
207 eval(contract<3, 1>(sA,
208 std::get<k + 2>(psiphiA))))));
209
210 auto const bk = linearize(contract<1, 0>(
211 std::get<k>(psiphib),
212 eval(contract<2, 1>(sB, std::get<k + 2>(psiphib)))));
213
214 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
215 contract<1, 0>(std::get<k>(psiphihatA),
216 eval(contract<3, 1>(sA,
217 std::get<k + 2>(psiphihatA))))));
218
219 auto const hatbk = linearize(contract<1, 0>(
220 std::get<k>(psiphihatb),
221 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
222
223 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
224 contract<1, 0>(std::get<k>(psiphiA),
225 eval(contract<3, 1>(sA,
226 std::get<k + 2>(psiphihatA))))));
227
228 auto const rbk = linearize(contract<1, 0>(
229 std::get<k>(psiphib),
230 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
231
232 // solve local system
233 pRC::Tensor<T,
234 pRC::RemoveConstReference<
235 decltype(x.template core<k>())>::size(0),
236 2, 2,
237 pRC::RemoveConstReference<
238 decltype(x.template core<k + 1>())>::size(2)>
239 sol;
240 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
241 linearize(sX), toleranceLocal);
242
243 // left orthogonalization
244 auto const [u, s, v] = svd<Ranks::size(k)>(matricize(sol));
245
246 pRC::Tensor<T,
247 pRC::RemoveConstReference<
248 decltype(x.template core<k>())>::size(0),
249 2, ERanks::size(k)>
250 ex;
251 linearize(
252 slice<pRC::RemoveConstReference<
253 decltype(x.template core<k>())>::size(0),
254 2, Ranks::size(k)>(ex, 0, 0, 0)) = linearize(u);
255
256 if constexpr(ERanks::size(k) - Ranks::size(k) > 0)
257 {
258 constexpr pRC::Size LL = pRC::RemoveConstReference<
259 decltype(x.template core<k>())>::size(0);
260 constexpr pRC::Size ZZ = pRC::RemoveConstReference<
261 decltype(z.template core<k + 1>())>::size(2);
262 pRC::Tensor<T, LL, 2, 2, ZZ> eta;
263 auto const tt = eval(rbk - rAk * linearize(sol));
264 linearize(eta) = tt;
265
266 auto const [zu, zs, zv] =
267 svd<ERanks::size(k) - Ranks::size(k)>(
268 matricize(eta));
269
270 linearize(
271 slice<pRC::RemoveConstReference<
272 decltype(x.template core<k>())>::size(0),
273 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
274 Ranks::size(k))) =
275 linearize(zu * fromDiagonal(zs));
276 }
277
278 pRC::Tensor<T,
279 pRC::RemoveConstReference<
280 decltype(z.template core<k>())>::size(0),
281 2, 2,
282 pRC::RemoveConstReference<
283 decltype(z.template core<k + 1>())>::size(2)>
284 zz;
285 linearize(zz) = hatbk - hatAk * linearize(sol);
286
287 auto const [zu, zs, zv] =
288 svd<ZRanks::size(k)>(matricize(zz));
289
290 linearize(z.template core<k>()) = linearize(zu);
291
292 pRC::Tensor<T, Ranks::size(k), 2,
293 pRC::RemoveConstReference<
294 decltype(x.template core<k + 1>())>::size(2)>
295 tNext;
296 folding<pRC::Position::Right>(tNext) =
297 fromDiagonal(s) * adjoint(v);
298 pRC::Tensor<T, ERanks::size(k), 2,
299 pRC::RemoveConstReference<
300 decltype(x.template core<k + 1>())>::size(2)>
301 eNext;
302 slice<Ranks::size(k), 2,
303 pRC::RemoveConstReference<
304 decltype(x.template core<k + 1>())>::size(2)>(eNext,
305 0, 0, 0) = tNext;
306 slice<ERanks::size(k) - Ranks::size(k), 2,
307 pRC::RemoveConstReference<
308 decltype(x.template core<k + 1>())>::size(2)>(eNext,
309 Ranks::size(k), 0, 0) = pRC::zero();
310
311 auto const [llambda, qq, rr] =
312 orthogonalize<pRC::Position::Left>(ex);
313 x.template core<k>() = qq;
314 x.template core<k + 1>() =
315 llambda * contract<1, 0>(rr, eNext);
316
317 // calculate psiphiA(k+1)
318 std::get<k + 1>(psiphiA) =
319 contract<1, 0, 0, 3>(conj(x.template core<k>()),
320 contract<2, 0, 0, 3>(op.template core<k>(),
321 eval(contract<0, 2>(x.template core<k>(),
322 std::get<k>(psiphiA)))));
323
324 // calculate psiphib(k+1)
325 std::get<k + 1>(psiphib) =
326 contract<0, 1, 2, 0>(conj(x.template core<k>()),
327 eval(contract<0, 1>(b.template core<k>(),
328 std::get<k>(psiphib))));
329
330 // calculate psiphiAhat(k+1)
331 std::get<k + 1>(psiphihatA) =
332 contract<1, 0, 0, 3>(conj(z.template core<k>()),
333 contract<2, 0, 0, 3>(op.template core<k>(),
334 eval(contract<0, 2>(x.template core<k>(),
335 std::get<k>(psiphihatA)))));
336
337 // calculate psiphibhat(k+1)
338 std::get<k + 1>(psiphihatb) =
339 contract<0, 1, 2, 0>(conj(z.template core<k>()),
340 eval(contract<0, 1>(b.template core<k>(),
341 std::get<k>(psiphihatb))));
342 });
343
344 // backward sweep
345 pRC::range<pRC::Context::CompileTime, 1, D,
346 pRC::Direction::Backwards>(
347 [&](auto const k)
348 {
349 // local A and b
350 auto const sA = pRC::reshape<
351 pRC::RemoveConstReference<
352 decltype(op.template core<k - 1>())>::size(0),
353 4, 4,
354 pRC::RemoveConstReference<
355 decltype(op.template core<k>())>::size(3)>(
356 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
357 op.template core<k - 1>(), op.template core<k>())));
358
359 auto const sB = pRC::reshape<
360 pRC::RemoveConstReference<
361 decltype(b.template core<k - 1>())>::size(0),
362 4,
363 pRC::RemoveConstReference<
364 decltype(b.template core<k>())>::size(2)>(
365 contract<2, 0>(b.template core<k - 1>(),
366 b.template core<k>()));
367
368 auto const sX = pRC::reshape<
369 pRC::RemoveConstReference<
370 decltype(x.template core<k - 1>())>::size(0),
371 4,
372 pRC::RemoveConstReference<
373 decltype(x.template core<k>())>::size(2)>(
374 contract<2, 0>(x.template core<k - 1>(),
375 x.template core<k>()));
376
377 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
378 contract<1, 0>(std::get<k - 1>(psiphiA),
379 eval(contract<3, 1>(sA,
380 std::get<k + 1>(psiphiA))))));
381
382 auto const bk = linearize(contract<1, 0>(
383 std::get<k - 1>(psiphib),
384 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
385
386 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
387 contract<1, 0>(std::get<k - 1>(psiphihatA),
388 eval(contract<3, 1>(sA,
389 std::get<k + 1>(psiphihatA))))));
390
391 auto const hatbk = linearize(contract<1, 0>(
392 std::get<k - 1>(psiphihatb),
393 eval(contract<2, 1>(sB, std::get<k + 1>(psiphihatb)))));
394
395 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
396 contract<1, 0>(std::get<k - 1>(psiphihatA),
397 eval(contract<3, 1>(sA,
398 std::get<k + 1>(psiphiA))))));
399
400 auto const rbk = linearize(contract<1, 0>(
401 std::get<k - 1>(psiphihatb),
402 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
403
404 // solve local system
405 pRC::Tensor<T,
406 pRC::RemoveConstReference<
407 decltype(x.template core<k - 1>())>::size(0),
408 2, 2,
409 pRC::RemoveConstReference<
410 decltype(x.template core<k>())>::size(2)>
411 sol;
412 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
413 linearize(sX), tolerance);
414
415 // right orthogonalization
416 auto const [u, s, v] =
417 svd<Ranks::size(k - 1)>(matricize(sol));
418
419 pRC::Tensor<T, ERanks::size(k - 1), 2,
420 pRC::RemoveConstReference<
421 decltype(x.template core<k>())>::size(2)>
422 ex;
423 linearize(slice<Ranks::size(k - 1), 2,
424 pRC::RemoveConstReference<
425 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
426 0)) = linearize(adjoint(v));
427
428 if constexpr(ERanks::size(k - 1) - Ranks::size(k - 1) > 0)
429 {
430 constexpr pRC::Size ZZ = pRC::RemoveConstReference<
431 decltype(z.template core<k - 1>())>::size(0);
432 constexpr pRC::Size RR = pRC::RemoveConstReference<
433 decltype(x.template core<k>())>::size(2);
434 pRC::Tensor<T, ZZ, 2, 2, RR> eta;
435 auto const tt = eval(rbk - rAk * linearize(sol));
436 linearize(eta) = tt;
437
438 auto const [zu, zs, zv] =
439 svd<ERanks::size(k - 1) - Ranks::size(k - 1)>(
440 matricize(eta));
441
442 linearize(
443 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
444 pRC::RemoveConstReference<
445 decltype(x.template core<k>())>::size(2)>(
446 ex, Ranks::size(k - 1), 0, 0)) =
447 linearize(fromDiagonal(zs) * adjoint(zv));
448 }
449
450 pRC::Tensor<T,
451 pRC::RemoveConstReference<
452 decltype(z.template core<k - 1>())>::size(0),
453 2, 2,
454 pRC::RemoveConstReference<
455 decltype(z.template core<k>())>::size(2)>
456 zz;
457 linearize(zz) = hatbk - hatAk * linearize(sol);
458
459 auto const [zu, zs, zv] =
460 svd<ZRanks::size(k - 1)>(matricize(zz));
461
462 linearize(z.template core<k>()) = linearize(adjoint(zv));
463
464 pRC::Tensor<T,
465 pRC::RemoveConstReference<
466 decltype(x.template core<k - 1>())>::size(0),
467 2, Ranks::size(k - 1)>
468 tNext;
469 folding<pRC::Position::Left>(tNext) = u * fromDiagonal(s);
470 pRC::Tensor<T,
471 pRC::RemoveConstReference<
472 decltype(x.template core<k - 1>())>::size(0),
473 2, ERanks::size(k - 1)>
474 eNext;
475 slice<pRC::RemoveConstReference<
476 decltype(x.template core<k - 1>())>::size(0),
477 2, Ranks::size(k - 1)>(eNext, 0, 0, 0) = tNext;
478 linearize(
479 slice<pRC::RemoveConstReference<
480 decltype(x.template core<k - 1>())>::size(0),
481 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(eNext,
482 0, 0, Ranks::size(k - 1))) = pRC::zero();
483
484 auto const [llambda, ll, qq] =
485 orthogonalize<pRC::Position::Right>(ex);
486 x.template core<k>() = qq;
487 x.template core<k - 1>() =
488 llambda * contract<2, 0>(eNext, ll);
489
490 // k-1 in Alg.
491 // calculate psiphiA(k)
492 std::get<k>(psiphiA) =
493 contract<1, 2, 1, 3>(conj(x.template core<k>()),
494 contract<2, 3, 1, 3>(op.template core<k>(),
495 eval(contract<2, 2>(x.template core<k>(),
496 std::get<k + 1>(psiphiA)))));
497
498 // calculate psiphib(k)
499 std::get<k>(psiphib) =
500 contract<1, 2, 1, 2>(conj(x.template core<k>()),
501 eval(contract<2, 1>(b.template core<k>(),
502 std::get<k + 1>(psiphib))));
503
504 // calculate psiphihatA(k)
505 std::get<k>(psiphihatA) =
506 contract<1, 2, 1, 3>(conj(z.template core<k>()),
507 contract<2, 3, 1, 3>(op.template core<k>(),
508 eval(contract<2, 2>(x.template core<k>(),
509 std::get<k + 1>(psiphihatA)))));
510
511 // calculate psiphihatb
512 std::get<k>(psiphihatb) =
513 contract<1, 2, 1, 2>(conj(z.template core<k>()),
514 eval(contract<2, 1>(b.template core<k>(),
515 std::get<k + 1>(psiphihatb))));
516 });
517 }
518 return x;
519 }
520}
521
522#endif // cMHN_TT_MAMEN_H
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: als.hpp:12
auto MAMEN(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance)
Definition: mamen.hpp:16