cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
bfgs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_BFGS_H
4#define pRC_ALGORITHMS_OPTIMIZER_BFGS_H
5
6#include <prc/config.hpp>
15
17{
18 template<class LS = LineSearch::Bracketing>
19 class BFGS
20 {
21 private:
22 static constexpr Size defaultMaxIterations()
23 {
24 return 1000;
25 }
26
// Gradient-based stopping test. operator() uses its result as a condition
// (true => stop): once before the first iteration and again after every
// accepted step.
//
// g          gradient tensor filled in by the objective function.
// tolerance  threshold for the gradient test (same value operator() passes
//            to the relative value test).
//
// NOTE(review): the local name suggests norm<2, 0> is the infinity
// (max-abs) norm of g - confirm against pRC's norm documentation. No
// projection is visible here despite the name.
27 template<class G, class T>
28 static constexpr auto projectedGradientConverged(G const &g,
29 T const &tolerance)
30 {
31 auto const infNorm = norm<2, 0>(g)();
32
// NOTE(review): the return statement (listing line 33) is missing from
// this rendering; presumably it compares infNorm against tolerance -
// confirm in the real header.
34 }
35
36 template<class F, class T>
37 static constexpr auto valueConverged(F const &f0, F const &f,
38 T const &tolerance)
39 {
40 auto const scale = max(abs(f0), abs(f), identity<T>());
41
42 return delta(f0, f) <= tolerance * scale;
43 }
44
/// Divergence test: true when the objective value is not finite, as
/// decided by isFinite (presumably NaN or +/-infinity).
template<class Value>
static constexpr auto valueDiverged(Value const &f)
{
    auto const finite = isFinite(f);
    return !finite;
}
50
/// Reports whether a step made the objective strictly worse.
///
/// @param f0 objective value before the step.
/// @param f  objective value after the step.
/// @return   the comparison f > f0. For IEEE floating-point values any
///           comparison involving NaN is false; operator() tests for a
///           non-finite value before calling this.
template<class Value>
static constexpr auto valueIncreased(Value const &f0, Value const &f)
{
    auto const worse = f > f0;
    return worse;
}
56
57 public:
58 constexpr BFGS(LS const &lineSearch,
59 Size const maxIterations = defaultMaxIterations())
60 : mLineSearch(lineSearch)
61 , mMaxIterations(maxIterations)
62 {
63 }
64
65 constexpr BFGS(Size const maxIterations = defaultMaxIterations())
66 : mMaxIterations(maxIterations)
67 {
68 }
69
70 constexpr auto &lineSearch() const
71 {
72 return mLineSearch;
73 }
74
75 constexpr auto maxIterations() const
76 {
77 return mMaxIterations;
78 }
79
// Minimizes `function` with BFGS, starting from x0.
//
// Requirements enforced by the template constraints below:
//   - x0 is tensor-like;
//   - function is callable as (RXE const &, RXE &) and returns a
//     floating-point value;
//   - callback is callable with RXE &;
//   - the tolerance type VT and the element value type VX are floating
//     point;
//   - LS is invocable with the listed argument types.
//
// Parameters:
//   x0        starting point, taken by forwarding reference.
//   function  objective, called as f = function(x, g). g is passed as a
//             mutable RXE and afterwards used as the gradient, so the
//             objective is expected to write the gradient at x into g.
//   callback  invoked with the current iterate after every line search
//             and once more before the final return (not on the early
//             return below).
//   tolerance shared by the gradient test and the relative value test;
//             defaults to NumericLimits<VT>::tolerance().
// Returns the final iterate (auto return type, i.e. by value).
//
// NOTE(review): the LS constraint lists six argument types, but the call
// at listing line 130 passes seven (the trailing `alpha`). Confirm that
// the line search's last parameter is defaulted; otherwise this overload
// is silently removed by SFINAE. Consider adding alpha's type to the
// constraint.
80 template<class XX, class RX = RemoveReference<XX>,
81 class TX = typename RX::Type, class VX = typename TX::Value,
82 If<IsTensorish<RX>> = 0,
83 class RXE = RemoveConstReference<ResultOf<Eval, XX>>, class FF,
84 If<IsInvocable<FF, RXE const &, RXE &>> = 0,
85 If<IsFloat<ResultOf<FF, RXE const &, RXE &>>> = 0, class FC,
86 If<IsInvocable<FC, RXE &>> = 0, class VT = VX,
87 If<All<IsFloat<VX>, IsFloat<VT>>> = 0,
88 If<IsInvocable<LS, RXE &, ResultOf<FF, RXE const &, RXE &> &, RXE &,
89 VX &, FF, RXE const &>> = 0>
90 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
91 VT const &tolerance = NumericLimits<VT>::tolerance()) const
92 {
// The copy<> condition is IsReference<XX> || IsConst<RX>: lvalue or const
// arguments are copied. For a non-const rvalue the evaluated argument is
// passed through, so x presumably aliases x0 and is updated in place.
93 decltype(auto) x =
94 copy<!(!IsReference<XX>() && !IsConst<RX>())>(eval(x0));
95
// Objective value and gradient at the starting point.
96 RXE g;
97 auto f = function(x, g);
98
99 Logging::debug("BFGS initial f(x) =", f);
100
// Start point already satisfies the gradient test: return immediately.
// NOTE(review): this path does not invoke callback. When x aliases an
// rvalue argument, `return x;` presumably copies it, whereas the final
// path below forwards x0.
101 if(projectedGradientConverged(g, tolerance))
102 {
103 return x;
104 }
105
// H is the inverse-Hessian approximation (the update at the end of the
// loop has the standard BFGS inverse form). It starts as the identity
// with shape (dims of x, dims of x).
106 Tensor H = expand(makeSeries<Index, typename RXE::Dimension{}>(),
107 [&](auto const... seq)
108 {
109 return identity<
110 Tensor<TX, RXE::size(seq)..., RXE::size(seq)...>>();
111 });
112
// Step length. The previous value is handed back to the line search each
// iteration (presumably as its initial guess).
113 auto alpha = identity<TX>();
// The counter is advanced only inside the max-iterations test below.
114 for(Index iteration = 0;;)
115 {
// Quasi-Newton search direction and its directional derivative d = p . g.
116 Tensor p = -H * g;
117
118 auto d = scalarProduct(p, g)();
// d > 0: p is not a descent direction. Discard the accumulated curvature
// information and fall back to steepest descent. norm<2, 1> is presumably
// the squared 2-norm, so d = -|g|^2 - confirm.
119 if(d > zero())
120 {
121 H = identity();
122 p = -g;
123 d = -norm<2, 1>(g)();
124 }
125
// Snapshot the state before the step, for roll-back and for the BFGS
// update. These locals shadow the x0 parameter only inside the loop; the
// final return below still refers to the parameter.
126 auto const f0 = f;
127 auto const g0 = g;
128 auto const x0 = x;
129
// x, f, g and d are passed as mutable references and are presumably
// updated in place by the line search. It returns the step length used
// along p.
130 alpha = lineSearch()(x, f, g, d, function, p, alpha);
131
132 callback(x);
133
// Termination tests, in this order. On divergence or on an increase only
// x is rolled back; f and g keep the rejected values, but neither is used
// after the loop.
134 if(valueDiverged(f))
135 {
136 Logging::info("BFGS diverged at f(x) =", f);
137 x = x0;
138 break;
139 }
140
141 if(valueIncreased(f0, f))
142 {
143 Logging::info("BFGS made no further progress at f(x) =", f);
144 x = x0;
145 break;
146 }
147
// Checked before the convergence tests, so an iteration that both
// converges and exhausts the budget is logged as "max iterations reached".
// With maxIterations() == 0 one iteration is still performed.
148 if(++iteration; !(iteration < maxIterations()))
149 {
150 Logging::info("BFGS max iterations reached at f(x) =", f);
151 break;
152 }
153
154 if(valueConverged(f0, f, tolerance))
155 {
156 Logging::info("BFGS converged at f(x) =", f);
157 break;
158 }
159
160 if(projectedGradientConverged(g, tolerance))
161 {
162 Logging::info("BFGS converged at f(x) =", f);
163 break;
164 }
165
166 Logging::debug("BFGS current f(x) =", f);
167
// BFGS inverse-Hessian update with:
//   s   = step taken,
//   y   = change in gradient,
//   rho = 1 / (y . s).
// NOTE(review): there is no curvature check (y . s > 0) before forming rho.
// If y . s <= 0 the update can lose positive definiteness; only the d > 0
// reset above guards against that. y . s == 0 gives a non-finite rho.
// Consider skipping the update in that case.
168 Tensor const s = alpha * p;
169 Tensor const y = g - g0;
170 Tensor const rho = rcp(scalarProduct(y, s));
171
// NOTE(review): the right-hand side of V (listing line 173) is missing
// from this rendering; presumably V = I - rho * (y outer s) - confirm.
172 Tensor const V =
174 H = eval(transpose(V) * H) * V + rho * tensorProduct(s, s);
175 }
176
// Report the final iterate (after any roll-back above).
177 callback(x);
178
// When x aliases the caller's rvalue argument, hand that object back via
// forward; otherwise return the local copy.
179 if constexpr(IsReference<decltype(x)>())
180 {
181 return forward<XX>(x0);
182 }
183 else
184 {
185 return x;
186 }
187 }
188
189 private:
// Line-search strategy, invoked once per iteration by operator() to choose
// the step along the search direction. Being const (like mMaxIterations),
// it makes BFGS copy-constructible but not copy-assignable.
190 LS const mLineSearch;
// Iteration budget checked by operator() after every step.
191 Size const mMaxIterations;
192 };
193}
194#endif // pRC_ALGORITHMS_OPTIMIZER_BFGS_H
Definition bfgs.hpp:20
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition bfgs.hpp:90
constexpr BFGS(Size const maxIterations=defaultMaxIterations())
Definition bfgs.hpp:65
constexpr BFGS(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition bfgs.hpp:58
constexpr auto maxIterations() const
Definition bfgs.hpp:75
constexpr auto & lineSearch() const
Definition bfgs.hpp:70
Definition tensor.hpp:28
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
Definition bfgs.hpp:17
static constexpr auto isFinite(T const &a)
Definition is_finite.hpp:13
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
static constexpr auto rcp(Complex< T > const &b)
Definition rcp.hpp:13
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto transpose(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:319
static constexpr auto makeSeries()
Definition sequence.hpp:351
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto delta(Complex< TA > const &a, Complex< TB > const &b)
Definition delta.hpp:12
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
Definition sequence.hpp:344
static constexpr auto tensorProduct(XA &&a, XB &&b)
Definition tensor_product.hpp:19
static constexpr auto identity()
Definition identity.hpp:12
static constexpr X max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13