3#ifndef cMHN_COMMON_CALCULATE_PTHETA_H
4#define cMHN_COMMON_CALCULATE_PTHETA_H
// calculatePTheta overload taking an operator `op`, a starting tensor `pInit`
// and a solver tolerance.
// NOTE(review): this listing is an excerpt -- the signature line, braces, the
// lambda head and the return statement are not visible here. The template
// parameters and trailing `toleranceSolver` argument match the nonTT
// signature given in the cross-reference at the end of this listing; confirm
// against the real header.
32 template<
class T,
class X, pRC::Size D>
35 T const &toleranceSolver)
// Builds a tensor whose mode sizes come from a constant sequence
// <pRC::Size, D, 2> (presumably D modes of size 2 -- TODO confirm) via
// Tensor::Single(identity, Subscripts(0)). It is used as `p0` in the solve
// below; the assignment line itself is not visible in this excerpt.
38 eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
41 return pRC::Tensor<T, Ns...>::Single(pRC::identity<T>(),
42 pRC::Subscripts<Ns...>(0));
// Solve with the GMRES solver, passing op, p0 and pInit (the remaining
// arguments are cut off in this excerpt).
45 auto pTheta = pRC::solve<pRC::Solver::GMRES<>>(op, p0, pInit,
// Rescale the solution by its 1-norm.
48 pTheta /= pRC::norm<1>(pTheta)();
// calculatePTheta overload without a caller-supplied starting tensor: it
// creates a random `pInit` itself.
// NOTE(review): excerpt only -- the signature, the lambda head and whatever
// follows the normalisation (presumably a call to the overload that accepts
// pInit -- TODO confirm) are not visible here.
67 template<
class T, pRC::Size D>
// Random engine seeded from a seed sequence constructed with the literal
// arguments (8, 16), plus a Gaussian distribution over pRC::Float<>.
72 pRC::SeedSequence seq(8, 16);
73 pRC::RandomEngine rng(seq);
74 pRC::GaussianDistribution<pRC::Float<>> dist;
// Draw a random Tensor<T, Ns...> whose mode sizes come from the constant
// sequence <pRC::Size, D, 2> (presumably D modes of size 2 -- TODO confirm).
75 auto pInit = eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
78 return pRC::random<pRC::Tensor<
T, Ns...>>(rng, dist);
// Rescale the random starting tensor by its 1-norm.
80 pInit /= norm<1>(pInit)();
// calculatePTheta overload with an extra size parameter R that is forwarded
// to TT::ALS<R>; takes an operator `op`, a starting tensor `pInit` and a
// solver tolerance.
// NOTE(review): excerpt only -- the signature line, braces and the return
// statement are not visible here. R is presumably the tensor-train rank used
// by the solver -- TODO confirm.
102 template<pRC::Size R,
class T,
class X, pRC::Size D>
105 T const &toleranceSolver)
// Mode sizes of the tensor-train representation, as reported by
// TT::getModeSizes<D>().
107 using ModeSizes =
decltype(TT::getModeSizes<D>());
// p0: a TensorTrain::Tensor over ModeSizes whose rank list is the constant
// sequence <pRC::Size, D - 1, 1>, created via Single(identity, seq...) where
// seq expands the constant index sequence <pRC::Index, D, 0>.
109 auto const p0 = expand(pRC::makeConstantSequence<pRC::Index, D, 0>(),
110 [](
auto const... seq)
112 return pRC::TensorTrain::Tensor<
T, ModeSizes,
113 decltype(pRC::Sizes(pRC::makeConstantSequence<pRC::Size,
114 D - 1, 1>()))>::Single(pRC::identity<T>(), seq...);
// Solve with TT::ALS<R>, passing the evaluated pInit as the last argument.
117 auto pTheta = TT::ALS<R>(op, p0, toleranceSolver, eval(pInit));
// Rescale by the scalar product of pTheta with pRC::unit<...>() of the same
// mode sizes (presumably the all-ones tensor, making this the sum of the
// entries -- TODO confirm the semantics of pRC::unit).
119 pTheta /= scalarProduct(pTheta,
120 pRC::unit<pRC::TensorTrain::Tensor<T, ModeSizes>>())();
// calculatePTheta<R> overload without a caller-supplied starting tensor: it
// creates a random tensor-train `pInit`, rescales it, and delegates to
// calculatePTheta<R>(op, pInit, toleranceSolver).
// NOTE(review): excerpt only -- the signature line, braces and the argument
// list of the random/round call are not visible here.
140 template<pRC::Size R,
class T, pRC::Size D>
// Mode sizes and ranks of the tensor-train type, as reported by
// TT::getModeSizes<D>() and TT::getRanks<D,R>().
144 using ModeSizes =
decltype(TT::getModeSizes<D>());
145 using Ranks =
decltype(TT::getRanks<D,R>());
// Random engine seeded from a seed sequence constructed with the literal
// arguments (8, 16), plus a Gaussian distribution over pRC::Float<>.
148 pRC::SeedSequence seq(8, 16);
149 pRC::RandomEngine rng(seq);
150 pRC::GaussianDistribution<pRC::Float<>> dist;
// Random TensorTrain::Tensor<T, ModeSizes, Ranks>, passed through
// round<Ranks>.
151 auto pInit = round<Ranks>(
152 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, Ranks>>
// Rescale by the scalar product of pInit with pRC::unit<...>() of the same
// mode sizes (presumably the sum of the entries -- TODO confirm the semantics
// of pRC::unit).
154 pInit /= scalarProduct(pInit,
155 pRC::unit<pRC::TensorTrain::Tensor<T, ModeSizes>>())();
// Hand the generated starting tensor to calculatePTheta<R>.
157 return calculatePTheta<R>(op, pInit, toleranceSolver);
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: calculate_pTheta.hpp:15
X calculatePTheta(nonTT::MHNOperator< T, D > const &op, X const &pInit, T const &toleranceSolver)
Calculates the vector pTheta given a nonTT MHN operator, a starting vector pInit, and a solver tolerance.
Definition: calculate_pTheta.hpp:33