cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
bfgs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_BFGS_H
4#define pRC_ALGORITHMS_OPTIMIZER_BFGS_H
5
9
11{
12 template<class LS = LineSearch::Bracketing>
13 class BFGS
14 {
15 private:
16 static constexpr Size defaultMaxIterations()
17 {
18 return 1000;
19 }
20
21 template<class G, class T>
22 static constexpr auto projectedGradientConverged(G const &g,
23 T const &tolerance)
24 {
25 auto const infNorm = norm<2, 0>(g)();
26
27 return infNorm <= tolerance * identity<T>(1e-3);
28 }
29
30 template<class F, class T>
31 static constexpr auto valueConverged(F const &f0, F const &f,
32 T const &tolerance)
33 {
34 auto const scale = max(abs(f0), abs(f), identity<T>());
35
36 return delta(f0, f) <= tolerance * scale;
37 }
38
// Divergence test: true when isFinite(f) reports f as not finite
// (presumably NaN or +/-infinity).
template<class F>
static constexpr auto valueDiverged(F const &f)
{
    auto const finite = isFinite(f);
    return !finite;
}
44
// True when the value f is strictly greater than the earlier value f0.
template<class F>
static constexpr auto valueIncreased(F const &f0, F const &f)
{
    auto const increased = f > f0;
    return increased;
}
50
51 public:
52 constexpr BFGS(LS const &lineSearch,
53 Size const maxIterations = defaultMaxIterations())
54 : mLineSearch(lineSearch)
55 , mMaxIterations(maxIterations)
56 {
57 }
58
59 constexpr BFGS(Size const maxIterations = defaultMaxIterations())
60 : mMaxIterations(maxIterations)
61 {
62 }
63
64 constexpr auto &lineSearch() const
65 {
66 return mLineSearch;
67 }
68
69 constexpr auto maxIterations() const
70 {
71 return mMaxIterations;
72 }
73
// Minimizes `function` with a BFGS quasi-Newton iteration, starting from x0.
//
// `function(x, g)` is called with the current point and a gradient buffer g,
// and its result is used as f(x). It presumably writes the gradient into g,
// since g is read as the gradient below -- confirm against callers.
// `callback(x)` is invoked after every line search and once more after the
// loop ends. It is NOT invoked on the early return taken when the initial
// gradient already passes the convergence test.
// `tolerance` feeds both stopping tests (valueConverged and
// projectedGradientConverged) and defaults to NumericLimits<VT>::tolerance().
//
// The loop stops in any of these cases:
//   - f becomes non-finite (x is reverted to the previous iterate);
//   - f increases (x is reverted to the previous iterate);
//   - maxIterations() iterations have run;
//   - either convergence test passes.
//
// Returns x by value. When x is a reference (decided by its initializer,
// which is not shown in this listing), it returns forward<XX>(x0) instead.
//
// NOTE(review): this Doxygen listing omits source lines 75-76, 78-82, 90,
// 102 and 169. Those are the template constraints, the initializer of x,
// the declaration of H and the definition of V. The comments below describe
// only the visible code.
74 template<class XX, class FF, class FC,
77 IsFloat VX = Value<RX>, IsFloat VT = VX>
83 RXE const &>
84 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
85 VT const &tolerance = NumericLimits<VT>::tolerance()) const
86 {
87 using TX = typename RX::Type;
88
89 decltype(auto) x =
91
// Evaluate f and its gradient buffer at the starting point.
92 RXE g;
93 auto f = function(x, g);
94
95 Logging::debug("BFGS initial f(x) =", f);
96
// Starting point already passes the gradient test: return it without
// iterating.
97 if(projectedGradientConverged(g, tolerance))
98 {
99 return x;
100 }
// H is presumably the inverse-Hessian approximation. It is initialized to
// the identity tensor built by the lambda below; its declaration (source
// line 102) is not shown.
101
103 [&](auto const... seq)
104 {
105 return identity<
106 Tensor<TX, RXE::size(seq)..., RXE::size(seq)...>>();
107 });
108
// Step length returned by the previous line search. It is passed back in as
// the last argument of the next line-search call.
109 auto alpha = identity<TX>();
110 for(Index iteration = 0;;)
111 {
// Search direction p = -H g and directional derivative d = <p, g>.
112 Tensor p = -H * g;
113
114 auto d = scalarProduct(p, g)();
// d > 0 means p is not a descent direction. Reset H to the identity and fall
// back to steepest descent: p = -g, d = -norm<2, 1>(g).
115 if(d > zero())
116 {
117 H = identity();
118 p = -g;
119 d = -norm<2, 1>(g)();
120 }
121
// Snapshot the current iterate so it can be compared against and restored.
// These loop-local copies shadow the parameter x0 inside the loop body.
122 auto const f0 = f;
123 auto const g0 = g;
124 auto const x0 = x;
125
// x, f and g are passed to the line search and are read afterwards as the
// new iterate, so it presumably updates them in place -- confirm in the LS
// implementation.
126 alpha = lineSearch()(x, f, g, d, function, p, alpha);
127
128 callback(x);
129
// Non-finite or increased f: restore the previous point and stop.
130 if(valueDiverged(f))
131 {
132 Logging::info("BFGS diverged at f(x) =", f);
133 x = x0;
134 break;
135 }
136
137 if(valueIncreased(f0, f))
138 {
139 Logging::info("BFGS made no further progress at f(x) =", f);
140 x = x0;
141 break;
142 }
143
// The iteration cap is checked before the convergence tests, so the
// max-iterations message takes precedence when both apply.
144 if(++iteration; !(iteration < maxIterations()))
145 {
146 Logging::info("BFGS max iterations reached at f(x) =", f);
147 break;
148 }
149
150 if(valueConverged(f0, f, tolerance))
151 {
152 Logging::info("BFGS converged at f(x) =", f);
153 break;
154 }
155
156 if(projectedGradientConverged(g, tolerance))
157 {
158 Logging::info("BFGS converged at f(x) =", f);
159 break;
160 }
161
162 Logging::debug("BFGS current f(x) =", f);
163
// Update H = V^T H V + rho * s s^T, where:
//   s = alpha * p (the step),
//   y = g - g0 (the gradient change),
//   rho = 1 / <y, s>.
// This has the form of the BFGS inverse-Hessian update; the definition of V
// (source line 169) is not shown.
// NOTE(review): nothing here checks that <y, s> > 0. Presumably the line
// search guarantees the curvature condition -- confirm.
164 Tensor const s = alpha * p;
165 Tensor const y = g - g0;
166 Tensor const rho = rcp(scalarProduct(y, s));
167
168 Tensor const V =
170 H = eval(transpose(V) * H) * V + rho * tensorProduct(s, s);
171 }
172
// Report the final point (after any revert) to the callback.
173 callback(x);
174
// From here on, x0 refers to the function parameter again.
175 if constexpr(IsReference<decltype(x)>)
176 {
177 return forward<XX>(x0);
178 }
179 else
180 {
181 return x;
182 }
183 }
184
185 private:
186 LS const mLineSearch;
187 Size const mMaxIterations;
188 };
189}
190#endif // pRC_ALGORITHMS_OPTIMIZER_BFGS_H
Definition value.hpp:12
Definition bfgs.hpp:14
constexpr BFGS(Size const maxIterations=defaultMaxIterations())
Definition bfgs.hpp:59
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition bfgs.hpp:84
constexpr BFGS(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition bfgs.hpp:52
constexpr auto maxIterations() const
Definition bfgs.hpp:69
constexpr auto & lineSearch() const
Definition bfgs.hpp:64
Definition tensor.hpp:25
Definition concepts.hpp:25
Definition value.hpp:24
Definition concepts.hpp:31
Definition concepts.hpp:19
Definition declarations.hpp:27
Definition declarations.hpp:45
const double y
Definition gmock-matchers-containers_test.cc:377
int x
Definition gmock-matchers-containers_test.cc:376
const char * p
Definition gmock-matchers-containers_test.cc:379
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
Definition bfgs.hpp:11
static constexpr auto isFinite(T const &a)
Definition is_finite.hpp:13
static constexpr Conditional< C, RemoveConstReference< X >, RemoveConst< X > > copy(X &&a)
Definition copy.hpp:13
static constexpr auto rcp(T const &b)
Definition rcp.hpp:12
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto transpose(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:306
static constexpr auto abs(T const &a)
Definition abs.hpp:11
typename ValueType< T >::Type Value
Definition value.hpp:72
static constexpr auto scalarProduct(TA const &a, TB const &b)
Definition scalar_product.hpp:11
static constexpr auto makeSeries()
Definition sequence.hpp:390
static constexpr auto delta(TA const &a, TB const &b)
Definition delta.hpp:11
static constexpr auto tensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:17
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:383
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13