cMHN 1.0
C++ library for learning MHNs with pRC
generate_pD.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GENERATE_PD_H
4#define cMHN_UTILITY_GENERATE_PD_H
5
8
#include <algorithm>
#include <map>
#include <vector>

#include <prc.hpp>
10
11namespace cMHN
12{
27 template<class T, pRC::Size D>
28 static inline auto generatePD(pRC::RandomEngine &rng,
29 nonTT::MHNOperator<T, D> const &op, pRC::Size const &size,
30 T const &toleranceSolverP)
31 {
32 using Subscripts =
33 decltype(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
34 [](auto const... seq)
35 {
36 return pRC::Subscripts<seq...>();
37 }));
38
39 std::map<Subscripts, T> pD;
40 pRC::UnsignedInteger<64> pDSum = 0;
41
42 auto const pTheta = calculatePTheta(op, toleranceSolverP);
43
44 // rng setup
45 pRC::UniformDistribution<T> dist;
46
47 for(pRC::Index i = 0; i < size; ++i)
48 {
49 T sum = pRC::zero();
50 T const r = pRC::random<T>(rng, dist);
51 pRC::Index index = 0;
52 while(index < Subscripts::size()-1)
53 {
54 auto pThetaE = pTheta(Subscripts(index));
55 sum += pThetaE;
56 if(sum >= r)
57 {
58 break;
59 }
60 ++index;
61 }
62 pD.try_emplace(Subscripts(index), pRC::zero<T>());
63 pD[Subscripts(index)] += pRC::unit<T>();
64
65 pDSum += pRC::unit<decltype(pDSum)>();
66 }
67
68 // normalize pD
69 for(auto &[k, v] : pD)
70 {
71 v /= pDSum;
72 }
73
74 return pD;
75 }
76
89 template<class T, pRC::Size D>
90 static inline auto generatePD(nonTT::MHNOperator<T, D> const &op,
91 pRC::Size const &size, T const &toleranceSolverP)
92 {
93 // rng setup
94 pRC::SeedSequence seq(8, 16);
95 pRC::RandomEngine rng(seq);
96
97 return generatePD(rng, op, size, toleranceSolverP);
98 }
99
115 template<pRC::Size RP, class T, pRC::Size D>
116 static inline auto generatePD(pRC::RandomEngine &rng,
117 TT::MHNOperator<T, D> const &op, pRC::Size const &size,
118 T const &toleranceSolverP)
119 {
120 using Subscripts =
121 decltype(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
122 [](auto const... seq)
123 {
124 return pRC::Subscripts<seq...>();
125 }));
126
127 auto const pTheta = calculatePTheta<RP>(op, toleranceSolverP);
128
129 // get heuristic for maximum entry of pTheta;
130 T maxHeuristic = pRC::zero();
131 for(pRC::Index i = 0; i < 1000; ++i)
132 {
133 auto const subscripts = getRandomSubscripts<Subscripts>();
134 auto const value = pTheta(subscripts);
135 if(value > maxHeuristic)
136 {
137 maxHeuristic = value;
138 }
139 }
140
141 // choose prefactor according to heuristic
142 T prefactor = T(1) / maxHeuristic;
143
144 std::map<Subscripts, T> pD;
145 pRC::UnsignedInteger<64> pDSum = 0;
146
147 // rng setup
148 pRC::UniformDistribution<T> dist;
149
150 pRC::Index i = 0;
151 while(i < size)
152 {
153 auto const subscripts = getRandomSubscripts<Subscripts>();
154 T const r = pRC::random<T>(rng, dist);
155 if(r < prefactor * pTheta(subscripts))
156 {
157 pD.try_emplace(subscripts, pRC::zero<T>());
158 pD[subscripts] += pRC::unit<T>();
159 pDSum += pRC::unit<decltype(pDSum)>();
160 ++i;
161 }
162 }
163
164 // normalize pD
165 for(auto &[k, v] : pD)
166 {
167 v /= pDSum;
168 }
169
170 return pD;
171 }
172
186 template<pRC::Size RP, class T, pRC::Size D>
187 static inline auto generatePD(TT::MHNOperator<T, D> const &op,
188 pRC::Size const &size, T const &toleranceSolverP)
189 {
190 // rng setup
191 pRC::SeedSequence seq(8, 16);
192 pRC::RandomEngine rng(seq);
193
194 return generatePD<RP>(rng, op, size, toleranceSolverP);
195 }
196}
197
198#endif // cMHN_UTILITY_GENERATE_PD_H
Class storing an MHN operator represented by a theta matrix (for TT calculations)
Definition: mhn_operator.hpp:23
Class storing an MHN operator represented by a theta matrix (for non TT calculations)
Definition: mhn_operator.hpp:23
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: calculate_pTheta.hpp:15
static auto generatePD(pRC::RandomEngine &rng, nonTT::MHNOperator< T, D > const &op, pRC::Size const &size, T const &toleranceSolverP)
Generates a data distribution from a given nonTT MHNOperator.
Definition: generate_pD.hpp:28
X calculatePTheta(nonTT::MHNOperator< T, D > const &op, X const &pInit, T const &toleranceSolver)
Calculates the vector pTheta given a nonTT MHN Operator and a tolerance.
Definition: calculate_pTheta.hpp:33