pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
bfgs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_BFGS_H
4#define pRC_ALGORITHMS_OPTIMIZER_BFGS_H
5
6#include <prc/config.hpp>
14
16{
17 template<class LS = LineSearch::Bracketing>
18 class BFGS
19 {
20 private:
21 static constexpr Size defaultMaxIterations()
22 {
23 return 1000;
24 }
25
26 template<class G, class T>
27 static constexpr auto projectedGradientConverged(G const &g,
28 T const &tolerance)
29 {
30 auto const infNorm = norm<2, 0>(g)();
31
33 }
34
35 template<class F, class T>
36 static constexpr auto valueConverged(F const &f0, F const &f,
37 T const &tolerance)
38 {
39 auto const scale = max(abs(f0), abs(f), identity<T>());
40
41 return delta(f0, f) <= tolerance * scale;
42 }
43
44 public:
45 constexpr BFGS(LS const &lineSearch,
46 Size const maxIterations = defaultMaxIterations())
47 : mLineSearch(lineSearch)
48 , mMaxIterations(maxIterations)
49 {
50 }
51
52 constexpr BFGS(Size const maxIterations = defaultMaxIterations())
53 : mMaxIterations(maxIterations)
54 {
55 }
56
57 constexpr auto &lineSearch() const
58 {
59 return mLineSearch;
60 }
61
62 constexpr auto maxIterations() const
63 {
64 return mMaxIterations;
65 }
66
67 template<class XX, class RX = RemoveReference<XX>,
68 class TX = typename RX::Type, class VX = typename TX::Value,
69 If<IsTensorish<RX>> = 0,
70 class RXE = RemoveConstReference<ResultOf<Eval, XX>>, class FF,
71 If<IsInvocable<FF, RXE const &, RXE &>> = 0,
72 If<IsFloat<ResultOf<FF, RXE const &, RXE &>>> = 0, class FC,
73 If<IsInvocable<FC, RXE &>> = 0, class VT = VX,
74 If<All<IsFloat<VX>, IsFloat<VT>>> = 0,
75 If<IsInvocable<LS, RXE &, ResultOf<FF, RXE const &, RXE &> &, RXE &,
76 VX &, FF, RXE const &>> = 0>
77 inline constexpr auto operator()(XX &&x0, FF &&function, FC &&callback,
78 VT const &tolerance = NumericLimits<VT>::tolerance()) const
79 {
80 decltype(auto) x =
81 copy<!(!IsReference<XX>() && !IsConst<RX>())>(eval(x0));
82
83 RXE g;
84 auto f = function(x, g);
85
86 Logging::debug("BFGS initial f(x) =", f);
87
88 if(projectedGradientConverged(g, tolerance))
89 {
90 return x;
91 }
92
93 Tensor H = expand(makeSeries<Index, typename RXE::Dimension{}>(),
94 [&](auto const... seq)
95 {
96 return identity<
97 Tensor<TX, RXE::size(seq)..., RXE::size(seq)...>>();
98 });
99
100 auto alpha = identity<TX>();
101 for(Index iteration = 0;;)
102 {
103 Tensor p = -H * g;
104
105 auto d = scalarProduct(p, g)();
106 if(d > zero())
107 {
108 H = identity();
109 p = -g;
110 d = -norm<2, 1>(g)();
111 }
112
113 auto const f0 = f;
114 auto const g0 = g;
115
116 alpha = lineSearch()(x, f, g, d, function, p, alpha);
117
118 callback(x);
119
120 if(++iteration; !(iteration < maxIterations()))
121 {
122 Logging::info("BFGS max iterations reached at f(x) =", f);
123 break;
124 }
125
126 if(valueConverged(f0, f, tolerance))
127 {
128 Logging::info("BFGS converged at f(x) =", f);
129 break;
130 }
131
132 if(projectedGradientConverged(g, tolerance))
133 {
134 Logging::info("BFGS converged at f(x) =", f);
135 break;
136 }
137
138 Logging::debug("BFGS current f(x) =", f);
139
140 Tensor const s = alpha * p;
141 Tensor const y = g - g0;
142 Tensor const rho = rcp(scalarProduct(y, s));
143
144 Tensor const V =
146 H = eval(transpose(V) * H) * V + rho * tensorProduct(s, s);
147 }
148
149 if constexpr(IsReference<decltype(x)>())
150 {
151 return forward<XX>(x0);
152 }
153 else
154 {
155 return x;
156 }
157 }
158
159 private:
160 LS const mLineSearch;
161 Size const mMaxIterations;
162 };
163}
164#endif // pRC_ALGORITHMS_OPTIMIZER_BFGS_H
Definition bfgs.hpp:19
constexpr auto operator()(XX &&x0, FF &&function, FC &&callback, VT const &tolerance=NumericLimits< VT >::tolerance()) const
Definition bfgs.hpp:77
constexpr BFGS(Size const maxIterations=defaultMaxIterations())
Definition bfgs.hpp:52
constexpr BFGS(LS const &lineSearch, Size const maxIterations=defaultMaxIterations())
Definition bfgs.hpp:45
constexpr auto maxIterations() const
Definition bfgs.hpp:62
constexpr auto & lineSearch() const
Definition bfgs.hpp:57
Class storing tensors.
Definition tensor.hpp:44
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
Definition bfgs.hpp:16
static constexpr X eval(X &&a)
Definition eval.hpp:11
static constexpr auto rcp(Complex< T > const &b)
Definition rcp.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto transpose(JacobiRotation< T > const &a)
Definition jacobi_rotation.hpp:319
static constexpr auto makeSeries()
Definition sequence.hpp:361
static constexpr auto delta(Complex< TA > const &a, Complex< TB > const &b)
Definition delta.hpp:12
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
std::is_reference< T > IsReference
Definition type_traits.hpp:47
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
static constexpr decltype(auto) expand(Sequence< T, Seq... > const, F &&f, Xs &&...args)
forwards the values in a pRC::Sequence to a function as parameters
Definition sequence.hpp:354
static constexpr auto tensorProduct(XA &&a, XB &&b)
Calculates the tensor product of two Tensors.
Definition tensor_product.hpp:32
Size Index
Definition type_traits.hpp:21
static constexpr auto identity()
Definition identity.hpp:12
static constexpr X max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13