cMHN 1.0
C++ library for learning MHNs with pRC
als.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_ALS_H
4#define cMHN_TT_ALS_H
5
7#include <cmhn/tt/utility.hpp>
8
9#include <prc.hpp>
10
11namespace cMHN::TT
12{
13 template<pRC::Size R,
14 pRC::Operator::Transform OT = pRC::Operator::Transform::None,
15 pRC::Size D, class T, class Tb, class X>
16 auto ALS(MHNOperator<T, D> const &MHNop, Tb const &b,
17 T const &tolerance, X const &x0)
18 {
19 using Ranks = decltype(getRanks<D, R>());
20 using BRanks = typename Tb::Ranks;
21
22 auto const toleranceLocal =
23 tolerance / pRC::sqrt(pRC::identity<pRC::Float<>>(D));
24
25 // get TT for 1-Q
26 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
27 [&](auto const... seq)
28 {
29 return pRC::TensorTrain::fromCores(
30 MHNop.template core<seq>()...);
31 }));
32
33 auto x = x0;
34
35 // create the objects holding psis/phis for A (without values yet)
36 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
37 [](auto const... seq)
38 {
39 return std::tuple(pRC::Tensor<T,
40 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
41 seq),
42 decltype(pRC::Sizes<1>(),
43 pRC::Sizes(pRC::makeConstantSequence<pRC::Size, D - 1,
44 D + 1>()),
45 pRC::Sizes<1>())::size(seq),
46 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
47 seq)>()...);
48 });
49
50 // first and last psiphi is always 1
51 reshape<1, 1>(std::get<0>(psiphiA)) =
52 pRC::identity<pRC::Tensor<T, 1, 1>>();
53 reshape<1, 1>(std::get<D>(psiphiA)) =
54 pRC::identity<pRC::Tensor<T, 1, 1>>();
55
56 // create the objects holding psis/phis for b (without values yet)
57 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
58 [](auto const... seq)
59 {
60 return std::tuple(pRC::Tensor<T,
61 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
62 seq),
63 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
64 seq)>()...);
65 });
66
67 // first and last psiphi is always 1
68 reshape<1, 1>(std::get<0>(psiphib)) =
69 pRC::identity<pRC::Tensor<T, 1, 1>>();
70 reshape<1, 1>(std::get<D>(psiphib)) =
71 pRC::identity<pRC::Tensor<T, 1, 1>>();
72
73 // calculate the first psiphiA and psiphib and do initial QR
74 pRC::range<pRC::Context::CompileTime, 1, D, pRC::Direction::Backwards>(
75 [&](auto const k)
76 {
77 auto const [lambda, l, q] =
78 orthogonalize<pRC::Position::Right>(x.template core<k>());
79
80 x.template core<k>() = q;
81 x.template core<k - 1>() =
82 lambda * contract<2, 0>(x.template core<k - 1>(), l);
83
84 // k-1 in Alg.
85 std::get<k>(psiphiA) =
86 contract<1, 2, 1, 3>(conj(x.template core<k>()),
87 contract<2, 3, 1, 3>(op.template core<k>(),
88 eval(contract<2, 2>(x.template core<k>(),
89 std::get<k + 1>(psiphiA)))));
90
91 std::get<k>(psiphib) =
92 contract<1, 2, 1, 2>(conj(x.template core<k>()),
93 eval(contract<2, 1>(b.template core<k>(),
94 std::get<k + 1>(psiphib))));
95 });
96
97 // for now, just use 2 iterations
98 for(pRC::Index i = 0; i < 2; ++i)
99 {
100 // forward sweep
101 pRC::range<pRC::Context::CompileTime, D - 1>(
102 [&](auto const k)
103 {
104 // local A and b
105 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
106 contract<1, 0>(std::get<k>(psiphiA),
107 eval(contract<3, 1>(op.template core<k>(),
108 std::get<k + 1>(psiphiA))))));
109 auto const bk =
110 linearize(contract<1, 0>(std::get<k>(psiphib),
111 eval(contract<2, 1>(b.template core<k>(),
112 std::get<k + 1>(psiphib)))));
113
114 // solve local system
115 pRC::RemoveConstReference<decltype(x.template core<k>())>
116 sol;
117 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
118 linearize(x.template core<k>()), toleranceLocal);
119
120 // left orthogonalization
121 auto const [lambda, q, r] =
122 orthogonalize<pRC::Position::Left>(sol);
123 x.template core<k>() = q;
124 x.template core<k + 1>() =
125 lambda * contract<1, 0>(r, x.template core<k + 1>());
126
127 // calculate psiphiA(k+1)
128 std::get<k + 1>(psiphiA) =
129 contract<1, 0, 0, 3>(conj(x.template core<k>()),
130 contract<2, 0, 0, 3>(op.template core<k>(),
131 eval(contract<0, 2>(x.template core<k>(),
132 std::get<k>(psiphiA)))));
133
134 // calculate psiphib(k+1)
135 std::get<k + 1>(psiphib) =
136 contract<0, 1, 2, 0>(conj(x.template core<k>()),
137 eval(contract<0, 1>(b.template core<k>(),
138 std::get<k>(psiphib))));
139 });
140
141 // backward sweep
142 pRC::range<pRC::Context::CompileTime, 1, D,
143 pRC::Direction::Backwards>(
144 [&](auto const k)
145 {
146 // local A and b
147 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
148 contract<1, 0>(std::get<k>(psiphiA),
149 eval(contract<3, 1>(op.template core<k>(),
150 std::get<k + 1>(psiphiA))))));
151 auto const bk =
152 linearize(contract<1, 0>(std::get<k>(psiphib),
153 eval(contract<2, 1>(b.template core<k>(),
154 std::get<k + 1>(psiphib)))));
155
156 // solve local system
157 pRC::RemoveConstReference<decltype(x.template core<k>())>
158 sol;
159 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
160 linearize(x.template core<k>()), toleranceLocal);
161
162 // right orthogonalization
163 auto const [lambda, l, q] =
164 orthogonalize<pRC::Position::Right>(sol);
165 x.template core<k>() = q;
166 x.template core<k - 1>() =
167 lambda * contract<2, 0>(x.template core<k - 1>(), l);
168
169 // k-1 in Alg. (shift due to dimensions of psiphi objects
170 // calculate psiphiA(k)
171 std::get<k>(psiphiA) =
172 contract<1, 2, 1, 3>(conj(x.template core<k>()),
173 contract<2, 3, 1, 3>(op.template core<k>(),
174 eval(contract<2, 2>(x.template core<k>(),
175 std::get<k + 1>(psiphiA)))));
176
177 // calculate psiphib(k)
178 std::get<k>(psiphib) =
179 contract<1, 2, 1, 2>(conj(x.template core<k>()),
180 eval(contract<2, 1>(b.template core<k>(),
181 std::get<k + 1>(psiphib))));
182 });
183 }
184 return x;
185 }
186
187 template<pRC::Size R,
188 pRC::Operator::Transform OT = pRC::Operator::Transform::None,
189 pRC::Size D, class T, class Tb>
190 auto ALS(MHNOperator<T, D> const &MHNop, Tb const &b,
191 T const &tolerance)
192 {
193 using ModeSizes = decltype(getModeSizes<D>());
194 using Ranks = decltype(getRanks<D, R>());
195
196 // start from a random TT with norm 1
197 pRC::SeedSequence seq(8, 16);
198 pRC::RandomEngine rng(seq);
199 pRC::GaussianDistribution<pRC::Float<>> dist;
200 auto x = round<Ranks>(
201 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, Ranks>>(rng,
202 dist));
203 x = x / norm(x);
204
205 return ALS<R, OT>(MHNop, b, tolerance, x);
206 }
207}
208
209#endif // cMHN_TT_ALS_H
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: als.hpp:12
auto ALS(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition: als.hpp:16