cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
lbfgs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
4#define pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
5
6#include <prc/config.hpp>
15
16namespace pRC::Optimizer
17{
18 template<class LS = LineSearch::Bracketing, Size M = 5>
19 class LBFGS
20 {
21 private:
// Default iteration limit: both constructors use this value (1000) as the
// default argument for their maxIterations parameter.
22 static constexpr Size defaultMaxIterations()
23 {
24 return 1000;
25 }
26
// Gradient-based convergence test used by operator() before and during the
// iteration.
//
// g:         gradient to test.
// tolerance: threshold supplied by the caller of operator().
//
// NOTE(review): the listing skips its line 33, so the return statement is not
// visible here. Presumably it compares infNorm against tolerance -- confirm
// against the original header.
27 template<class G, class T>
28 static constexpr auto projectedGradientConverged(G const &g,
29 T const &tolerance)
30 {
// The local is named infNorm; which norm norm<2, 0> actually selects is
// defined elsewhere -- TODO confirm.
31 auto const infNorm = norm<2, 0>(g)();
32
34 }
35
// Relative convergence test on the objective value.
//
// f0:        objective value before the step.
// f:         objective value after the step.
// tolerance: relative threshold.
//
// Returns the result of delta(f0, f) <= tolerance * scale, where scale is the
// largest of abs(f0), abs(f) and identity<T>().
36 template<class F, class T>
37 static constexpr auto valueConverged(F const &f0, F const &f,
38 T const &tolerance)
39 {
// identity<T>() acts as a lower bound on the scale; presumably it is the
// multiplicative identity (1), which would keep the test meaningful when
// both values are close to zero -- TODO confirm.
40 auto const scale = max(abs(f0), abs(f), identity<T>());
41
// delta() is defined elsewhere; presumably the distance between its two
// arguments -- TODO confirm.
42 return delta(f0, f) <= tolerance * scale;
43 }
44
// Divergence test: reports true when isFinite(f) is false for the given
// objective value f.
45 template<class F>
46 static constexpr auto valueDiverged(F const &f)
47 {
48 return !isFinite(f);
49 }
50
// Progress test: reports true when the new objective value f is strictly
// greater than the previous value f0. Equal values do not count as an
// increase.
51 template<class F>
52 static constexpr auto valueIncreased(F const &f0, F const &f)
53 {
54 return f > f0;
55 }
56
// Discards the accumulated quasi-Newton state in place.
//
// s, y, rho: history containers; each one is emptied via clear().
// H0:        initial scaling; reassigned from identity().
//
// operator() calls this when the computed search direction is not a descent
// direction.
57 template<class S, class Y, class R, class H>
58 static constexpr auto resetHistory(S &s, Y &y, R &rho, H &H0)
59 {
60 s.clear();
61 y.clear();
62 rho.clear();
63 H0 = identity();
64
65 return;
66 }
67
// Applies the history-based operator to g and returns the result. operator()
// negates the returned tensor to obtain its search direction. The structure
// (a pass using back(i), a scaling by H0, a pass using front(i)) follows the
// shape of the L-BFGS two-loop recursion.
//
// s, y, rho: history containers, read through back(i) and front(i).
// H0:        scaling applied between the two passes.
// size:      number of history entries to process in each pass.
// g:         input tensor (the current gradient at the call site).
//
// NOTE(review): the listing skips its lines 73-74, so the remainder of the
// alias T and the declaration of alpha are not visible here.
// NOTE(review): the comments below assume back(i) counts from the most
// recently pushed entry and front(i) from the opposite end -- confirm against
// deque.hpp.
68 template<class S, class Y, class R, class H, class G>
69 static constexpr auto applyHessianMatrix(S const &s, Y const &y,
70 R const &rho, H const &H0, Size const size, G const &g)
71 {
72 using T = ResultOf<Mul, typename R::Type,
75
// Work on a copy of g; q is updated in place and returned.
76 Tensor q = g;
// First pass: for each entry compute a coefficient rho * <s, q>, store it
// at the front of alpha, and subtract coefficient * y from q.
77 for(Index i = 0; i < size; ++i)
78 {
79 alpha.pushFront(rho.back(i) * scalarProduct(s.back(i), q)());
// alpha.back(i) is read right after the push; under the assumption above
// it is the coefficient that was just stored.
80 q -= alpha.back(i) * y.back(i);
81 }
82
// Scale by the initial approximation H0.
83 q *= H0;
84
// Second pass: traverse the history through front(i) and add
// s * (alpha - beta) to q, with beta = rho * <y, q>.
85 for(Index i = 0; i < size; ++i)
86 {
87 auto const beta = rho.front(i) * scalarProduct(y.front(i), q)();
88 q += s.front(i) * (alpha.front(i) - beta);
89 }
90
91 return q;
92 }
93
94 public:
// Constructs an optimizer from an explicit line search object.
//
// lineSearch:    copied into mLineSearch.
// maxIterations: stored in mMaxIterations; defaults to
//                defaultMaxIterations().
95 constexpr LBFGS(LS const &lineSearch,
96 Size const maxIterations = defaultMaxIterations())
97 : mLineSearch(lineSearch)
98 , mMaxIterations(maxIterations)
99 {
100 }
101
// Constructs an optimizer without an explicit line search object.
// mLineSearch does not appear in the initializer list, so it is
// default-initialized.
//
// maxIterations: stored in mMaxIterations; defaults to
//                defaultMaxIterations().
102 constexpr LBFGS(Size const maxIterations = defaultMaxIterations())
103 : mMaxIterations(maxIterations)
104 {
105 }
106
// Accessor for the stored line search object; returns a reference to
// mLineSearch (no copy is made).
107 constexpr auto &lineSearch() const
108 {
109 return mLineSearch;
110 }
111
// Accessor for the iteration limit; returns mMaxIterations by value.
112 constexpr auto maxIterations() const
113 {
114 return mMaxIterations;
115 }
116
// Runs the optimization starting from x0 and returns the final iterate.
//
// x0:        starting point; must satisfy IsTensorish with a floating-point
//            value type (see the template constraints).
// function:  invoked as function(x, g). Its floating-point return value is
//            used as the objective value f. g is passed by non-const
//            reference and read immediately afterwards, so function is
//            expected to write the gradient into it.
// callback:  invoked with the current iterate after every line search and
//            once more after the loop has ended.
// tolerance: threshold passed to projectedGradientConverged() and
//            valueConverged(); defaults to NumericLimits<VT>::tolerance().
//
// The loop ends when the objective becomes non-finite, the objective
// increases, the iteration limit is reached, or one of the two convergence
// tests succeeds. In the first two cases x is restored to its value from
// before the step.
//
// NOTE(review): the listing skips its lines 143-145; the declarations of the
// history containers s, y and rho are presumably there -- confirm against the
// original header.
117 template<class XX, class RX = RemoveReference<XX>,
118 class TX = typename RX::Type, class VX = typename TX::Value,
119 If<IsTensorish<RX>> = 0,
120 class RXE = RemoveConstReference<ResultOf<Eval, XX>>, class FF,
121 If<IsInvocable<FF, RXE const &, RXE &>> = 0,
122 If<IsFloat<ResultOf<FF, RXE const &, RXE &>>> = 0, class FC,
123 If<IsInvocable<FC, RXE>> = 0, class VT = VX,
124 If<All<IsFloat<VX>, IsFloat<VT>>> = 0,
125 If<IsInvocable<LS, RXE &, ResultOf<FF, RXE const &, RXE &> &, RXE &,
126 VX &, FF, RXE const &>> = 0>
127 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
128 VT const &tolerance = NumericLimits<VT>::tolerance()) const
129 {
// The copy<> condition is true when XX is a reference type or RX is const;
// it is false only for a non-const, non-reference XX. decltype(auto) keeps
// whatever copy() returns, and the return statement at the end of this
// function branches on whether x turned out to be a reference.
130 decltype(auto) x =
131 copy<!(!IsReference<XX>() && !IsConst<RX>())>(eval(x0));
132
// Initial evaluation: f receives the objective value, g is filled by
// function.
133 RXE g;
134 auto f = function(x, g);
135
136 Logging::info("L-BFGS initial f(x) =", f);
137
// Nothing to do if the starting point already passes the gradient test.
138 if(projectedGradientConverged(g, tolerance))
139 {
140 return x;
141 }
142
146 TX H0 = identity();
147
// alpha is passed to each line search and overwritten with its result, so
// every search receives the previous step length.
148 auto alpha = identity<TX>();
// The for statement has no condition; the loop is left only through the
// break statements below.
149 for(Index iteration = 0;;)
150 {
// Search direction: negated result of the history-based operator.
151 Tensor p = -applyHessianMatrix(s, y, rho, H0, s.size(), g);
152
// d = <p, g>. A positive value means p is not a descent direction; in that
// case the history is dropped and the negative gradient is used instead.
153 auto d = scalarProduct(p, g)();
154 if(d > zero())
155 {
156 resetHistory(s, y, rho, H0);
157 p = -g;
// presumably norm<2, 1>(g) equals <g, g>, which would make d match the new
// p = -g -- TODO confirm.
158 d = -norm<2, 1>(g)();
159 }
160
// Snapshot of the state before the step. This local x0 shadows the function
// parameter x0 for the remainder of the loop body.
161 auto const f0 = f;
162 auto const g0 = g;
163 auto const x0 = x;
164
// The line search receives x, f and g as non-const references (see the LS
// constraint above) and returns the value stored in alpha.
165 alpha = lineSearch()(x, f, g, d, function, p, alpha);
166
167 callback(x);
168
// Non-finite objective: restore the previous iterate and stop.
169 if(valueDiverged(f))
170 {
171 Logging::info("L-BFGS diverged at f(x) =", f);
172 x = x0;
173 break;
174 }
175
// Objective got worse: restore the previous iterate and stop.
176 if(valueIncreased(f0, f))
177 {
178 Logging::info("L-BFGS made no further progress at f(x) =",
179 f);
180 x = x0;
181 break;
182 }
183
// The counter is incremented before the comparison and after the line
// search, so at least one step is always taken. This check also runs before
// the convergence tests below, so on the final iteration the limit message
// is logged even if those tests would pass.
184 if(++iteration; !(iteration < maxIterations()))
185 {
186 Logging::info("L-BFGS max iterations reached at f(x) =", f);
187 break;
188 }
189
190 if(valueConverged(f0, f, tolerance))
191 {
192 Logging::info("L-BFGS converged at f(x) =", f);
193 break;
194 }
195
196 if(projectedGradientConverged(g, tolerance))
197 {
198 Logging::info("L-BFGS converged at f(x) =", f);
199 break;
200 }
201
202 Logging::info("L-BFGS current f(x) =", f);
203
// History update: s is the step taken, y the gradient change,
// rho = 1 / <y, s>, and H0 = 1 / (rho * norm<2, 1>(y)).
204 s.pushBack(alpha * p);
205 y.pushBack(g - g0);
206 rho.pushBack(rcp(scalarProduct(y.back(), s.back()))());
207 H0 = rcp(rho.back() * norm<2, 1>(y.back())());
208 }
209
// Final notification with the iterate that is about to be returned.
210 callback(x);
211
// Outside the loop, x0 names the function parameter again. If x is a
// reference, the parameter is forwarded; otherwise the local x is returned.
212 if constexpr(IsReference<decltype(x)>())
213 {
214 return forward<XX>(x0);
215 }
216 else
217 {
218 return x;
219 }
220 }
221
222 private:
// Line search object used by operator(); exposed through lineSearch().
223 LS const mLineSearch;
// Iteration limit checked in operator(); exposed through maxIterations().
224 Size const mMaxIterations;
225 };
226}
227#endif // pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
Definition deque.hpp:15
constexpr auto size() const
Definition deque.hpp:28
constexpr auto pushBack(R const &element) &&
Definition deque.hpp:133
constexpr decltype(auto) back(Index const position=0) &&
Definition deque.hpp:53
Definition lbfgs.hpp:20
constexpr LBFGS(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition lbfgs.hpp:95
constexpr auto maxIterations() const
Definition lbfgs.hpp:112
constexpr auto & lineSearch() const
Definition lbfgs.hpp:107
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition lbfgs.hpp:127
constexpr LBFGS(Size const maxIterations=defaultMaxIterations())
Definition lbfgs.hpp:102
Definition tensor.hpp:28
TN::Subscripts S
Definition externs_nonTT.hpp:9
static void info(Xs &&...args)
Definition log.hpp:27
Definition bfgs.hpp:17
static constexpr auto isFinite(T const &a)
Definition is_finite.hpp:13
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto rcp(Complex< T > const &b)
Definition rcp.hpp:13
std::size_t Size
Definition type_traits.hpp:20
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
static constexpr auto zero()
Definition zero.hpp:12
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto delta(Complex< TA > const &a, Complex< TB > const &b)
Definition delta.hpp:12
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
static constexpr auto identity()
Definition identity.hpp:12
static constexpr X max(X &&a)
Definition max.hpp:13
Definition mul.hpp:12
Definition limits.hpp:13