cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
als.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_ALS_H
4#define cMHN_TT_ALS_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
13 template<pRC::Size R,
15 pRC::Size D, class T, class Tb, class X>
16 auto als(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance,
17 X const &x0)
18 {
19 using Ranks = decltype(getRanks<D, R>());
20 using BRanks = typename Tb::Ranks;
21
22 auto const toleranceLocal =
24
25 // get TT for 1-Q
26 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
27 [&](auto const... seq)
28 {
30 MHNop.template core<seq>()...);
31 }));
32
33 auto x = x0;
34
35 // create the objects holding psis/phis for A (without values yet)
36 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
37 [](auto const... seq)
38 {
39 return std::tuple(pRC::Tensor<T,
40 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
41 seq),
42 decltype(pRC::Sizes<1>(),
44 pRC::Sizes<1>())::size(seq),
45 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
46 seq)>()...);
47 });
48
49 // first and last psiphi is always 1
50 reshape<1, 1>(std::get<0>(psiphiA)) =
52 reshape<1, 1>(std::get<D>(psiphiA)) =
54
55 // create the objects holding psis/phis for b (without values yet)
56 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
57 [](auto const... seq)
58 {
59 return std::tuple(pRC::Tensor<T,
60 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
61 seq),
62 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
63 seq)>()...);
64 });
65
66 // first and last psiphi is always 1
67 reshape<1, 1>(std::get<0>(psiphib)) =
69 reshape<1, 1>(std::get<D>(psiphib)) =
71
72 // calculate the first psiphiA and psiphib and do initial QR
74 [&](auto const k)
75 {
76 auto const [l, q] =
77 orthogonalize<pRC::Position::Right>(x.template core<k>());
78
79 x.template core<k>() = q;
80 x.template core<k - 1>() =
81 contract<2, 0>(x.template core<k - 1>(), l);
82
83 // k-1 in Alg.
84 std::get<k>(psiphiA) =
85 contract<1, 2, 1, 3>(conj(x.template core<k>()),
86 contract<2, 3, 1, 3>(op.template core<k>(),
87 eval(contract<2, 2>(x.template core<k>(),
88 std::get<k + 1>(psiphiA)))));
89
90 std::get<k>(psiphib) =
91 contract<1, 2, 1, 2>(conj(x.template core<k>()),
92 eval(contract<2, 1>(b.template core<k>(),
93 std::get<k + 1>(psiphib))));
94 });
95
96 // for now, just use 2 iterations
97 for(pRC::Index i = 0; i < 2; ++i)
98 {
99 // forward sweep
101 [&](auto const k)
102 {
103 // local A and b
104 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
105 contract<1, 0>(std::get<k>(psiphiA),
106 eval(contract<3, 1>(op.template core<k>(),
107 std::get<k + 1>(psiphiA))))));
108 auto const bk =
109 linearize(contract<1, 0>(std::get<k>(psiphib),
110 eval(contract<2, 1>(b.template core<k>(),
111 std::get<k + 1>(psiphib)))));
112
113 // solve local system
114 pRC::RemoveConstReference<decltype(x.template core<k>())>
115 sol;
116 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
117 linearize(x.template core<k>()), toleranceLocal);
118
119 // left orthogonalization
120 auto const [q, r] = orthogonalize<pRC::Position::Left>(sol);
121 x.template core<k>() = q;
122 x.template core<k + 1>() =
123 contract<1, 0>(r, x.template core<k + 1>());
124
125 // calculate psiphiA(k+1)
126 std::get<k + 1>(psiphiA) =
127 contract<1, 0, 0, 3>(conj(x.template core<k>()),
128 contract<2, 0, 0, 3>(op.template core<k>(),
129 eval(contract<0, 2>(x.template core<k>(),
130 std::get<k>(psiphiA)))));
131
132 // calculate psiphib(k+1)
133 std::get<k + 1>(psiphib) =
134 contract<0, 1, 2, 0>(conj(x.template core<k>()),
135 eval(contract<0, 1>(b.template core<k>(),
136 std::get<k>(psiphib))));
137 });
138
139 // backward sweep
142 [&](auto const k)
143 {
144 // local A and b
145 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
146 contract<1, 0>(std::get<k>(psiphiA),
147 eval(contract<3, 1>(op.template core<k>(),
148 std::get<k + 1>(psiphiA))))));
149 auto const bk =
150 linearize(contract<1, 0>(std::get<k>(psiphib),
151 eval(contract<2, 1>(b.template core<k>(),
152 std::get<k + 1>(psiphib)))));
153
154 // solve local system
155 pRC::RemoveConstReference<decltype(x.template core<k>())>
156 sol;
157 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
158 linearize(x.template core<k>()), toleranceLocal);
159
160 // right orthogonalization
161 auto const [l, q] =
162 orthogonalize<pRC::Position::Right>(sol);
163 x.template core<k>() = q;
164 x.template core<k - 1>() =
165 contract<2, 0>(x.template core<k - 1>(), l);
166
167 // calculate psiphiA(k)
168 // k-1 in Alg. (shift due to dimensions of psiphi objects)
169 std::get<k>(psiphiA) =
170 contract<1, 2, 1, 3>(conj(x.template core<k>()),
171 contract<2, 3, 1, 3>(op.template core<k>(),
172 eval(contract<2, 2>(x.template core<k>(),
173 std::get<k + 1>(psiphiA)))));
174
175 // calculate psiphib(k)
176 std::get<k>(psiphib) =
177 contract<1, 2, 1, 2>(conj(x.template core<k>()),
178 eval(contract<2, 1>(b.template core<k>(),
179 std::get<k + 1>(psiphib))));
180 });
181 }
182 return x;
183 }
184
185 template<pRC::Size R,
187 pRC::Size D, class T, class Tb>
188 auto als(MHNOperator<T, D> const &MHNop, Tb const &b, T const &tolerance)
189 {
190 using ModeSizes = decltype(getModeSizes<D>());
191 using Ranks = decltype(getRanks<D, R>());
192
193 // start from a random TT with norm 1
194 pRC::SeedSequence seq(8, 16);
197 auto x = round<Ranks>(
199 dist));
200 x = x / norm(x);
201
202 return als<R, OT>(MHNop, b, tolerance, x);
203 }
204}
205
206#endif // cMHN_TT_ALS_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition sequence.hpp:29
Definition gmres.hpp:22
Definition declarations.hpp:16
Definition tensor.hpp:25
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
int x
Definition gmock-matchers-containers_test.cc:376
Definition als.hpp:12
constexpr auto getRanks()
Definition utility.hpp:17
auto als(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition als.hpp:16
constexpr auto getModeSizes()
Definition utility.hpp:11
Transform
Definition transform.hpp:9
static constexpr auto fromCores(Xs &&...cores)
Definition from_cores.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto makeSeries()
Definition sequence.hpp:390
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13