cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
generate_theta.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GENERATE_THETA_H
4#define cMHN_UTILITY_GENERATE_THETA_H
5
6#include <prc.hpp>
7
8namespace cMHN
9{
36 template<pRC::Size D, class T = pRC::Float<>>
37 static inline auto generateTheta(pRC::RandomEngine &rng,
38 pRC::Float<> const &fullness = 0.5, T const &diagonalMean = -2,
39 T const &diagonalStd = 2, T const &offDiagonalMean = 0,
40 T const &offDiagonalB = 0.75)
41 {
42 // diagonal
43 pRC::GaussianDistribution<T> diagonalDist(diagonalMean, diagonalStd);
44 auto diagonalTheta =
45 pRC::random<pRC::Tensor<T, D, D>>(rng, diagonalDist);
46
47 // off diagonal
48 pRC::LaplaceDistribution<T> offDiagonalDist(offDiagonalMean,
49 offDiagonalB);
50 auto offDiagonalTheta =
51 pRC::random<pRC::Tensor<T, D, D>>(rng, offDiagonalDist);
52
53 // combine both
54 auto theta =
55 eval(diagonal(diagonalTheta) + offDiagonal(offDiagonalTheta));
56
57 // add sparsity
59 auto allowedPlaces = offDiagonal(
60 pRC::random<pRC::Tensor<pRC::Float<>, D, D>>(rng, sparsityDist));
61 for(pRC::Index i = 0; i < D; ++i)
62 {
63 for(pRC::Index j = 0; j < D; ++j)
64 {
65 if(allowedPlaces(i, j) > fullness)
66 {
67 theta(i, j) = pRC::zero();
68 }
69 }
70 }
71
72 return theta;
73 }
74
101 template<pRC::Size D, class T = pRC::Float<>>
102 static inline auto generateTheta(pRC::Float<> const &fullness = 0.5,
103 T const &diagonalMean = -2, T const &diagonalStd = 2,
104 T const &offDiagonalMean = 0, T const &offDiagonalB = 0.75)
105 {
106 // rng setup
107 pRC::SeedSequence seq(8, 16);
108 pRC::RandomEngine rng(seq);
109
110 return generateTheta<D>(rng, fullness, diagonalMean, diagonalStd,
111 offDiagonalMean, offDiagonalB);
112 }
113}
114
115#endif // cMHN_UTILITY_GENERATE_THETA_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition type_traits.hpp:57
Definition type_traits.hpp:41
Definition seq.hpp:13
Definition tensor.hpp:28
Definition threefry.hpp:24
Definition type_traits.hpp:49
Definition calculate_pTheta.hpp:16
static auto generateTheta(pRC::RandomEngine &rng, pRC::Float<> const &fullness=0.5, T const &diagonalMean=-2, T const &diagonalStd=2, T const &offDiagonalMean=0, T const &offDiagonalB=0.75)
Generates a random theta matrix according to given distributions, and with given fullness.
Definition generate_theta.hpp:37
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
static constexpr auto random(RandomEngine &rng, D &distribution)
Definition random.hpp:12
Size Index
Definition type_traits.hpp:21
static constexpr auto zero()
Definition zero.hpp:12