cMHN 1.0
C++ library for learning MHNs with pRC
regulator.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef cMHN_COMMON_REGULATOR_H
4#define cMHN_COMMON_REGULATOR_H
5
6#include <functional>
7#include <string>
8
9#include <prc.hpp>
10
11namespace cMHN
12{
28 template<class T, pRC::Size D>
30 {
31 public:
32 ~Regulator() = default;
33 Regulator(Regulator const &) = default;
34 Regulator(Regulator &&) = default;
35 Regulator &operator=(Regulator const &) & = default;
36 Regulator &operator=(Regulator &&) & = default;
37 Regulator() = delete;
38
39 explicit Regulator(auto const &score, auto const &grad,
40 auto const &name, auto const &lambda = T(1e-2))
41 : mScore(score)
42 , mGrad(grad)
43 , mLambda(lambda)
44 , mName(name)
45 {
46 }
47
48 auto score(pRC::Tensor<T, D, D> const &theta) const
49 {
50 return mLambda * mScore(theta);
51 }
52
53 auto grad(pRC::Tensor<T, D, D> const &theta) const
54 {
55 return mLambda * mGrad(theta);
56 }
57
58 auto &lambda()
59 {
60 return mLambda;
61 }
62
63 auto const lambda() const
64 {
65 return mLambda;
66 }
67
68 auto name() const
69 {
70 return mName;
71 }
72
73 private:
74 std::function<T(pRC::Tensor<T, D, D>)> mScore;
75 std::function<pRC::Tensor<T, D, D>(pRC::Tensor<T, D, D>)> mGrad;
76 T mLambda;
77 std::string mName;
78 };
79
92 template<class T, pRC::Size D>
93 class L1Regulator : public Regulator<T, D>
94 {
95 public:
96 ~L1Regulator() = default;
97 L1Regulator(L1Regulator const &) = default;
98 L1Regulator(L1Regulator &&) = default;
99 L1Regulator &operator=(L1Regulator const &) & = default;
101
102 explicit L1Regulator(T const &lambda)
103 : Regulator<T, D>(
104 [](pRC::Tensor<T, D, D> const &theta)
105 {
106 return pRC::norm<1>(pRC::offDiagonal(theta))();
107 },
108 [](pRC::Tensor<T, D, D> const &theta)
109 {
110 return pRC::offDiagonal(pRC::sign(theta));
111 },
112 "L1", lambda)
113 {
114 }
115 };
116
129 template<class T, pRC::Size D>
130 class L2Regulator : public Regulator<T, D>
131 {
132 public:
133 ~L2Regulator() = default;
134 L2Regulator(L2Regulator const &) = default;
136 L2Regulator &operator=(L2Regulator const &) & = default;
138
139 explicit L2Regulator(T const &lambda)
140 : Regulator<T, D>(
141 [](pRC::Tensor<T, D, D> const &theta)
142 {
143 return contract<0, 1, 0, 1>(offDiagonal(theta),
144 offDiagonal(theta))();
145 },
146 [](pRC::Tensor<T, D, D> const &theta)
147 {
148 return pRC::offDiagonal(T(2) * theta);
149 },
150 "L2", lambda)
151 {
152 }
153 };
154}
155
156#endif // cMHN_COMMON_REGULATOR_H
pRC::Size const D
Definition: CalculatePThetaTests.cpp:9
Class storing an L1 Regulator, specializes the Regulator class.
Definition: regulator.hpp:94
L1Regulator(L1Regulator &&)=default
L1Regulator & operator=(L1Regulator &&) &=default
~L1Regulator()=default
L1Regulator & operator=(L1Regulator const &) &=default
L1Regulator(T const &lambda)
Definition: regulator.hpp:102
L1Regulator(L1Regulator const &)=default
Class storing an L2 Regulator, specializes the Regulator class.
Definition: regulator.hpp:131
L2Regulator(L2Regulator &&)=default
L2Regulator(L2Regulator const &)=default
~L2Regulator()=default
L2Regulator & operator=(L2Regulator const &) &=default
L2Regulator(T const &lambda)
Definition: regulator.hpp:139
L2Regulator & operator=(L2Regulator &&) &=default
Class storing all relevant information for a regulator.
Definition: regulator.hpp:30
~Regulator()=default
auto & lambda()
Definition: regulator.hpp:58
Regulator()=delete
Regulator(auto const &score, auto const &grad, auto const &name, auto const &lambda=T(1e-2))
Definition: regulator.hpp:39
auto const lambda() const
Definition: regulator.hpp:63
auto grad(pRC::Tensor< T, D, D > const &theta) const
Definition: regulator.hpp:53
Regulator(Regulator const &)=default
Regulator(Regulator &&)=default
auto name() const
Definition: regulator.hpp:68
Regulator & operator=(Regulator const &) &=default
Regulator & operator=(Regulator &&) &=default
auto score(pRC::Tensor< T, D, D > const &theta) const
Definition: regulator.hpp:48
pRC::Float<> T
Definition: externs_nonTT.hpp:1
Definition: calculate_pTheta.hpp:15