cMHN 1.0
C++ library for learning MHNs with pRC
mhn_operator.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_TT_MHN_OPERATOR_H
4#define cMHN_TT_MHN_OPERATOR_H
5
6#include <prc.hpp>
7
8namespace cMHN::TT
9{
21 template<class T, pRC::Size D>
23 {
24 public:
25 using Type = T;
26
27 public:
28 constexpr MHNOperator(pRC::Tensor<T, D, D> const &theta)
29 : mTheta(exp(theta))
30 {
31 }
32
33 constexpr auto &theta(pRC::Index const i, pRC::Index const j) const
34 {
35 return mTheta(i, j);
36 }
37
38 constexpr auto &theta() const
39 {
40 return mTheta;
41 }
42
43 auto &theta()
44 {
45 return mTheta;
46 }
47
48 // calculates one TT term in eq. (2.2) of Peter's Thesis
49 // -> Q=sum_i term(i)
50 constexpr auto term(pRC::Index const i) const
51 {
52 TermType ret;
53 pRC::range<pRC::Context::CompileTime, D>(
54 [this, &ret, i](auto const j)
55 {
56 ret.template core<j>() = this->core(i, j);
57 });
58
59 return ret;
60 }
61
62 // calculates a core of the TT (1-Q)
63 template<pRC::Index i>
64 constexpr auto core() const
65 {
66 return expand(pRC::makeSeries<pRC::Index, D>(),
67 [&](auto const... seq)
68 {
69 return (pRC::identity<TermType>() - (term(seq) + ...))
70 .template core<i>();
71 });
72 }
73
74 // calculates dQ/d(theta_ij) - note that the derivative is wrt small
75 // thetas
76 constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
77 {
78 TermType ret;
79 pRC::range<pRC::Context::CompileTime, D>(
80 [this, &ret, i, j](auto const m)
81 {
82 ret.template core<m>() = this->derivative(i, j, m);
83 });
84
85 return ret;
86 }
87
88 private:
89 template<pRC::Index... Ms, pRC::Index... Ns, pRC::Index... Rs>
90 static constexpr auto termType(pRC::Sequence<pRC::Index, Ms...> const,
91 pRC::Sequence<pRC::Index, Ns...> const,
92 pRC::Sequence<pRC::Index, Rs...> const)
93 {
94 using M = pRC::Sizes<Ms...>;
95 using N = pRC::Sizes<Ns...>;
96 using R = pRC::Sizes<Rs...>;
97
98 return pRC::TensorTrain::Operator<T, M, N, R>();
99 }
100
101 using TermType =
102 decltype(termType(pRC::makeConstantSequence<pRC::Index, D, 2>(),
103 pRC::makeConstantSequence<pRC::Index, D, 2>(),
104 pRC::makeConstantSequence<pRC::Index, D - 1, 1>()));
105
106 private:
107 // cores for calculation of Q
108 constexpr auto core(pRC::Size const i, pRC::Size const j) const
109 {
110 pRC::Tensor c = pRC::zero<pRC::Tensor<T, 1, 2, 2, 1>>();
111 if(i == j)
112 {
113 c(0, 0, 0, 0) = -mTheta(i, i);
114 c(0, 1, 0, 0) = mTheta(i, i);
115 }
116 else
117 {
118 c(0, 0, 0, 0) = pRC::identity<T>();
119 c(0, 1, 1, 0) = mTheta(i, j);
120 }
121 return c;
122 }
123
124 constexpr auto derivative(pRC::Size const i, pRC::Size const j,
125 pRC::Size const m) const
126 {
127 pRC::Tensor c = pRC::zero<pRC::Tensor<T, 1, 2, 2, 1>>();
128 if(j == m)
129 {
130 if(i == j)
131 {
132 c(0, 0, 0, 0) = mTheta(i, m);
133 c(0, 1, 0, 0) = -mTheta(i, m);
134 }
135 else
136 {
137 c(0, 1, 1, 0) = mTheta(i, m);
138 }
139 }
140 else
141 {
142 if(i == m)
143 {
144 c(0, 0, 0, 0) = mTheta(i, m);
145 c(0, 1, 0, 0) = -mTheta(i, m);
146 }
147 else
148 {
149 c(0, 0, 0, 0) = pRC::identity<T>();
150 c(0, 1, 1, 0) = mTheta(i, m);
151 }
152 }
153
154 return c;
155 }
156
157 private:
158 pRC::Tensor<T, D, D> mTheta;
159 };
160}
161
162#endif // cMHN_TT_MHN_OPERATOR_H
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
constexpr auto & theta() const
Definition: mhn_operator.hpp:38
auto & theta()
Definition: mhn_operator.hpp:43
constexpr auto term(pRC::Index const i) const
Definition: mhn_operator.hpp:50
constexpr auto & theta(pRC::Index const i, pRC::Index const j) const
Definition: mhn_operator.hpp:33
constexpr MHNOperator(pRC::Tensor< T, D, D > const &theta)
Definition: mhn_operator.hpp:28
T Type
Definition: mhn_operator.hpp:25
constexpr auto derivative(pRC::Index const i, pRC::Index const j) const
Definition: mhn_operator.hpp:76
constexpr auto core() const
Definition: mhn_operator.hpp:64
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: als.hpp:12