cMHN 1.0
C++ library for learning MHNs with pRC
mhn_operator.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_NONTT_MHN_OPERATOR_H
4#define cMHN_NONTT_MHN_OPERATOR_H
5
6#include <prc.hpp>
7
8namespace cMHN::nonTT
9{
21 template<class T, pRC::Size D>
23 {
24 public:
25 using Type = T;
26
27 public:
28 constexpr MHNOperator(pRC::Tensor<T, D, D> const &theta)
29 : mTheta(exp(theta))
30 {
31 }
32
33 constexpr auto &theta(pRC::Index const i, pRC::Index const j) const
34 {
35 return mTheta(i, j);
36 }
37
38 constexpr auto &theta() const
39 {
40 return mTheta;
41 }
42
43 auto &theta()
44 {
45 return mTheta;
46 }
47
48 private:
49 pRC::Tensor<T, D, D> mTheta;
50 };
51
64 // apply dQ/dtheta_ii to a vector x
65 template<pRC::Operator::Transform OT = pRC::Operator::Transform::None,
66 pRC::Operator::Restrict OR = pRC::Operator::Restrict::None,
67 pRC::Operator::Hint OH = pRC::Operator::Hint::None, class T1,
68 pRC::Size D, class T2, pRC::Size... Ns,
69 pRC::If<pRC::IsSatisfied<(OR == pRC::Operator::Restrict::None)>> = 0,
70 pRC::If<pRC::IsSatisfied<(OT == pRC::Operator::Transform::None)>> = 0>
71 static inline constexpr auto applyDerivative(MHNOperator<T1, D> const &op,
72 pRC::Tensor<T2, Ns...> const &x, pRC::Index const &i)
73
74 {
75 auto r = x;
76 pRC::range<pRC::Context::CompileTime, D>(
77 [&](auto const j)
78 {
79 if(i == j)
80 {
81 pRC::chip<j>(r, 0) *= -op.theta(i, i);
82 pRC::chip<j>(r, 1) = -pRC::chip<j>(r, 0);
83 }
84 else
85 {
86 pRC::chip<j>(r, 1) *= op.theta(i, j);
87 }
88 });
89 return r;
90 }
91
105 // apply (1-Q) or the transposed to a vector x
106 template<pRC::Operator::Transform OT = pRC::Operator::Transform::None,
107 pRC::Operator::Restrict OR = pRC::Operator::Restrict::None,
108 pRC::Operator::Hint OH = pRC::Operator::Hint::None, class T1,
109 pRC::Size D, class T2, pRC::Size... Ns,
110 pRC::If<pRC::IsSatisfied<(OR == pRC::Operator::Restrict::None)>> = 0>
111 static inline constexpr auto apply(MHNOperator<T1, D> const &op,
112 pRC::Tensor<T2, Ns...> const &x)
113
114 {
115 auto b = x;
116 for(pRC::Index i = 0; i < D; ++i)
117 {
118 auto v = x;
119 pRC::range<pRC::Context::CompileTime, D>(
120 [&](auto const j)
121 {
122 if(i == j)
123 {
124 if constexpr(OT == pRC::Operator::Transform::Transpose)
125 {
126 pRC::chip<j>(v, 0) =
127 -op.theta(i, i) * pRC::chip<j>(v, 0) +
128 op.theta(i, i) * pRC::chip<j>(v, 1);
129 pRC::chip<j>(v, 1) = pRC::zero();
130 }
131 else
132 {
133 pRC::chip<j>(v, 0) *= -op.theta(i, i);
134 pRC::chip<j>(v, 1) = -pRC::chip<j>(v, 0);
135 }
136 }
137 else
138 {
139 pRC::chip<j>(v, 1) *= op.theta(i, j);
140 }
141 });
142 b -= v;
143 }
144 return b;
145 }
146
147}
148
149#endif // cMHN_NONTT_MHN_OPERATOR_H
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition: mhn_operator.hpp:23
T Type
Definition: mhn_operator.hpp:25
constexpr auto & theta(pRC::Index const i, pRC::Index const j) const
Definition: mhn_operator.hpp:33
auto & theta()
Definition: mhn_operator.hpp:43
constexpr MHNOperator(pRC::Tensor< T, D, D > const &theta)
Definition: mhn_operator.hpp:28
constexpr auto & theta() const
Definition: mhn_operator.hpp:38
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: learn_theta.hpp:20
static constexpr auto apply(MHNOperator< T1, D > const &op, pRC::Tensor< T2, Ns... > const &x)
apply (1-Q) or its transpose to a vector x, given an MHN Q
Definition: mhn_operator.hpp:111
static constexpr auto applyDerivative(MHNOperator< T1, D > const &op, pRC::Tensor< T2, Ns... > const &x, pRC::Index const &i)
apply the derivative of an MHN Q with respect to theta_ii to a vector x
Definition: mhn_operator.hpp:71