cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
gradient_descent.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_GRADIENT_DESCENT_H
4#define pRC_ALGORITHMS_OPTIMIZER_GRADIENT_DESCENT_H
5
6#include <prc/config.hpp>
15
16namespace pRC::Optimizer
17{
18 template<class LS = LineSearch::Bracketing>
20 {
21 private:
22 static constexpr Size defaultMaxIterations()
23 {
24 return 1000;
25 }
26
27 template<class G, class T>
28 static constexpr auto projectedGradientConverged(G const &g,
29 T const &tolerance)
30 {
31 auto const infNorm = norm<2, 0>(g)();
32
34 }
35
36 template<class F, class T>
37 static constexpr auto valueConverged(F const &f0, F const &f,
38 T const &tolerance)
39 {
40 auto const scale = max(abs(f0), abs(f), identity<T>());
41
42 return delta(f0, f) <= tolerance * scale;
43 }
44
45 template<class F>
46 static constexpr auto valueDiverged(F const &f)
47 {
48 return !isFinite(f);
49 }
50
51 template<class F>
52 static constexpr auto valueIncreased(F const &f0, F const &f)
53 {
54 return f > f0;
55 }
56
57 public:
58 constexpr GradientDescent(LS const &lineSearch,
59 Size const maxIterations = defaultMaxIterations())
60 : mLineSearch(lineSearch)
61 , mMaxIterations(maxIterations)
62 {
63 }
64
65 constexpr GradientDescent(
66 Size const maxIterations = defaultMaxIterations())
67 : mMaxIterations(maxIterations)
68 {
69 }
70
71 constexpr auto &lineSearch() const
72 {
73 return mLineSearch;
74 }
75
76 constexpr auto maxIterations() const
77 {
78 return mMaxIterations;
79 }
80
81 template<class XX, class RX = RemoveReference<XX>,
82 class TX = typename RX::Type, class VX = typename TX::Value,
83 If<IsTensorish<RX>> = 0,
84 class RXE = RemoveConstReference<ResultOf<Eval, XX>>, class FF,
85 If<IsInvocable<FF, RXE const &, RXE &>> = 0,
86 If<IsFloat<ResultOf<FF, RXE const &, RXE &>>> = 0, class FC,
87 If<IsInvocable<FC, RXE &>> = 0, class VT = VX,
88 If<All<IsFloat<VX>, IsFloat<VT>>> = 0,
89 If<IsInvocable<LS, RXE &, ResultOf<FF, RXE const &, RXE &> &, RXE &,
90 VX &, FF, RXE const &>> = 0>
91 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
92 VT const &tolerance = NumericLimits<VT>::tolerance()) const
93 {
94 decltype(auto) x =
95 copy<!(!IsReference<XX>() && !IsConst<RX>())>(eval(x0));
96
97 RXE g;
98 auto f = function(x, g);
99
100 Logging::info("Gradient Descent initial f(x) =", f);
101
102 if(projectedGradientConverged(g, tolerance))
103 {
104 return x;
105 }
106
107 auto alpha = identity<TX>();
108 for(Index iteration = 0;;)
109 {
110 auto const f0 = f;
111 auto const g0 = g;
112 auto const x0 = x;
113
114 Tensor p = -g;
115 auto d = -norm<2, 1>(g)();
116
117 alpha = lineSearch()(x, f, g, d, function, p, alpha);
118
119 callback(x);
120
121 if(valueDiverged(f))
122 {
123 Logging::info("Gradient Descent diverged at f(x) =", f);
124 x = x0;
125 break;
126 }
127
128 /*if(valueIncreased(f0, f))
129 {
130 Logging::info("Gradient Descent made no further progress at
131 f(x) =", f); x = x0; break;
132 }*/
133
134 if(++iteration; !(iteration < maxIterations()))
135 {
137 "Gradient Descent max iterations reached at f(x) =", f);
138 break;
139 }
140
141 if(valueConverged(f0, f, tolerance))
142 {
143 Logging::info("Gradient Descent converged at f(x) =", f);
144 break;
145 }
146
147 if(projectedGradientConverged(g, tolerance))
148 {
149 Logging::info("Gradient Descent converged at f(x) =", f);
150 break;
151 }
152
153 Logging::info("Gradient Descent current f(x) =", f);
154 }
155
156 callback(x);
157
158 if constexpr(IsReference<decltype(x)>())
159 {
160 return forward<XX>(x0);
161 }
162 else
163 {
164 return x;
165 }
166 }
167
168 private:
169 LS const mLineSearch;
170 Size const mMaxIterations;
171 };
172}
173#endif // pRC_ALGORITHMS_OPTIMIZER_GRADIENT_DESCENT_H
Definition gradient_descent.hpp:20
constexpr GradientDescent(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition gradient_descent.hpp:58
constexpr auto & lineSearch() const
Definition gradient_descent.hpp:71
constexpr GradientDescent(Size const maxIterations=defaultMaxIterations())
Definition gradient_descent.hpp:65
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition gradient_descent.hpp:91
constexpr auto maxIterations() const
Definition gradient_descent.hpp:76
Definition tensor.hpp:28
static void info(Xs &&...args)
Definition log.hpp:27
Definition bfgs.hpp:17
static constexpr auto isFinite(T const &a)
Definition is_finite.hpp:13
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
std::size_t Size
Definition type_traits.hpp:20
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto delta(Complex< TA > const &a, Complex< TB > const &b)
Definition delta.hpp:12
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
static constexpr X max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13