cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
rmals.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_RMALS_H
4#define cMHN_TT_RMALS_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
// NOTE(review): this text was scraped from a rendered Doxygen source page.
// Every code line still carries its original line number as a prefix, and a
// number of source lines are missing (for example the declarations of OT, x,
// rng and dist, the initializer of toleranceLocal, the values assigned to the
// first/last psiphi tensors, and the range<...>( call openers). Restore
// rmals.hpp from the repository before changing any code here.
11namespace cMHN::TT
12{
// rmals (with initial guess): approximately solves the linear system
// op * x = b in tensor-train (TT) format by alternating two-site sweeps.
// op is assembled from the cores of MHNop (the in-code comment describes it
// as the TT for "1-Q"). Each two-site step:
//   1. solves a projected local system with GMRES;
//   2. splits the result with an SVD truncated to Ranks;
//   3. enriches the bond with random directions up to the ERanks width.
//
// Template parameters:
//   R  - rank parameter passed to getRanks<D, R>. The result is rounded to
//        these ranks before it is returned.
//   D  - number of TT cores.
//   T  - scalar type used for the tolerance and all tensors.
//   Tb - TT type of the right-hand side. It must provide a nested `Ranks`
//        type and core<k>() accessors.
//   X  - TT type of the initial guess. It must provide core<k>() accessors.
//   OT (used below) is declared on a line missing from this listing.
//
// Parameters:
//   MHNop     - MHN operator whose cores are transformed into op.
//   b         - right-hand side in TT format.
//   tolerance - tolerance passed to every local GMRES solve.
//   x0        - initial guess. Each of its cores is copied into the slice at
//               offset (0, 0, 0) of the corresponding core of the working TT
//               x, which has the enriched ranks.
//
// Returns round<Ranks>(x) after the sweeps.
13 template<pRC::Size R,
15 pRC::Size D, class T, class Tb, class X>
16 auto rmals(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
19 using ModeSizes = decltype(getModeSizes<D>());
20
// Ranks: target TT ranks of the result.
// ERanks: ranks from getRanks with R + 4, used as the working (enriched)
//         ranks of x during the sweeps.
// BRanks: TT ranks of the right-hand side b.
21 using Ranks = decltype(getRanks<D, R>());
22 using ERanks = decltype(getRanks<D, R + pRC::Size(4)>());
23 using BRanks = typename Tb::Ranks;
24
// NOTE(review): the initializer of toleranceLocal is missing from this
// listing. The GMRES calls visible below pass `tolerance`, not
// toleranceLocal. Confirm which tolerance the local solves are meant to use.
25 auto const toleranceLocal =
27
28 // get TT for 1-Q
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](auto const... seq)
31 {
33 MHNop.template core<seq>()...);
34 }));
35
// Copy each core of x0 into the slice at offset (0, 0, 0) of the matching
// core of x. The enclosing loop over k is on lines missing from this listing.
38 [&](auto const k)
39 {
41 decltype(x0.template core<k>())>::size(0),
42 2,
44 decltype(x0.template core<k>())>::size(2)>(
45 x.template core<k>(), 0, 0, 0) = x0.template core<k>();
46 });
47
48 // random setup
// The seed arguments are constants. The random enrichment is therefore
// presumably reproducible across calls; confirm against the rng/dist
// construction (original lines 50-51, missing here).
49 pRC::SeedSequence seq(8, 16);
52
53 // create the objects holding psis/phis for A (without values yet)
// psiphiA holds D + 1 three-way interface tensors. The outer dimensions of
// entry k are read at index k from the Sizes expression
// (Sizes<1>, ERanks, Sizes<1>). This is presumably 1 at both ends and the
// enriched x ranks in between. The middle dimension comes from a line that
// is missing here; presumably it is the rank of op (confirm).
54 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
55 [](auto const... seq)
56 {
57 return std::tuple(pRC::Tensor<T,
58 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
59 seq),
60 decltype(pRC::Sizes<1>(),
62 pRC::Sizes<1>())::size(seq),
63 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
64 seq)>()...);
65 });
66
67 // first and last psiphi is always 1
68 reshape<1, 1>(std::get<0>(psiphiA)) =
70 reshape<1, 1>(std::get<D>(psiphiA)) =
72
73 // create the objects holding psis/phis for b (without values yet)
// psiphib holds D + 1 two-way interface tensors. Entry k has shape
// (x-side rank, b-side rank), taken from the ERanks and BRanks expressions.
74 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
75 [](auto const... seq)
76 {
77 return std::tuple(pRC::Tensor<T,
78 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
79 seq),
80 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
81 seq)>()...);
82 });
83
84 // first and last psiphi is always 1
85 reshape<1, 1>(std::get<0>(psiphib)) =
87 reshape<1, 1>(std::get<D>(psiphib)) =
89
90 // calculate the first psiphiA and psiphib and do initial QR
// For each k (the loop bounds are on a missing line):
//   - right-orthogonalize core k;
//   - push the L factor into core k-1;
//   - rebuild psiphiA[k] and psiphib[k] from core k and psiphi[k + 1].
// Because psiphi[k] depends on psiphi[k + 1], k has to run from right to left.
92 [&](auto const k)
93 {
94 auto const [l, q] =
95 orthogonalize<pRC::Position::Right>(x.template core<k>());
96
97 x.template core<k>() = q;
98 x.template core<k - 1>() =
99 contract<2, 0>(x.template core<k - 1>(), l);
100
101 // k-1 in Alg.
102 std::get<k>(psiphiA) =
103 contract<1, 2, 1, 3>(conj(x.template core<k>()),
104 contract<2, 3, 1, 3>(op.template core<k>(),
105 eval(contract<2, 2>(x.template core<k>(),
106 std::get<k + 1>(psiphiA)))));
107
108 std::get<k>(psiphib) =
109 contract<1, 2, 1, 2>(conj(x.template core<k>()),
110 eval(contract<2, 1>(b.template core<k>(),
111 std::get<k + 1>(psiphib))));
112 });
113
114 // for now, just use 2 iterations
// NOTE(review): the number of sweeps is fixed at 2. `tolerance` only
// controls the local GMRES solves, and the visible code has no global
// convergence check.
115 for(pRC::Index i = 0; i < 2; ++i)
116 {
117 // forward sweep
119 [&](auto const k)
120 {
121 // local two-site A, b and x
// sA, sB and sX merge cores k and k+1 into one two-site core. Its
// combined mode has size 4 (2 x 2).
122 auto const sA =
124 decltype(op.template core<k>())>::size(0),
125 4, 4,
127 decltype(op.template core<k + 1>())>::size(3)>(
128 permute<0, 1, 3, 2, 4, 5>(
129 contract<3, 0>(op.template core<k>(),
130 op.template core<k + 1>())));
131 auto const sB =
133 decltype(b.template core<k>())>::size(0),
134 4,
136 decltype(b.template core<k + 1>())>::size(2)>(
137 contract<2, 0>(b.template core<k>(),
138 b.template core<k + 1>()));
139 auto const sX =
141 decltype(x.template core<k>())>::size(0),
142 4,
144 decltype(x.template core<k + 1>())>::size(2)>(
145 contract<2, 0>(x.template core<k>(),
146 x.template core<k + 1>()));
147
148 // local A and b
// Project the two-site operator and right-hand side onto the left
// interface psiphi[k] and the right interface psiphi[k + 2].
149 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
150 contract<1, 0>(get<k>(psiphiA),
151 eval(contract<3, 1>(sA, get<k + 2>(psiphiA))))));
152 auto const bk = linearize(contract<1, 0>(get<k>(psiphib),
153 eval(contract<2, 1>(sB, get<k + 2>(psiphib)))));
154
155 // solve local system
// sX (the current two-site iterate) is passed as the third
// argument, presumably as the GMRES initial guess.
158 decltype(x.template core<k>())>::size(0),
159 2, 2,
161 decltype(x.template core<k + 1>())>::size(2)>
162 sol;
163 linearize(sol) = pRC::Solver::GMRES<32, 0>()(Ak, bk,
164 linearize(sX), tolerance);
165
166 // split two-site solution into two cores
// The SVD is truncated to rank Ranks::size(k).
167 auto const [u, s, v] = svd<Ranks::size(k)>(matricize(sol));
168
169 // add enrichment to current core
// Core k becomes [u | random columns]. This widens the bond from
// Ranks::size(k) to ERanks::size(k).
172 decltype(x.template core<k>())>::size(0),
173 2, ERanks::size(k)>
174 ex;
175 linearize(
177 decltype(x.template core<k>())>::size(0),
178 2, Ranks::size(k)>(ex, 0, 0, 0)) = linearize(u);
180 decltype(x.template core<k>())>::size(0),
181 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
182 Ranks::size(k)) = random<pRC::Tensor<T,
184 decltype(x.template core<k>())>::size(0),
185 2, ERanks::size(k) - Ranks::size(k)>>(rng, dist);
186
187 // add enrichment to next core
// Core k+1 becomes s * v^H stacked on zero rows. The zero rows sit
// opposite the random columns of core k, so contracting the two cores
// still yields exactly u * s * v^H.
188 pRC::Tensor<T, Ranks::size(k), 2,
190 decltype(x.template core<k + 1>())>::size(2)>
191 tNext;
192 folding<pRC::Position::Right>(tNext) =
193 fromDiagonal(s) * adjoint(v);
194 pRC::Tensor<T, ERanks::size(k), 2,
196 decltype(x.template core<k + 1>())>::size(2)>
197 enext;
198 slice<Ranks::size(k), 2,
200 decltype(x.template core<k + 1>())>::size(2)>(enext,
201 0, 0, 0) = tNext;
202 slice<ERanks::size(k) - Ranks::size(k), 2,
204 decltype(x.template core<k + 1>())>::size(2)>(enext,
205 Ranks::size(k), 0, 0) = pRC::zero();
206
207 // orthogonalize
// Left-orthogonalize core k and absorb the R factor into core k+1.
// The product of the two cores is unchanged.
208 auto const [qq, rr] =
209 orthogonalize<pRC::Position::Left>(ex);
210 x.template core<k>() = qq;
211 x.template core<k + 1>() = contract<1, 0>(rr, enext);
212
213 // calculate psiphiA(k+1)
214 get<k + 1>(psiphiA) =
215 contract<1, 0, 0, 3>(conj(x.template core<k>()),
216 contract<2, 0, 0, 3>(op.template core<k>(),
217 eval(contract<0, 2>(x.template core<k>(),
218 get<k>(psiphiA)))));
219
220 // calculate psiphib(k+1)
221 get<k + 1>(psiphib) =
222 contract<0, 1, 2, 0>(conj(x.template core<k>()),
223 eval(contract<0, 1>(b.template core<k>(),
224 get<k>(psiphib))));
225 });
226
227 // backward sweep
// This mirrors the forward sweep on cores k-1 and k:
//   - core k gets adjoint(v) plus random rows;
//   - core k-1 gets u * s plus zero columns;
//   - core k is right-orthogonalized;
//   - psiphiA[k] and psiphib[k] are rebuilt from psiphi[k + 1].
// The range<...>( call opener is on lines missing from this listing.
230 [&](auto const k)
231 {
232 // local two-site A, b and x
233 auto const sA = reshape<
235 decltype(op.template core<k - 1>())>::size(0),
236 4, 4,
238 decltype(op.template core<k>())>::size(3)>(
239 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
240 op.template core<k - 1>(), op.template core<k>())));
241 auto const sB = reshape<
243 decltype(b.template core<k - 1>())>::size(0),
244 4,
246 decltype(b.template core<k>())>::size(2)>(
247 contract<2, 0>(b.template core<k - 1>(),
248 b.template core<k>()));
249 auto const sX = reshape<
251 decltype(x.template core<k - 1>())>::size(0),
252 4,
254 decltype(x.template core<k>())>::size(2)>(
255 contract<2, 0>(x.template core<k - 1>(),
256 x.template core<k>()));
257
258 // local A and b
259 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
260 contract<1, 0>(get<k - 1>(psiphiA),
261 eval(contract<3, 1>(sA, get<k + 1>(psiphiA))))));
262 auto const bk =
263 linearize(contract<1, 0>(get<k - 1>(psiphib),
264 eval(contract<2, 1>(sB, get<k + 1>(psiphib)))));
265
266 // solve local system
269 decltype(x.template core<k - 1>())>::size(0),
270 2, 2,
272 decltype(x.template core<k>())>::size(2)>
273 sol;
274 linearize(sol) = pRC::Solver::GMRES<32, 0>()(Ak, bk,
275 linearize(sX), tolerance);
276
277 // split two-site solution into two cores
278 auto const [u, s, v] =
279 svd<Ranks::size(k - 1)>(matricize(sol));
280
281 // add enrichment to current core
282 pRC::Tensor<T, ERanks::size(k - 1), 2,
284 decltype(x.template core<k>())>::size(2)>
285 ex;
286 linearize(slice<Ranks::size(k - 1), 2,
288 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
289 0)) = linearize(adjoint(v));
290 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
292 decltype(x.template core<k>())>::size(2)>(ex,
293 Ranks::size(k - 1), 0, 0) = random<pRC::Tensor<T,
294 ERanks::size(k - 1) - Ranks::size(k - 1), 2,
296 decltype(x.template core<k>())>::size(2)>>(rng,
297 dist);
298
299 // add enrichment to next core
// The zero columns sit opposite the random rows of core k, so the
// two-site product still equals u * s * v^H.
302 decltype(x.template core<k - 1>())>::size(0),
303 2, Ranks::size(k - 1)>
304 tNext;
305 folding<pRC::Position::Left>(tNext) = u * fromDiagonal(s);
308 decltype(x.template core<k - 1>())>::size(0),
309 2, ERanks::size(k - 1)>
310 enext;
312 decltype(x.template core<k - 1>())>::size(0),
313 2, Ranks::size(k - 1)>(enext, 0, 0, 0) = tNext;
314 linearize(
316 decltype(x.template core<k - 1>())>::size(0),
317 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(enext,
318 0, 0, Ranks::size(k - 1))) = pRC::zero();
319
320 // orthogonalize
321 auto const [ll, qq] =
322 orthogonalize<pRC::Position::Right>(ex);
323 x.template core<k>() = qq;
324 x.template core<k - 1>() = contract<2, 0>(enext, ll);
325
326 // calculate psiphiA(k)
327 // k -1 in Alg.
328 get<k>(psiphiA) =
329 contract<1, 2, 1, 3>(conj(x.template core<k>()),
330 contract<2, 3, 1, 3>(op.template core<k>(),
331 eval(contract<2, 2>(x.template core<k>(),
332 get<k + 1>(psiphiA)))));
333
334 // calculate psiphib(k)
335 get<k>(psiphib) =
336 contract<1, 2, 1, 2>(conj(x.template core<k>()),
337 eval(contract<2, 1>(b.template core<k>(),
338 get<k + 1>(psiphib))));
339 });
340 }
// Truncate the enriched iterate back to the target ranks.
341 return round<Ranks>(x);
342 }
343
// rmals (random initial guess): same contract as the overload above, but it
// starts from a random TT with ranks Ranks, rounded and scaled to norm 1,
// and then delegates to rmals<R, OT>(MHNop, b, tolerance, x).
// OT and the random generator/distribution are declared on lines missing
// from this listing.
344 template<pRC::Size R,
346 pRC::Size D, class T, class Tb>
347 auto rmals(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
348 {
349 using ModeSizes = decltype(getModeSizes<D>());
350 using Ranks = decltype(getRanks<D, R>());
351
352 // start from a random TT with norm 1
// Constant seed arguments, as in the overload above.
353 pRC::SeedSequence seq(8, 16);
356 auto x = round<Ranks>(
358 dist));
359 x = x / norm(x);
360
361 return rmals<R, OT>(MHNop, b, tolerance, x);
362 }
363}
364
365#endif // cMHN_TT_RMALS_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition sequence.hpp:29
Definition gmres.hpp:22
Definition declarations.hpp:16
Definition tensor.hpp:25
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
int x
Definition gmock-matchers-containers_test.cc:376
Definition als.hpp:12
constexpr auto getRanks()
Definition utility.hpp:17
auto rmals(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition rmals.hpp:16
constexpr auto getModeSizes()
Definition utility.hpp:11
Transform
Definition transform.hpp:9
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto makeSeries()
Definition sequence.hpp:390
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13