cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
generate_data.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GENERATE_DATA_H
4#define cMHN_UTILITY_GENERATE_DATA_H
5
8
9#include <fstream>
10#include <iomanip>
11#include <string>
12
13#include <prc.hpp>
14
15namespace cMHN
16{
/// Generates a data file from a given ground truth model.
///
/// Simulates `size` samples from the rates encoded in `smallThetaGT` and
/// writes them to `filename`: first one line containing `header`, then one
/// line per sample with its D entries separated by single spaces. The file
/// is opened with truncation on every attempt. If some event index is never
/// set in any sample of an attempt (a "zero column"), the whole file is
/// regenerated; the loop stops after at most 10 attempts and the function
/// then returns even if the last attempt still has a zero column.
///
/// NOTE(review): this text is an extracted documentation listing. Each code
/// line carries its listing number, and listing lines 40, 47 and 119 are not
/// present, so the block is not compilable exactly as shown here.
///
/// @tparam T element type of the rate tensor and of the random draws
/// @tparam D extent of both dimensions of `smallThetaGT`; number of entries
///           written per sample
/// @param rng          random engine handed to every pRC::random call
/// @param smallThetaGT D x D tensor; exp of its diagonal gives the initial
///                     rates, and exp of entry (j, k) rescales rate j after
///                     event k has been selected
/// @param size         number of samples written per attempt
/// @param header       text written as the first line of the file
/// @param filename     path of the output file
/// @return nothing (the only return statement is a bare `return;`)
33 template<class T, pRC::Size D>
34 static inline auto generateData(pRC::RandomEngine &rng,
35 pRC::Tensor<T, D, D> const &smallThetaGT, pRC::Size const &size,
36 std::string const &header, std::string const &filename)
37 {
38
// Container type used for one sample and for the per-attempt `check` record.
// NOTE(review): the opening of this type expression (listing line 40) is
// missing from this view; only the lambda that builds a pRC::Subscripts from
// a pack of values is visible. Confirm the resulting extents against the
// original header.
39 using Subscripts =
41 [](auto const... seq)
42 {
43 return pRC::Subscripts<seq...>();
44 }));
45
// NOTE(review): `dist`, used by pRC::random below, is not declared in any
// visible line; presumably it is declared here on the missing listing
// line 47. Confirm its type and range.
46 // rng setup
48
// `worked` records whether the latest attempt had no zero column;
// `at_try` counts attempts and bounds the retry loop below.
49 pRC::Bool worked = true;
50 pRC::Index at_try = 0;
51 do
52 {
// Reopened with trunc on every attempt, so a retry overwrites the
// previous attempt's output.
53 std::ofstream file(filename,
54 std::ofstream::out | std::ofstream::trunc);
55
// NOTE(review): whether pRC::Logging::error stops execution is not visible
// here; if it returns, the writes below go to a stream that is not open.
56 if(!file.is_open())
57 {
58 pRC::Logging::error("Unable to open output file!");
59 }
60
61 // print header to file
62 file << header << std::endl;
63
64 ++at_try;
// check[k] is set to 1 as soon as event k occurs in any sample of this
// attempt. It is value-initialized (presumably all zeros; confirm against
// pRC::Subscripts).
65 Subscripts check{};
66 for(pRC::Index i = 0; i < size; ++i)
67 {
// One sample: entry k becomes 1 once event k has been selected.
68 Subscripts sample{};
69
// Initial rate of each event: exp of the diagonal of smallThetaGT.
70 pRC::Tensor<T, D> transitionRates =
71 exp(extractDiagonal(smallThetaGT));
72
73 // simulate until we terminate
74 while(true)
75 {
76 // simulate one step
// Total weight of this step: the sum of all current rates plus
// pRC::unit<T>(). The extra unit is the share of the draw for which no
// event is selected, which ends the sample (newEvent reaches D below).
77 T const rateSum =
78 reduce<pRC::Add>(transitionRates)() + pRC::unit<T>();
// NOTE(review): the scaling by rateSum below presumably expects `rand` to
// lie in [0, 1); this depends on `dist`, which is not visible.
79 T const rand = pRC::random<T>(rng, dist);
80 T sumRejected = pRC::zero();
81 pRC::Index newEvent = 0;
// Walk the running sum of rates until it reaches rand * rateSum. newEvent
// ends at the selected index, or at D when the target lies beyond the sum
// of all rates. The inner break fires before the loop condition is
// re-evaluated, so transitionRates is never read at index D.
82 while(sumRejected + transitionRates(newEvent) <
83 rand * rateSum)
84 {
85 sumRejected += transitionRates(newEvent);
86 ++newEvent;
87 if(newEvent == D)
88 break;
89 }
// No event selected: this sample is finished.
90 if(newEvent == D)
91 break;
92 sample[newEvent] = 1;
93 check[newEvent] = 1;
94
95 // update transitionRates
// The selected event's rate is set to zero first; the scaling loop below
// also multiplies that entry, which leaves it at zero.
96 transitionRates[newEvent] = pRC::zero();
97 for(pRC::Index j = 0; j < D; ++j)
98 {
99 transitionRates[j] *= exp(smallThetaGT(j, newEvent));
100 }
101 }
// Write the sample as D space-separated entries on one line.
// NOTE(review): this `i` shadows the sample counter `i` of the enclosing
// for loop.
102 for(pRC::Index i = 0; i < D; ++i)
103 {
104 file << sample[i];
105 if(i != D - 1)
106 {
107 file << " ";
108 }
109 }
110 file << std::endl;
111 }
112
113 // check if there is a zero column
// An index whose `check` entry is still 0 was never selected in any sample
// of this attempt. There is no break, so the body runs once per such index.
114 worked = true;
115 for(pRC::Index i = 0; i < D; ++i)
116 {
117 if(check[i] == 0)
118 {
// NOTE(review): the call receiving this message is on the missing listing
// line 119, presumably pRC::Logging::warning; confirm.
120 "Zero column detected in generated dataset, "
121 "regenerating!");
122 worked = false;
123 }
124 }
125 file.close();
126 }
// Retry while the last attempt had a zero column, up to 10 attempts total.
127 while(!worked && at_try < 10);
128
129 return;
130 }
131
145 template<class T, pRC::Size D>
146 static inline auto generateData(pRC::Tensor<T, D, D> const &smallThetaGT,
147 pRC::Size const &size, std::string const &header,
148 std::string const &filename)
149 {
150 // rng setup
151 pRC::SeedSequence seq(8, 16);
152 pRC::RandomEngine rng(seq);
153
154 return generateData(rng, smallThetaGT, size, header, filename);
155 }
156}
157
158#endif // cMHN_UTILITY_GENERATE_DATA_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition seq.hpp:13
Definition subscripts.hpp:20
Definition tensor.hpp:28
Definition threefry.hpp:24
Definition type_traits.hpp:49
Definition calculate_pTheta.hpp:16
static auto generateData(pRC::RandomEngine &rng, pRC::Tensor< T, D, D > const &smallThetaGT, pRC::Size const &size, std::string const &header, std::string const &filename)
Generates a data file from a given ground truth model.
Definition generate_data.hpp:34
static void warning(Xs &&...args)
Definition log.hpp:21
static void error(Xs &&...args)
Definition log.hpp:14
bool Bool
Definition type_traits.hpp:18
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto zero()
Definition zero.hpp:12