cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
calculate_score_and_gradient.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_COMMON_CALCULATE_SCORE_AND_GRADIENT_H
4#define cMHN_COMMON_CALCULATE_SCORE_AND_GRADIENT_H
5
6#include <map>
7#include <tuple>
8
11#include <cmhn/common/score.hpp>
14
15#include <prc.hpp>
16
17namespace cMHN
18{
// Non-TT overload: computes the score of the MHN operator `op` on the data
// distribution `pD` together with its gradient with respect to the D x D
// theta matrix. Full parameter list (per the generated signature index):
//   op               - non-TT MHN operator
//   pD               - data distribution; its keys are used directly as
//                      indices into pTheta (pTheta(k))
//   Score            - provides pointwiseScore / pointwiseDSDP
//   Regulator        - provides the penalty score / grad, evaluated on
//                      log(op.bigTheta())
//   toleranceSolverQ - tolerance handed to the solve that yields q
//                      (default 1e-4)
// Returns std::make_tuple(score, g).
// NOTE(review): source lines 37, 41, 62-63 and 71 are omitted from this
// listing (presumably the Score/Regulator parameters, the declaration of g,
// the head of the solve producing q, and the head of the per-j dispatch).
34 template<class T, pRC::Size D, class S>
35 std::tuple<T, pRC::Tensor<T, D, D>> calculateScoreAndGradient(
36 nonTT::MHNOperator<T, D> const &op, std::map<S, T> const &pD,
38 T const &toleranceSolverQ = 1e-4)
39 {
40 T score = pRC::zero();
42
// pTheta: the vector computed from op by calculatePTheta.
43 auto const pTheta = calculatePTheta(op);
44
// Data term of the score: sum of pointwise scores over the entries of pD
// only (states absent from pD contribute nothing here).
45 for(auto const &[k, v] : pD)
46 {
47 score += Score.pointwiseScore(v, pTheta(k));
48 }
49
50 // calculate g = dS/dtheta
// rhs: zero tensor with D modes of extent 2 each, then filled with
// dS/dp at the states present in pD.
51 auto rhs = eval(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
52 [](auto const... Ns)
53 {
54 return pRC::zero<pRC::Tensor<T, Ns...>>();
55 }));
56
57 for(auto const &[k, v] : pD)
58 {
59 rhs(k) = Score.pointwiseDSDP(v, pTheta(k));
60 }
61
64 pRC::zero<decltype(rhs)>(), toleranceSolverQ);
65
// q: result of the solve whose first lines are omitted above; it starts
// from a zero initial guess and uses toleranceSolverQ.
// NOTE(review): the TT overload documents its analogous solve as
// (1-Q)^T q = rhs; presumably the same system here - confirm.
66 // this follows appendix C of Rudi's Thesis
67 for(pRC::Index i = 0; i < D; ++i)
68 {
// r = q (elementwise) * (dQ/dtheta_ii applied to pTheta).
69 auto r = hadamardProduct(q, nonTT::applyDerivative(op, pTheta, i));
70
72 [&](auto const j)
73 {
// Diagonal entry: sum over all of r. Off-diagonal entry: sum over
// chip<j>(r, 1), i.e. r with mode j fixed to index 1 (presumably the
// states in which event j is present - confirm pRC::chip semantics).
// j must be a compile-time constant, since it is a template argument
// of chip<j>.
74 if(i == j)
75 {
76 g(i, j) = pRC::reduce<pRC::Add>(r)();
77 }
78 else
79 {
80 g(i, j) = pRC::reduce<pRC::Add>(pRC::chip<j>(r, 1))();
81 }
82 });
83 }
84
// Subtract the regulator's penalty, evaluated on log(bigTheta), from both
// the score and the gradient.
85 score -= Regulator.score(log(op.bigTheta()));
86 g -= Regulator.grad(log(op.bigTheta()));
87
88 return std::make_tuple(score, g);
89 }
90
// TT overload: computes the score of the TT MHN operator `op` on the data
// distribution `pD` and its gradient with respect to the D x D theta matrix.
//   RP, RQ           - rank template parameters forwarded to
//                      calculatePTheta<RP> and TT::ALS<RQ, ...>
//   op               - TT MHN operator
//   pD               - data distribution; keys are used as indices into
//                      pTheta (pTheta(k))
//   pInit            - initial guess passed to calculatePTheta; overwritten
//                      with the computed pTheta before returning
//   toleranceSolverP - tolerance for calculatePTheta (default 1e-4)
//   toleranceSolverQ - tolerance for each TT::ALS solve (default 1e-4)
// Score and Regulator are also parameters, between pD and pInit (see the
// forwarding call in the overload without pInit). They are presumably on
// source line 118, which is omitted from this listing.
// Returns std::make_tuple(score, g).
115 template<pRC::Size RP, pRC::Size RQ, class T, pRC::Size D, class S, class X>
116 std::tuple<T, pRC::Tensor<T, D, D>> calculateScoreAndGradient(
117 TT::MHNOperator<T, D> const &op, std::map<S, T> const &pD,
119 X &pInit, T const &toleranceSolverP = 1e-4,
120 T const &toleranceSolverQ = 1e-4)
121 {
122 using ModeSizes = decltype(TT::getModeSizes<D>());
123
124 T score = pRC::zero();
126
127 auto const pTheta = calculatePTheta<RP>(op, pInit, toleranceSolverP);
128
129 T scoreT = pRC::zero();
// Data term: every entry of pD is processed independently, so the loop is
// parallelised with OpenMP. The declared reduction combines the T and
// Tensor<T, D, D> partial results with `+`, starting each private copy at
// zero. NOTE(review): gT is declared on a line omitted from this listing.
131#if defined(_OPENMP)
132 #pragma omp declare reduction(+ : T, \
133 pRC::Tensor<T, D, D> : omp_out = omp_in + omp_out) \
134 initializer(omp_priv(pRC::Zero()))
135 #pragma omp parallel for schedule(dynamic, 10) reduction(+ : scoreT, gT)
136#endif
137 for(pRC::Index s = 0; s < pD.size(); ++s)
138 {
// NOTE(review): std::map iterators are bidirectional, so std::advance is
// linear in s. Over the whole loop this costs O(n^2) iterator steps for
// n = pD.size(). Collecting the entries (or iterators) into a std::vector
// before the parallel region would reduce this to O(n).
// `auto const [k, v] = *it` also copies the key and the value.
139 auto it = pD.cbegin();
140 std::advance(it, s);
141 auto const [k, v] = *it;
142 auto const pThetaE = pTheta(k);
143
144 scoreT += Score.pointwiseScore(v, pThetaE);
145
147 pRC::identity<T>(), k);
148
// rhs is built on a line omitted from this listing; its visible tail takes
// pRC::identity<T>() and the key k (presumably a TT vector that is nonzero
// only at state k - confirm).
149 // solve (1-Q)T * q = rhs
150 auto const q = TT::ALS<RQ, pRC::Operator::Transform::Transpose>(op,
151 rhs, toleranceSolverQ);
152
// tmp(i, j) = -dS/dp(v, pThetaE) * <q, dQ/dtheta_ij * pTheta>.
// NOTE(review): pointwiseDSDP(v, pThetaE) does not depend on i or j, and
// op.derivative(i, j) * pTheta does not depend on s. Both are recomputed
// inside the loops. The first could be hoisted cheaply; the second only by
// storing D * D TT vectors.
154 for(pRC::Index i = 0; i < D; ++i)
155 {
156 for(pRC::Index j = 0; j < D; ++j)
157 {
158 tmp(i, j) = -Score.pointwiseDSDP(v, pThetaE) *
159 scalarProduct(q, op.derivative(i, j) * pTheta)();
160 }
161 }
162
163 gT += tmp;
164 }
165
// Combine the reduced data terms with the regulariser penalty, evaluated on
// log(bigTheta). g is declared on a line omitted from this listing.
166 score += scoreT - Regulator.score(log(op.bigTheta()));
167 g += gT - Regulator.grad(log(op.bigTheta()));
168
// Presumably lets a subsequent call warm-start calculatePTheta from this
// pTheta.
169 // update pInit
170 pInit = pTheta;
171
172 return std::make_tuple(score, g);
173 }
174
// Convenience TT overload without an initial guess. It builds a random TT
// vector pInit, rounded to the ranks TT::getRanks<D, RP>(), divides it by a
// scalarProduct of itself with another operand, and forwards everything to
// the overload that takes pInit.
// NOTE(review): the seed sequence is constructed from the constants (8, 16),
// so presumably every call starts from the same pInit - confirm against
// pRC::SeedSequence.
// NOTE(review): source lines 197, 206, 208 and 211 are omitted from this
// listing (presumably the Score/Regulator parameters, the distribution
// `dist`, the random draw, and the second operand of scalarProduct).
194 template<pRC::Size RP, pRC::Size RQ, class T, pRC::Size D, class S>
195 std::tuple<T, pRC::Tensor<T, D, D>> calculateScoreAndGradient(
196 TT::MHNOperator<T, D> const &op, std::map<S, T> const &pD,
198 T const &toleranceSolverP = 1e-4, T const &toleranceSolverQ = 1e-4)
199 {
200 using ModeSizes = decltype(TT::getModeSizes<D>());
201 using Ranks = decltype(TT::getRanks<D, RP>());
202
203 // use a random pInit
204 pRC::SeedSequence seq(8, 16);
205 pRC::RandomEngine rng(seq);
207 auto pInit = round<Ranks>(
209 dist));
210 pInit /= scalarProduct(pInit,
212
// Forward to the pInit overload. It overwrites this local pInit with pTheta,
// which is then discarded.
213 return calculateScoreAndGradient<RP, RQ>(op, pD, Score, Regulator,
214 pInit, toleranceSolverP, toleranceSolverQ);
215 }
216}
217
218#endif
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing all relevant information for a regulator.
Definition regulator.hpp:30
auto grad(pRC::Tensor< T, D, D > const &theta) const
Definition regulator.hpp:53
auto score(pRC::Tensor< T, D, D > const &theta) const
Definition regulator.hpp:48
Class storing all relevant information for a score.
Definition score.hpp:27
auto pointwiseScore(T const &pDE, T const &pThetaE) const
Definition score.hpp:44
auto pointwiseDSDP(T const &pDE, T const &pThetaE) const
Definition score.hpp:49
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
constexpr auto & bigTheta(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:34
constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:87
Class storing an MHN operator represented by a theta matrix (for non-TT calculations)
Definition mhn_operator.hpp:24
constexpr auto & bigTheta(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:34
Definition type_traits.hpp:57
Definition seq.hpp:13
Definition type_traits.hpp:17
Definition tensor.hpp:28
Definition threefry.hpp:24
pRC::Float<> T
Definition externs_nonTT.hpp:1
static constexpr auto applyDerivative(MHNOperator< T1, D > const &op, pRC::Tensor< T2, Ns... > const &x, pRC::Index const &i)
Apply the derivative of an MHN Q with respect to theta_ii to a vector x.
Definition mhn_operator.hpp:82
Definition calculate_pTheta.hpp:16
decltype(expand(pRC::makeConstantSequence< pRC::Size, D, 2 >(), [](auto const ... Ns) { return pRC::Tensor< T, Ns... >{};})) calculatePTheta(nonTT::MHNOperator< T, D > const &op)
Calculates the vector pTheta given a non-TT MHN operator.
Definition calculate_pTheta.hpp:35
std::tuple< T, pRC::Tensor< T, D, D > > calculateScoreAndGradient(nonTT::MHNOperator< T, D > const &op, std::map< S, T > const &pD, cMHN::Score< T > const &Score, cMHN::Regulator< T, D > const &Regulator, T const &toleranceSolverQ=1e-4)
Calculate score and gradient of a theta matrix given some data distribution pD.
Definition calculate_score_and_gradient.hpp:35
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
static constexpr auto random(RandomEngine &rng, D &distribution)
Definition random.hpp:12
Size Index
Definition type_traits.hpp:21
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto unit()
Definition unit.hpp:12