3#ifndef cMHN_COMMON_CALCULATE_SCORE_AND_GRADIENT_H
4#define cMHN_COMMON_CALCULATE_SCORE_AND_GRADIENT_H
34 template<
class T, pRC::Size D,
class S>
38 T const &toleranceSolverQ = 1e-8)
45 for(
auto const &[k, v] : pD)
57 for(
auto const &[k, v] : pD)
64 eval(
pRC::zero<
decltype(rhs)>()), toleranceSolverQ);
88 return std::make_tuple(score, g);
115 template<pRC::Size RP, pRC::Size RQ,
class T, pRC::Size D,
class S,
class X>
119 X &pInit,
T const &toleranceSolverP = 1e-4,
120 T const &toleranceSolverQ = 1e-4)
132 #pragma omp declare reduction(+ : T, \
133 pRC::Tensor<T, D, D> : omp_out = omp_in + omp_out) \
134 initializer(omp_priv(pRC::Zero()))
135 #pragma omp parallel for schedule(dynamic, 10) reduction(+ : scoreT, gT)
139 auto it = pD.cbegin();
141 auto const [k, v] = *it;
142 auto const pThetaE = pTheta(k);
151 rhs, toleranceSolverQ);
172 return std::make_tuple(score, g);
194 template<pRC::Size RP, pRC::Size RQ,
class T, pRC::Size D,
class S>
198 T const &toleranceSolverP = 1e-4,
T const &toleranceSolverQ = 1e-4)
207 auto pInit = round<Ranks>(
210 pInit /= scalarProduct(pInit,
214 pInit, toleranceSolverP, toleranceSolverQ);
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Class storing all relevant information for a regulator.
Definition regulator.hpp:30
auto grad(pRC::Tensor< T, D, D > const &theta) const
Definition regulator.hpp:53
auto score(pRC::Tensor< T, D, D > const &theta) const
Definition regulator.hpp:48
Class storing all relevant information for a score.
Definition score.hpp:28
auto pointwiseScore(T const &pDE, T const &pThetaE) const
Definition score.hpp:45
auto pointwiseDSDP(T const &pDE, T const &pThetaE) const
Definition score.hpp:50
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
constexpr auto & bigTheta(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:34
constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:87
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition mhn_operator.hpp:24
constexpr auto & bigTheta(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:34
Definition gaussian.hpp:14
Definition declarations.hpp:16
Definition threefry.hpp:22
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
constexpr auto getRanks()
Definition utility.hpp:17
auto als(MHNOperator< T, D > const &MHNop, Tb const &b, T const &tolerance, X const &x0)
Definition als.hpp:16
constexpr auto getModeSizes()
Definition utility.hpp:11
static constexpr auto applyDerivative(MHNOperator< T1, D > const &op, pRC::Tensor< T2, Ns... > const &x, pRC::Index const &i)
apply the derivative of an MHN Q wrt to theta_ii to a vector x
Definition mhn_operator.hpp:82
Definition calculate_pTheta.hpp:20
decltype(expand(pRC::makeConstantSequence< pRC::Size, D, 2 >(), [](auto const ... Ns) { return pRC::Tensor< T, Ns... >{};})) calculatePTheta(nonTT::MHNOperator< T, D > const &op)
Calculates the vector pTheta given a nonTT MHN Operator.
Definition calculate_pTheta.hpp:39
std::tuple< T, pRC::Tensor< T, D, D > > calculateScoreAndGradient(nonTT::MHNOperator< T, D > const &op, std::map< S, T > const &pD, cMHN::Score< T > const &Score, cMHN::Regulator< T, D > const &Regulator, T const &toleranceSolverQ=1e-8)
Calculate score and gradient of a theta matrix given some data distribution pD.
Definition calculate_score_and_gradient.hpp:35
static constexpr decltype(auto) solve(Solver &&solver, XA &&A, XB &&b)
Definition solve.hpp:18
static constexpr auto unit()
Definition unit.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
Size Index
Definition basics.hpp:32
static constexpr auto reduce(Sequence< T, I1, I2, Is... > const)
Definition sequence.hpp:458
static constexpr auto chip(Sequence< T, Is... > const)
Definition sequence.hpp:584
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13