14 pRC::Operator::Transform OT = pRC::Operator::Transform::None,
15 pRC::Size
D,
class T,
class Tb>
19 using ModeSizes =
decltype(getModeSizes<D>());
20 using Ranks =
decltype(getRanks<D, R>());
21 using ERanks =
decltype(getRanks<D, R + pRC::Size(4)>());
22 using ZRanks =
decltype(getRanks<D, pRC::Size(4)>());
23 using BRanks =
typename Tb::Ranks;
25 auto const toleranceLocal =
26 tolerance / pRC::sqrt(pRC::identity<pRC::Float<>>(
D));
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](
auto const... seq)
32 return pRC::TensorTrain::fromCores(
33 MHNop.template core<seq>()...);
37 pRC::SeedSequence seq(8, 16);
38 pRC::RandomEngine rng(seq);
39 pRC::GaussianDistribution<pRC::Float<>> dist;
40 auto x = round<ERanks>(
41 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, ERanks>>(rng,
46 auto z = round<ZRanks>(
47 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, ZRanks>>(rng,
52 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
55 return std::tuple(pRC::Tensor<
T,
56 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
58 decltype(pRC::Sizes<1>(),
59 pRC::makeConstantSequence<pRC::Size, D - 1, D + 1>(),
60 pRC::Sizes<1>())::size(seq),
61 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
66 reshape<1, 1>(std::get<0>(psiphiA)) =
67 pRC::identity<pRC::Tensor<T, 1, 1>>();
68 reshape<1, 1>(std::get<D>(psiphiA)) =
69 pRC::identity<pRC::Tensor<T, 1, 1>>();
72 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
75 return std::tuple(pRC::Tensor<
T,
76 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
78 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
83 reshape<1, 1>(std::get<0>(psiphib)) =
84 pRC::identity<pRC::Tensor<T, 1, 1>>();
85 reshape<1, 1>(std::get<D>(psiphib)) =
86 pRC::identity<pRC::Tensor<T, 1, 1>>();
89 auto psiphihatA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
92 return std::tuple(pRC::Tensor<
T,
93 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
95 decltype(pRC::Sizes<1>(),
96 pRC::makeConstantSequence<pRC::Size, D - 1, D + 1>(),
97 pRC::Sizes<1>())::size(seq),
98 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
103 reshape<1, 1>(std::get<0>(psiphihatA)) =
104 pRC::identity<pRC::Tensor<T, 1, 1>>();
105 reshape<1, 1>(std::get<D>(psiphihatA)) =
106 pRC::identity<pRC::Tensor<T, 1, 1>>();
109 auto psiphihatb = expand(pRC::makeSeries<pRC::Index, D + 1>(),
110 [](
auto const... seq)
112 return std::tuple(pRC::Tensor<
T,
113 decltype(pRC::Sizes<1>(), ZRanks(), pRC::Sizes<1>())::size(
115 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
120 reshape<1, 1>(std::get<0>(psiphihatb)) =
121 pRC::identity<pRC::Tensor<T, 1, 1>>();
122 reshape<1, 1>(std::get<D>(psiphihatb)) =
123 pRC::identity<pRC::Tensor<T, 1, 1>>();
127 pRC::range<pRC::Context::CompileTime, 1, D, pRC::Direction::Backwards>(
130 auto const [lambda, l, q] =
131 orthogonalize<pRC::Position::Right>(x.template core<k>());
133 x.template core<k>() = q;
134 x.template core<k - 1>() =
135 lambda * contract<2, 0>(x.template core<k - 1>(), l);
138 auto const [lambda, l, q] =
139 orthogonalize<pRC::Position::Right>(
140 z.template core<k>());
141 z.template core<k>() = q;
142 z.template core<k - 1>() =
143 lambda * contract<2, 0>(z.template core<k - 1>(), l);
147 std::get<k>(psiphiA) =
148 contract<1, 2, 1, 3>(conj(x.template core<k>()),
149 contract<2, 3, 1, 3>(op.template core<k>(),
150 eval(contract<2, 2>(x.template core<k>(),
151 std::get<k + 1>(psiphiA)))));
153 std::get<k>(psiphib) =
154 contract<1, 2, 1, 2>(conj(x.template core<k>()),
155 eval(contract<2, 1>(b.template core<k>(),
156 std::get<k + 1>(psiphib))));
158 std::get<k>(psiphihatA) =
159 contract<1, 2, 1, 3>(conj(z.template core<k>()),
160 contract<2, 3, 1, 3>(op.template core<k>(),
161 eval(contract<2, 2>(x.template core<k>(),
162 get<k + 1>(psiphihatA)))));
164 std::get<k>(psiphihatb) =
165 contract<1, 2, 1, 2>(conj(z.template core<k>()),
166 eval(contract<2, 1>(b.template core<k>(),
167 get<k + 1>(psiphihatb))));
171 for(pRC::Index i = 0; i < 1; ++i)
174 pRC::range<pRC::Context::CompileTime,
D - 1>(
178 auto const sA = pRC::reshape<
179 pRC::RemoveConstReference<
180 decltype(op.template core<k>())>::size(0),
182 pRC::RemoveConstReference<
183 decltype(op.template core<k + 1>())>::size(3)>(
184 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
185 op.template core<k>(), op.template core<k + 1>())));
187 auto const sB = pRC::reshape<
188 pRC::RemoveConstReference<
189 decltype(b.template core<k>())>::size(0),
191 pRC::RemoveConstReference<
192 decltype(b.template core<k + 1>())>::size(2)>(
193 contract<2, 0>(b.template core<k>(),
194 b.template core<k + 1>()));
196 auto const sX = pRC::reshape<
197 pRC::RemoveConstReference<
198 decltype(x.template core<k>())>::size(0),
200 pRC::RemoveConstReference<
201 decltype(x.template core<k + 1>())>::size(2)>(
202 contract<2, 0>(x.template core<k>(),
203 x.template core<k + 1>()));
205 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
206 contract<1, 0>(std::get<k>(psiphiA),
207 eval(contract<3, 1>(sA,
208 std::get<k + 2>(psiphiA))))));
210 auto const bk = linearize(contract<1, 0>(
211 std::get<k>(psiphib),
212 eval(contract<2, 1>(sB, std::get<k + 2>(psiphib)))));
214 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
215 contract<1, 0>(std::get<k>(psiphihatA),
216 eval(contract<3, 1>(sA,
217 std::get<k + 2>(psiphihatA))))));
219 auto const hatbk = linearize(contract<1, 0>(
220 std::get<k>(psiphihatb),
221 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
223 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
224 contract<1, 0>(std::get<k>(psiphiA),
225 eval(contract<3, 1>(sA,
226 std::get<k + 2>(psiphihatA))))));
228 auto const rbk = linearize(contract<1, 0>(
229 std::get<k>(psiphib),
230 eval(contract<2, 1>(sB, std::get<k + 2>(psiphihatb)))));
234 pRC::RemoveConstReference<
235 decltype(x.template core<k>())>::size(0),
237 pRC::RemoveConstReference<
238 decltype(x.template core<k + 1>())>::size(2)>
240 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
241 linearize(sX), toleranceLocal);
244 auto const [u, s, v] = svd<Ranks::size(k)>(matricize(sol));
247 pRC::RemoveConstReference<
248 decltype(x.template core<k>())>::size(0),
252 slice<pRC::RemoveConstReference<
253 decltype(x.template core<k>())>::size(0),
254 2, Ranks::size(k)>(ex, 0, 0, 0)) = linearize(u);
256 if constexpr(ERanks::size(k) - Ranks::size(k) > 0)
258 constexpr pRC::Size LL = pRC::RemoveConstReference<
259 decltype(x.template core<k>())>::size(0);
260 constexpr pRC::Size ZZ = pRC::RemoveConstReference<
261 decltype(z.template core<k + 1>())>::size(2);
262 pRC::Tensor<T, LL, 2, 2, ZZ> eta;
263 auto const tt = eval(rbk - rAk * linearize(sol));
266 auto const [zu, zs, zv] =
267 svd<ERanks::size(k) - Ranks::size(k)>(
271 slice<pRC::RemoveConstReference<
272 decltype(x.template core<k>())>::size(0),
273 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
275 linearize(zu * fromDiagonal(zs));
279 pRC::RemoveConstReference<
280 decltype(z.template core<k>())>::size(0),
282 pRC::RemoveConstReference<
283 decltype(z.template core<k + 1>())>::size(2)>
285 linearize(zz) = hatbk - hatAk * linearize(sol);
287 auto const [zu, zs, zv] =
288 svd<ZRanks::size(k)>(matricize(zz));
290 linearize(z.template core<k>()) = linearize(zu);
292 pRC::Tensor<
T, Ranks::size(k), 2,
293 pRC::RemoveConstReference<
294 decltype(x.template core<k + 1>())>::size(2)>
296 folding<pRC::Position::Right>(tNext) =
297 fromDiagonal(s) * adjoint(v);
298 pRC::Tensor<
T, ERanks::size(k), 2,
299 pRC::RemoveConstReference<
300 decltype(x.template core<k + 1>())>::size(2)>
302 slice<Ranks::size(k), 2,
303 pRC::RemoveConstReference<
304 decltype(x.template core<k + 1>())>::size(2)>(eNext,
306 slice<ERanks::size(k) - Ranks::size(k), 2,
307 pRC::RemoveConstReference<
308 decltype(x.template core<k + 1>())>::size(2)>(eNext,
309 Ranks::size(k), 0, 0) = pRC::zero();
311 auto const [llambda, qq, rr] =
312 orthogonalize<pRC::Position::Left>(ex);
313 x.template core<k>() = qq;
314 x.template core<k + 1>() =
315 llambda * contract<1, 0>(rr, eNext);
318 std::get<k + 1>(psiphiA) =
319 contract<1, 0, 0, 3>(conj(x.template core<k>()),
320 contract<2, 0, 0, 3>(op.template core<k>(),
321 eval(contract<0, 2>(x.template core<k>(),
322 std::get<k>(psiphiA)))));
325 std::get<k + 1>(psiphib) =
326 contract<0, 1, 2, 0>(conj(x.template core<k>()),
327 eval(contract<0, 1>(b.template core<k>(),
328 std::get<k>(psiphib))));
331 std::get<k + 1>(psiphihatA) =
332 contract<1, 0, 0, 3>(conj(z.template core<k>()),
333 contract<2, 0, 0, 3>(op.template core<k>(),
334 eval(contract<0, 2>(x.template core<k>(),
335 std::get<k>(psiphihatA)))));
338 std::get<k + 1>(psiphihatb) =
339 contract<0, 1, 2, 0>(conj(z.template core<k>()),
340 eval(contract<0, 1>(b.template core<k>(),
341 std::get<k>(psiphihatb))));
345 pRC::range<pRC::Context::CompileTime, 1,
D,
346 pRC::Direction::Backwards>(
350 auto const sA = pRC::reshape<
351 pRC::RemoveConstReference<
352 decltype(op.template core<k - 1>())>::size(0),
354 pRC::RemoveConstReference<
355 decltype(op.template core<k>())>::size(3)>(
356 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
357 op.template core<k - 1>(), op.template core<k>())));
359 auto const sB = pRC::reshape<
360 pRC::RemoveConstReference<
361 decltype(b.template core<k - 1>())>::size(0),
363 pRC::RemoveConstReference<
364 decltype(b.template core<k>())>::size(2)>(
365 contract<2, 0>(b.template core<k - 1>(),
366 b.template core<k>()));
368 auto const sX = pRC::reshape<
369 pRC::RemoveConstReference<
370 decltype(x.template core<k - 1>())>::size(0),
372 pRC::RemoveConstReference<
373 decltype(x.template core<k>())>::size(2)>(
374 contract<2, 0>(x.template core<k - 1>(),
375 x.template core<k>()));
377 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
378 contract<1, 0>(std::get<k - 1>(psiphiA),
379 eval(contract<3, 1>(sA,
380 std::get<k + 1>(psiphiA))))));
382 auto const bk = linearize(contract<1, 0>(
383 std::get<k - 1>(psiphib),
384 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
386 auto const hatAk = matricize(permute<0, 2, 4, 1, 3, 5>(
387 contract<1, 0>(std::get<k - 1>(psiphihatA),
388 eval(contract<3, 1>(sA,
389 std::get<k + 1>(psiphihatA))))));
391 auto const hatbk = linearize(contract<1, 0>(
392 std::get<k - 1>(psiphihatb),
393 eval(contract<2, 1>(sB, std::get<k + 1>(psiphihatb)))));
395 auto const rAk = matricize(permute<0, 2, 4, 1, 3, 5>(
396 contract<1, 0>(std::get<k - 1>(psiphihatA),
397 eval(contract<3, 1>(sA,
398 std::get<k + 1>(psiphiA))))));
400 auto const rbk = linearize(contract<1, 0>(
401 std::get<k - 1>(psiphihatb),
402 eval(contract<2, 1>(sB, std::get<k + 1>(psiphib)))));
406 pRC::RemoveConstReference<
407 decltype(x.template core<k - 1>())>::size(0),
409 pRC::RemoveConstReference<
410 decltype(x.template core<k>())>::size(2)>
412 linearize(sol) = pRC::Solver::GMRES<256, 0>()(Ak, bk,
413 linearize(sX), tolerance);
416 auto const [u, s, v] =
417 svd<Ranks::size(k - 1)>(matricize(sol));
419 pRC::Tensor<
T, ERanks::size(k - 1), 2,
420 pRC::RemoveConstReference<
421 decltype(x.template core<k>())>::size(2)>
423 linearize(slice<Ranks::size(k - 1), 2,
424 pRC::RemoveConstReference<
425 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
426 0)) = linearize(adjoint(v));
428 if constexpr(ERanks::size(k - 1) - Ranks::size(k - 1) > 0)
430 constexpr pRC::Size ZZ = pRC::RemoveConstReference<
431 decltype(z.template core<k - 1>())>::size(0);
432 constexpr pRC::Size RR = pRC::RemoveConstReference<
433 decltype(x.template core<k>())>::size(2);
434 pRC::Tensor<T, ZZ, 2, 2, RR> eta;
435 auto const tt = eval(rbk - rAk * linearize(sol));
438 auto const [zu, zs, zv] =
439 svd<ERanks::size(k - 1) - Ranks::size(k - 1)>(
443 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
444 pRC::RemoveConstReference<
445 decltype(x.template core<k>())>::size(2)>(
446 ex, Ranks::size(k - 1), 0, 0)) =
447 linearize(fromDiagonal(zs) * adjoint(zv));
451 pRC::RemoveConstReference<
452 decltype(z.template core<k - 1>())>::size(0),
454 pRC::RemoveConstReference<
455 decltype(z.template core<k>())>::size(2)>
457 linearize(zz) = hatbk - hatAk * linearize(sol);
459 auto const [zu, zs, zv] =
460 svd<ZRanks::size(k - 1)>(matricize(zz));
462 linearize(z.template core<k>()) = linearize(adjoint(zv));
465 pRC::RemoveConstReference<
466 decltype(x.template core<k - 1>())>::size(0),
467 2, Ranks::size(k - 1)>
469 folding<pRC::Position::Left>(tNext) = u * fromDiagonal(s);
471 pRC::RemoveConstReference<
472 decltype(x.template core<k - 1>())>::size(0),
473 2, ERanks::size(k - 1)>
475 slice<pRC::RemoveConstReference<
476 decltype(x.template core<k - 1>())>::size(0),
477 2, Ranks::size(k - 1)>(eNext, 0, 0, 0) = tNext;
479 slice<pRC::RemoveConstReference<
480 decltype(x.template core<k - 1>())>::size(0),
481 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(eNext,
482 0, 0, Ranks::size(k - 1))) = pRC::zero();
484 auto const [llambda, ll, qq] =
485 orthogonalize<pRC::Position::Right>(ex);
486 x.template core<k>() = qq;
487 x.template core<k - 1>() =
488 llambda * contract<2, 0>(eNext, ll);
492 std::get<k>(psiphiA) =
493 contract<1, 2, 1, 3>(conj(x.template core<k>()),
494 contract<2, 3, 1, 3>(op.template core<k>(),
495 eval(contract<2, 2>(x.template core<k>(),
496 std::get<k + 1>(psiphiA)))));
499 std::get<k>(psiphib) =
500 contract<1, 2, 1, 2>(conj(x.template core<k>()),
501 eval(contract<2, 1>(b.template core<k>(),
502 std::get<k + 1>(psiphib))));
505 std::get<k>(psiphihatA) =
506 contract<1, 2, 1, 3>(conj(z.template core<k>()),
507 contract<2, 3, 1, 3>(op.template core<k>(),
508 eval(contract<2, 2>(x.template core<k>(),
509 std::get<k + 1>(psiphihatA)))));
512 std::get<k>(psiphihatb) =
513 contract<1, 2, 1, 2>(conj(z.template core<k>()),
514 eval(contract<2, 1>(b.template core<k>(),
515 std::get<k + 1>(psiphihatb))));
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
auto MAMEN(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance)
Definition: mamen.hpp:16