cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
get_cross_val_splits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
4#define cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
5
6#include <algorithm>
7#include <fstream>
8#include <map>
9#include <random>
10#include <sstream>
11#include <string>
12#include <tuple>
13#include <vector>
14
15#include <prc.hpp>
16
17namespace cMHN
18{
// Reads samples from a comma-separated text file, shuffles them, and
// partitions them into k folds. For every fold a map from the parsed sample
// (Subscripts) to a value of type T is built by counting occurrences and is
// then divided entry-wise by `sum`.
//
// NOTE(review): this listing is missing several source lines (the embedded
// numbering skips 37, 63, 89, 109, 113 and 124). The notes below marked
// "presumably" describe what those lines most likely contain; confirm them
// against the real header.
//
// Template parameters:
//   T - mapped type of each fold's map; initialised with pRC::zero<T>() and
//       incremented by pRC::unit<T>() per sample.
//   D - number of values expected on each sample line; compared against the
//       number of values actually parsed.
// Parameters:
//   filename - path of the input file. Its first line is read and discarded
//              as a header.
//   k        - number of folds. It is used as a divisor below, so k == 0 is
//              not handled here.
// Returns:
//   std::tuple of (vector with one map per fold, vector with the k fold
//   sizes).
30 template<class T, pRC::Size D>
31 static inline auto getCrossValSplits(std::string const &filename,
32 pRC::Index const &k)
33 {
34 std::ifstream file(filename);
35
// Key type of the per-fold maps. Only the tail of the alias is visible here;
// the first line of its right-hand side is missing from the listing.
36 using Subscripts =
38 [](auto const... seq)
39 {
40 return pRC::Subscripts<seq...>();
41 }));
42
// NOTE(review): whether pRC::Logging::error() terminates is not visible
// here; if it returns, the code below continues with an unopened stream.
43 if(!file.is_open())
44 {
45 pRC::Logging::error("Unable to open input file!");
46 }
47
48 std::vector<std::string> lines;
49
50 // number of samples
51 pRC::UnsignedInteger<64> totalSum = 0;
52
53 std::string line;
54
55 // first line contains header (event names)
56 std::getline(file, line);
57
58 // write samples to lines
59 while(std::getline(file, line))
60 {
61 lines.push_back(line);
62
// NOTE(review): presumably totalSum is incremented on the missing listing
// line 63; no increment is visible in this view.
64 }
65
66 // shuffle data
// The generator is seeded from std::random_device, so the assignment of
// samples to folds differs from run to run.
67 std::random_device rd;
68 std::mt19937 g(rd());
69 std::shuffle(lines.begin(), lines.end(), g);
70
71 std::vector<std::map<Subscripts, T>> pDs;
// Base fold size and the number of leftover samples.
// NOTE(review): k == 0 would make both expressions divide by zero.
72 pRC::Index length = totalSum() / k;
73 pRC::Index remainder = totalSum() % k;
74
75 std::vector<pRC::Index> lengths(k, length);
76
// Give one extra sample to each of the first `remainder` folds.
77 while(remainder > 0)
78 {
79 ++lengths[remainder - 1];
80 --remainder;
81 }
82
// Index into `lines` of the first sample belonging to the current fold.
83 pRC::Index minInd = 0;
84
85 // write the maps in pDs
86 for(pRC::Index ind = 0; ind < k; ++ind)
87 {
88 std::map<Subscripts, T> pD;
// NOTE(review): `sum`, used for normalisation below, is presumably declared
// on the missing listing line 89.
90
// This `line` shadows the one declared at function scope.
91 std::string line;
92 for(pRC::Index innerInd = 0; innerInd < lengths[ind]; ++innerInd)
93 {
94 line = lines[minInd + innerInd];
95
96 // turn commas into spaces
97 std::replace(line.begin(), line.end(), ',', ' ');
98
99 std::istringstream iss(line);
100
101 // store event data for current sample in 'bits'
102 Subscripts bits;
103 std::size_t i = 0;
104 unsigned v;
// Extraction stops at the first token that does not parse as unsigned.
// NOTE(review): bits[i++] is written before i is compared with D, so a line
// with more than D values indexes past D unless Subscripts checks bounds
// itself - confirm.
105 while(iss >> v)
106 {
107 bits[i++] = v;
108 }
// NOTE(review): the condition guarding this block (listing line 109) and the
// opening of the call that takes the message below (listing line 113) are
// missing from this view.
110 {
111 if(i != D)
112 {
114 "Number of events differs for input file and "
115 "binary. "
116 "File:",
117 i, "Binary:", D);
118 }
119 }
120
// Insert the sample with a zero value if it is new, then add one to it.
121 pD.try_emplace(bits, pRC::zero<T>());
122 pD[bits] += pRC::unit<T>();
123
// NOTE(review): presumably `sum` is incremented on the missing listing
// line 124.
125 }
126
127 minInd += lengths[ind];
128
129 // normalize pD
// The binding named `k` shadows the function parameter k inside this loop.
// NOTE(review): if a fold is empty and `sum` is its sample count, this
// divides by zero - confirm how `sum` is maintained.
130 for(auto &[k, v] : pD)
131 {
132 v /= sum;
133 }
134 pDs.push_back(pD);
135 }
136
137 return std::make_tuple(pDs, lengths);
138 }
139} // namespace cMHN
140
141#endif // cMHN_UTILITY_GET_CROSS_VAL_SPLITS_H
pRC::Size const D
Definition CalculatePThetaTests.cpp:9
Definition value.hpp:15
Definition subscripts.hpp:21
int i
Definition gmock-matchers-comparisons_test.cc:603
Definition calculate_pTheta.hpp:20
static auto getCrossValSplits(std::string const &filename, pRC::Index const &k)
Splits the samples in a dataset into k sets of as equal size as possible that can be used for k-fold cross-validation.
Definition get_cross_val_splits.hpp:31
static void error(Xs &&...args)
Definition log.hpp:14
static constexpr auto unit()
Definition unit.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:444
Size Index
Definition basics.hpp:32
constexpr auto cDebugLevel
Definition config.hpp:48
static constexpr auto zero()
Definition zero.hpp:12