pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
bracketing.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_BRACKETING_H
4#define pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_BRACKETING_H
5
6#include <prc/config.hpp>
13
15{
17 {
18 private:
19 static constexpr Bool defaultStrongWolfeConditions()
20 {
21 return true;
22 }
23
24 static constexpr Float<> defaultDecreasingStepScale()
25 {
26 return 0.5;
27 }
28
29 static constexpr Float<> defaultIncreasingStepScale()
30 {
31 return 2.1;
32 }
33
34 static constexpr Size defaultMaxIterations()
35 {
36 return 5;
37 }
38
39 static constexpr Float<> defaultC1()
40 {
41 return 1e-4;
42 }
43
44 static constexpr Float<> defaultC2()
45 {
46 return 0.9;
47 }
48
49 public:
50 constexpr Bracketing(
51 Bool const strongWolfeConditions = defaultStrongWolfeConditions(),
52 Float<> const decreasingStepScale = defaultDecreasingStepScale(),
53 Float<> const increasingStepScale = defaultIncreasingStepScale(),
54 Size const maxIterations = defaultMaxIterations(),
55 Float<> const c1 = defaultC1(), Float<> const c2 = defaultC2())
56 : mStrongWolfeConditions(strongWolfeConditions)
57 , mDecreasingStepScale(decreasingStepScale)
58 , mIncreasingStepScale(increasingStepScale)
59 , mMaxIterations(maxIterations)
60 , mC1(c1)
61 , mC2(c2)
62 {
63 }
64
65 constexpr auto strongWolfeConditions() const
66 {
67 return mStrongWolfeConditions;
68 }
69
70 template<class T = Float<>>
71 constexpr decltype(auto) decreasingStepScale() const
72 {
73 return cast<T>(mDecreasingStepScale);
74 }
75
76 template<class T = Float<>>
77 constexpr decltype(auto) increasingStepScale() const
78 {
79 return cast<T>(mIncreasingStepScale);
80 }
81
82 constexpr auto maxIterations() const
83 {
84 return mMaxIterations;
85 }
86
87 template<class T = Float<>>
88 constexpr decltype(auto) c1() const
89 {
90 return cast<T>(mC1);
91 }
92
93 template<class T = Float<>>
94 constexpr decltype(auto) c2() const
95 {
96 return cast<T>(mC2);
97 }
98
99 template<class X, If<IsTensor<X>> = 0,
100 class T = typename X::Type::Value, If<IsFloat<T>> = 0, class F,
101 If<IsInvocable<F, X const &, X &>> = 0,
102 If<IsFloat<ResultOf<F, X const &, X &>>> = 0, class FC,
103 If<IsInvocable<FC, X const &>> = 0,
104 If<IsConvertible<ResultOf<FC, X const &>, X>> = 0>
107 FC &&constraint, X const &p, T alpha = identity<T>(),
108 T const alphaMin = zero<T>(),
110 {
111 if constexpr(cDebugLevel >= DebugLevel::High)
112 {
113 if(alphaMin < zero<T>())
114 {
115 Logging::error("LS-Bracketing: Minimum alpha < 0.");
116 }
117
118 if(alphaMax < zero<T>())
119 {
120 Logging::error("LS-Bracketing: Maximum alpha < 0.");
121 }
122
123 if(alphaMax < alphaMin)
124 {
126 "LS-Bracketing: Minimum alpha > Maximum alpha.");
127 }
128
130 {
131 Logging::error("Initial alpha not in range (min, max]");
132 }
133 }
134
135 Tensor const x0 = x;
136 auto const f0 = f;
137 auto const d0 = d;
138
139 auto low = alphaMin;
140 auto high = alphaMax;
141
142 for(Index iteration = 0;; ++iteration,
145 {
147 {
148 Logging::debug("Line search: Max iterations reached.");
149 break;
150 }
151
152 x = constraint(x0 + alpha * p);
153 f = function(x, g);
154
155 if(f > f0 + c1<T>() * alpha * d0)
156 {
157 high = alpha;
158 Logging::info("First Wolfe condition not satisfied, trying again");
159 continue;
160 }
161
162 d = scalarProduct(g, p)();
163 if(d < c2<T>() * d0)
164 {
165 low = alpha;
166 Logging::info("Second Wolfe condition not satisfied, trying again");
167 continue;
168 }
169
171 {
172 if(abs(d) > abs(c2<T>() * d0))
173 {
174 high = alpha;
175 Logging::info("Strong Wolfe condition not satisfied, trying again");
176 continue;
177 }
178 }
179
180 break;
181 }
182
183 return alpha;
184 }
185
186 template<class X, If<IsTensor<X>> = 0,
187 class T = typename X::Type::Value, If<IsFloat<T>> = 0, class F,
188 If<IsInvocable<F, X const &, X &>> = 0,
189 If<IsFloat<ResultOf<F, X const &, X &>>> = 0>
192 X const &p, T alpha = identity<T>(), T const alphaMin = zero<T>(),
194 {
195 return operator()(
196 x, f, g, d, forward<F>(function),
197 [](auto &&x) -> decltype(auto)
198 {
199 return forward<decltype(x)>(x);
200 },
202 }
203
204 private:
205 Bool const mStrongWolfeConditions;
206 Float<> const mDecreasingStepScale;
207 Float<> const mIncreasingStepScale;
208 Size const mMaxIterations;
209 Float<> const mC1;
210 Float<> const mC2;
211 };
212}
213#endif // pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_BRACKETING_H
Top-level class storing a floating point number.
Definition float.hpp:35
Class storing tensors.
Definition tensor.hpp:44
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Definition bracketing.hpp:15
bool Bool
Definition type_traits.hpp:18
static constexpr X min(X &&a)
Definition min.hpp:13
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
constexpr auto cDebugLevel
Definition config.hpp:46
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
Size Index
Definition type_traits.hpp:21
Definition limits.hpp:13
Definition bracketing.hpp:17
constexpr decltype(auto) decreasingStepScale() const
Definition bracketing.hpp:71
constexpr auto strongWolfeConditions() const
Definition bracketing.hpp:65
constexpr decltype(auto) c2() const
Definition bracketing.hpp:94
constexpr decltype(auto) increasingStepScale() const
Definition bracketing.hpp:77
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition bracketing.hpp:190
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, FC &&constraint, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition bracketing.hpp:105
constexpr Bracketing(Bool const strongWolfeConditions=defaultStrongWolfeConditions(), Float<> const decreasingStepScale=defaultDecreasingStepScale(), Float<> const increasingStepScale=defaultIncreasingStepScale(), Size const maxIterations=defaultMaxIterations(), Float<> const c1=defaultC1(), Float<> const c2=defaultC2())
Definition bracketing.hpp:50
constexpr auto maxIterations() const
Definition bracketing.hpp:82
constexpr decltype(auto) c1() const
Definition bracketing.hpp:88