cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
rals.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_RALS_H
4#define cMHN_TT_RALS_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
// Rank-adaptive alternating least squares (rALS) solver in tensor-train (TT)
// format.
//
// Approximately solves "op * x = b", where op is the TT operator obtained by
// applying transform<OT> to the operator assembled from the cores of MHNop
// (the in-code note calls it "1-Q") and b is a TT right-hand side.
//
// Algorithm outline (as implemented below):
//   1. Embed the initial guess x0 into a working TT x whose bond dimensions
//      are the enriched ranks ERanks (= getRanks<D, R + 4>()).
//   2. Right-to-left pass: right-orthogonalize the cores and build the
//      interface tensors psiphiA (x^H * op * x) and psiphib (x^H * b).
//   3. For a fixed number of outer iterations, do a left-to-right and then a
//      right-to-left sweep. At each core:
//      - assemble and solve the local system with GMRES;
//      - truncate the local solution to the target ranks Ranks
//        (= getRanks<D, R>());
//      - pad it with a random enrichment block back up to ERanks;
//      - re-orthogonalize and update the interface tensors.
//   4. Round x back to the target ranks Ranks and return it.
//
// Template parameters:
//   R  - target rank parameter passed to getRanks<D, R>().
//   OT - second template parameter; selects the transform applied to the
//        operator (used as transform<OT>, and forwarded explicitly by the
//        3-argument overload at the end of this namespace).
//   D  - number of TT cores.
//   T  - scalar type (of tolerance and of all pRC::Tensor objects used here).
//   Tb - type of the right-hand side; must provide `Ranks` and `core<k>()`.
//   X  - type of the initial guess; must provide `core<k>()`.
// Parameters:
//   MHNop     - MHN operator whose cores make up op.
//   b         - right-hand side in TT format.
//   tolerance - requested accuracy. NOTE(review): `tolerance` is not used on
//               any other visible line, so it presumably only feeds
//               toleranceLocal (the GMRES tolerance). The outer loop always
//               runs 2 sweeps, regardless of `tolerance`.
//   x0        - initial guess; each of its cores is copied into the leading
//               slice of the corresponding core of the working TT x.
// Returns: the approximate solution, rounded to ranks Ranks.
13 template<pRC::Size R,
15 pRC::Size D, class T, class Tb, class X>
16 auto rals(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
19 using ModeSizes = decltype(getModeSizes<D>());
20
// Target ranks, enriched ranks (computed as for rank R + 4) and the ranks of
// the right-hand side b.
21 using Ranks = decltype(getRanks<D, R>());
22 using ERanks = decltype(getRanks<D, R + pRC::Size(4)>());
23 using BRanks = typename Tb::Ranks;
24
// Tolerance handed to every local GMRES solve below.
25 auto const toleranceLocal =
27
28 // get TT for 1-Q
29 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
30 [&](auto const... seq)
31 {
33 MHNop.template core<seq>()...);
34 }));
35
// Embed x0 into x: core k of x0 is assigned to the slice of core k of x that
// starts at (0, 0, 0) and has x0's extents (size(0) x 2 x size(2)).
38 [&](auto const k)
39 {
41 decltype(x0.template core<k>())>::size(0),
42 2,
44 decltype(x0.template core<k>())>::size(2)>(
45 x.template core<k>(), 0, 0, 0) = x0.template core<k>();
46 });
47
48 // random setup
// Fixed seed sequence, used for the random enrichment blocks (drawn via
// rng/dist in the sweeps). NOTE(review): the generic lambdas below name their
// parameter pack `seq` as well, which shadows this variable inside them.
49 pRC::SeedSequence seq(8, 16);
52
53 // create the objects holding psis/phis for A (without values yet)
// psiphiA is a tuple of D + 1 rank-3 tensors. Entry j is the interface
// tensor at bond j (between cores j-1 and j). Its outer extents are taken
// from the enriched ranks, padded with 1 at both ends.
54 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
55 [](auto const... seq)
56 {
57 return std::tuple(pRC::Tensor<T,
58 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
59 seq),
60 decltype(pRC::Sizes<1>(),
62 pRC::Sizes<1>())::size(seq),
63 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
64 seq)>()...);
65 });
66
67 // first and last psiphi is always 1
68 reshape<1, 1>(std::get<0>(psiphiA)) =
70 reshape<1, 1>(std::get<D>(psiphiA)) =
72
73 // create the objects holding psis/phis for b (without values yet)
// psiphib is a tuple of D + 1 rank-2 tensors, sized (enriched rank of x,
// rank of b) at each bond, padded with 1 at both ends.
74 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
75 [](auto const... seq)
76 {
77 return std::tuple(pRC::Tensor<T,
78 decltype(pRC::Sizes<1>(), ERanks(), pRC::Sizes<1>())::size(
79 seq),
80 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
81 seq)>()...);
82 });
83
84 // first and last psiphi is always 1
85 reshape<1, 1>(std::get<0>(psiphib)) =
87 reshape<1, 1>(std::get<D>(psiphib)) =
89
90 // calculate the first psiphiA and psiphib and do initial QR
// Right-to-left pass. Right-orthogonalize core k and absorb the factor l into
// core k - 1 (so the represented TT is unchanged). Then build psiphiA[k] and
// psiphib[k] from core k and the already-known psiphi[k + 1].
92 [&](auto const k)
93 {
94 auto const [l, q] =
95 orthogonalize<pRC::Position::Right>(x.template core<k>());
96
97 x.template core<k>() = q;
98 x.template core<k - 1>() =
99 contract<2, 0>(x.template core<k - 1>(), l);
100
101 // k-1 in Alg.
102 std::get<k>(psiphiA) =
103 contract<1, 2, 1, 3>(conj(x.template core<k>()),
104 contract<2, 3, 1, 3>(op.template core<k>(),
105 eval(contract<2, 2>(x.template core<k>(),
106 std::get<k + 1>(psiphiA)))));
107
108 std::get<k>(psiphib) =
109 contract<1, 2, 1, 2>(conj(x.template core<k>()),
110 eval(contract<2, 1>(b.template core<k>(),
111 std::get<k + 1>(psiphib))));
112 });
113
114 // for now, just use 2 iterations
// NOTE(review): fixed iteration count; no residual/convergence check is
// evaluated against `tolerance` here.
115 for(pRC::Index i = 0; i < 2; ++i)
116 {
117 // forward sweep
// Left-to-right: optimize core k, then push the remainder into core k + 1.
119 [&](auto const k)
120 {
121 // local A and b
// Ak: the local operator, obtained by contracting psiphiA[k], op core k and
// psiphiA[k + 1] and matricizing. bk: the linearized local right-hand side.
122 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
123 contract<1, 0>(std::get<k>(psiphiA),
124 eval(contract<3, 1>(op.template core<k>(),
125 std::get<k + 1>(psiphiA))))));
126 auto const bk =
127 linearize(contract<1, 0>(std::get<k>(psiphib),
128 eval(contract<2, 1>(b.template core<k>(),
129 std::get<k + 1>(psiphib)))));
130
131 // solve local system
// The current core is used as the GMRES starting vector.
132 pRC::RemoveConstReference<decltype(x.template core<k>())>
133 sol;
134 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
135 linearize(x.template core<k>()), toleranceLocal);
136
137 // left orthogonalization
// Truncate the local solution to Ranks::size(k). q stays in core k; r is
// pushed into core k + 1 below.
138 auto const [q, r] =
139 truncate<Ranks::size(k), pRC::Position::Right>(sol);
140
141 // add enrichment to current core
// ex = [q | random block]: widens the right bond from Ranks::size(k) to
// ERanks::size(k).
144 decltype(x.template core<k>())>::size(0),
145 2, ERanks::size(k)>
146 ex;
148 decltype(x.template core<k>())>::size(0),
149 2, Ranks::size(k)>(ex, 0, 0, 0) = q;
151 decltype(x.template core<k>())>::size(0),
152 2, ERanks::size(k) - Ranks::size(k)>(ex, 0, 0,
153 Ranks::size(k)) = random<pRC::Tensor<T,
155 decltype(x.template core<k>())>::size(0),
156 2, ERanks::size(k) - Ranks::size(k)>>(rng, dist);
157
158 // add enrichment to next core
// enext = [r * core(k + 1) ; 0]. The rows matching the random columns of ex
// are zero, so ex * enext still equals q * r * core(k + 1).
159 auto tNext = contract<1, 0>(r, x.template core<k + 1>());
160 pRC::Tensor<T, ERanks::size(k), 2,
162 decltype(x.template core<k + 1>())>::size(2)>
163 enext;
164 slice<Ranks::size(k), 2,
166 decltype(x.template core<k + 1>())>::size(2)>(enext,
167 0, 0, 0) = tNext;
168 slice<ERanks::size(k) - Ranks::size(k), 2,
170 decltype(x.template core<k + 1>())>::size(2)>(enext,
171 Ranks::size(k), 0, 0) = pRC::zero();
172
173 // re-orthogonalization
// Left-orthogonalize the enriched core; the factor rr is absorbed into
// core k + 1.
174 auto const [qq, rr] =
175 orthogonalize<pRC::Position::Left>(ex);
176 x.template core<k>() = qq;
177 x.template core<k + 1>() = contract<1, 0>(rr, enext);
178
179 // calculate psiphiA(k+1)
180 std::get<k + 1>(psiphiA) =
181 contract<1, 0, 0, 3>(conj(x.template core<k>()),
182 contract<2, 0, 0, 3>(op.template core<k>(),
183 eval(contract<0, 2>(x.template core<k>(),
184 std::get<k>(psiphiA)))));
185
186 // calculate psiphib(k+1)
187 std::get<k + 1>(psiphib) =
188 contract<0, 1, 2, 0>(conj(x.template core<k>()),
189 eval(contract<0, 1>(b.template core<k>(),
190 std::get<k>(psiphib))));
191 });
192
193 // backward sweep
// Right-to-left: mirror image of the forward sweep. The remainder is pushed
// into core k - 1 and the left bond of core k is enriched.
196 [&](auto const k)
197 {
198 // local A and b
199 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
200 contract<1, 0>(std::get<k>(psiphiA),
201 eval(contract<3, 1>(op.template core<k>(),
202 std::get<k + 1>(psiphiA))))));
203 auto const bk =
204 linearize(contract<1, 0>(std::get<k>(psiphib),
205 eval(contract<2, 1>(b.template core<k>(),
206 std::get<k + 1>(psiphib)))));
207
208 // solve local system
209 pRC::RemoveConstReference<decltype(x.template core<k>())>
210 sol;
211 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
212 linearize(x.template core<k>()), toleranceLocal);
213
214 // right orthogonalization
// Truncate to Ranks::size(k - 1). q stays in core k; l is pushed into
// core k - 1 below.
215 auto const [l, q] =
216 truncate<Ranks::size(k - 1), pRC::Position::Left>(sol);
217
218 // add enrichment to current core
// ex = [q ; random block]: widens the left bond from Ranks::size(k - 1) to
// ERanks::size(k - 1).
219 pRC::Tensor<T, ERanks::size(k - 1), 2,
221 decltype(x.template core<k>())>::size(2)>
222 ex;
223 slice<Ranks::size(k - 1), 2,
225 decltype(x.template core<k>())>::size(2)>(ex, 0, 0,
226 0) = q;
227 slice<ERanks::size(k - 1) - Ranks::size(k - 1), 2,
229 decltype(x.template core<k>())>::size(2)>(ex,
230 Ranks::size(k - 1), 0, 0) = random<pRC::Tensor<T,
231 ERanks::size(k - 1) - Ranks::size(k - 1), 2,
233 decltype(x.template core<k>())>::size(2)>>(rng,
234 dist);
235
236 // add enrichment to next core
// enext = [core(k - 1) * l | 0]. The zero columns match the random rows of
// ex, so enext * ex still equals core(k - 1) * l * q.
237 auto tNext = contract<2, 0>(x.template core<k - 1>(), l);
240 decltype(x.template core<k - 1>())>::size(0),
241 2, ERanks::size(k - 1)>
242 enext;
244 decltype(x.template core<k - 1>())>::size(0),
245 2, Ranks::size(k - 1)>(enext, 0, 0, 0) = tNext;
246 linearize(
248 decltype(x.template core<k - 1>())>::size(0),
249 2, ERanks::size(k - 1) - Ranks::size(k - 1)>(enext,
250 0, 0, Ranks::size(k - 1))) = pRC::zero();
251
252 // re-orthogonalize
// Right-orthogonalize the enriched core; the factor ll is absorbed into
// core k - 1.
253 auto const [ll, qq] =
254 orthogonalize<pRC::Position::Right>(ex);
255 x.template core<k>() = qq;
256 x.template core<k - 1>() = contract<2, 0>(enext, ll);
257
258 // calculate psiphiA(k)
259 // k-1 in Alg. (shift due to dimensions of psiphi objects)
260 std::get<k>(psiphiA) =
261 contract<1, 2, 1, 3>(conj(x.template core<k>()),
262 contract<2, 3, 1, 3>(op.template core<k>(),
263 eval(contract<2, 2>(x.template core<k>(),
264 std::get<k + 1>(psiphiA)))));
265
266 // calculate psiphib(k)
267 std::get<k>(psiphib) =
268 contract<1, 2, 1, 2>(conj(x.template core<k>()),
269 eval(contract<2, 1>(b.template core<k>(),
270 std::get<k + 1>(psiphib))));
271 });
272 }
// Drop the enrichment: round the working TT back to the target ranks.
273 return round<Ranks>(x);
274 }
275
// Convenience overload of rals that starts from a random initial guess.
//
// The guess is a random TT drawn with the fixed seed sequence (8, 16). It is
// rounded to ranks Ranks (= getRanks<D, R>()) and scaled to unit norm, then
// passed to the 4-argument rals<R, OT>(MHNop, b, tolerance, x). OT is the
// second template parameter, as in the 4-argument overload.
//
// Parameters:
//   MHNop     - MHN operator whose cores make up the system operator.
//   b         - right-hand side in TT format.
//   tolerance - forwarded unchanged to the 4-argument overload.
// Returns: whatever the 4-argument overload returns for this starting guess.
276 template<pRC::Size R,
278 pRC::Size D, class T, class Tb>
279 auto rals(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
280 {
281 using ModeSizes = decltype(getModeSizes<D>());
282 using Ranks = decltype(getRanks<D, R>());
283
284 // start from a random TT with norm 1
285 pRC::SeedSequence seq(8, 16);
288 auto x = round<Ranks>(
290 dist));
291 x = x / norm(x);
292
293 return rals<R, OT>(MHNop, b, tolerance, x);
294 }
295}
296
297#endif // cMHN_TT_RALS_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition sequence.hpp:29
Definition gmres.hpp:22
Definition declarations.hpp:16
Definition tensor.hpp:25
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
int x
Definition gmock-matchers-containers_test.cc:376
Definition als.hpp:12
constexpr auto getRanks()
Definition utility.hpp:17
constexpr auto getModeSizes()
Definition utility.hpp:11
auto rals(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition rals.hpp:16
Transform
Definition transform.hpp:9
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto makeSeries()
Definition sequence.hpp:390
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13