cMHN 1.0
C++ library for learning MHNs with pRC
get_cross_val_splits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
4#define cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
5
#include <algorithm>
#include <fstream>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <prc.hpp>
15
16namespace cMHN
17{
29 template<class T, pRC::Size D>
30 static inline auto getCrossValSplits(std::string const &filename,
31 pRC::Index const &k)
32 {
33 std::ifstream file(filename);
34
35 using Subscripts =
36 decltype(expand(pRC::makeConstantSequence<pRC::Size, D, 2>(),
37 [](auto const... seq)
38 {
39 return pRC::Subscripts<seq...>();
40 }));
41
42 if(!file.is_open())
43 {
44 pRC::Logging::error("Unable to open input file!");
45 }
46
47 std::vector<std::string> lines;
48
49 // number of samples
50 pRC::UnsignedInteger<64> totalSum = 0;
51
52 std::string line;
53
54 // first line contains header (event names)
55 std::getline(file, line);
56
57 // write samples to lines
58 while(std::getline(file, line))
59 {
60 lines.push_back(line);
61
62 totalSum += pRC::unit<decltype(totalSum)>();
63 }
64
65 // shuffle data
66 std::random_device rd;
67 std::mt19937 g(rd());
68 std::shuffle(lines.begin(), lines.end(), g);
69
70 std::vector<std::map<Subscripts, T>> pDs;
71 pRC::Index length = totalSum() / k;
72 pRC::Index remainder = totalSum() % k;
73
74 std::vector<pRC::Index> lengths(k, length);
75
76 while(remainder > 0)
77 {
78 ++lengths[remainder - 1];
79 --remainder;
80 }
81
82 pRC::Index minInd = 0;
83
84 // write the maps in pDs
85 for(pRC::Index ind = 0; ind < k; ++ind)
86 {
87 std::map<Subscripts, T> pD;
88 pRC::UnsignedInteger<64> sum = 0;
89
90 std::string line;
91 for(pRC::Index innerInd = 0; innerInd < lengths[ind]; ++innerInd)
92 {
93 line = lines[minInd + innerInd];
94
95 // turn commas into spaces
96 std::replace(line.begin(), line.end(), ',', ' ');
97
98 std::istringstream iss(line);
99
100 // store event data for current sample in 'bits'
101 Subscripts bits;
102 std::size_t i = 0;
103 unsigned v;
104 while(iss >> v)
105 {
106 bits[i++] = v;
107 }
108 if constexpr(pRC::cDebugLevel >= pRC::DebugLevel::Low)
109 {
110 if(i != D)
111 {
112 pRC::Logging::error(
113 "Number of events differs for input file and "
114 "binary. "
115 "File:",
116 i, "Binary:", D);
117 }
118 }
119
120 pD.try_emplace(bits, pRC::zero<T>());
121 pD[bits] += pRC::unit<T>();
122
123 sum += pRC::unit<decltype(sum)>();
124 }
125
126 minInd += lengths[ind];
127
128 // normalize pD
129 for(auto &[k, v] : pD)
130 {
131 v /= sum;
132 }
133 pDs.push_back(pD);
134 }
135
136 return std::make_tuple(pDs, lengths);
137 }
138} // namespace cMHN
139
140#endif // cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Definition: calculate_pTheta.hpp:15
static auto getCrossValSplits(std::string const &filename, pRC::Index const &k)
Splits the samples in a dataset into <i>k</i> sets of as equal a size as possible that can be used for ...
Definition: get_cross_val_splits.hpp:30