cMHN 1.0
C++ library for learning MHNs with pRC
calculate_score_and_gradient.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_COMMON_CALCULATE_SCORE_AND_GRADIENT_H
4#define cMHN_COMMON_CALCULATE_SCORE_AND_GRADIENT_H
5
6#include <map>
7#include <tuple>
8
11#include <cmhn/common/score.hpp>
14
15#include <prc.hpp>
16
17namespace cMHN
18{
33 template<class T, pRC::Size D, class S>
34 std::tuple<T, pRC::Tensor<T, D, D>>
36 std::map<S, T> const &pD, cMHN::Score<T> const &Score,
37 cMHN::Regulator<T, D> const &Regulator, T const &toleranceSolverP = 1e-4,
38 T const &toleranceSolverQ = 1e-4)
39 {
40 T score = pRC::zero();
41 pRC::Tensor<T, D, D> g = pRC::zero();
42
43 // use a random pInit
44 pRC::SeedSequence seq(8, 16);
45 pRC::RandomEngine rng(seq);
46 pRC::GaussianDistribution<pRC::Float<>> dist;
47 auto pInit = eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
48 [&](auto const... Ns)
49 {
50 return pRC::random<pRC::Tensor<T, Ns...>>(rng, dist);
51 }));
52 pInit = pInit / norm(pInit);
53
54 auto const pTheta = calculatePTheta(op, pInit, toleranceSolverP);
55
56 for(auto const &[k, v] : pD)
57 {
58 score += Score.pointwiseScore(v, pTheta(k));
59 }
60
61 // calculate g = dS/dtheta
62 auto rhs =
63 eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
64 [](auto const... Ns)
65 {
66 return pRC::zero<pRC::Tensor<T, Ns...>>();
67 }));
68
69 for(auto const &[k, v] : pD)
70 {
71 rhs(k) = Score.pointwiseDSDP(v, pTheta(k));
72 }
73
74 auto const q = pRC::solve<pRC::Solver::GMRES<>,
75 pRC::Operator::Transform::Transpose>(op, rhs,
76 pRC::zero<decltype(rhs)>(), toleranceSolverQ);
77
78 // this follows appendix C of Rudi's Thesis
79 for(pRC::Index i = 0; i < D; ++i)
80 {
81 auto r = hadamardProduct(q,
82 nonTT::applyDerivative(op, pTheta, i));
83
84 pRC::range<pRC::Context::CompileTime, D>(
85 [&](auto const j)
86 {
87 if(i == j)
88 {
89 g(i, j) = pRC::reduce<pRC::Add>(r)();
90 }
91 else
92 {
93 g(i, j) = pRC::reduce<pRC::Add>(pRC::chip<j>(r, 1))();
94 }
95 });
96 }
97
98 score -= Regulator.score(log(op.theta()));
99 g -= Regulator.grad(log(op.theta()));
100
101 return std::make_tuple(score, g);
102 }
103
120 template<pRC::Size RP, pRC::Size RQ, class T, pRC::Size D, class S>
121 std::tuple<T, pRC::Tensor<T, D, D>>
123 std::map<S, T> const &pD, cMHN::Score<T> const &Score,
124 cMHN::Regulator<T, D> const &Regulator, T const &toleranceSolverP = 1e-4,
125 T const &toleranceSolverQ = 1e-4)
126 {
127 using ModeSizes = decltype(TT::getModeSizes<D>());
128 using Ranks = decltype(TT::getRanks<D, RP>());
129
130 T score = pRC::zero();
131 pRC::Tensor<T, D, D> g = pRC::zero();
132
133 // use a random pInit
134 pRC::SeedSequence seq(8, 16);
135 pRC::RandomEngine rng(seq);
136 pRC::GaussianDistribution<pRC::Float<>> dist;
137 auto pInit = round<Ranks>(
138 pRC::random<pRC::TensorTrain::Tensor<T, ModeSizes, Ranks>>
139 (rng, dist));
140 pInit /= scalarProduct(pInit,
141 pRC::unit<pRC::TensorTrain::Tensor<T, ModeSizes>>())();
142
143 auto const pTheta =
144 calculatePTheta<RP>(op, pInit, toleranceSolverP);
145
146 T scoreT = pRC::zero();
147 pRC::Tensor<T, D, D> gT = pRC::zero();
148#if defined(_OPENMP)
149 #pragma omp declare reduction(+: T, pRC::Tensor<T, D, D>: omp_out = omp_in + omp_out) \
150 initializer (omp_priv(pRC::Zero()))
151 #pragma omp parallel for schedule(dynamic, 10) reduction(+ : scoreT, gT)
152#endif
153 for(pRC::Index s = 0; s < pD.size(); ++s)
154 {
155 auto it = pD.cbegin();
156 std::advance(it, s);
157 auto const [k, v] = *it;
158 auto const pThetaE = pTheta(k);
159
160 scoreT += Score.pointwiseScore(v, pThetaE);
161
162 auto const rhs =
163 pRC::TensorTrain::Tensor<T, ModeSizes>::Single(
164 pRC::identity<T>(), k);
165
166 // solve (1-Q)T * q = rhs
167 auto const q = TT::ALS<RQ, pRC::Operator::Transform::Transpose>(
168 op, rhs, toleranceSolverQ);
169
170 pRC::Tensor<T, D, D> tmp = pRC::zero();
171 for(pRC::Index i = 0; i < D; ++i)
172 {
173 for(pRC::Index j = 0; j < D; ++j)
174 {
175 gT(i, j) = -Score.pointwiseDSDP(v, pThetaE) *
176 scalarProduct(q, op.derivative(i, j) * pTheta)();
177 }
178 }
179 }
180
181 score += scoreT - Regulator.score(log(op.theta()));
182 g += gT - Regulator.grad(log(op.theta()));
183
184 return std::make_tuple(score, g);
185 }
186}
187
188#endif
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing all relevant information for a regulator.
Definition: regulator.hpp:30
auto grad(pRC::Tensor< T, D, D > const &theta) const
Definition: regulator.hpp:53
auto score(pRC::Tensor< T, D, D > const &theta) const
Definition: regulator.hpp:48
Class storing all relevant information for a score.
Definition: score.hpp:27
auto pointwiseScore(T const &pDE, T const &pThetaE) const
Definition: score.hpp:44
auto pointwiseDSDP(T const &pDE, T const &pThetaE) const
Definition: score.hpp:49
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
constexpr auto & theta(pRC::Index const i, pRC::Index const j) const
Definition: mhn_operator.hpp:33
constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
Definition: mhn_operator.hpp:76
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition: mhn_operator.hpp:23
constexpr auto & theta(pRC::Index const i, pRC::Index const j) const
Definition: mhn_operator.hpp:33
pRC::Float<> T
Definition: externs_nonTT.hpp:1
static constexpr auto applyDerivative(MHNOperator< T1, D > const &op, pRC::Tensor< T2, Ns... > const &x, pRC::Index const &i)
Applies the derivative of an MHN Q with respect to theta_ii to a vector x.
Definition: mhn_operator.hpp:71
Definition: calculate_pTheta.hpp:15
std::tuple< T, pRC::Tensor< T, D, D > > calculateScoreAndGradient(nonTT::MHNOperator< T, D > const &op, std::map< S, T > const &pD, cMHN::Score< T > const &Score, cMHN::Regulator< T, D > const &Regulator, T const &toleranceSolverP=1e-4, T const &toleranceSolverQ=1e-4)
Calculate score and gradient of a theta matrix given some data distribution pD.
Definition: calculate_score_and_gradient.hpp:35
X calculatePTheta(nonTT::MHNOperator< T, D > const &op, X const &pInit, T const &toleranceSolver)
Calculates the vector pTheta given a nonTT MHN Operator and a tolerance.
Definition: calculate_pTheta.hpp:33