cMHN 1.2
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
more_thuente.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3// References:
4// Authors: Jorge J. More, David J. Thuente
5// Title: Line Search Algorithms with Guaranteed Sufficient Decrease
6// Year: 1994
7// URL: https://doi.org/10.1145/192115.192132
8
9#ifndef pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
10#define pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
11
16
18{
20 {
21 private:
22 static constexpr Size defaultMaxIterations()
23 {
24 return 20;
25 }
26
27 static constexpr Float<> defaultC1()
28 {
29 return 1e-4;
30 }
31
32 static constexpr Float<> defaultC2()
33 {
34 return 0.9;
35 }
36
37 static constexpr Float<> defaultTrapLower()
38 {
39 return 1.1;
40 }
41
42 static constexpr Float<> defaultTrapUpper()
43 {
44 return 4.0;
45 }
46
47 static constexpr Float<> defaultDelta()
48 {
49 return 2.0 / 3.0;
50 }
51
52 public:
53 constexpr MoreThuente(Size const maxIterations = defaultMaxIterations(),
54 Float<> const c1 = defaultC1(), Float<> const c2 = defaultC2(),
55 Float<> const trapLower = defaultTrapLower(),
56 Float<> const trapUpper = defaultTrapUpper(),
57 Float<> const delta = defaultDelta())
58 : mMaxIterations(maxIterations)
59 , mC1(c1)
60 , mC2(c2)
61 , mTrapLower(trapLower)
62 , mTrapUpper(trapUpper)
63 , mDelta(delta)
64 {
65 }
66
67 constexpr auto maxIterations() const
68 {
69 return mMaxIterations;
70 }
71
72 template<class T = Float<>>
73 constexpr decltype(auto) c1() const
74 {
75 return cast<T>(mC1);
76 }
77
78 template<class T = Float<>>
79 constexpr decltype(auto) c2() const
80 {
81 return cast<T>(mC2);
82 }
83
84 template<class T = Float<>>
85 constexpr decltype(auto) trapLower() const
86 {
87 return cast<T>(mTrapLower);
88 }
89
90 template<class T = Float<>>
91 constexpr decltype(auto) trapUpper() const
92 {
93 return cast<T>(mTrapUpper);
94 }
95
96 template<class T = Float<>>
97 constexpr decltype(auto) delta() const
98 {
99 return cast<T>(mDelta);
100 }
101
102 template<IsTensor X, IsFloat T = Value<X>, class F, class FC>
107 constexpr auto operator()(X &x, ResultOf<F, X const &, X &> &f, X &g,
108 typename ResultOf<ScalarProduct, X, X>::Type &d, F &&function,
109 FC &&constraint, X const &p, T alpha = identity<T>(),
110 T const alphaMin = zero<T>(),
111 T const alphaMax = identity<T>(NumericLimits<T>::max())) const
112 {
113 if constexpr(cDebugLevel >= DebugLevel::High)
114 {
115 if(alphaMin < zero<T>())
116 {
117 Logging::error("LS-MoreThuente: Minimum alpha < 0.");
118 }
119
120 if(alphaMax < zero<T>())
121 {
122 Logging::error("LS-MoreThuente: Maximum alpha < 0.");
123 }
124
125 if(alphaMax < alphaMin)
126 {
128 "LS-MoreThuente: Minimum alpha > Maximum alpha.");
129 }
130
131 if(alpha <= alphaMin || alpha > alphaMax)
132 {
133 Logging::error("Initial alpha not in range (min, max]");
134 }
135 }
136
137 Tensor const x0 = x;
138 auto const f0 = f;
139 auto const d0 = d;
140
141 auto const curvatureTest = c1<T>() * d0;
142
143 T alphaLower = zero();
144 auto fLower = f0;
145 auto dLower = d0;
146
147 T alphaUpper = zero();
148 auto fUpper = f0;
149 auto dUpper = d0;
150
151 auto alphaLowerBound = zero<T>();
152 auto alphaUpperBound = alpha + trapUpper<T>() * alpha;
153
154 auto nextWidth = alphaMax - alphaMin;
155 auto width = identity<T>(2) * nextWidth;
156
157 Bool firstStage = true;
158 Bool bracketed = false;
159 for(Index iteration = 0;; ++iteration)
160 {
161 if(iteration > maxIterations())
162 {
163 Logging::debug("Line search: Max iterations reached.");
164 break;
165 }
166
167 if(bracketed)
168 {
169 if(alpha <= alphaLowerBound || alpha >= alphaUpperBound)
170 {
172 "Line search: Rounding errors prevent progress.");
173 break;
174 }
175
176 if(isZero(alphaUpperBound - alphaLowerBound))
177 {
178 Logging::debug("Line search: Tolerance is satisfied.");
179 break;
180 }
181 }
182
183 x = constraint(x0 + alpha * p);
184 f = function(x, g);
185 d = scalarProduct(p, g)();
186
187 auto const sufficientDecreaseTest = f0 + alpha * curvatureTest;
188
189 if(alpha == alphaMax && f <= sufficientDecreaseTest &&
190 d <= curvatureTest)
191 {
192 Logging::debug("Line search: Max alpha reached.");
193 break;
194 }
195
196 if(alpha == alphaMin &&
197 (f > sufficientDecreaseTest || d >= curvatureTest))
198 {
199 Logging::debug("Line search: Min alpha reached.");
200 break;
201 }
202
203 if(f <= sufficientDecreaseTest && abs(d) <= abs(c2<T>() * d0))
204 {
205 break;
206 }
207
208 if(firstStage && f <= sufficientDecreaseTest && d >= zero())
209 {
210 firstStage = false;
211 }
212
213 if(firstStage && f <= fLower && f > sufficientDecreaseTest)
214 {
215 auto const fMod = f - alpha * curvatureTest;
216 auto fLowerMod = fLower - alphaLower * curvatureTest;
217 auto fUpperMod = fUpper - alphaUpper * curvatureTest;
218 auto const dMod = d - curvatureTest;
219 auto dLowerMod = dLower - curvatureTest;
220 auto dUpperMod = dUpper - curvatureTest;
221
222 alpha = computeAlpha(alphaLower, fLowerMod, dLowerMod,
223 alphaUpper, fUpperMod, dUpperMod, alpha, fMod, dMod,
224 bracketed, alphaLowerBound, alphaUpperBound);
225
226 fLower = fLowerMod + alphaLower * curvatureTest;
227 fUpper = fUpperMod + alphaUpper * curvatureTest;
228 dLower = dLowerMod + curvatureTest;
229 dUpper = dUpperMod + curvatureTest;
230 }
231 else
232 {
233 alpha = computeAlpha(alphaLower, fLower, dLower, alphaUpper,
234 fUpper, dUpper, alpha, f, d, bracketed, alphaLowerBound,
235 alphaUpperBound);
236 }
237
238 if(bracketed)
239 {
240 auto const alphaDelta = pRC::delta(alphaLower, alphaUpper);
241
242 if(alphaDelta >= delta<T>() * width)
243 {
244 alpha = mean(alphaLower, alphaUpper);
245 }
246
247 width = nextWidth;
248 nextWidth = alphaDelta;
249 }
250
251 if(bracketed)
252 {
253 alphaLowerBound = min(alphaLower, alphaUpper);
254 alphaUpperBound = max(alphaLower, alphaUpper);
255 }
256 else
257 {
258 alphaLowerBound =
259 alpha + trapLower<T>() * (alpha - alphaLower);
260 alphaUpperBound =
261 alpha + trapUpper<T>() * (alpha - alphaLower);
262 }
263
264 alpha = max(alpha, alphaMin);
265 alpha = min(alpha, alphaMax);
266
267 if(bracketed &&
268 ((alpha <= alphaLowerBound || alpha >= alphaUpperBound) ||
269 isZero(alphaUpperBound - alphaLowerBound)))
270 {
271 alpha = alphaLower;
272 }
273 }
274
275 return alpha;
276 }
277
278 template<IsTensor X, IsFloat T = Value<X>, class F>
281 constexpr auto operator()(X &x, ResultOf<F, X const &, X &> &f, X &g,
282 typename ResultOf<ScalarProduct, X, X>::Type &d, F &&function,
283 X const &p, T alpha = identity<T>(), T const alphaMin = zero<T>(),
284 T const alphaMax = identity<T>(NumericLimits<T>::max())) const
285 {
286 return operator()(
287 x, f, g, d, forward<F>(function),
288 [](auto &&x) -> decltype(auto)
289 {
290 return forward<decltype(x)>(x);
291 },
292 p, alpha, alphaMin, alphaMax);
293 }
294
295 private:
/// Secant step from two derivative samples: returns
/// `y + dY / (dY - dX) * (x - y)`.
///
/// @param x,dX  First abscissa and the derivative there.
/// @param y,dY  Second abscissa and the derivative there.
template<typename T>
static constexpr auto secantMinimizer(T const &x, T const &dX,
    T const &y, T const &dY)
{
    auto const ratio = dY / (dY - dX);
    return y + ratio * (x - y);
}
302
303 template<typename T>
304 static constexpr auto quadraticMinimizer(T const &x, T const &fX,
305 T const &dX, T const &y, T const &fY)
306 {
307 return x +
308 dX / ((fX - fY) / (y - x) + dX) / identity<T>(2) * (y - x);
309 }
310
311 template<typename T>
312 static constexpr auto cubicMinimizer(T const &x, T const &fX,
313 T const &dX, T const &y, T const &fY, T const &dY)
314 {
315 auto const theta = identity<T>(3) * (fX - fY) / (y - x) + dX + dY;
316
317 auto const s = max(abs(theta), abs(dX), abs(dY));
318
319#ifdef __FAST_MATH__
320 auto gamma = sqrt(square(theta) - dX * dY);
321#else
322 auto gamma = s * sqrt(square(theta / s) - (dX / s) * (dY / s));
323#endif // __FAST_MATH__
324
325 if(y < x)
326 {
327 gamma = -gamma;
328 }
329
330 auto const p = (gamma - dX) + theta;
331 auto const q = (gamma - dX) + gamma + dY;
332 auto const r = p / q;
333
334 return x + r * (y - x);
335 }
336
337 template<typename T>
338 static constexpr auto cubicMinimizer(T const &x, T const &fX,
339 T const &dX, T const &y, T const &fY, T const &dY,
340 T const &lowerBound, T const &upperBound)
341 {
342 auto const theta = identity<T>(3) * (fX - fY) / (y - x) + dX + dY;
343
344 auto const s = max(abs(theta), abs(dX), abs(dY));
345
346#ifdef __FAST_MATH__
347 auto gamma = sqrt(max(zero<T>(), square(theta) - dX * dY));
348#else
349 auto gamma = s *
350 sqrt(max(zero<T>(), square(theta / s) - (dX / s) * (dY / s)));
351#endif // __FAST_MATH__
352
353 if(y < x)
354 {
355 gamma = -gamma;
356 }
357
358 auto const p = (gamma - dX) + theta;
359 auto const q = (gamma - dX) + gamma + dY;
360 auto const r = p / q;
361
362 if((r < zero()) && (gamma != zero()))
363 {
364 return x + r * (y - x);
365 }
366 else if(y < x)
367 {
368 return upperBound;
369 }
370 else
371 {
372 return lowerBound;
373 }
374 }
375
376 private:
377 template<typename T>
378 constexpr auto computeAlpha(T &alphaLower, T &fLower, T &dLower,
379 T &alphaUpper, T &fUpper, T &dUpper, T const &alpha, T const &f,
380 T const &d, Bool &bracketed, T const &alphaLowerBound,
381 T const &alphaUpperBound) const
382 {
383 if(f > fLower)
384 {
385 auto const alphaC =
386 cubicMinimizer(alphaLower, fLower, dLower, alpha, f, d);
387
388 auto const alphaQ =
389 quadraticMinimizer(alphaLower, fLower, dLower, alpha, f);
390
391 bracketed = true;
392 alphaUpper = alpha;
393 fUpper = f;
394 dUpper = d;
395
396 if(pRC::delta(alphaC, alphaLower) <
397 pRC::delta(alphaQ, alphaLower))
398 {
399 return alphaC;
400 }
401 else
402 {
403 return mean(alphaQ, alphaC);
404 }
405 }
406
407 RemoveConstReference<decltype(alpha)> nextAlpha;
408 if(d * dLower < zero())
409 {
410 auto const alphaC =
411 cubicMinimizer(alpha, f, d, alphaLower, fLower, dLower);
412
413 auto const alphaS =
414 secantMinimizer(alphaLower, dLower, alpha, d);
415
416 if(pRC::delta(alphaC, alpha) > pRC::delta(alphaS, alpha))
417 {
418 nextAlpha = alphaC;
419 }
420 else
421 {
422 nextAlpha = alphaS;
423 }
424
425 bracketed = true;
426 alphaUpper = alphaLower;
427 fUpper = fLower;
428 dUpper = dLower;
429 }
430 else if(abs(d) < abs(dLower))
431 {
432 auto const alphaC = cubicMinimizer(alpha, f, d, alphaLower,
433 fLower, dLower, alphaLowerBound, alphaUpperBound);
434
435 auto const alphaS =
436 secantMinimizer(alphaLower, dLower, alpha, d);
437
438 if(bracketed)
439 {
440 auto const trap = alpha + delta<T>() * (alphaUpper - alpha);
441 if(pRC::delta(alphaC, alpha) < pRC::delta(alphaS, alpha))
442 {
443 if(alpha > alphaLower)
444 {
445 nextAlpha = min(trap, alphaC);
446 }
447 else
448 {
449 nextAlpha = max(trap, alphaC);
450 }
451 }
452 else
453 {
454 if(alpha > alphaLower)
455 {
456 nextAlpha = min(trap, alphaS);
457 }
458 else
459 {
460 nextAlpha = max(trap, alphaS);
461 }
462 }
463 }
464 else
465 {
466 if(pRC::delta(alphaC, alpha) > pRC::delta(alphaS, alpha))
467 {
468 nextAlpha = alphaC;
469 }
470 else
471 {
472 nextAlpha = alphaS;
473 }
474
475 nextAlpha = min(alphaUpperBound, nextAlpha);
476 nextAlpha = max(alphaLowerBound, nextAlpha);
477 }
478 }
479 else
480 {
481 if(bracketed)
482 {
483 nextAlpha =
484 cubicMinimizer(alpha, f, d, alphaUpper, fUpper, dUpper);
485 }
486 else if(alpha > alphaLower)
487 {
488 nextAlpha = alphaUpperBound;
489 }
490 else
491 {
492 nextAlpha = alphaLowerBound;
493 }
494 }
495
496 alphaLower = alpha;
497 fLower = f;
498 dLower = d;
499
500 return nextAlpha;
501 }
502
503 private:
504 Size const mMaxIterations;
505 Float<> const mC1;
506 Float<> const mC2;
507 Float<> const mTrapLower;
508 Float<> const mTrapUpper;
509 Float<> const mDelta;
510 };
511}
512#endif // pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
Definition value.hpp:12
Definition tensor.hpp:25
Definition concepts.hpp:43
Definition value.hpp:24
Definition concepts.hpp:31
const double y
Definition gmock-matchers-containers_test.cc:377
int x
Definition gmock-matchers-containers_test.cc:376
const char * p
Definition gmock-matchers-containers_test.cc:379
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Definition bracketing.hpp:11
Float(Float< 16 >::Fundamental const) -> Float< 16 >
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
static constexpr auto cast(T const &a)
Definition cast.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
static constexpr auto abs(T const &a)
Definition abs.hpp:11
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
static constexpr auto scalarProduct(TA const &a, TB const &b)
Definition scalar_product.hpp:11
static constexpr auto delta(TA const &a, TB const &b)
Definition delta.hpp:11
constexpr auto cDebugLevel
Definition config.hpp:48
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto isZero(T const a)
Definition is_zero.hpp:11
static constexpr auto square(T const &a)
Definition square.hpp:11
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto mean(Xs &&...args)
Definition mean.hpp:16
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition gtest_pred_impl_unittest.cc:54
Definition limits.hpp:13
Definition more_thuente.hpp:20
constexpr decltype(auto) delta() const
Definition more_thuente.hpp:97
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition more_thuente.hpp:281
constexpr MoreThuente(Size const maxIterations=defaultMaxIterations(), Float<> const c1=defaultC1(), Float<> const c2=defaultC2(), Float<> const trapLower=defaultTrapLower(), Float<> const trapUpper=defaultTrapUpper(), Float<> const delta=defaultDelta())
Definition more_thuente.hpp:53
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, FC &&constraint, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition more_thuente.hpp:107
constexpr decltype(auto) c2() const
Definition more_thuente.hpp:79
constexpr decltype(auto) trapLower() const
Definition more_thuente.hpp:85
constexpr decltype(auto) c1() const
Definition more_thuente.hpp:73
constexpr auto maxIterations() const
Definition more_thuente.hpp:67
constexpr decltype(auto) trapUpper() const
Definition more_thuente.hpp:91