cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
generate_data.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GENERATE_DATA_H
4#define cMHN_UTILITY_GENERATE_DATA_H
5
8
9#include <fstream>
10#include <iomanip>
11#include <string>
12
13#include <prc.hpp>
14
15namespace cMHN
16{
// Simulates `size` samples from the model described by smallThetaGT and
// writes them to `filename`, one sample per line, after a first line that
// contains `header`.
//
// rng           random engine handed to pRC::random for every draw
// smallThetaGT  D x D tensor; exp() of its diagonal gives the initial rates,
//               and exp() of entry (j, e) rescales rate j once event e has
//               occurred
// size          number of samples written per attempt
// header        text written as the first line of the file
// filename      output path; the file is truncated on every attempt
//
// The whole file is regenerated, for at most 10 attempts, whenever some event
// index does not occur in any sample of the attempt. The function ends with a
// bare `return;`, so nothing is returned to the caller. If the last attempt
// still has such a column, the file written by that attempt is left in place.
//
// NOTE(review): this listing skips original source lines 41, 48 and 120 (the
// embedded numbering jumps 40->42, 47->49 and 119->121). The right-hand side
// of the Subscripts alias, the declaration of `dist`, and the call that
// receives the "Zero column" message are therefore not visible here.
34 template<class T, pRC::Size D, class F>
35 static inline auto generateData(pRC::RandomEngine<F> &rng,
36 pRC::Tensor<T, D, D> const &smallThetaGT, pRC::Size const &size,
37 std::string const &header, std::string const &filename)
38 {
39
40 using Subscripts =
42 [](auto const... seq)
43 {
44 return pRC::Subscripts<seq...>();
45 }));
46
47 // rng setup
49
50 pRC::Bool worked = true;
51 pRC::Index at_try = 0;
52 do
53 {
// every attempt starts from an empty (truncated) output file
54 std::ofstream file(filename,
55 std::ofstream::out | std::ofstream::trunc);
56
// NOTE(review): as written, execution continues past this check; whether
// pRC::Logging::error terminates the program is not visible here - confirm.
57 if(!file.is_open())
58 {
59 pRC::Logging::error("Unable to open output file!");
60 }
61
62 // print header to file
63 file << header << std::endl;
64
65 ++at_try;
// check[e] is set to 1 as soon as any sample of this attempt contains event e
66 Subscripts check{};
67 for(pRC::Index i = 0; i < size; ++i)
68 {
// presumably value-initialised to all zeros (no event has occurred yet) -
// confirm against pRC::Subscripts
69 Subscripts sample{};
70
// initial rate of each event: exp of the matching diagonal entry
71 pRC::Tensor<T, D> transitionRates =
72 exp(extractDiagonal(smallThetaGT));
73
74 // simulate until we terminate
75 while(true)
76 {
77 // simulate one step
// total rate = sum of the current event rates plus one unit; the extra unit
// is the share of the draw that ends the sample (newEvent == D below)
78 T const rateSum =
79 reduce<pRC::Add>(transitionRates)() + pRC::unit<T>();
80 T const rand = pRC::random<T>(rng, dist);
81 T sumRejected = pRC::zero();
82 pRC::Index newEvent = 0;
// walk the cumulative rates until the scaled draw falls into the interval
// that belongs to event newEvent
83 while(sumRejected + transitionRates(newEvent) <
84 rand * rateSum)
85 {
86 sumRejected += transitionRates(newEvent);
87 ++newEvent;
// stop before transitionRates is indexed with D in the loop condition
88 if(newEvent == D)
89 break;
90 }
// ran past the last event: this sample is finished
91 if(newEvent == D)
92 break;
93 sample[newEvent] = 1;
94 check[newEvent] = 1;
95
96 // update transitionRates
// zero the rate of the event that just occurred, then rescale every rate j
// by exp(smallThetaGT(j, newEvent))
97 transitionRates[newEvent] = pRC::zero();
98 for(pRC::Index j = 0; j < D; ++j)
99 {
100 transitionRates[j] *= exp(smallThetaGT(j, newEvent));
101 }
102 }
// write the sample as D space-separated entries on one line; note that this
// i shadows the sample counter i of the enclosing loop
103 for(pRC::Index i = 0; i < D; ++i)
104 {
105 file << sample[i];
106 if(i != D - 1)
107 {
108 file << " ";
109 }
110 }
111 file << std::endl;
112 }
113
114 // check if there is a zero column
115 worked = true;
116 for(pRC::Index i = 0; i < D; ++i)
117 {
118 if(check[i] == 0)
119 {
// NOTE(review): the call taking this message (source line 120) is missing
// from the listing; presumably pRC::Logging::warning - confirm. There is no
// break, so the body runs once for every event index that never occurred.
121 "Zero column detected in generated dataset, "
122 "regenerating!");
123 worked = false;
124 }
125 }
126 file.close();
127 }
// retry only while the last attempt failed and fewer than 10 attempts were made
128 while(!worked && at_try < 10);
129
130 return;
131 }
132
// Convenience overload without an engine parameter: builds a
// pRC::SeedSequence from the fixed arguments (8, 16) and forwards
// smallThetaGT, size, header and filename, together with `rng`, to the
// generateData overload that takes a random engine.
//
// NOTE(review): original source line 153 is missing from this listing (the
// embedded numbering jumps 152->154); presumably it declares `rng` from
// `seq` - confirm against the real header.
146 template<class T, pRC::Size D>
147 static inline auto generateData(pRC::Tensor<T, D, D> const &smallThetaGT,
148 pRC::Size const &size, std::string const &header,
149 std::string const &filename)
150 {
151 // rng setup
152 pRC::SeedSequence seq(8, 16);
154
155 return generateData(rng, smallThetaGT, size, header, filename);
156 }
157}
158
159#endif // cMHN_UTILITY_GENERATE_DATA_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition value.hpp:12
Definition engine.hpp:13
Definition seq.hpp:13
Definition subscripts.hpp:21
Definition tensor.hpp:25
Definition threefry.hpp:22
Definition uniform.hpp:18
int i
Definition gmock-matchers-comparisons_test.cc:603
Definition calculate_pTheta.hpp:20
static auto generateData(pRC::RandomEngine< F > &rng, pRC::Tensor< T, D, D > const &smallThetaGT, pRC::Size const &size, std::string const &header, std::string const &filename)
Generates a data file from a given ground truth model.
Definition generate_data.hpp:35
static void warning(Xs &&...args)
Definition log.hpp:21
static void error(Xs &&...args)
Definition log.hpp:14
static constexpr auto unit()
Definition unit.hpp:13
bool Bool
Definition basics.hpp:29
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto random(URNG &rng, D &distribution)
Definition random.hpp:13