cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
calculate_pTheta.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_COMMON_CALCULATE_PTHETA_H
4#define cMHN_COMMON_CALCULATE_PTHETA_H
5
8#include <cmhn/tt/als.hpp>
9#include <cmhn/tt/mamen.hpp>
11#include <cmhn/tt/utility.hpp>
12
13#include <prc.hpp>
14
15namespace cMHN
16{
29 template<class T, pRC::Size D>
31 [](auto const... Ns)
32 {
33 return pRC::Tensor<T, Ns...>{};
34 }))
36 {
37 auto const p0 =
39 [](auto const... Ns)
40 {
43 }));
44
45 return nonTT::jacobi(op, p0);
46 }
47
65 template<pRC::Size R, class T, class X, pRC::Size D>
66 X calculatePTheta(TT::MHNOperator<T, D> const &op, X const &pInit,
67 T const &toleranceSolver)
68 {
69 using ModeSizes = decltype(TT::getModeSizes<D>());
70
71 auto const p0 = expand(pRC::makeConstantSequence<pRC::Index, D, 0>(),
72 [](auto const... seq)
73 {
74 return pRC::TensorTrain::Tensor<T, ModeSizes,
76 D - 1, 1>()))>::Single(pRC::identity<T>(), seq...);
77 });
78
79 auto pTheta = TT::ALS<R>(op, p0, toleranceSolver, eval(pInit));
80
81 pTheta /= scalarProduct(pTheta,
83
84 return pTheta;
85 }
86
102 template<pRC::Size R, class T, pRC::Size D>
104 T const &toleranceSolver)
105 {
106 using ModeSizes = decltype(TT::getModeSizes<D>());
107 using Ranks = decltype(TT::getRanks<D, R>());
108
109 // use a random pInit
110 pRC::SeedSequence seq(8, 16);
111 pRC::RandomEngine rng(seq);
113 auto pInit = round<Ranks>(
115 dist));
116 pInit /= scalarProduct(pInit,
118
119 return calculatePTheta<R>(op, pInit, toleranceSolver);
120 }
121}
122
123#endif // cMHN_COMMON_CALCULATE_PTHETA_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition mhn_operator.hpp:24
Definition type_traits.hpp:57
Definition seq.hpp:13
Definition subscripts.hpp:20
Definition type_traits.hpp:17
Definition tensor.hpp:28
static constexpr auto Single(X &&value, Is const ... indices)
Definition tensor.hpp:78
Definition threefry.hpp:24
pRC::Float<> T
Definition externs_nonTT.hpp:1
X jacobi(nonTT::MHNOperator< T, D > const &op, X const &b)
Solves the linear system $(1-Q)\,x = b$ or its transposed form $(1-Q)^{T} x = b$.
Definition jacobi.hpp:26
Definition calculate_pTheta.hpp:16
decltype(expand(pRC::makeConstantSequence< pRC::Size, D, 2 >(), [](auto const ... Ns) { return pRC::Tensor< T, Ns... >{};})) calculatePTheta(nonTT::MHNOperator< T, D > const &op)
Calculates the vector pTheta given a nonTT MHN Operator.
Definition calculate_pTheta.hpp:35
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
static constexpr auto random(RandomEngine &rng, D &distribution)
Definition random.hpp:12
std::size_t Size
Definition type_traits.hpp:20
Sequence< Size, Ns... > Sizes
Definition type_traits.hpp:238
static constexpr auto unit()
Definition unit.hpp:12