cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
get_cross_val_splits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
4#define cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
5
#include <algorithm>
#include <fstream>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
13
14#include <prc.hpp>
15
16namespace cMHN
17{
29 template<class T, pRC::Size D>
30 static inline auto getCrossValSplits(std::string const &filename,
31 pRC::Index const &k)
32 {
33 std::ifstream file(filename);
34
35 using Subscripts =
37 [](auto const... seq)
38 {
39 return pRC::Subscripts<seq...>();
40 }));
41
42 if(!file.is_open())
43 {
44 pRC::Logging::error("Unable to open input file!");
45 }
46
47 std::vector<std::string> lines;
48
49 // number of samples
50 pRC::UnsignedInteger<64> totalSum = 0;
51
52 std::string line;
53
54 // first line contains header (event names)
55 std::getline(file, line);
56
57 // write samples to lines
58 while(std::getline(file, line))
59 {
60 lines.push_back(line);
61
63 }
64
65 // shuffle data
66 std::random_device rd;
67 std::mt19937 g(rd());
68 std::shuffle(lines.begin(), lines.end(), g);
69
70 std::vector<std::map<Subscripts, T>> pDs;
71 pRC::Index length = totalSum() / k;
72 pRC::Index remainder = totalSum() % k;
73
74 std::vector<pRC::Index> lengths(k, length);
75
76 while(remainder > 0)
77 {
78 ++lengths[remainder - 1];
79 --remainder;
80 }
81
82 pRC::Index minInd = 0;
83
84 // write the maps in pDs
85 for(pRC::Index ind = 0; ind < k; ++ind)
86 {
87 std::map<Subscripts, T> pD;
89
90 std::string line;
91 for(pRC::Index innerInd = 0; innerInd < lengths[ind]; ++innerInd)
92 {
93 line = lines[minInd + innerInd];
94
95 // turn commas into spaces
96 std::replace(line.begin(), line.end(), ',', ' ');
97
98 std::istringstream iss(line);
99
100 // store event data for current sample in 'bits'
101 Subscripts bits;
102 std::size_t i = 0;
103 unsigned v;
104 while(iss >> v)
105 {
106 bits[i++] = v;
107 }
109 {
110 if(i != D)
111 {
113 "Number of events differs for input file and "
114 "binary. "
115 "File:",
116 i, "Binary:", D);
117 }
118 }
119
120 pD.try_emplace(bits, pRC::zero<T>());
121 pD[bits] += pRC::unit<T>();
122
124 }
125
126 minInd += lengths[ind];
127
128 // normalize pD
129 for(auto &[k, v] : pD)
130 {
131 v /= sum;
132 }
133 pDs.push_back(pD);
134 }
135
136 return std::make_tuple(pDs, lengths);
137 }
138} // namespace cMHN
139
140#endif // cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition integer.hpp:22
Definition subscripts.hpp:20
Definition calculate_pTheta.hpp:16
static auto getCrossValSplits(std::string const &filename, pRC::Index const &k)
Splits the samples in a dataset into k sets of as nearly equal size as possible that can be used for k-fold cross-validation.
Definition get_cross_val_splits.hpp:30
static void error(Xs &&...args)
Definition log.hpp:14
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
constexpr auto cDebugLevel
Definition config.hpp:46