9#ifndef pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
10#define pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_MORE_THUENTE_H
22 static constexpr Size defaultMaxIterations()
27 static constexpr Float<> defaultC1()
32 static constexpr Float<> defaultC2()
37 static constexpr Float<> defaultTrapLower()
42 static constexpr Float<> defaultTrapUpper()
47 static constexpr Float<> defaultDelta()
69 return mMaxIterations;
72 template<
class T = Float<>>
73 constexpr decltype(
auto)
c1()
const
78 template<
class T = Float<>>
79 constexpr decltype(
auto)
c2()
const
84 template<
class T = Float<>>
90 template<
class T = Float<>>
96 template<
class T = Float<>>
97 constexpr decltype(
auto)
delta()
const
102 template<IsTensor X, IsFloat T = Value<X>,
class F,
class FC>
125 if(alphaMax < alphaMin)
128 "LS-MoreThuente: Minimum alpha > Maximum alpha.");
131 if(alpha <= alphaMin || alpha > alphaMax)
141 auto const curvatureTest =
c1<T>() * d0;
143 T alphaLower =
zero();
147 T alphaUpper =
zero();
151 auto alphaLowerBound =
zero<T>();
154 auto nextWidth = alphaMax - alphaMin;
157 Bool firstStage =
true;
158 Bool bracketed =
false;
159 for(
Index iteration = 0;; ++iteration)
169 if(alpha <= alphaLowerBound || alpha >= alphaUpperBound)
172 "Line search: Rounding errors prevent progress.");
176 if(
isZero(alphaUpperBound - alphaLowerBound))
183 x = constraint(x0 + alpha *
p);
187 auto const sufficientDecreaseTest = f0 + alpha * curvatureTest;
189 if(alpha == alphaMax && f <= sufficientDecreaseTest &&
196 if(alpha == alphaMin &&
197 (f > sufficientDecreaseTest || d >= curvatureTest))
203 if(f <= sufficientDecreaseTest &&
abs(d) <=
abs(
c2<T>() * d0))
208 if(firstStage && f <= sufficientDecreaseTest && d >=
zero())
213 if(firstStage && f <= fLower && f > sufficientDecreaseTest)
215 auto const fMod = f - alpha * curvatureTest;
216 auto fLowerMod = fLower - alphaLower * curvatureTest;
217 auto fUpperMod = fUpper - alphaUpper * curvatureTest;
218 auto const dMod = d - curvatureTest;
219 auto dLowerMod = dLower - curvatureTest;
220 auto dUpperMod = dUpper - curvatureTest;
222 alpha = computeAlpha(alphaLower, fLowerMod, dLowerMod,
223 alphaUpper, fUpperMod, dUpperMod, alpha, fMod, dMod,
224 bracketed, alphaLowerBound, alphaUpperBound);
226 fLower = fLowerMod + alphaLower * curvatureTest;
227 fUpper = fUpperMod + alphaUpper * curvatureTest;
228 dLower = dLowerMod + curvatureTest;
229 dUpper = dUpperMod + curvatureTest;
233 alpha = computeAlpha(alphaLower, fLower, dLower, alphaUpper,
234 fUpper, dUpper, alpha, f, d, bracketed, alphaLowerBound,
240 auto const alphaDelta =
pRC::delta(alphaLower, alphaUpper);
242 if(alphaDelta >=
delta<T>() * width)
244 alpha =
mean(alphaLower, alphaUpper);
248 nextWidth = alphaDelta;
253 alphaLowerBound =
min(alphaLower, alphaUpper);
254 alphaUpperBound =
max(alphaLower, alphaUpper);
264 alpha =
max(alpha, alphaMin);
265 alpha =
min(alpha, alphaMax);
268 ((alpha <= alphaLowerBound || alpha >= alphaUpperBound) ||
269 isZero(alphaUpperBound - alphaLowerBound)))
278 template<IsTensor X, IsFloat T = Value<X>,
class F>
287 x, f, g, d, forward<F>(function),
288 [](
auto &&
x) ->
decltype(
auto)
290 return forward<decltype(x)>(
x);
292 p, alpha, alphaMin, alphaMax);
297 static constexpr auto secantMinimizer(
T const &
x,
T const &dX,
298 T const &
y,
T const &dY)
300 return y + dY / (dY - dX) * (
x -
y);
304 static constexpr auto quadraticMinimizer(
T const &
x,
T const &fX,
305 T const &dX,
T const &
y,
T const &fY)
312 static constexpr auto cubicMinimizer(
T const &
x,
T const &fX,
313 T const &dX,
T const &
y,
T const &fY,
T const &dY)
315 auto const theta =
identity<T>(3) * (fX - fY) / (
y -
x) + dX + dY;
322 auto gamma = s *
sqrt(
square(theta / s) - (dX / s) * (dY / s));
330 auto const p = (gamma - dX) + theta;
331 auto const q = (gamma - dX) + gamma + dY;
332 auto const r =
p / q;
334 return x + r * (
y -
x);
338 static constexpr auto cubicMinimizer(
T const &
x,
T const &fX,
339 T const &dX,
T const &
y,
T const &fY,
T const &dY,
340 T const &lowerBound,
T const &upperBound)
342 auto const theta =
identity<T>(3) * (fX - fY) / (
y -
x) + dX + dY;
358 auto const p = (gamma - dX) + theta;
359 auto const q = (gamma - dX) + gamma + dY;
360 auto const r =
p / q;
362 if((r <
zero()) && (gamma !=
zero()))
364 return x + r * (
y -
x);
378 constexpr auto computeAlpha(
T &alphaLower,
T &fLower,
T &dLower,
379 T &alphaUpper,
T &fUpper,
T &dUpper,
T const &alpha,
T const &f,
380 T const &d,
Bool &bracketed,
T const &alphaLowerBound,
381 T const &alphaUpperBound)
const
386 cubicMinimizer(alphaLower, fLower, dLower, alpha, f, d);
389 quadraticMinimizer(alphaLower, fLower, dLower, alpha, f);
403 return mean(alphaQ, alphaC);
408 if(d * dLower <
zero())
411 cubicMinimizer(alpha, f, d, alphaLower, fLower, dLower);
414 secantMinimizer(alphaLower, dLower, alpha, d);
426 alphaUpper = alphaLower;
430 else if(
abs(d) <
abs(dLower))
432 auto const alphaC = cubicMinimizer(alpha, f, d, alphaLower,
433 fLower, dLower, alphaLowerBound, alphaUpperBound);
436 secantMinimizer(alphaLower, dLower, alpha, d);
440 auto const trap = alpha +
delta<T>() * (alphaUpper - alpha);
443 if(alpha > alphaLower)
445 nextAlpha =
min(trap, alphaC);
449 nextAlpha =
max(trap, alphaC);
454 if(alpha > alphaLower)
456 nextAlpha =
min(trap, alphaS);
460 nextAlpha =
max(trap, alphaS);
475 nextAlpha =
min(alphaUpperBound, nextAlpha);
476 nextAlpha =
max(alphaLowerBound, nextAlpha);
484 cubicMinimizer(alpha, f, d, alphaUpper, fUpper, dUpper);
486 else if(alpha > alphaLower)
488 nextAlpha = alphaUpperBound;
492 nextAlpha = alphaLowerBound;
504 Size const mMaxIterations;
Definition concepts.hpp:43
Definition concepts.hpp:31
const double y
Definition gmock-matchers-containers_test.cc:377
int x
Definition gmock-matchers-containers_test.cc:376
const char * p
Definition gmock-matchers-containers_test.cc:379
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Definition bracketing.hpp:11
Float(Float< 16 >::Fundamental const) -> Float< 16 >
static constexpr auto sqrt(T const &a)
Definition sqrt.hpp:11
static constexpr auto cast(T const &a)
Definition cast.hpp:11
Size Index
Definition basics.hpp:32
std::size_t Size
Definition basics.hpp:31
std::invoke_result_t< F, Args... > ResultOf
Definition basics.hpp:59
static constexpr auto abs(T const &a)
Definition abs.hpp:11
static constexpr decltype(auto) min(X &&a)
Definition min.hpp:13
static constexpr auto scalarProduct(TA const &a, TB const &b)
Definition scalar_product.hpp:11
static constexpr auto delta(TA const &a, TB const &b)
Definition delta.hpp:11
constexpr auto cDebugLevel
Definition config.hpp:48
RemoveConst< RemoveReference< T > > RemoveConstReference
Definition basics.hpp:47
static constexpr auto isZero(T const a)
Definition is_zero.hpp:11
static constexpr auto square(T const &a)
Definition square.hpp:11
static constexpr auto identity()
Definition identity.hpp:13
static constexpr auto zero()
Definition zero.hpp:12
static constexpr auto mean(Xs &&...args)
Definition mean.hpp:16
static constexpr decltype(auto) max(X &&a)
Definition max.hpp:13
Definition gtest_pred_impl_unittest.cc:54
Definition more_thuente.hpp:20
constexpr decltype(auto) delta() const
Definition more_thuente.hpp:97
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition more_thuente.hpp:281
constexpr MoreThuente(Size const maxIterations=defaultMaxIterations(), Float<> const c1=defaultC1(), Float<> const c2=defaultC2(), Float<> const trapLower=defaultTrapLower(), Float<> const trapUpper=defaultTrapUpper(), Float<> const delta=defaultDelta())
Definition more_thuente.hpp:53
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, FC &&constraint, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition more_thuente.hpp:107
constexpr decltype(auto) c2() const
Definition more_thuente.hpp:79
constexpr decltype(auto) trapLower() const
Definition more_thuente.hpp:85
constexpr decltype(auto) c1() const
Definition more_thuente.hpp:73
constexpr auto maxIterations() const
Definition more_thuente.hpp:67
constexpr decltype(auto) trapUpper() const
Definition more_thuente.hpp:91