pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
lbfgs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
4#define pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
5
6#include <prc/config.hpp>
14
15namespace pRC::Optimizer
16{
// NOTE(review): this listing is a rendered copy of lbfgs.hpp with each line
// prefixed by its source line number; source lines 7-13, 32, 60-61 and
// 130-132 are missing from it, so the code below is incomplete as shown.
// Check the real header before relying on any of the missing parts.

/// L-BFGS (limited-memory BFGS) minimizer.
///
/// Each outer iteration builds a quasi-Newton direction from stored
/// correction pairs (s, y, rho), runs the line search along it, and stops on
/// a value-change test, a gradient test, or the iteration cap.
///
/// @tparam LS  line-search functor type, LineSearch::Bracketing by default.
/// @tparam M   defaults to 5. M is not referenced in any visible line; it is
///             presumably the capacity of the correction-pair histories
///             declared in the missing lines — confirm.
17 template<class LS = LineSearch::Bracketing, Size M = 5>
18 class LBFGS
19 {
20 private:
/// Iteration cap used when a constructor is not given maxIterations.
21 static constexpr Size defaultMaxIterations()
22 {
23 return 1000;
24 }
25
/// Gradient-based stopping test; operator() calls it on the initial gradient
/// and again inside the iteration loop.
///
/// Evaluates norm<2, 0>(g)() into infNorm; the name suggests the infinity
/// (max-abs) norm of g — confirm against norm's definition.
// NOTE(review): the return statement (source line 32) is missing from this
// listing; presumably it compares infNorm with tolerance. No projection is
// applied in the visible lines despite the function's name.
26 template<class G, class T>
27 static constexpr auto projectedGradientConverged(G const &g,
28 T const &tolerance)
29 {
30 auto const infNorm = norm<2, 0>(g)();
31
33 }
34
/// Function-value stopping test.
///
/// Returns true when delta(f0, f) (presumably |f0 - f|) is at most
/// tolerance * max(|f0|, |f|, identity<T>()). This is a relative-change test
/// whose scale is floored at identity<T>() (presumably 1), so values near
/// zero fall back to an absolute test.
35 template<class F, class T>
36 static constexpr auto valueConverged(F const &f0, F const &f,
37 T const &tolerance)
38 {
39 auto const scale = max(abs(f0), abs(f), identity<T>());
40
41 return delta(f0, f) <= tolerance * scale;
42 }
43
/// Discards every stored correction pair by clearing s, y and rho, and
/// resets the initial inverse-Hessian scaling H0 to identity().
/// operator() uses this when the quasi-Newton direction is not a descent
/// direction.
44 template<class S, class Y, class R, class H>
45 static constexpr auto resetHistory(S &s, Y &y, R &rho, H &H0)
46 {
47 s.clear();
48 y.clear();
49 rho.clear();
50 H0 = identity();
51
52 return;
53 }
54
/// Applies the limited-memory inverse-Hessian approximation to g with the
/// L-BFGS two-loop recursion and returns the result; operator() negates it
/// to obtain the search direction.
///
/// @param s, y, rho  correction-pair histories as filled by operator()
///                   (s = accepted step, y = gradient change,
///                   rho = 1 / (y·s)).
/// @param H0         initial inverse-Hessian scaling (a TX scalar in
///                   operator()).
/// @param size       number of pairs to use.
/// @param g          vector the approximation is applied to.
// NOTE(review): the first loop reads back(i) (newest pairs) and the second
// reads front(i) (oldest pairs). They visit the same pairs only when
// size == s.size(), which is what the one visible call passes; a smaller
// size would mix different pairs.
// NOTE(review): source lines 60-61 are missing from this listing; they
// presumably complete the `using T = ...` alias and declare the `alpha`
// history used below — confirm.
55 template<class S, class Y, class R, class H, class G>
56 static constexpr auto applyHessianMatrix(S const &s, Y const &y,
57 R const &rho, H const &H0, Size const size, G const &g)
58 {
59 using T = ResultOf<Mul, typename R::Type,
62
63 Tensor q = g;
// First loop, from the newest pair (back(0)) to the oldest. After each
// pushFront, alpha holds i + 1 entries, so alpha.back(i) is the value just
// pushed (assuming alpha starts empty).
64 for(Index i = 0; i < size; ++i)
65 {
66 alpha.pushFront(rho.back(i) * scalarProduct(s.back(i), q)());
67 q -= alpha.back(i) * y.back(i);
68 }
69
// Scale by the initial inverse-Hessian estimate.
70 q *= H0;
71
// Second loop, from the oldest pair (front(0)) to the newest.
// alpha.front(i) was pushed for the same pair as s.front(i) when
// size == s.size().
72 for(Index i = 0; i < size; ++i)
73 {
74 auto const beta = rho.front(i) * scalarProduct(y.front(i), q)();
75 q += s.front(i) * (alpha.front(i) - beta);
76 }
77
78 return q;
79 }
80
81 public:
/// Creates an optimizer that stores a copy of the given line search.
///
/// @param lineSearch     line-search functor, stored by value.
/// @param maxIterations  cap on outer iterations (default 1000).
82 constexpr LBFGS(LS const &lineSearch,
83 Size const maxIterations = defaultMaxIterations())
84 : mLineSearch(lineSearch)
85 , mMaxIterations(maxIterations)
86 {
87 }
88
/// Creates an optimizer whose line search is default-initialized.
///
/// @param maxIterations  cap on outer iterations (default 1000).
// NOTE(review): this constructor is not explicit, so a Size converts
// implicitly to LBFGS.
89 constexpr LBFGS(Size const maxIterations = defaultMaxIterations())
90 : mMaxIterations(maxIterations)
91 {
92 }
93
/// Returns a const reference to the stored line-search functor.
94 constexpr auto &lineSearch() const
95 {
96 return mLineSearch;
97 }
98
/// Returns the cap on outer iterations set at construction.
99 constexpr auto maxIterations() const
100 {
101 return mMaxIterations;
102 }
103
/// Minimizes `function` with L-BFGS, starting from x0.
///
/// @param x0         starting point. If x0 is an lvalue or const, a copy is
///                   optimized (copy<...> yields RemoveConstReference<X>).
///                   Otherwise x may refer to x0 directly, and the result is
///                   returned via forward<XX>(x0).
/// @param function   called as function(x, g); returns the objective value
///                   (a float type) and presumably writes the gradient at x
///                   into g.
/// @param callback   called with the current iterate after every line
///                   search.
/// @param tolerance  threshold for both stopping tests; defaults to
///                   NumericLimits<VT>::tolerance().
/// @return the final iterate.
// NOTE(review): the IsInvocable<LS, ...> constraint below lists six
// arguments and a non-const LS, but the loop calls the line search with a
// seventh argument (alpha) through the `LS const &` that lineSearch()
// returns. An LS that satisfies the constraint can therefore still fail to
// compile at the call. Constraining on `LS const &` and including the alpha
// argument would make the check match the call.
104 template<class XX, class RX = RemoveReference<XX>,
105 class TX = typename RX::Type, class VX = typename TX::Value,
106 If<IsTensorish<RX>> = 0,
107 class RXE = RemoveConstReference<ResultOf<Eval, XX>>, class FF,
108 If<IsInvocable<FF, RXE const &, RXE &>> = 0,
109 If<IsFloat<ResultOf<FF, RXE const &, RXE &>>> = 0, class FC,
110 If<IsInvocable<FC, RXE>> = 0, class VT = VX,
111 If<All<IsFloat<VX>, IsFloat<VT>>> = 0,
112 If<IsInvocable<LS, RXE &, ResultOf<FF, RXE const &, RXE &> &, RXE &,
113 VX &, FF, RXE const &>> = 0>
114 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
115 VT const &tolerance = NumericLimits<VT>::tolerance()) const
116 {
// Work on a copy unless x0 is a non-const rvalue.
117 decltype(auto) x =
118 copy<!(!IsReference<XX>() && !IsConst<RX>())>(eval(x0));
119
// Initial objective value and gradient.
120 RXE g;
121 auto f = function(x, g);
122
123 Logging::info("L-BFGS initial f(x) =", f);
124
// Stop immediately if the starting point already passes the gradient test.
// NOTE(review): when x refers to x0 (decltype(x) is a reference), this
// `return x;` copies, whereas the normal exit at the end of the function
// moves via forward<XX>(x0).
125 if(projectedGradientConverged(g, tolerance))
126 {
127 return x;
128 }
129
// NOTE(review): source lines 130-132 are missing from this listing; they
// presumably declare the s, y and rho histories used below — confirm.
133 TX H0 = identity();
134
// alpha carries the accepted step from one line search into the next.
135 auto alpha = identity<TX>();
136 for(Index iteration = 0;;)
137 {
// Quasi-Newton direction p = -(H g), built from the stored pairs.
138 Tensor p = -applyHessianMatrix(s, y, rho, H0, s.size(), g);
139
// d = p·g. If d > 0, p is not a descent direction: drop the history and
// fall back to steepest descent, with d = -norm<2, 1>(g)() (presumably
// -|g|^2, i.e. p·g for p = -g).
140 auto d = scalarProduct(p, g)();
141 if(d > zero())
142 {
143 resetHistory(s, y, rho, H0);
144 p = -g;
145 d = -norm<2, 1>(g)();
146 }
147
// Previous value and gradient, needed for the value test and for y.
148 auto const f0 = f;
149 auto const g0 = g;
150
// Per the IsInvocable constraint, x, f, g and d are passed by non-const
// reference, so the line search presumably updates them in place.
151 alpha = lineSearch()(x, f, g, d, function, p, alpha);
152
153 callback(x);
154
// The cap is tested right after the step and before the convergence tests.
// As a result, at least one step always runs (even if maxIterations() is
// 0), and a final step that also converges is reported as reaching the cap.
155 if(++iteration; !(iteration < maxIterations()))
156 {
157 Logging::info("L-BFGS max iterations reached at f(x) =", f);
158 break;
159 }
160
161 if(valueConverged(f0, f, tolerance))
162 {
163 Logging::info("L-BFGS converged at f(x) =", f);
164 break;
165 }
166
167 if(projectedGradientConverged(g, tolerance))
168 {
169 Logging::info("L-BFGS converged at f(x) =", f);
170 break;
171 }
172
173 Logging::info("L-BFGS current f(x) =", f);
174
// Record the new pair s = alpha * p, y = g - g0 and rho = 1 / (y·s), then
// rescale H0 to 1 / (rho * norm<2, 1>(y)). That is presumably
// (s·y) / (y·y) if norm<2, 1> is the squared 2-norm.
// NOTE(review): no check that y·s > 0 is visible before rho is stored; a
// zero or negative curvature pair would make rho infinite or negative.
// Whether the line search guarantees y·s > 0 cannot be seen here — confirm.
175 s.pushBack(alpha * p);
176 y.pushBack(g - g0);
177 rho.pushBack(rcp(scalarProduct(y.back(), s.back()))());
178 H0 = rcp(rho.back() * norm<2, 1>(y.back())());
179 }
180
// Return the caller's object if it was optimized in place; otherwise return
// the local copy.
181 if constexpr(IsReference<decltype(x)>())
182 {
183 return forward<XX>(x0);
184 }
185 else
186 {
187 return x;
188 }
189 }
190
191 private:
// Line-search functor invoked once per outer iteration.
192 LS const mLineSearch;
// Cap on outer iterations (see maxIterations()).
193 Size const mMaxIterations;
194 };
195}
196#endif // pRC_ALGORITHMS_OPTIMIZER_LBFGS_H
Definition deque.hpp:15
constexpr auto size() const
Definition deque.hpp:28
constexpr auto pushBack(R const &element) &&
Definition deque.hpp:133
constexpr decltype(auto) back(Index const position=0) &&
Definition deque.hpp:53
Definition lbfgs.hpp:19
constexpr LBFGS(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition lbfgs.hpp:82
constexpr auto maxIterations() const
Definition lbfgs.hpp:99
constexpr auto & lineSearch() const
Definition lbfgs.hpp:94
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition lbfgs.hpp:114
constexpr LBFGS(Size const maxIterations=defaultMaxIterations())
Definition lbfgs.hpp:89
Class storing tensors.
Definition tensor.hpp:44
static void info(Xs &&...args)
Definition log.hpp:27
Definition bfgs.hpp:16
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto rcp(Complex< T > const &b)
Definition rcp.hpp:13
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
static constexpr auto zero()
Definition zero.hpp:12
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto delta(Complex< TA > const &a, Complex< TB > const &b)
Definition delta.hpp:12
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
Size Index
Definition type_traits.hpp:21
static constexpr auto identity()
Definition identity.hpp:12
static constexpr X max(X &&a)
Definition max.hpp:13
Definition mul.hpp:12
Definition limits.hpp:13