cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
mhn_operator.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_MHN_OPERATOR_H
4#define cMHN_TT_MHN_OPERATOR_H
5
6#include <prc.hpp>
7
8namespace cMHN::TT
9{
// MHNOperator: stores the exponentiated ("big") theta matrix of an MHN and
// builds tensor-train (TT) representations of the operator Q (term by term),
// of single cores of (I - Q), and of dQ/d(theta_ij). Doxygen brief: "Class
// storing an MHN operator represented by a theta matrix (for TT
// calculations)".
//
// Template parameters:
//   T - scalar type of the theta entries.
//   D - size of the D x D theta matrix; presumably the number of events of
//       the MHN (TODO confirm).
//
// NOTE(review): this is a Doxygen source listing with some original lines
// missing (23, 29, 64, 90, 102-103, 109, 113-115, 121, 138). Comments below
// hedge wherever those missing lines would matter.
22 template<class T, pRC::Size D>
24 {
25 public:
// Scalar type of the stored theta entries.
26 using Type = T;
27
28 public:
// Constructor. Its signature (original line 29) is missing from this
// listing; the Doxygen index gives it as
// `constexpr MHNOperator(pRC::Tensor<T, D, D> const &smallTheta)`.
// The log-space matrix is exponentiated element-wise once, here, so the
// object stores only Theta = exp(theta).
30 : mBigTheta(exp(smallTheta))
31 {
32 }
33
// Read-only reference to the entry Theta_ij = exp(theta_ij).
34 constexpr auto &bigTheta(pRC::Index const i, pRC::Index const j) const
35 {
36 return mBigTheta(i, j);
37 }
38
// Read-only reference to the whole D x D matrix Theta.
39 constexpr auto &bigTheta() const
40 {
41 return mBigTheta;
42 }
43
// Mutable reference to the whole matrix Theta. Writes go straight into the
// stored (already exponentiated) values; no exp() is applied to them.
44 auto &bigTheta()
45 {
46 return mBigTheta;
47 }
48
// Log-space parameter theta_ij, recomputed on every call as log(Theta_ij).
49 constexpr auto smallTheta(pRC::Index const i, pRC::Index const j) const
50 {
51 return log(mBigTheta(i, j));
52 }
53
// Whole log-space matrix log(Theta). eval() presumably materializes the
// element-wise log expression into a concrete tensor (TODO confirm against
// pRC).
54 constexpr auto smallTheta() const
55 {
56 return eval(log(mBigTheta));
57 }
58
59 // Builds the i-th summand of Q in TT format (eq. (2.2) of Peter's thesis),
60 // so that Q = sum_i term(i). Core j of the result is core(i, j).
61 constexpr auto term(pRC::Index const i) const
62 {
63 TermType ret;
// NOTE(review): original line 64 is missing here. Judging from the lambda
// below, it presumably iterates j over 0..D-1 at compile time (e.g. via
// pRC::range), so that core<j>() can be indexed by j.
65 [this, &ret, i](auto const j)
66 {
67 ret.template core<j>() = this->core(i, j);
68 });
69
70 return ret;
71 }
72
73 // Returns core i of the TT representation of (I - Q), where Q is the sum
// term(0) + ... + term(D-1) formed by the fold expression below.
// NOTE(review): each call builds all D terms and their sum only to extract
// one core; unless pRC evaluates this lazily, fetching every core this way
// repeats that work D times.
74 template<pRC::Index i>
75 constexpr auto core() const
76 {
77 return expand(pRC::makeSeries<pRC::Index, D>(),
78 [&](auto const... seq)
79 {
80 return (pRC::identity<TermType>() - (term(seq) + ...))
81 .template core<i>();
82 });
83 }
84
85 // Returns dQ/d(theta_ij) in TT format. The derivative is taken with respect
86 // to the log-space theta_ij (not Theta_ij). Core m of the result is
// derivative(i, j, m). Only term(i) reads row i of Theta, so only that
// summand contributes. Since dTheta_ij/dtheta_ij = Theta_ij, the result is
// term(i) with its j-th core differentiated.
87 constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
88 {
89 TermType ret;
// NOTE(review): original line 90 is missing here; presumably the same
// compile-time loop over m as in term().
91 [this, &ret, i, j](auto const m)
92 {
93 ret.template core<m>() = this->derivative(i, j, m);
94 });
95
96 return ret;
97 }
98
99 private:
// Type-level helper used to define TermType. It maps the index packs Ms, Ns
// and Rs to pRC::Sizes packs M, N and R. Its remaining parameters (original
// lines 102-103) and its return statement (line 109) are missing from this
// listing; presumably it yields a TT type built from M, N and R (TODO
// confirm).
100 template<pRC::Index... Ms, pRC::Index... Ns, pRC::Index... Rs>
101 static constexpr auto termType(pRC::Sequence<pRC::Index, Ms...> const,
104 {
105 using M = pRC::Sizes<Ms...>;
106 using N = pRC::Sizes<Ns...>;
107 using R = pRC::Sizes<Rs...>;
108
110 }
111
// TT type of a single term (also used for derivatives). Its definition
// (original lines 113-115) is missing from this listing. Because core(i, j)
// below only writes indices of the form (0, a, b, 0) with a, b in {0, 1},
// each core is presumably rank 1 with 2 x 2 modes (TODO confirm).
112 using TermType =
116
117 private:
118 // Local 2 x 2 factor at position j of term(i) (the TT cores that build Q).
// Index layout assumed to be (left rank, row, column, right rank).
//  - i == j: -Theta_ii at (0, 0) and Theta_ii at (1, 0); presumably the
//    "event i occurs" transition out of local state 0.
//  - i != j: diag(1, Theta_ij), i.e. the factor Theta_ij applies only when
//    the j-th local state is 1.
119 constexpr auto core(pRC::Size const i, pRC::Size const j) const
120 {
// NOTE(review): the declaration of c (original line 121) is missing. It is
// presumably a zero-initialized rank-1 core (pRC::zero appears in the
// Doxygen index), so every entry not written below is assumed to be zero.
122 if(i == j)
123 {
124 c(0, 0, 0, 0) = -mBigTheta(i, i);
125 c(0, 1, 0, 0) = mBigTheta(i, i);
126 }
127 else
128 {
129 c(0, 0, 0, 0) = pRC::identity<T>();
130 c(0, 1, 1, 0) = mBigTheta(i, j);
131 }
132 return c;
133 }
134
// Core m of dQ/d(theta_ij) as assembled by derivative(i, j); reads only row
// i of Theta.
//  - m == j: derivative of core(i, j) with respect to theta_ij. For i == j
//    this equals core(i, i), since both entries are linear in Theta_ii. For
//    i != j only the Theta_ij entry survives; the identity entry (0, 0) is
//    left unset (assumed zero).
//  - m != j: the factor is unchanged and identical to core(i, m).
135 constexpr auto derivative(pRC::Size const i, pRC::Size const j,
136 pRC::Size const m) const
137 {
// NOTE(review): the declaration of c (original line 138) is missing; it is
// presumably zero-initialized as in core(i, j).
139 if(j == m)
140 {
141 if(i == j)
142 {
143 c(0, 0, 0, 0) = -mBigTheta(i, m);
144 c(0, 1, 0, 0) = mBigTheta(i, m);
145 }
146 else
147 {
148 c(0, 1, 1, 0) = mBigTheta(i, m);
149 }
150 }
151 else
152 {
153 if(i == m)
154 {
155 c(0, 0, 0, 0) = -mBigTheta(i, m);
156 c(0, 1, 0, 0) = mBigTheta(i, m);
157 }
158 else
159 {
160 c(0, 0, 0, 0) = pRC::identity<T>();
161 c(0, 1, 1, 0) = mBigTheta(i, m);
162 }
163 }
164
165 return c;
166 }
167
168 private:
// Theta = exp(theta), element-wise; the only data member visible in this
// listing.
169 pRC::Tensor<T, D, D> mBigTheta;
170 };
171}
172
173#endif // cMHN_TT_MHN_OPERATOR_H
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition mhn_operator.hpp:24
constexpr auto term(pRC::Index const i) const
Definition mhn_operator.hpp:61
constexpr auto smallTheta() const
Definition mhn_operator.hpp:54
auto & bigTheta()
Definition mhn_operator.hpp:44
constexpr auto & bigTheta() const
Definition mhn_operator.hpp:39
constexpr MHNOperator(pRC::Tensor< T, D, D > const &smallTheta)
Definition mhn_operator.hpp:29
constexpr auto & bigTheta(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:34
constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:87
constexpr auto core() const
Definition mhn_operator.hpp:75
constexpr auto smallTheta(pRC::Index const i, pRC::Index const j) const
Definition mhn_operator.hpp:49
Definition value.hpp:12
Definition sequence.hpp:29
Definition tensor.hpp:25
pRC::Float<> T
Definition externs_nonTT.hpp:1
int i
Definition gmock-matchers-comparisons_test.cc:603
Definition als.hpp:12
Operator(OperatorViews::View< T, M, N, Ranks, F > const &) -> Operator< T, M, N, Ranks >
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto range(F &&f, Xs &&...args)
Definition range.hpp:18
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12