cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
mals.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_MALS_H
4#define cMHN_TT_MALS_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
// Sweep-based solver working on pairs of neighbouring TT cores of x.
//
// Visible steps: build `op` from the cores of MHNop, copy x0 into x,
// orthogonalize the cores of x while filling the psiphiA / psiphib caches,
// then run two forward+backward sweeps. Each sweep step contracts two
// neighbouring cores, solves a local system with pRC::Solver::GMRES, splits
// the solution again with an SVD and refreshes the affected cache entry.
//
// Template parameters: R is forwarded to getRanks<D, R>(); D is the number of
// cores. `OT` is used below but not declared in the visible lines.
// Parameters:
//   MHNop      operator whose cores (core<0>() .. core<D-1>()) are read
//   b          right-hand side; must expose `Ranks` and `core<k>()`
//   tolerance  forwarded unchanged to every GMRES call
//   x0         starting value; copied, never modified
// Returns the updated copy of x0.
//
// NOTE(review): this listing skips several source lines (the embedded line
// numbers jump), so some statements below are visibly incomplete.
13 template<pRC::Size R,
// NOTE(review): listing line 14 is missing; presumably it declares `OT`,
// which is used in transform<OT>(...) further down - confirm against the file.
15 pRC::Size D, class T, class Tb, class X>
16 auto mals(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
19 using Ranks = decltype(getRanks<D, R>());
20 using BRanks = typename Tb::Ranks;
21
// NOTE(review): the initializer (listing line 23) is missing here. In the
// visible lines `toleranceLocal` is never read; both GMRES calls below are
// given `tolerance` instead - verify whether that is intended.
22 auto const toleranceLocal =
24
25 // get TT for 1-Q
// All D cores of MHNop are expanded into one call (callee on the missing
// listing line 29) and the result is passed through transform<OT>.
26 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
27 [&](auto const... seq)
28 {
30 MHNop.template core<seq>()...);
31 }));
32
// Work on a copy so the caller's x0 stays untouched.
33 auto x = x0;
34
35 // create the objects holding psis/phis for A (without values yet)
// Tuple of D + 1 three-index tensors, one per position 0..D; the extents are
// looked up per position via ::size(seq).
36 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
37 [](auto const... seq)
38 {
39 return std::tuple(pRC::Tensor<T,
40 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
41 seq),
// NOTE(review): listing line 43 (middle operand of this decltype) is missing.
42 decltype(pRC::Sizes<1>(),
44 pRC::Sizes<1>())::size(seq),
45 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
46 seq)>()...);
47 });
48
49 // first and last psiphi is always 1
// NOTE(review): the right-hand sides of both assignments (listing lines 51
// and 53) are missing from this listing.
50 reshape<1, 1>(std::get<0>(psiphiA)) =
52 reshape<1, 1>(std::get<D>(psiphiA)) =
54
55 // create the objects holding psis/phis for b (without values yet)
// Same layout as psiphiA but with two indices per entry: one extent from
// Ranks (x) and one from BRanks (b).
56 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
57 [](auto const... seq)
58 {
59 return std::tuple(pRC::Tensor<T,
60 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
61 seq),
62 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
63 seq)>()...);
64 });
65
66 // first and last psiphi is always 1
// NOTE(review): right-hand sides (listing lines 68 and 70) are missing.
67 reshape<1, 1>(std::get<0>(psiphib)) =
69 reshape<1, 1>(std::get<D>(psiphib)) =
71
72 // calculate the first psiphiA and psiphib and do initial QR
// NOTE(review): the call that drives this lambda (listing line 73) is
// missing, so the range and direction of k cannot be read off here. The body
// uses core<k - 1> and psiphi entries k + 1, so k presumably runs from the
// last core downwards - confirm.
74 [&](auto const k)
75 {
// Core k is replaced by the factor q; the remaining factor l is contracted
// into the neighbouring core k - 1.
76 auto const [l, q] =
77 orthogonalize<pRC::Position::Right>(x.template core<k>());
78
79 x.template core<k>() = q;
80 x.template core<k - 1>() =
81 contract<2, 0>(x.template core<k - 1>(), l);
82
83 // k-1 in Alg.
// Cache entry k is built from the new core k, op's core k and entry k + 1.
84 std::get<k>(psiphiA) =
85 contract<1, 2, 1, 3>(conj(x.template core<k>()),
86 contract<2, 3, 1, 3>(op.template core<k>(),
87 eval(contract<2, 2>(x.template core<k>(),
88 std::get<k + 1>(psiphiA)))));
89
// Same update for the right-hand-side cache, using b's core k.
90 std::get<k>(psiphib) =
91 contract<1, 2, 1, 2>(conj(x.template core<k>()),
92 eval(contract<2, 1>(b.template core<k>(),
93 std::get<k + 1>(psiphib))));
94 });
95
96 // for now, just use 2 iterations
// NOTE(review): the sweep count is hard-coded; `tolerance` is not used as a
// stopping criterion for this outer loop.
97 for(pRC::Index i = 0; i < 2; ++i)
98 {
99 // forward sweep
// NOTE(review): the driver call for this lambda (listing line 100) is
// missing. The body touches cores k and k + 1.
101 [&](auto const k)
102 {
103 // local two-site A, b and X
// Cores k and k + 1 of op, b and x are contracted over their shared index.
// NOTE(review): the opening lines of each of the three expressions (listing
// lines 105/108, 114/117, 122/125) are missing. Judging by the matching
// backward-sweep code these are reshape<...> calls - confirm.
104 auto const sA =
106 decltype(op.template core<k>())>::size(0),
107 4, 4,
109 decltype(op.template core<k + 1>())>::size(3)>(
110 permute<0, 1, 3, 2, 4, 5>(
111 contract<3, 0>(op.template core<k>(),
112 op.template core<k + 1>())));
113 auto const sB =
115 decltype(b.template core<k>())>::size(0),
116 4,
118 decltype(b.template core<k + 1>())>::size(2)>(
119 contract<2, 0>(b.template core<k>(),
120 b.template core<k + 1>()));
121 auto const sX =
123 decltype(x.template core<k>())>::size(0),
124 4,
126 decltype(x.template core<k + 1>())>::size(2)>(
127 contract<2, 0>(x.template core<k>(),
128 x.template core<k + 1>()));
129
130 // local A and b
// The two-site pieces are sandwiched between cache entries k and k + 2, then
// flattened into a matrix (Ak) and a vector (bk).
131 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
132 contract<1, 0>(get<k>(psiphiA),
133 eval(contract<3, 1>(sA, get<k + 2>(psiphiA))))));
134 auto const bk = linearize(contract<1, 0>(get<k>(psiphib),
135 eval(contract<2, 1>(sB, get<k + 2>(psiphib)))));
136
137 // solve local system
// NOTE(review): the start of the declaration of `sol` (listing lines 138-139
// and 142) is missing.
140 decltype(x.template core<k>())>::size(0),
141 2, 2,
143 decltype(x.template core<k + 1>())>::size(2)>
144 sol;
// linearize(sX) is presumably the initial guess for GMRES - confirm against
// the pRC::Solver::GMRES signature.
145 linearize(sol) = pRC::Solver::GMRES<32, 0>()(Ak, bk,
146 linearize(sX), tolerance);
147
148 // split two-site solution into two cores
// Forward direction: u goes into core k, and s * adjoint(v) is pushed into
// core k + 1.
149 auto const [u, s, v] = svd<Ranks::size(k)>(matricize(sol));
150 folding<pRC::Position::Left>(x.template core<k>()) = u;
151 folding<pRC::Position::Right>(x.template core<k + 1>()) =
152 fromDiagonal(s) * adjoint(v);
153
154 // calculate psiphiA(k+1)
// Uses the freshly written core k together with cache entry k.
155 get<k + 1>(psiphiA) =
156 contract<1, 0, 0, 3>(conj(x.template core<k>()),
157 contract<2, 0, 0, 3>(op.template core<k>(),
158 eval(contract<0, 2>(x.template core<k>(),
159 get<k>(psiphiA)))));
160
161 // calculate psiphib(k+1)
162 get<k + 1>(psiphib) =
163 contract<0, 1, 2, 0>(conj(x.template core<k>()),
164 eval(contract<0, 1>(b.template core<k>(),
165 get<k>(psiphib))));
166 });
167
168 // backward sweep
// NOTE(review): the driver call for this lambda (listing lines 169-170) is
// missing. The body touches cores k - 1 and k.
171 [&](auto const k)
172 {
173 // local two-site A, b and X
// Same construction as in the forward sweep, for the pair (k - 1, k).
// NOTE(review): one listing line inside each reshape<...> argument list is
// missing (175/178, 183/186, 191/194).
174 auto const sA = reshape<
176 decltype(op.template core<k - 1>())>::size(0),
177 4, 4,
179 decltype(op.template core<k>())>::size(3)>(
180 permute<0, 1, 3, 2, 4, 5>(contract<3, 0>(
181 op.template core<k - 1>(), op.template core<k>())));
182 auto const sB = reshape<
184 decltype(b.template core<k - 1>())>::size(0),
185 4,
187 decltype(b.template core<k>())>::size(2)>(
188 contract<2, 0>(b.template core<k - 1>(),
189 b.template core<k>()));
190 auto const sX = reshape<
192 decltype(x.template core<k - 1>())>::size(0),
193 4,
195 decltype(x.template core<k>())>::size(2)>(
196 contract<2, 0>(x.template core<k - 1>(),
197 x.template core<k>()));
198
199 // local A and b
// Sandwiched between cache entries k - 1 and k + 1.
200 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
201 contract<1, 0>(get<k - 1>(psiphiA),
202 eval(contract<3, 1>(sA, get<k + 1>(psiphiA))))));
203 auto const bk =
204 linearize(contract<1, 0>(get<k - 1>(psiphib),
205 eval(contract<2, 1>(sB, get<k + 1>(psiphib)))));
206
207 // solve local system
// NOTE(review): the start of the declaration of `sol` (listing lines 208-209
// and 212) is missing.
210 decltype(x.template core<k - 1>())>::size(0),
211 2, 2,
213 decltype(x.template core<k>())>::size(2)>
214 sol;
215 linearize(sol) = pRC::Solver::GMRES<32, 0>()(Ak, bk,
216 linearize(sX), tolerance);
217
218 // split two-site solution into two cores
// Backward direction: adjoint(v) goes into core k, and u * s is pushed into
// core k - 1 (mirror image of the forward split).
219 auto const [u, s, v] =
220 svd<Ranks::size(k - 1)>(matricize(sol));
221 folding<pRC::Position::Right>(x.template core<k>()) =
222 adjoint(v);
223 folding<pRC::Position::Left>(x.template core<k - 1>()) =
224 u * fromDiagonal(s);
225
226 // calculate psiphiA(k)
227 // k-1 in Alg. (shift due to dimensions of psiphi objects)
// Same contraction pattern as in the initial orthogonalization pass above.
228 get<k>(psiphiA) =
229 contract<1, 2, 1, 3>(conj(x.template core<k>()),
230 contract<2, 3, 1, 3>(op.template core<k>(),
231 eval(contract<2, 2>(x.template core<k>(),
232 get<k + 1>(psiphiA)))));
233
234 // calculate psiphib(k)
235 get<k>(psiphib) =
236 contract<1, 2, 1, 2>(conj(x.template core<k>()),
237 eval(contract<2, 1>(b.template core<k>(),
238 get<k + 1>(psiphib))));
239 });
240 }
241 return x;
242 }
243
// Convenience overload without a caller-supplied starting value: builds a
// random x, rounds it to `Ranks`, scales it by 1 / norm(x) and forwards
// everything to the four-argument mals.
//
// Parameters:
//   MHNop      operator, forwarded unchanged
//   b          right-hand side, forwarded unchanged
//   tolerance  forwarded unchanged
// Returns whatever the four-argument overload returns.
//
// NOTE(review): listing lines 245, 254, 255 and 257 are missing from this
// listing; the notes below mark what cannot be seen.
244 template<pRC::Size R,
// NOTE(review): `OT` is used in the final call but not declared in the
// visible lines; presumably it is declared on the missing listing line 245.
246 pRC::Size D, class T, class Tb>
247 auto mals(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
248 {
// NOTE(review): `ModeSizes` is not referenced in the visible lines; it may be
// used on the missing listing line 257.
249 using ModeSizes = decltype(getModeSizes<D>());
250 using Ranks = decltype(getRanks<D, R>());
251
252 // start from a random TT with norm 1
// The seed sequence is built from fixed constants (8, 16).
// NOTE(review): the generator and `dist` are presumably defined on the
// missing listing lines 254-255 - confirm against the file.
253 pRC::SeedSequence seq(8, 16);
256 auto x = round<Ranks>(
258 dist));
// Normalize so that the starting value has norm 1.
259 x = x / norm(x);
260
261 return mals<R, OT>(MHNop, b, tolerance, x);
262 }
263}
264
265#endif // cMHN_TT_MALS_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition sequence.hpp:29
Definition gmres.hpp:22
Definition declarations.hpp:16
Definition tensor.hpp:25
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
int x
Definition gmock-matchers-containers_test.cc:376
Definition als.hpp:12
auto mals(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition mals.hpp:16
constexpr auto getRanks()
Definition utility.hpp:17
constexpr auto getModeSizes()
Definition utility.hpp:11
Transform
Definition transform.hpp:9
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto makeSeries()
Definition sequence.hpp:390
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13