cMHN 1.1
C++ library for learning MHNs with pRC
Loading...
Searching...
No Matches
bracketing.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: BSD-2-Clause
2
3#ifndef pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_BRACKETING_H
4#define pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_BRACKETING_H
5
6#include <prc/config.hpp>
13
15{
17 {
18 private:
19 static constexpr Bool defaultStrongWolfeConditions()
20 {
21 return true;
22 }
23
24 static constexpr Float<> defaultDecreasingStepScale()
25 {
26 return 0.5;
27 }
28
29 static constexpr Float<> defaultIncreasingStepScale()
30 {
31 return 2.1;
32 }
33
34 static constexpr Size defaultMaxIterations()
35 {
36 return 20;
37 }
38
39 static constexpr Float<> defaultC1()
40 {
41 return 1e-4;
42 }
43
44 static constexpr Float<> defaultC2()
45 {
46 return 0.9;
47 }
48
49 public:
50 constexpr Bracketing(
51 Bool const strongWolfeConditions = defaultStrongWolfeConditions(),
52 Float<> const decreasingStepScale = defaultDecreasingStepScale(),
53 Float<> const increasingStepScale = defaultIncreasingStepScale(),
54 Size const maxIterations = defaultMaxIterations(),
55 Float<> const c1 = defaultC1(), Float<> const c2 = defaultC2())
56 : mStrongWolfeConditions(strongWolfeConditions)
57 , mDecreasingStepScale(decreasingStepScale)
58 , mIncreasingStepScale(increasingStepScale)
59 , mMaxIterations(maxIterations)
60 , mC1(c1)
61 , mC2(c2)
62 {
63 }
64
65 constexpr auto strongWolfeConditions() const
66 {
67 return mStrongWolfeConditions;
68 }
69
70 template<class T = Float<>>
71 constexpr decltype(auto) decreasingStepScale() const
72 {
73 return cast<T>(mDecreasingStepScale);
74 }
75
76 template<class T = Float<>>
77 constexpr decltype(auto) increasingStepScale() const
78 {
79 return cast<T>(mIncreasingStepScale);
80 }
81
82 constexpr auto maxIterations() const
83 {
84 return mMaxIterations;
85 }
86
87 template<class T = Float<>>
88 constexpr decltype(auto) c1() const
89 {
90 return cast<T>(mC1);
91 }
92
93 template<class T = Float<>>
94 constexpr decltype(auto) c2() const
95 {
96 return cast<T>(mC2);
97 }
98
/// Bracketing line search along the direction p, with a projection/constraint.
///
/// Each trial evaluates x = constraint(x0 + alpha * p) and f = function(x, g).
/// A failed sufficient-decrease test lowers the upper bracket end
/// (high = alpha). A failed curvature test raises the lower end
/// (low = alpha). The loop stops when all enabled tests pass or the
/// iteration cap is hit, and the function returns alpha.
///
/// On entry x, f and d are copied into x0, f0 and d0.
/// NOTE(review): d0 is presumably scalarProduct(g, p) at x0 (the directional
/// derivative), matching how d is recomputed in the loop — confirm with callers.
/// NOTE(review): parts of the signature, the alpha update in the
/// for-increment, and several conditions are missing from this listing.
99 template<class X, If<IsTensor<X>> = 0,
100 class T = typename X::Type::Value, If<IsFloat<T>> = 0, class F,
101 If<IsInvocable<F, X const &, X &>> = 0,
102 If<IsFloat<ResultOf<F, X const &, X &>>> = 0, class FC,
103 If<IsInvocable<FC, X const &>> = 0,
104 If<IsConvertible<ResultOf<FC, X const &>, X>> = 0>
107 FC &&constraint, X const &p, T alpha = identity<T>(),
108 T const alphaMin = zero<T>(),
110 {
// Debug-only sanity checks on [alphaMin, alphaMax] and the initial alpha;
// violations are reported through Logging::error.
111 if constexpr(cDebugLevel >= DebugLevel::High)
112 {
113 if(alphaMin < zero<T>())
114 {
115 Logging::error("LS-Bracketing: Minimum alpha < 0.");
116 }
117
118 if(alphaMax < zero<T>())
119 {
120 Logging::error("LS-Bracketing: Maximum alpha < 0.");
121 }
122
123 if(alphaMax < alphaMin)
124 {
126 "LS-Bracketing: Minimum alpha > Maximum alpha.");
127 }
128
130 {
// NOTE(review): unlike the other checks, this message lacks the
// "LS-Bracketing:" prefix.
131 Logging::error("Initial alpha not in range (min, max]");
132 }
133 }
134
// Snapshot of the starting state; every trial point is measured from here.
135 Tensor const x0 = x;
136 auto const f0 = f;
137 auto const d0 = d;
138
// Current bracket for the step length, initially the full allowed range.
139 auto low = alphaMin;
140 auto high = alphaMax;
141
142 for(Index iteration = 0;; ++iteration,
145 {
147 {
148 Logging::debug("Line search: Max iterations reached.");
// NOTE(review): this exit happens after the (hidden) for-increment. If
// that increment updates alpha, the returned alpha no longer matches
// x/f/g, and d may be stale (it is only refreshed after the first test
// passes). Verify against the full source.
149 break;
150 }
151
// Evaluate the trial point, projected through `constraint`; `function`
// returns the objective value and writes the gradient into g.
152 x = constraint(x0 + alpha * p);
153 f = function(x, g);
154
// Sufficient-decrease (first Wolfe / Armijo) test; on failure the step
// is treated as too long and becomes the new upper bracket end.
155 if(f > f0 + c1<T>() * alpha * d0)
156 {
157 high = alpha;
159 "First Wolfe condition not satisfied, trying again");
160 continue;
161 }
162
// Curvature (second Wolfe) test on the new directional derivative; on
// failure the step is treated as too short and becomes the lower end.
163 d = scalarProduct(g, p)();
164 if(d < c2<T>() * d0)
165 {
166 low = alpha;
168 "Second Wolfe condition not satisfied, trying again");
169 continue;
170 }
171
173 {
// Strong Wolfe test: also bounds |d|. The guarding condition (line 172)
// is not visible here — presumably strongWolfeConditions().
174 if(abs(d) > abs(c2<T>() * d0))
175 {
176 high = alpha;
178 "Strong Wolfe condition not satisfied, trying "
179 "again");
180 continue;
181 }
182 }
183
// All enabled conditions hold for the current alpha.
184 break;
185 }
186
187 return alpha;
188 }
189
/// Unconstrained overload: delegates to the constrained operator() with a
/// pass-through constraint, so trial points are x0 + alpha * p unchanged.
/// NOTE(review): the remaining forwarded arguments (line 205) are not visible
/// in this listing; presumably p, alpha, alphaMin, alphaMax — confirm.
190 template<class X, If<IsTensor<X>> = 0,
191 class T = typename X::Type::Value, If<IsFloat<T>> = 0, class F,
192 If<IsInvocable<F, X const &, X &>> = 0,
193 If<IsFloat<ResultOf<F, X const &, X &>>> = 0>
196 X const &p, T alpha = identity<T>(), T const alphaMin = zero<T>(),
198 {
199 return operator()(
200 x, f, g, d, forward<F>(function),
// Identity constraint: perfect-forwards its argument unchanged.
201 [](auto &&x) -> decltype(auto)
202 {
203 return forward<decltype(x)>(x);
204 },
206 }
207
208 private:
// Configuration captured by the constructor. All members are const, so a
// Bracketing object cannot be reconfigured after construction. Declaration
// order matches the constructor's initializer list.
209 Bool const mStrongWolfeConditions;
// Read through decreasingStepScale<T>() / increasingStepScale<T>().
210 Float<> const mDecreasingStepScale;
211 Float<> const mIncreasingStepScale;
// Iteration cap, read through maxIterations().
212 Size const mMaxIterations;
// Sufficient-decrease and curvature constants, read through c1<T>() / c2<T>().
213 Float<> const mC1;
214 Float<> const mC2;
215 };
216}
217#endif // pRC_ALGORITHMS_OPTIMIZER_LINE_SEARCH_BRACKETING_H
Definition tensor.hpp:28
static void info(Xs &&...args)
Definition log.hpp:27
static void debug(Xs &&...args)
Definition log.hpp:33
static void error(Xs &&...args)
Definition log.hpp:14
Definition bracketing.hpp:15
bool Bool
Definition type_traits.hpp:18
static constexpr X min(X &&a)
Definition min.hpp:13
static constexpr auto makeConstantSequence()
Definition sequence.hpp:402
Size Index
Definition type_traits.hpp:21
std::size_t Size
Definition type_traits.hpp:20
std::invoke_result_t< F, Args... > ResultOf
Definition type_traits.hpp:140
static constexpr auto abs(Complex< T > const &a)
Definition abs.hpp:12
constexpr auto cDebugLevel
Definition config.hpp:46
static constexpr auto scalarProduct(Complex< TA > const &a, Complex< TB > const &b)
Definition scalar_product.hpp:13
Definition limits.hpp:13
Definition bracketing.hpp:17
constexpr decltype(auto) decreasingStepScale() const
Definition bracketing.hpp:71
constexpr auto strongWolfeConditions() const
Definition bracketing.hpp:65
constexpr decltype(auto) c2() const
Definition bracketing.hpp:94
constexpr decltype(auto) increasingStepScale() const
Definition bracketing.hpp:77
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition bracketing.hpp:194
constexpr auto operator()(X &x, ResultOf< F, X const &, X & > &f, X &g, typename ResultOf< ScalarProduct, X, X >::Type &d, F &&function, FC &&constraint, X const &p, T alpha=identity< T >(), T const alphaMin=zero< T >(), T const alphaMax=identity< T >(NumericLimits< T >::max())) const
Definition bracketing.hpp:105
constexpr Bracketing(Bool const strongWolfeConditions=defaultStrongWolfeConditions(), Float<> const decreasingStepScale=defaultDecreasingStepScale(), Float<> const increasingStepScale=defaultIncreasingStepScale(), Size const maxIterations=defaultMaxIterations(), Float<> const c1=defaultC1(), Float<> const c2=defaultC2())
Definition bracketing.hpp:50
constexpr auto maxIterations() const
Definition bracketing.hpp:82
constexpr decltype(auto) c1() const
Definition bracketing.hpp:88