cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
gradient_descent.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_GRADIENT_DESCENT_H
4#define pRC_ALGORITHMS_OPTIMIZER_GRADIENT_DESCENT_H
5
9
10namespace pRC::Optimizer
11{
12 template<class LS = LineSearch::Fixed>
14 {
15 private:
16 static constexpr Size defaultMaxIterations()
17 {
18 return 1000;
19 }
20
21 template<class G, class T>
22 static constexpr auto projectedGradientConverged(G const &g,
23 T const &tolerance)
24 {
25 auto const infNorm = norm<2, 0>(g)();
26
27 return infNorm <= tolerance * identity<T>(1e-3);
28 }
29
30 template<class F, class T>
31 static constexpr auto valueConverged(F const &f0, F const &f,
32 T const &tolerance)
33 {
34 auto const scale = max(abs(f0), abs(f), identity<T>());
35
36 return delta(f0, f) <= tolerance * scale;
37 }
38
/// Divergence test: true when the objective value f is not finite
/// (as reported by isFinite).
template<class F>
static constexpr auto valueDiverged(F const &f)
{
    auto const finiteValue = isFinite(f);
    return !finiteValue;
}
44
/// Monotonicity test: true when the new objective value f is strictly
/// greater than the previous value f0. Currently referenced only from a
/// commented-out check in operator().
template<class F>
static constexpr auto valueIncreased(F const &f0, F const &f)
{
    auto const rose = f > f0;
    return rose;
}
50
51 public:
52 constexpr GradientDescent(LS const &lineSearch,
53 Size const maxIterations = defaultMaxIterations())
54 : mLineSearch(lineSearch)
55 , mMaxIterations(maxIterations)
56 {
57 }
58
59 constexpr GradientDescent(
60 Size const maxIterations = defaultMaxIterations())
61 : mMaxIterations(maxIterations)
62 {
63 }
64
65 constexpr auto &lineSearch() const
66 {
67 return mLineSearch;
68 }
69
70 constexpr auto maxIterations() const
71 {
72 return mMaxIterations;
73 }
74
// Minimizes `function` by steepest descent, starting from x0.
//   x0        initial point. The definition of x is on listing line 91,
//             which this view omits; the IsReference<decltype(x)> branch
//             below suggests x may alias x0 — confirm.
//   function  called as function(x, g); its result is treated as f(x), and
//             g is presumably filled with the gradient in place — confirm.
//   callback  invoked with the current iterate after every line search and
//             once more after the loop exits.
//   tolerance used by the value test (relative, floored by identity) and
//             the gradient test (scaled by 1e-3); defaults to
//             NumericLimits<VT>::tolerance().
// Returns the final iterate. On divergence, the iterate from before the
// failing step is restored first.
75 template<class XX, class FF, class FC,
78 IsFloat VX = Value<RX>, IsFloat VT = VX>
84 RXE const &>
85 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
86 VT const &tolerance = NumericLimits<VT>::tolerance()) const
87 {
88 using TX = typename RX::Type;
89
90 decltype(auto) x =
92
// Gradient buffer; presumably written by function() and by the line search.
93 RXE g;
94 auto f = function(x, g);
95
96 Logging::info("Gradient Descent initial f(x) =", f);
97
// Already converged at the starting point: return without iterating.
// NOTE(review): this early exit never calls callback(x), unlike every other
// exit path, and returns a copy even when x0 was passed as an rvalue.
98 if(projectedGradientConverged(g, tolerance))
99 {
100 return x;
101 }
102
// Step length carried across iterations: passed into the line search and
// replaced by the value it returns.
103 auto alpha = identity<TX>();
104 for(Index iteration = 0;;)
105 {
// Snapshot of the pre-step state: f0 feeds the value-convergence test and
// x0 is used to roll back on divergence.
// NOTE(review): g0 is never read, so this copies the whole gradient every
// iteration for nothing. The local x0 also shadows the parameter x0.
106 auto const f0 = f;
107 auto const g0 = g;
108 auto const x0 = x;
109
// Steepest-descent direction p = -g. d = -norm<2, 1>(g) is presumably
// -||g||^2, i.e. the directional derivative g.p — confirm that norm<2, 1>
// is the squared 2-norm.
110 Tensor p = -g;
111 auto d = -norm<2, 1>(g)();
112
// The line search presumably advances x and updates f and g in place: they
// are inspected below without being recomputed. It returns the step length
// to reuse as the next initial guess.
113 alpha = lineSearch()(x, f, g, d, function, p, alpha);
114
115 callback(x);
116
// Non-finite objective: restore the pre-step iterate and stop. Note that
// the callback above has already been given the rejected iterate.
117 if(valueDiverged(f))
118 {
119 Logging::info("Gradient Descent diverged at f(x) =", f);
120 x = x0;
121 break;
122 }
123
// Disabled monotonicity check, kept for reference.
124 /*if(valueIncreased(f0, f))
125 {
126 Logging::info("Gradient Descent made no further progress at
127 f(x) =", f); x = x0; break;
128 }*/
129
// Stop checks, in order: iteration cap, value convergence, then gradient
// convergence.
130 if(++iteration; !(iteration < maxIterations()))
131 {
133 "Gradient Descent max iterations reached at f(x) =", f);
134 break;
135 }
136
137 if(valueConverged(f0, f, tolerance))
138 {
139 Logging::info("Gradient Descent converged at f(x) =", f);
140 break;
141 }
142
143 if(projectedGradientConverged(g, tolerance))
144 {
145 Logging::info("Gradient Descent converged at f(x) =", f);
146 break;
147 }
148
149 Logging::info("Gradient Descent current f(x) =", f);
150 }
151
// Final callback. On a non-divergent exit x has not changed since the
// in-loop callback, so the callback sees the same iterate twice.
152 callback(x);
153
// If x is a reference (presumably aliasing x0), return through the
// forwarded parameter so that an rvalue argument can be moved out.
// Otherwise return the local by value.
154 if constexpr(IsReference<decltype(x)>)
155 {
156 return forward<XX>(x0);
157 }
158 else
159 {
160 return x;
161 }
162 }
163
164 private:
165 LS const mLineSearch;
166 Size const mMaxIterations;
167 };
168}
169#endif // pRC_ALGORITHMS_OPTIMIZER_GRADIENT_DESCENT_H
Definition value.hpp:12
Definition gradient_descent.hpp:14
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition gradient_descent.hpp:85
constexpr GradientDescent(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition gradient_descent.hpp:52
constexpr auto & lineSearch() const
Definition gradient_descent.hpp:65
constexpr GradientDescent(Size const maxIterations=defaultMaxIterations())
Definition gradient_descent.hpp:59
constexpr auto maxIterations() const
Definition gradient_descent.hpp:70
Definition tensor.hpp:25
Definition concepts.hpp:25
Definition value.hpp:24
Definition concepts.hpp:31
Definition concepts.hpp:19
Definition declarations.hpp:27
Definition declarations.hpp:45
int x
Definition gmock-matchers-containers_test.cc:376
const char * p
Definition gmock-matchers-containers_test.cc:379
static void info(Xs &&...args)
Definition log.hpp:27
Definition bfgs.hpp:11
static constexpr auto isFinite(T const &a)
Definition is_finite.hpp:13
static constexpr Conditional< C, RemoveConstReference< X >, RemoveConst< X > > copy(X &&a)
Definition copy.hpp:13
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto abs(T const &a)
Definition abs.hpp:11
typename ValueType< T >::Type Value
Definition value.hpp:72
static constexpr auto delta(TA const &a, TB const &b)
Definition delta.hpp:11
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto identity()
Definition identity.hpp:13
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13