cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
lbfgs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
4#define pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
5
10
11namespace pRC::Optimizer
12{
13 template<class LS = LineSearch::Bracketing, Size M = 5>
14 class LBFGS
15 {
16 private:
17 static constexpr Size defaultMaxIterations()
18 {
19 return 1000;
20 }
21
22 template<class G, class T>
23 static constexpr auto projectedGradientConverged(G const &g,
24 T const &tolerance)
25 {
26 auto const infNorm = norm<2, 0>(g)();
27
28 return infNorm <= tolerance * identity<T>(1e-3);
29 }
30
31 template<class F, class T>
32 static constexpr auto valueConverged(F const &f0, F const &f,
33 T const &tolerance)
34 {
35 auto const scale = max(abs(f0), abs(f), identity<T>());
36
37 return delta(f0, f) <= tolerance * scale;
38 }
39
// The objective value counts as diverged as soon as isFinite rejects it.
template<class F>
static constexpr auto valueDiverged(F const &f)
{
    auto const finite = isFinite(f);
    return !finite;
}
45
// True when the new objective value f is strictly larger than the previous
// value f0; equal values do not count as an increase.
template<class F>
static constexpr auto valueIncreased(F const &f0, F const &f)
{
    auto const increased = f > f0;
    return increased;
}
51
52 template<class S, class Y, class R, class H>
53 static constexpr auto resetHistory(S &s, Y &y, R &rho, H &H0)
54 {
55 s.clear();
56 y.clear();
57 rho.clear();
58 H0 = identity();
59
60 return;
61 }
62
// Applies the limited-memory inverse-Hessian approximation to g with the
// L-BFGS two-loop recursion and returns the product q; operator() negates
// it to obtain the search direction.
//   s, y  - stored correction pairs; operator() appends new pairs with
//           pushBack, so back(i) walks from the newest pair to older ones
//           and front(i) from the oldest (assuming Deque::back/front index
//           from their respective end, as the `position` parameter in
//           deque.hpp suggests -- confirm).
//   rho   - per-pair factors, 1 / (y . s) as computed in operator().
//   H0    - initial inverse-Hessian scale (a scalar TX in operator()).
//   size  - number of pairs to use.
// NOTE(review): the first loop uses the `size` newest pairs, the second the
// `size` oldest; these coincide only when size == s.size(), which is what
// the sole caller passes.
63 template<class S, class Y, class R, class H, class G>
64 static constexpr auto applyHessianMatrix(S const &s, Y const &y,
65 R const &rho, H const &H0, Size const size, G const &g)
66 {
67 using T = ResultOf<Mul, typename R::Type,
69 Deque<T, M> alpha;
70
// q starts as the gradient and is transformed in place.
71 Tensor q = g;
// First loop, newest to oldest: alpha_i = rho_i * (s_i . q), then
// q -= alpha_i * y_i. pushFront leaves the alphas ordered oldest-first for
// the second loop; inside this loop alpha.back(i) is the value just pushed.
72 for(Index i = 0; i < size; ++i)
73 {
74 alpha.pushFront(rho.back(i) * scalarProduct(s.back(i), q)());
75 q -= alpha.back(i) * y.back(i);
76 }
77
// Apply the initial inverse-Hessian approximation H0.
78 q *= H0;
79
// Second loop, oldest to newest: beta_i = rho_i * (y_i . q), then
// q += s_i * (alpha_i - beta_i).
80 for(Index i = 0; i < size; ++i)
81 {
82 auto const beta = rho.front(i) * scalarProduct(y.front(i), q)();
83 q += s.front(i) * (alpha.front(i) - beta);
84 }
85
86 return q;
87 }
88
89 public:
90 constexpr LBFGS(LS const &lineSearch,
91 Size const maxIterations = defaultMaxIterations())
92 : mLineSearch(lineSearch)
93 , mMaxIterations(maxIterations)
94 {
95 }
96
97 constexpr LBFGS(Size const maxIterations = defaultMaxIterations())
98 : mMaxIterations(maxIterations)
99 {
100 }
101
102 constexpr auto &lineSearch() const
103 {
104 return mLineSearch;
105 }
106
107 constexpr auto maxIterations() const
108 {
109 return mMaxIterations;
110 }
111
// Minimizes `function` with L-BFGS starting from x0 and returns the final
// iterate.
//   x0        - starting point. Per the condition passed to copy<>, x is a
//               fresh copy when XX is a reference or RX is const; otherwise
//               x may refer to x0's storage, which is why the final return
//               forwards x0 in that case.
//   function  - called as function(x, g); its result is used as f(x) and g
//               is read afterwards as the gradient (presumably written by
//               the callee -- confirm).
//   callback  - invoked with x after every line search and once more after
//               the loop.
//   tolerance - shared by the value test (valueConverged) and the gradient
//               test (projectedGradientConverged).
// The loop stops when:
//   - f becomes non-finite or increases (x is rolled back to the previous
//     iterate);
//   - maxIterations() iterations have run;
//   - either convergence test passes.
// NOTE(review): the declarations of s and y (original lines 140-141) and
// parts of the template header are not visible in this listing.
112 template<class XX, class FF, class FC,
115 IsFloat VX = Value<RX>, IsFloat VT = VX>
121 RXE const &>
122 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
123 VT const &tolerance = NumericLimits<VT>::tolerance()) const
124 {
125 using TX = typename RX::Type;
126
127 decltype(auto) x =
128 copy<!(!IsReference<XX> && !IsConst<RX>)>(eval(x0));
129
130 RXE g;
131 auto f = function(x, g);
132
133 Logging::info("L-BFGS initial f(x) =", f);
134
// Already converged at the start point. NOTE(review): this early exit does
// not call callback(x), unlike the normal exit after the loop.
135 if(projectedGradientConverged(g, tolerance))
136 {
137 return x;
138 }
139
142 Deque<TX, M> rho;
143 TX H0 = identity();
144
// Step length returned by the line search; passed back in as the starting
// step of the next search.
145 auto alpha = identity<TX>();
146 for(Index iteration = 0;;)
147 {
148 Tensor p = -applyHessianMatrix(s, y, rho, H0, s.size(), g);
149
// d is the directional derivative p . g. If p is not a descent direction,
// drop the history and fall back to steepest descent; d is then recomputed
// as -norm<2, 1>(g) (presumably -||g||^2, i.e. p . g for p = -g).
150 auto d = scalarProduct(p, g)();
151 if(d > zero())
152 {
153 resetHistory(s, y, rho, H0);
154 p = -g;
155 d = -norm<2, 1>(g)();
156 }
157
// Snapshot of the current iterate. NOTE(review): this local x0 shadows the
// parameter x0; the parameter is what the final `return forward<XX>(x0)`
// refers to.
158 auto const f0 = f;
159 auto const g0 = g;
160 auto const x0 = x;
161
// The line search receives x, f and g and they are read afterwards as the
// new iterate, value and gradient (presumably updated in place -- confirm
// against LS).
162 alpha = lineSearch()(x, f, g, d, function, p, alpha);
163
164 callback(x);
165
// On divergence or increase only x is rolled back; f and g keep the
// rejected values, which are only logged afterwards.
166 if(valueDiverged(f))
167 {
168 Logging::info("L-BFGS diverged at f(x) =", f);
169 x = x0;
170 break;
171 }
172
173 if(valueIncreased(f0, f))
174 {
175 Logging::info("L-BFGS made no further progress at f(x) =",
176 f);
177 x = x0;
178 break;
179 }
180
181 if(++iteration; !(iteration < maxIterations()))
182 {
183 Logging::info("L-BFGS max iterations reached at f(x) =", f);
184 break;
185 }
186
187 if(valueConverged(f0, f, tolerance))
188 {
189 Logging::info("L-BFGS converged at f(x) =", f);
190 break;
191 }
192
193 if(projectedGradientConverged(g, tolerance))
194 {
195 Logging::info("L-BFGS converged at f(x) =", f);
196 break;
197 }
198
199 Logging::info("L-BFGS current f(x) =", f);
200
// Record the new correction pair:
//   s = alpha * p (the step taken), y = g - g0 (the gradient change),
//   rho = 1 / (y . s), H0 = 1 / (rho * norm<2, 1>(y)).
// If norm<2, 1> is the squared 2-norm, H0 = (y . s) / (y . y).
// NOTE(review): nothing here checks y . s > 0. If the line search does not
// enforce the curvature condition, rho can be negative or infinite; a
// negative H0 is only caught by the d > 0 reset in the next iteration.
201 s.pushBack(alpha * p);
202 y.pushBack(g - g0);
203 rho.pushBack(rcp(scalarProduct(y.back(), s.back()))());
204 H0 = rcp(rho.back() * norm<2, 1>(y.back())());
205 }
206
207 callback(x);
208
// When x refers to the caller's storage, forward the original argument;
// otherwise return the local result.
209 if constexpr(IsReference<decltype(x)>)
210 {
211 return forward<XX>(x0);
212 }
213 else
214 {
215 return x;
216 }
217 }
218
219 private:
// Line-search functor, exposed through lineSearch() and invoked once per
// iteration by operator().
220 LS const mLineSearch;
// Iteration cap, exposed through maxIterations() and checked in
// operator()'s loop.
221 Size const mMaxIterations;
222 };
223}
224#endif // pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
Definition deque.hpp:15
constexpr auto size() const
Definition deque.hpp:28
constexpr decltype(auto) front(Index const position=0) &&
Definition deque.hpp:33
constexpr auto pushFront(R const &element) &&
Definition deque.hpp:81
constexpr decltype(auto) back(Index const position=0) &&
Definition deque.hpp:53
constexpr auto pushBack(R const &element) &&
Definition deque.hpp:135
Definition value.hpp:12
Definition lbfgs.hpp:15
constexpr LBFGS(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition lbfgs.hpp:90
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition lbfgs.hpp:122
constexpr auto maxIterations() const
Definition lbfgs.hpp:107
constexpr auto & lineSearch() const
Definition lbfgs.hpp:102
constexpr LBFGS(Size const maxIterations=defaultMaxIterations())
Definition lbfgs.hpp:97
Definition tensor.hpp:25
Definition concepts.hpp:25
Definition value.hpp:24
Definition concepts.hpp:31
Definition concepts.hpp:19
Definition declarations.hpp:27
Definition declarations.hpp:45
TN::Subscripts S
Definition externs_nonTT.hpp:9
int i
Definition gmock-matchers-comparisons_test.cc:603
const double y
Definition gmock-matchers-containers_test.cc:377
int x
Definition gmock-matchers-containers_test.cc:376
const char * p
Definition gmock-matchers-containers_test.cc:379
static void info(Xs &&...args)
Definition log.hpp:27
Definition bfgs.hpp:11
static constexpr auto isFinite(T const &a)
Definition is_finite.hpp:13
static constexpr Conditional< C, RemoveConstReference< X >, RemoveConst< X > > copy(X &&a)
Definition copy.hpp:13
static constexpr auto rcp(T const &b)
Definition rcp.hpp:12
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
std::remove_reference_t< T > RemoveReference
Definition basics.hpp:41
static constexpr auto abs(T const &a)
Definition abs.hpp:11
typename ValueType< T >::Type Value
Definition value.hpp:72
static constexpr auto scalarProduct(TA const &a, TB const &b)
Definition scalar_product.hpp:11
static constexpr auto delta(TA const &a, TB const &b)
Definition delta.hpp:11
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr decltype(auto) eval(X &&a)
Definition eval.hpp:12
static constexpr auto norm(T const &a)
Definition norm.hpp:12
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition mul.hpp:12
Definition limits.hpp:13