cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
calculate_pTheta.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_COMMON_CALCULATE_PTHETA_H
4#define cMHN_COMMON_CALCULATE_PTHETA_H
5
8#include <cmhn/tt/als.hpp>
9#include <cmhn/tt/amen.hpp>
10#include <cmhn/tt/mals.hpp>
11#include <cmhn/tt/mamen.hpp>
13#include <cmhn/tt/rals.hpp>
14#include <cmhn/tt/rmals.hpp>
15#include <cmhn/tt/utility.hpp>
16
17#include <prc.hpp>
18
19namespace cMHN
20{
33 template<class T, pRC::Size D>
35 [](auto const... Ns)
36 {
37 return pRC::Tensor<T, Ns...>{};
38 }))
40 {
41 auto const p0 =
43 [](auto const... Ns)
44 {
47 }));
48
49 return nonTT::jacobi(op, p0);
50 }
51
69 template<pRC::Size R, class T, class X, pRC::Size D>
70 X calculatePTheta(TT::MHNOperator<T, D> const &op, X const &pInit,
71 T const &toleranceSolver)
72 {
73 using ModeSizes = decltype(TT::getModeSizes<D>());
74
75 auto const p0 = expand(pRC::makeConstantSequence<pRC::Index, D, 0>(),
76 [](auto const... seq)
77 {
78 return pRC::TensorTrain::Tensor<T, ModeSizes,
80 1>())>::Single(pRC::identity<T>(), seq...);
81 });
82
83 auto pTheta = TT::als<R>(op, p0, toleranceSolver, eval(pInit));
84
85 pTheta /= scalarProduct(pTheta,
87
88 return pTheta;
89 }
90
106 template<pRC::Size R, class T, pRC::Size D>
108 T const &toleranceSolver)
109 {
110 using ModeSizes = decltype(TT::getModeSizes<D>());
111 using Ranks = decltype(TT::getRanks<D, R>());
112
113 // use a random pInit
114 pRC::SeedSequence seq(8, 16);
117 auto pInit = round<Ranks>(
119 dist));
120 pInit /= scalarProduct(pInit,
122
123 return calculatePTheta<R>(op, pInit, toleranceSolver);
124 }
125}
126
127#endif // cMHN_COMMON_CALCULATE_PTHETA_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition mhn_operator.hpp:24
Definition value.hpp:12
Definition gaussian.hpp:14
Definition seq.hpp:13
Definition subscripts.hpp:21
Definition declarations.hpp:16
Definition tensor.hpp:25
static constexpr auto Single(X &&value, Is const ... indices)
Definition tensor.hpp:51
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
constexpr auto getRanks()
Definition utility.hpp:17
auto als(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition als.hpp:16
constexpr auto getModeSizes()
Definition utility.hpp:11
X jacobi(nonTT::MHNOperator< T, D > const &op, X const &b)
Solves the linear system (1-Q)x=b or (1-Q)^Tx=b.
Definition jacobi.hpp:26
Definition calculate_pTheta.hpp:20
decltype(expand(pRC::makeConstantSequence< pRC::Size, D, 2 >(), [](auto const ... Ns) { return pRC::Tensor< T, Ns... >{};})) calculatePTheta(nonTT::MHNOperator< T, D > const &op)
Calculates the vector pTheta given a nonTT MHN Operator.
Definition calculate_pTheta.hpp:39
static constexpr auto unit()
Definition unit.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
std::size_t Size
Definition basics.hpp:31
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13