14 pRC::Operator::Transform OT = pRC::Operator::Transform::None,
15 pRC::Size
D,
class T,
class Tb,
class X>
17 T const &tolerance, X
const &x0)
19 using Ranks =
decltype(getRanks<D, R>());
20 using BRanks =
typename Tb::Ranks;
22 auto const toleranceLocal =
23 tolerance / pRC::sqrt(pRC::identity<pRC::Float<>>(
D));
26 auto const op = transform<OT>(expand(pRC::makeSeries<pRC::Index, D>(),
27 [&](
auto const... seq)
29 return pRC::TensorTrain::fromCores(
30 MHNop.template core<seq>()...);
36 auto psiphiA = expand(pRC::makeSeries<pRC::Index, D + 1>(),
39 return std::tuple(pRC::Tensor<
T,
40 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
42 decltype(pRC::Sizes<1>(),
43 pRC::Sizes(pRC::makeConstantSequence<pRC::Size,
D - 1,
45 pRC::Sizes<1>())::size(seq),
46 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
51 reshape<1, 1>(std::get<0>(psiphiA)) =
52 pRC::identity<pRC::Tensor<T, 1, 1>>();
53 reshape<1, 1>(std::get<D>(psiphiA)) =
54 pRC::identity<pRC::Tensor<T, 1, 1>>();
57 auto psiphib = expand(pRC::makeSeries<pRC::Index, D + 1>(),
60 return std::tuple(pRC::Tensor<
T,
61 decltype(pRC::Sizes<1>(), Ranks(), pRC::Sizes<1>())::size(
63 decltype(pRC::Sizes<1>(), BRanks(), pRC::Sizes<1>())::size(
68 reshape<1, 1>(std::get<0>(psiphib)) =
69 pRC::identity<pRC::Tensor<T, 1, 1>>();
70 reshape<1, 1>(std::get<D>(psiphib)) =
71 pRC::identity<pRC::Tensor<T, 1, 1>>();
74 pRC::range<pRC::Context::CompileTime, 1, D, pRC::Direction::Backwards>(
77 auto const [lambda, l, q] =
78 orthogonalize<pRC::Position::Right>(x.template core<k>());
80 x.template core<k>() = q;
81 x.template core<k - 1>() =
82 lambda * contract<2, 0>(x.template core<k - 1>(), l);
85 std::get<k>(psiphiA) =
86 contract<1, 2, 1, 3>(conj(x.template core<k>()),
87 contract<2, 3, 1, 3>(op.template core<k>(),
88 eval(contract<2, 2>(x.template core<k>(),
89 std::get<k + 1>(psiphiA)))));
91 std::get<k>(psiphib) =
92 contract<1, 2, 1, 2>(conj(x.template core<k>()),
93 eval(contract<2, 1>(b.template core<k>(),
94 std::get<k + 1>(psiphib))));
98 for(pRC::Index i = 0; i < 2; ++i)
101 pRC::range<pRC::Context::CompileTime,
D - 1>(
105 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
106 contract<1, 0>(std::get<k>(psiphiA),
107 eval(contract<3, 1>(op.template core<k>(),
108 std::get<k + 1>(psiphiA))))));
110 linearize(contract<1, 0>(std::get<k>(psiphib),
111 eval(contract<2, 1>(b.template core<k>(),
112 std::get<k + 1>(psiphib)))));
115 pRC::RemoveConstReference<
decltype(x.template core<k>())>
117 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
118 linearize(x.template core<k>()), toleranceLocal);
121 auto const [lambda, q, r] =
122 orthogonalize<pRC::Position::Left>(sol);
123 x.template core<k>() = q;
124 x.template core<k + 1>() =
125 lambda * contract<1, 0>(r, x.template core<k + 1>());
128 std::get<k + 1>(psiphiA) =
129 contract<1, 0, 0, 3>(conj(x.template core<k>()),
130 contract<2, 0, 0, 3>(op.template core<k>(),
131 eval(contract<0, 2>(x.template core<k>(),
132 std::get<k>(psiphiA)))));
135 std::get<k + 1>(psiphib) =
136 contract<0, 1, 2, 0>(conj(x.template core<k>()),
137 eval(contract<0, 1>(b.template core<k>(),
138 std::get<k>(psiphib))));
142 pRC::range<pRC::Context::CompileTime, 1,
D,
143 pRC::Direction::Backwards>(
147 auto const Ak = matricize(permute<0, 2, 4, 1, 3, 5>(
148 contract<1, 0>(std::get<k>(psiphiA),
149 eval(contract<3, 1>(op.template core<k>(),
150 std::get<k + 1>(psiphiA))))));
152 linearize(contract<1, 0>(std::get<k>(psiphib),
153 eval(contract<2, 1>(b.template core<k>(),
154 std::get<k + 1>(psiphib)))));
157 pRC::RemoveConstReference<
decltype(x.template core<k>())>
159 linearize(sol) = pRC::Solver::GMRES()(Ak, bk,
160 linearize(x.template core<k>()), toleranceLocal);
163 auto const [lambda, l, q] =
164 orthogonalize<pRC::Position::Right>(sol);
165 x.template core<k>() = q;
166 x.template core<k - 1>() =
167 lambda * contract<2, 0>(x.template core<k - 1>(), l);
171 std::get<k>(psiphiA) =
172 contract<1, 2, 1, 3>(conj(x.template core<k>()),
173 contract<2, 3, 1, 3>(op.template core<k>(),
174 eval(contract<2, 2>(x.template core<k>(),
175 std::get<k + 1>(psiphiA)))));
178 std::get<k>(psiphib) =
179 contract<1, 2, 1, 2>(conj(x.template core<k>()),
180 eval(contract<2, 1>(b.template core<k>(),
181 std::get<k + 1>(psiphib))));
187 template<pRC::Size R,
188 pRC::Operator::Transform OT = pRC::Operator::Transform::None,
189 pRC::Size
D,
class T,
class Tb>
193 using ModeSizes =
decltype(getModeSizes<D>());
194 using Ranks =
decltype(getRanks<D, R>());
197 pRC::SeedSequence seq(8, 16);
198 pRC::RandomEngine rng(seq);
199 pRC::GaussianDistribution<pRC::Float<>> dist;
200 auto x = round<Ranks>(
201 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, Ranks>>(rng,
205 return ALS<R, OT>(MHNop, b, tolerance, x);
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
auto ALS(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition: als.hpp:16