3#ifndef cMHN_TT_LEARN_THETA_H
4#define cMHN_TT_LEARN_THETA_H
// Listing fragment of learnTheta (TT variant). Every code line below keeps the
// line number of the original header as a prefix, and several original lines
// (among them the function-name line, some parameters and parts of the body)
// are not present in this excerpt.
//
// Template parameters: RP and RQ are pRC::Size values; T and D parameterise
// the returned pRC::Tensor<T, D, D>; S is the key type of the pD map.
// Returns a tuple of (theta tensor, string-valued log entries, numeric log
// entries).
56 template<pRC::Size RP, pRC::Size RQ,
class T, pRC::Size D,
class S>
57 std::tuple<pRC::Tensor<T, D, D>, std::map<std::string, std::string>,
58 std::map<std::string, double>>
// Visible parameters: `output` is forwarded to writeTheta below, `pD` is a map
// from S to T, followed by three tolerances of type T.
60 std::string
const &output, std::map<S, T>
const &pD,
62 T const &toleranceOptimizer,
T const &toleranceSolverP,
63 T const &toleranceSolverQ)
// Copy of the incoming theta; this copy is what gets written out and returned.
67 auto tempTheta = theta;
// Numeric log entries, seeded with the value of score() and at_iter.
75 std::map<std::string, double> logInfoNumbers{{
"Score", score()},
76 {
"Iterations", at_iter},
// String-valued log entries (their initialisers are not part of this excerpt).
80 std::map<std::string, std::string> logInfoNames{
// Write the starting theta together with both log maps to `output`.
83 writeTheta(output, header, tempTheta, logInfoNames, logInfoNumbers);
// Start banner: prints the score name and regulator name from logInfoNames.
85 std::cout <<
"cMHN learning started (TT):" << std::endl;
86 std::cout <<
"\tScore Name:\t" << logInfoNames[
"Score Name"]
88 std::cout <<
"\tRegulator Name:\t" << logInfoNames[
"Regulator Name"]
// pInit is produced by round<> instantiated with the type returned by
// getRanks<D, RP>(); the arguments of the call are not visible here.
96 auto pInit = round<decltype(getRanks<D, RP>())>(
// Tail of a lambda capture list (its beginning is not visible): the three
// tolerances are captured by reference.
106 &toleranceOptimizer, &toleranceSolverP, &toleranceSolverQ](
// Tail of a call that receives Regulator, pInit and the two solver tolerances
// (the callee is not visible in this excerpt).
113 Regulator, pInit, toleranceSolverP, toleranceSolverQ);
// Lambda taking the current theta as `tempTheta` (this parameter shadows the
// outer copy); all listed state is captured by reference so the log maps
// declared above are updated in place.
117 [&output, &header, &score, &at_iter, &startTime, &logInfoNames,
118 &logInfoNumbers](
auto const &tempTheta)
// Refresh the numeric log entries with the current values.
121 logInfoNumbers[
"Iterations"] = at_iter;
122 logInfoNumbers[
"Score"] = score();
123 logInfoNumbers[
"Time"] =
// Progress report: the iteration count is printed with default float
// formatting, then the stream is switched to scientific notation for
// Lambda, Score and Time.
126 std::cout <<
"cMHN learning in progress (TT):" << std::endl;
127 std::cout << std::defaultfloat;
128 std::cout <<
"\tIteration:\t" << logInfoNumbers[
"Iterations"]
130 std::cout << std::scientific;
131 std::cout <<
"\tLambda:\t\t" << logInfoNumbers[
"Lambda"]
133 std::cout <<
"\tScore:\t\t" << logInfoNumbers[
"Score"]
135 std::cout <<
"\tTime:\t\t" << logInfoNumbers[
"Time"]
// Restore default float formatting on std::cout after the report.
137 std::cout << std::defaultfloat;
// Write the current theta and the log maps to `output` (remaining arguments of
// this call are on a line not present in this excerpt).
139 writeTheta(output, header, tempTheta, logInfoNames,
// Hand back the theta copy together with both log maps.
144 return std::make_tuple(tempTheta, logInfoNames, logInfoNumbers);
Class storing all relevant information for a regulator.
Definition regulator.hpp:30
auto & lambda()
Definition regulator.hpp:58
auto name() const
Definition regulator.hpp:68
Class storing all relevant information for a score.
Definition score.hpp:28
auto name() const
Definition score.hpp:55
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
Definition gaussian.hpp:14
Definition gradient_descent.hpp:14
Definition declarations.hpp:16
Definition threefry.hpp:22
std::tuple< pRC::Tensor< T, D, D >, std::map< std::string, std::string >, std::map< std::string, double > > learnTheta(pRC::Tensor< T, D, D > const &theta, std::string const &header, std::string const &output, std::map< S, T > const &pD, cMHN::Score< T > const &Score, cMHN::Regulator< T, D > const &Regulator, T const &toleranceOptimizer, T const &toleranceSolverP, T const &toleranceSolverQ)
Optimizes an MHN represented by a theta matrix to best describe a given data distribution using the T...
Definition learn_theta.hpp:59
constexpr auto getRanks()
Definition utility.hpp:17
constexpr auto getModeSizes()
Definition utility.hpp:11
std::tuple< T, pRC::Tensor< T, D, D > > calculateScoreAndGradient(nonTT::MHNOperator< T, D > const &op, std::map< S, T > const &pD, cMHN::Score< T > const &Score, cMHN::Regulator< T, D > const &Regulator, T const &toleranceSolverQ=1e-8)
Calculate score and gradient of a theta matrix given some data distribution pD.
Definition calculate_score_and_gradient.hpp:35
static auto writeTheta(std::string const &filename, std::string const &header, pRC::Tensor< T, D, D > const &theta, std::map< std::string, std::string > const &logInfoNames={}, std::map< std::string, double > const &logInfoNumbers={})
Writes a theta matrix to file, including additional logging information at the bottom.
Definition write_theta.hpp:29
static constexpr auto unit()
Definition unit.hpp:13
Size Index
Definition basics.hpp:32
static constexpr auto optimize(Optimizer &&optimizer, XX &&x, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< Value< RemoveReference< XX > > >::tolerance())
Definition optimize.hpp:15
static Float< 64 > getTimeInSeconds()
Definition stopwatch.hpp:23
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13