cMHN 1.0
C++ library for learning MHNs with pRC
calculate_pTheta.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_COMMON_CALCULATE_PTHETA_H
4#define cMHN_COMMON_CALCULATE_PTHETA_H
5
7#include <cmhn/tt/als.hpp>
8#include <cmhn/tt/mamen.hpp>
10#include <cmhn/tt/utility.hpp>
11
12#include <prc.hpp>
13
14namespace cMHN
15{
32 template<class T, class X, pRC::Size D>
34 nonTT::MHNOperator<T, D> const &op, X const &pInit,
35 T const &toleranceSolver)
36 {
37 auto const p0 =
38 eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
39 [](auto const... Ns)
40 {
41 return pRC::Tensor<T, Ns...>::Single(pRC::identity<T>(),
42 pRC::Subscripts<Ns...>(0));
43 }));
44
45 auto pTheta = pRC::solve<pRC::Solver::GMRES<>>(op, p0, pInit,
46 toleranceSolver);
47
48 pTheta /= pRC::norm<1>(pTheta)();
49
50 return pTheta;
51 }
52
67 template<class T, pRC::Size D>
69 nonTT::MHNOperator<T, D> const &op, T const &toleranceSolver)
70 {
71 // use a random pInit
72 pRC::SeedSequence seq(8, 16);
73 pRC::RandomEngine rng(seq);
74 pRC::GaussianDistribution<pRC::Float<>> dist;
75 auto pInit = eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
76 [&](auto const... Ns)
77 {
78 return pRC::random<pRC::Tensor<T, Ns...>>(rng, dist);
79 }));
80 pInit /= norm<1>(pInit)();
81
82 return calculatePTheta(op, pInit, toleranceSolver);
83 }
84
102 template<pRC::Size R, class T, class X, pRC::Size D>
104 TT::MHNOperator<T, D> const &op, X const &pInit,
105 T const &toleranceSolver)
106 {
107 using ModeSizes = decltype(TT::getModeSizes<D>());
108
109 auto const p0 = expand(pRC::makeConstantSequence<pRC::Index, D, 0>(),
110 [](auto const... seq)
111 {
112 return pRC::TensorTrain::Tensor<T, ModeSizes,
113 decltype(pRC::Sizes(pRC::makeConstantSequence<pRC::Size,
114 D - 1, 1>()))>::Single(pRC::identity<T>(), seq...);
115 });
116
117 auto pTheta = TT::ALS<R>(op, p0, toleranceSolver, eval(pInit));
118
119 pTheta /= scalarProduct(pTheta,
120 pRC::unit<pRC::TensorTrain::Tensor<T, ModeSizes>>())();
121
122 return pTheta;
123 }
124
140 template<pRC::Size R, class T, pRC::Size D>
142 TT::MHNOperator<T, D> const &op, T const &toleranceSolver)
143 {
144 using ModeSizes = decltype(TT::getModeSizes<D>());
145 using Ranks = decltype(TT::getRanks<D,R>());
146
147 // use a random pInit
148 pRC::SeedSequence seq(8, 16);
149 pRC::RandomEngine rng(seq);
150 pRC::GaussianDistribution<pRC::Float<>> dist;
151 auto pInit = round<Ranks>(
152 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, Ranks>>
153 (rng, dist));
154 pInit /= scalarProduct(pInit,
155 pRC::unit<pRC::TensorTrain::Tensor<T, ModeSizes>>())();
156
157 return calculatePTheta<R>(op, pInit, toleranceSolver);
158 }
159}
160
161#endif // cMHN_COMMON_CALCULATE_PTHETA_H
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
Class storing an MHN operator represented by a theta matrix (for non-TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: calculate_pTheta.hpp:15
X calculatePTheta(nonTT::MHNOperator< T, D > const &op, X const &pInit, T const &toleranceSolver)
Calculates the vector pTheta given a non-TT MHN operator, an initial guess pInit, and a solver tolerance.
Definition: calculate_pTheta.hpp:33