pRC
multi-purpose Tensor Train library for C++
Loading...
Searching...
No Matches
more_thuente.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3// References:
4// Authors: Jorge J. More, David J. Thuente
5// Title: Line Search Algorithms with Guaranteed Sufficient Decrease
6// Year: 1994
7// URL: https://doi.org/10.1145/192115.192132
8
9#ifndef pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
10#define pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
11
12#include <prc/config.hpp>
20
22{
24 {
25 private:
26 static constexpr Size defaultMaxIterations()
27 {
28 return 20;
29 }
30
31 static constexpr Float<> defaultC1()
32 {
33 return 1e-4;
34 }
35
36 static constexpr Float<> defaultC2()
37 {
38 return 0.9;
39 }
40
41 static constexpr Float<> defaultTrapLower()
42 {
43 return 1.1;
44 }
45
46 static constexpr Float<> defaultTrapUpper()
47 {
48 return 4.0;
49 }
50
51 static constexpr Float<> defaultDelta()
52 {
53 return 2.0 / 3.0;
54 }
55
56 public:
57 constexpr MoreThuente(Size const maxIterations = defaultMaxIterations(),
58 Float<> const c1 = defaultC1(), Float<> const c2 = defaultC2(),
59 Float<> const trapLower = defaultTrapLower(),
60 Float<> const trapUpper = defaultTrapUpper(),
61 Float<> const delta = defaultDelta())
62 : mMaxIterations(maxIterations)
63 , mC1(c1)
64 , mC2(c2)
65 , mTrapLower(trapLower)
66 , mTrapUpper(trapUpper)
67 , mDelta(delta)
68 {
69 }
70
71 constexpr auto maxIterations() const
72 {
73 return mMaxIterations;
74 }
75
76 template<class T = Float<>>
77 constexpr decltype(auto) c1() const
78 {
79 return cast<T>(mC1);
80 }
81
82 template<class T = Float<>>
83 constexpr decltype(auto) c2() const
84 {
85 return cast<T>(mC2);
86 }
87
88 template<class T = Float<>>
89 constexpr decltype(auto) trapLower() const
90 {
91 return cast<T>(mTrapLower);
92 }
93
94 template<class T = Float<>>
95 constexpr decltype(auto) trapUpper() const
96 {
97 return cast<T>(mTrapUpper);
98 }
99
100 template<class T = Float<>>
101 constexpr decltype(auto) delta() const
102 {
103 return cast<T>(mDelta);
104 }
105
// Line search driver that takes an additional `constraint` callable.
//
// NOTE(review): this listing is missing several lines (the embedded line
// numbers skip, e.g. 111 -> 114), including the line that declares the
// function name and its leading parameters. The comments below therefore
// describe only statements that are actually visible.
//
// Template constraints visible here: X must satisfy IsTensor, its
// X::Type::Value must satisfy IsFloat, F must be invocable as
// F(X const &, X &) with an IsFloat result, and FC must be invocable as
// FC(X const &) with a result convertible to X.
//
// Each loop iteration assigns x = constraint(x0 + alpha * p), obtains f from
// function(x, g), recomputes d from p and g, and then either breaks out of
// the loop or asks computeAlpha for the next trial alpha. The final alpha is
// returned.
106 template<class X, If<IsTensor<X>> = 0,
107 class T = typename X::Type::Value, If<IsFloat<T>> = 0, class F,
108 If<IsInvocable<F, X const &, X &>> = 0,
109 If<IsFloat<ResultOf<F, X const &, X &>>> = 0, class FC,
110 If<IsInvocable<FC, X const &>> = 0,
111 If<IsConvertible<ResultOf<FC, X const &>, X>> = 0>
114 FC &&constraint, X const &p, T alpha = identity<T>(),
115 T const alphaMin = zero<T>(),
117 {
// Argument sanity checks; this branch is compiled in only when cDebugLevel is
// at least DebugLevel::High. Violations are reported via Logging::error.
118 if constexpr(cDebugLevel >= DebugLevel::High)
119 {
120 if(alphaMin < zero<T>())
121 {
122 Logging::error("LS-MoreThuente: Minimum alpha < 0.");
123 }
124
125 if(alphaMax < zero<T>())
126 {
127 Logging::error("LS-MoreThuente: Maximum alpha < 0.");
128 }
129
130 if(alphaMax < alphaMin)
131 {
133 "LS-MoreThuente: Minimum alpha > Maximum alpha.");
134 }
135
137 {
138 Logging::error("Initial alpha not in range (min, max]");
139 }
140 }
141
// Remember the starting point, function value and directional derivative.
142 Tensor const x0 = x;
143 auto const f0 = f;
144 auto const d0 = d;
145
// c1-scaled initial directional derivative; reused below for fMod and dMod.
146 auto const curvatureTest = c1<T>() * d0;
147
// Both interval endpoints start at alpha = 0 with the initial value and slope.
148 T alphaLower = zero();
149 auto fLower = f0;
150 auto dLower = d0;
151
152 T alphaUpper = zero();
153 auto fUpper = f0;
154 auto dUpper = d0;
155
156 auto alphaLowerBound = zero<T>();
158
159 auto nextWidth = alphaMax - alphaMin;
160 auto width = identity<T>(2) * nextWidth;
161
// firstStage is cleared further down in the loop. NOTE(review): bracketed is
// presumably updated by computeAlpha (which takes a `Bool &bracketed`); the
// argument lists of the calls below are truncated in this listing.
162 Bool firstStage = true;
163 Bool bracketed = false;
164 for(Index iteration = 0;; ++iteration)
165 {
167 {
168 Logging::debug("Line search: Max iterations reached.");
169 break;
170 }
171
172 if(bracketed)
173 {
175 {
177 "Line search: Rounding errors prevent progress.");
178 break;
179 }
180
182 {
183 Logging::debug("Line search: Tolerance is satisfied.");
184 break;
185 }
186 }
187
// Evaluate the trial point: x is the constrained step along p, f comes from
// the callback (which also receives g), and d is the scalar product of p and g.
188 x = constraint(x0 + alpha * p);
189 f = function(x, g);
190 d = scalarProduct(p, g)();
191
193
195 d <= curvatureTest)
196 {
197 Logging::debug("Line search: Max alpha reached.");
198 break;
199 }
200
201 if(alpha == alphaMin &&
203 {
204 Logging::debug("Line search: Min alpha reached.");
205 break;
206 }
207
// Stop once f is at most sufficientDecreaseTest (defined on a line missing
// from this listing) and |d| is at most |c2 * d0|.
208 if(f <= sufficientDecreaseTest && abs(d) <= abs(c2<T>() * d0))
209 {
210 break;
211 }
212
214 {
215 firstStage = false;
216 }
217
219 {
// Trial value and slope shifted by the c1-scaled initial slope.
220 auto const fMod = f - alpha * curvatureTest;
223 auto const dMod = d - curvatureTest;
226
227 alpha = computeAlpha(alphaLower, fLowerMod, dLowerMod,
230
235 }
236 else
237 {
238 alpha = computeAlpha(alphaLower, fLower, dLower, alphaUpper,
241 }
242
243 if(bracketed)
244 {
246
247 if(alphaDelta >= delta<T>() * width)
248 {
250 }
251
254 }
255
256 if(bracketed)
257 {
260 }
261 else
262 {
267 }
268
271
272 if(bracketed &&
275 {
277 }
278 }
279
280 return alpha;
281 }
282
// Convenience overload without a constraint callable.
//
// It delegates to operator(), passing x, f, g, d and the forwarded `function`
// followed by a lambda that simply returns its (perfectly forwarded) argument,
// i.e. an identity mapping in the place where the other overload expects the
// constraint.
//
// NOTE(review): the declaration line and the trailing call arguments are
// missing from this listing (embedded numbering skips 286 -> 289 and
// 297 -> 299).
283 template<class X, If<IsTensor<X>> = 0,
284 class T = typename X::Type::Value, If<IsFloat<T>> = 0, class F,
285 If<IsInvocable<F, X const &, X &>> = 0,
286 If<IsFloat<ResultOf<F, X const &, X &>>> = 0>
289 X const &p, T alpha = identity<T>(), T const alphaMin = zero<T>(),
291 {
292 return operator()(
293 x, f, g, d, forward<F>(function),
294 [](auto &&x) -> decltype(auto)
295 {
296 return forward<decltype(x)>(x);
297 },
299 }
300
301 private:
// Secant step built from the slopes dX (at x) and dY (at y): starting from y,
// move by the fraction dY / (dY - dX) of the distance towards x.
template<typename T>
static constexpr auto secantMinimizer(T const &x, T const &dX,
    T const &y, T const &dY)
{
    auto const slopeRatio = dY / (dY - dX);
    return y + slopeRatio * (x - y);
}
308
309 template<typename T>
310 static constexpr auto quadraticMinimizer(T const &x, T const &fX,
311 T const &dX, T const &y, T const &fY)
312 {
313 return x +
314 dX / ((fX - fY) / (y - x) + dX) / identity<T>(2) * (y - x);
315 }
316
317 template<typename T>
318 static constexpr auto cubicMinimizer(T const &x, T const &fX,
319 T const &dX, T const &y, T const &fY, T const &dY)
320 {
321 auto const theta = identity<T>(3) * (fX - fY) / (y - x) + dX + dY;
322
323 auto const s = max(abs(theta), abs(dX), abs(dY));
324
325#ifdef __FAST_MATH__
326 auto gamma = sqrt(square(theta) - dX * dY);
327#else
328 auto gamma = s * sqrt(square(theta / s) - (dX / s) * (dY / s));
329#endif // __FAST_MATH__
330
331 if(y < x)
332 {
333 gamma = -gamma;
334 }
335
336 auto const p = (gamma - dX) + theta;
337 auto const q = (gamma - dX) + gamma + dY;
338 auto const r = p / q;
339
340 return x + r * (y - x);
341 }
342
343 template<typename T>
344 static constexpr auto cubicMinimizer(T const &x, T const &fX,
345 T const &dX, T const &y, T const &fY, T const &dY,
346 T const &lowerBound, T const &upperBound)
347 {
348 auto const theta = identity<T>(3) * (fX - fY) / (y - x) + dX + dY;
349
350 auto const s = max(abs(theta), abs(dX), abs(dY));
351
352#ifdef __FAST_MATH__
353 auto gamma = sqrt(max(zero<T>(), square(theta) - dX * dY));
354#else
355 auto gamma = s *
356 sqrt(max(zero<T>(), square(theta / s) - (dX / s) * (dY / s)));
357#endif // __FAST_MATH__
358
359 if(y < x)
360 {
361 gamma = -gamma;
362 }
363
364 auto const p = (gamma - dX) + theta;
365 auto const q = (gamma - dX) + gamma + dY;
366 auto const r = p / q;
367
368 if((r < zero()) && (gamma != zero()))
369 {
370 return x + r * (y - x);
371 }
372 else if(y < x)
373 {
374 return upperBound;
375 }
376 else
377 {
378 return lowerBound;
379 }
380 }
381
382 private:
// Chooses the next trial step from the current interval endpoints
// (alphaLower / alphaUpper with their f and d values) and the latest trial
// (alpha, f, d). Endpoint data and `bracketed` are passed by reference and
// updated here.
//
// NOTE(review): this listing omits several lines of the function (the embedded
// numbering skips), among them the declaration of nextAlpha and most of the
// assignments to it. Branch bodies that look empty are truncated, not
// genuinely empty.
//
// Visible case split:
//  1. f > fLower: computes a cubic and a quadratic candidate, sets bracketed,
//     copies the trial value/slope into fUpper/dUpper, and returns either
//     alphaC or mean(alphaQ, alphaC) (the deciding condition is not visible).
//  2. d * dLower < 0: computes a cubic and a secant candidate, sets bracketed,
//     and copies the lower value/slope into fUpper/dUpper.
//  3. abs(d) < abs(dLower): computes a cubic candidate (call truncated in this
//     listing) and a secant candidate; when bracketed, a `trap` point between
//     alpha and alphaUpper is also formed.
//  4. otherwise: when bracketed, uses the cubic through the trial point and
//     the upper endpoint.
// After cases 2-4 the trial value/slope become fLower/dLower and nextAlpha is
// returned.
383 template<typename T>
384 constexpr auto computeAlpha(T &alphaLower, T &fLower, T &dLower,
385 T &alphaUpper, T &fUpper, T &dUpper, T const &alpha, T const &f,
386 T const &d, Bool &bracketed, T const &alphaLowerBound,
387 T const &alphaUpperBound) const
388 {
// Case 1: the trial value is higher than the value at the lower endpoint.
389 if(f > fLower)
390 {
391 auto const alphaC =
392 cubicMinimizer(alphaLower, fLower, dLower, alpha, f, d);
393
394 auto const alphaQ =
395 quadraticMinimizer(alphaLower, fLower, dLower, alpha, f);
396
397 bracketed = true;
399 fUpper = f;
400 dUpper = d;
401
404 {
405 return alphaC;
406 }
407 else
408 {
409 return mean(alphaQ, alphaC);
410 }
411 }
412
// Case 2: the slopes at the trial point and the lower endpoint differ in sign.
414 if(d * dLower < zero())
415 {
416 auto const alphaC =
417 cubicMinimizer(alpha, f, d, alphaLower, fLower, dLower);
418
419 auto const alphaS =
420 secantMinimizer(alphaLower, dLower, alpha, d);
421
423 {
425 }
426 else
427 {
429 }
430
431 bracketed = true;
433 fUpper = fLower;
434 dUpper = dLower;
435 }
// Case 3: same slope sign, but the slope magnitude shrank at the trial point.
436 else if(abs(d) < abs(dLower))
437 {
438 auto const alphaC = cubicMinimizer(alpha, f, d, alphaLower,
440
441 auto const alphaS =
442 secantMinimizer(alphaLower, dLower, alpha, d);
443
444 if(bracketed)
445 {
// Point a delta-fraction of the way from alpha towards alphaUpper.
446 auto const trap = alpha + delta<T>() * (alphaUpper - alpha);
448 {
449 if(alpha > alphaLower)
450 {
452 }
453 else
454 {
456 }
457 }
458 else
459 {
460 if(alpha > alphaLower)
461 {
463 }
464 else
465 {
467 }
468 }
469 }
470 else
471 {
473 {
475 }
476 else
477 {
479 }
480
483 }
484 }
// Case 4: same slope sign and the slope magnitude did not shrink.
485 else
486 {
487 if(bracketed)
488 {
489 nextAlpha =
490 cubicMinimizer(alpha, f, d, alphaUpper, fUpper, dUpper);
491 }
492 else if(alpha > alphaLower)
493 {
495 }
496 else
497 {
499 }
500 }
501
// The trial point's value and slope become the new lower-endpoint data.
503 fLower = f;
504 dLower = d;
505
506 return nextAlpha;
507 }
508
509 private:
// Parameters captured by the constructor. All fields are const, so they cannot
// be reassigned after construction; each is exposed through the accessor of
// the same name (without the `m` prefix).
510 Size const mMaxIterations;
511 Float<> const mC1;
512 Float<> const mC2;
513 Float<> const mTrapLower;
514 Float<> const mTrapUpper;
515 Float<> const mDelta;
516 };
517}
518#endif // pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
Top-level class storing a floating point number.
Definition float.hpp:35
Class storing tensors.
Definition tensor.hpp:44
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Definition bracketing.hpp:15
static constexpr auto mean(Xs &&...args)
Calculates the mean of a variable amount of pRC objects.
Definition mean.hpp:22
bool Bool
Definition type_traits.hpp:18
static constexpr X min(X &&a)
Definition min.hpp:13
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
static constexpr auto zero()
Definition zero.hpp:12
std::size_t Size
Definition type_traits.hpp:20
static constexpr auto square(Complex< T > const &a)
Definition square.hpp:14
static constexpr auto isApprox(XA &&a, XB &&b, TT const &tolerance=NumericLimits< TT >::tolerance())
Checks if two pRC objects agree up to a given tolerance.
Definition is_approx.hpp:44
static constexpr auto delta(Complex< TA > const &a, Complex< TB > const &b)
Definition delta.hpp:12
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
static constexpr Conditional< IsSatisfied< C >, RemoveConstReference< X >, X > copy(X &&a)
Definition copy.hpp:13
constexpr auto cDebugLevel
Definition config.hpp:46
static constexpr auto sqrt(Complex< T > const &a)
Definition sqrt.hpp:12
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition type_traits.hpp:62
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
Size Index
Definition type_traits.hpp:21
static constexpr X max(X &&a)
Definition max.hpp:13
Definition limits.hpp:13
Definition more_thuente.hpp:24
constexpr decltype(auto) delta() const
Definition more_thuente.hpp:101
constexpr MoreThuente(Size const maxIterations=defaultMaxIterations(), Float<> const c1=defaultC1(), Float<> const c2=defaultC2(), Float<> const trapLower=defaultTrapLower(), Float<> const trapUpper=defaultTrapUpper(), Float<> const delta=defaultDelta())
Definition more_thuente.hpp:57
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, FC &&constraint, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition more_thuente.hpp:112
constexpr decltype(auto) c2() const
Definition more_thuente.hpp:83
constexpr decltype(auto) trapLower() const
Definition more_thuente.hpp:89
constexpr decltype(auto) c1() const
Definition more_thuente.hpp:77
constexpr auto maxIterations() const
Definition more_thuente.hpp:71
constexpr decltype(auto) trapUpper() const
Definition more_thuente.hpp:95
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition more_thuente.hpp:287