Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #include <optional>
7 : #include <sstream>
8 : #include <stdexcept>
9 : #include <tuple>
10 : #include <vector>
11 :
12 : #include "DataStructures/DataVector.hpp"
13 : #include "Utilities/ErrorHandling/Error.hpp"
14 : #include "Utilities/Gsl.hpp"
15 :
16 : namespace RootFinder {
17 :
18 : namespace bracketing_detail {
19 : // Brackets a root, given a functor f(x) that returns a
20 : // std::optional<double> and given two arrays x and y (with y=f(x))
21 : // containing points that have already been tried for bracketing.
22 : //
23 : // Returns a std::tuple<double, double, double, double>
24 : // containing {{x1,x2,y1,y2}} where
25 : // x1 and x2 bracket the root, and y1=f(x1) and y2=f(x2).
26 : //
27 : // Note that y might be undefined (i.e. an invalid std::optional)
28 : // for values of x at or near the endpoints of the interval. We
29 : // assume that if f(x1) and f(x2) are valid for some x1 and x2, then
30 : // f(x) is valid for all x between x1 and x2.
31 : //
32 : // Assumes that there is a root between the first and last points of
33 : // the array.
34 : // We also assume that there is only one root.
35 : //
36 : // So this means we have only 2 possibilities for the validity of the
37 : // points in the input x,y arrays:
38 : // 1) All points are invalid, e.g. "X X X X X X X X".
39 : // (here X represents an invalid point)
40 : // 2) All valid points are adjacent, with the same sign, e.g. "X X o o X X"
41 : // or "o o X X X" or "X X X o o".
42 : // (here o represents a valid point)
43 : // Note that we assume that all valid points have the same sign; otherwise
44 : // the caller would have known that the root was bracketed and the caller would
45 : // not have called bracket_by_contracting.
46 : //
47 : // Also note that we exclude the case "o o o o o" (no roots), since the
48 : // caller would have known that too. If such a case is found, an error is
49 : // thrown. An error is also thrown if the size of the region where the
50 : // sign changes is so small that the number of iterations is exceeded.
51 : //
52 : // For case 1) above, we bisect each pair of points, and call
53 : // bracket_by_contracting recursively until we find a valid point.
54 : // For case 2) above, it is sufficent to check for a bracket only
55 : // between valid and invalid points. That is, for "X X + + X X" we
56 : // check only between points 1 and 2 and between points 3 and 4 (where
57 : // points are numbered starting from zero). For "+ + X X X" we check
58 : // only between points 1 and 2.
59 : template <typename Functor>
60 : auto bracket_by_contracting(const std::vector<double>& x,
61 : const std::vector<std::optional<double>>& y,
62 : const Functor& f, const size_t level = 0)
63 : -> std::tuple<double, double, double, double> {
64 : constexpr size_t max_level = 27;
65 : if (level > max_level) {
66 : std::stringstream ss;
67 : ss << "Too many iterations in bracket_by_contracting. Either the "
68 : "region where the root changes sign is so small that we cannot "
69 : "find it, or the given interval does not actually bracket a "
70 : "root. The points we are checking are "
71 : << x.size() << " almost-evenly-spaced points from " << x.front()
72 : << " to " << x.back() << " with difference " << x.back()-x.front();
73 : throw std::runtime_error(ss.str());
74 : }
75 :
76 : // First check if we have any valid points.
77 : size_t last_valid_index = y.size();
78 : for (size_t i = y.size(); i >= 1; --i) {
79 : if (y[i - 1].has_value()) {
80 : last_valid_index = i - 1;
81 : break;
82 : }
83 : }
84 :
85 : if (last_valid_index == y.size()) {
86 : // No valid points!
87 :
88 : // Create larger arrays with one point between each of the already
89 : // computed points.
90 : std::vector<double> bisected_x(x.size() * 2 - 1);
91 : std::vector<std::optional<double>> bisected_y(y.size() * 2 - 1);
92 :
93 : // Copy all even-numbered points in the range.
94 : for (size_t i = 0; i < x.size(); ++i) {
95 : bisected_x[2 * i] = x[i];
96 : bisected_y[2 * i] = y[i];
97 : }
98 :
99 : // Fill midpoints and check for bracket on each one.
100 : for (size_t i = 0; i < x.size() - 1; ++i) {
101 : bisected_x[2 * i + 1] = x[i] + 0.5 * (x[i + 1] - x[i]);
102 : bisected_y[2 * i + 1] = f(bisected_x[2 * i + 1]);
103 : if (bisected_y[2 * i + 1].has_value()) {
104 : // Valid point! We know that all the other points are
105 : // invalid, so we need to check only 3 points in the next
106 : // iteration: the new valid point and its neighbors.
107 : return bracket_by_contracting({{x[i], bisected_x[2 * i + 1], x[i + 1]}},
108 : {{y[i], bisected_y[2 * i + 1], y[i + 1]}},
109 : f, level + 1);
110 : }
111 : }
112 : // We still have no valid points. So recurse, using all points.
113 : // The next iteration will bisect all the points.
114 : return bracket_by_contracting(bisected_x, bisected_y, f, level + 1);
115 : }
116 :
117 : // If we get here, we have found a valid point; in particular we have
118 : // found the last valid point in the array.
119 :
120 : // Find the first valid point in the array.
121 : size_t first_valid_index = 0;
122 : for (size_t i = 0; i < y.size(); ++i) {
123 : if (y[i].has_value()) {
124 : first_valid_index = i;
125 : break;
126 : }
127 : }
128 :
129 : // Make a new set of points that includes only the points that
130 : // neighbor the boundary between valid and invalid points.
131 : std::vector<double> x_near_valid_point;
132 : std::vector<std::optional<double>> y_near_valid_point;
133 :
134 : if (first_valid_index == 0 and last_valid_index == y.size() - 1) {
135 : ERROR(
136 : "bracket_while_contracting: found a case where all points are valid,"
137 : "which should not happen under our assumptions.");
138 : }
139 :
140 : if (first_valid_index > 0) {
141 : // Check for a root between first_valid_index-1 and first_valid_index.
142 : const double x_test =
143 : x[first_valid_index - 1] +
144 : 0.5 * (x[first_valid_index] - x[first_valid_index - 1]);
145 : const auto y_test = f(x_test);
146 : if (y_test.has_value() and
147 : y[first_valid_index].value() * y_test.value() <= 0.0) {
148 : // Bracketed!
149 : return {
150 : x_test, x[first_valid_index], y_test.value(),
151 : y[first_valid_index].value()};
152 : } else {
153 : x_near_valid_point.push_back(x[first_valid_index - 1]);
154 : y_near_valid_point.push_back(y[first_valid_index - 1]);
155 : x_near_valid_point.push_back(x_test);
156 : y_near_valid_point.push_back(y_test);
157 : x_near_valid_point.push_back(x[first_valid_index]);
158 : y_near_valid_point.push_back(y[first_valid_index]);
159 : }
160 : }
161 : if (last_valid_index < y.size() - 1) {
162 : // Check for a root between last_valid_index and last_valid_index+1.
163 : const double x_test = x[last_valid_index] +
164 : 0.5 * (x[last_valid_index + 1] - x[last_valid_index]);
165 : const auto y_test = f(x_test);
166 : if (y_test.has_value() and
167 : y[last_valid_index].value() * y_test.value() <= 0.0) {
168 : // Bracketed!
169 : return {x[last_valid_index], x_test, y[last_valid_index].value(),
170 : y_test.value()};
171 : } else {
172 : if (first_valid_index != last_valid_index or first_valid_index == 0) {
173 : x_near_valid_point.push_back(x[last_valid_index]);
174 : y_near_valid_point.push_back(y[last_valid_index]);
175 : } // else we already pushed back last_valid_index (==first_valid_index).
176 : x_near_valid_point.push_back(x_test);
177 : y_near_valid_point.push_back(y_test);
178 : x_near_valid_point.push_back(x[last_valid_index + 1]);
179 : y_near_valid_point.push_back(y[last_valid_index + 1]);
180 : }
181 : }
182 :
183 : // We have one or more valid points but we didn't find a bracket.
184 : // That is, we have something like "X X o o X X" or "X X o o" or "o o X X".
185 : // So recurse, zooming in to the boundary (either one boundary or two
186 : // boundaries) between valid and invalid points.
187 : // Note that "o o o o" is prohibited by our assumptions, and checked for
188 : // above just in case it occurs by mistake.
189 : return bracket_by_contracting(x_near_valid_point, y_near_valid_point, f,
190 : level + 1);
191 : }
192 : } // namespace bracketing_detail
193 :
194 : /*!
195 : * \ingroup NumericalAlgorithmsGroup
196 : * \brief Brackets the root of the function `f`, assuming a single
197 : * root in a given interval \f$f[x_\mathrm{lo},x_\mathrm{up}]\f$
198 : * and assuming that `f` is defined only in an unknown smaller
199 : * interval \f$f[x_a,x_b]\f$ where
200 : * \f$x_\mathrm{lo} \leq x_a \leq x_b \leq x_\mathrm{hi}\f$.
201 : *
202 : * `f` is a unary invokable that takes a `double` which is the current value at
203 : * which to evaluate `f`. `f` returns a `std::optional<double>` which
204 : * evaluates to false if the function is undefined at the supplied point.
205 : *
206 : * Assumes that there is only one root in the interval.
207 : *
208 : * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
209 : * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
210 : * between \f$x_1\f$ and \f$x_2\f$.
211 : *
212 : * On input, assumes that the root lies in the interval
213 : * [`lower_bound`,`upper_bound`]. Optionally takes a `guess` for the
214 : * location of the root. If `guess` is supplied, then evaluates the
215 : * function first at `guess` and `upper_bound` before trying
216 : * `lower_bound`: this means that it would be optimal if `guess`
217 : * underestimates the actual root and if `upper_bound` was less likely
218 : * to be undefined than `lower_bound`.
219 : *
220 : * On return, `lower_bound` and `upper_bound` are replaced with values that
221 : * bracket the root and for which the function is defined, and
222 : * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
223 : * `f` evaluated at those bracketing points.
224 : *
225 : * `bracket_possibly_undefined_function_in_interval` throws an error if
226 : * all points are valid but of the same sign (because that would indicate
227 : * multiple roots but we assume only one root), if no root exists, or
228 : * if the range of a sign change is sufficently small relative to the
229 : * given interval that the number of iterations to find the root is exceeded.
230 : *
231 : */
232 : template <typename Functor>
233 1 : void bracket_possibly_undefined_function_in_interval(
234 : const gsl::not_null<double*> lower_bound,
235 : const gsl::not_null<double*> upper_bound,
236 : const gsl::not_null<double*> f_at_lower_bound,
237 : const gsl::not_null<double*> f_at_upper_bound, const Functor& f,
238 : const double guess) {
239 : // Initial values of x1,x2,y1,y2. Use `guess` and `upper_bound`,
240 : // because in typical usage `guess` underestimates the actual
241 : // root, and `lower_bound` is more likely than `upper_bound` to be
242 : // invalid.
243 : double x1 = guess;
244 : double x2 = *upper_bound;
245 : auto y1 = f(x1);
246 : auto y2 = f(x2);
247 : const bool y1_defined = y1.has_value();
248 : const bool y2_defined = y2.has_value();
249 : if (not(y1_defined and y2_defined and y1.value() * y2.value() <= 0.0)) {
250 : // Root is not bracketed.
251 : // Before moving to the general algorithm, try the remaining
252 : // input point that was supplied.
253 : const double x3 = *lower_bound;
254 : const auto y3 = f(x3);
255 : const bool y3_defined = y3.has_value();
256 : if (y1_defined and y3_defined and y1.value() * y3.value() <= 0.0) {
257 : // Bracketed! Throw out x2,y2. Rename variables to keep x1 < x2.
258 : x2 = x1;
259 : y2 = y1;
260 : x1 = x3;
261 : y1 = y3;
262 : } else {
263 : // Our simple checks didn't work, so call the more general method.
264 : // There are 8 cases:
265 : //
266 : // y3 y1 y2
267 : // --------
268 : // X X X
269 : // o X X
270 : // X o X
271 : // o o X
272 : // X o o
273 : // o o o
274 : // X X o
275 : // o X o
276 : //
277 : // where X means an invalid point, o means a valid point.
278 : // All valid points have the same sign, or we would have found a
279 : // bracket already.
280 : //
281 : // Before calling the general case, error on "o o o" and "o X o".
282 : // Both of these are prohibited by our assumptions (we
283 : // assume the root is in the interval so no "o o o", and we
284 : // assume that all invalid points are at the end of interval, so no
285 : // "o X o").
286 : if (y2_defined and y3_defined) {
287 : ERROR(
288 : "bracket_possibly_undefined_function_in_interval: found "
289 : "case that should not happen under our assumptions.");
290 : }
291 : try {
292 : std::tie(x1, x2, y1, y2) = bracketing_detail::bracket_by_contracting(
293 : {{x3, x1, x2}}, {{y3, y1, y2}}, f);
294 : } catch (std::runtime_error& e) {
295 : std::stringstream ss;
296 : ss << "bracket_by_contracting: Cannot bracket root between "
297 : << *lower_bound << " and " << *upper_bound
298 : << ". Internal error message is: '" << e.what() << "'";
299 : throw std::runtime_error(ss.str());
300 : }
301 : }
302 : }
303 : *f_at_lower_bound = y1.value();
304 : *f_at_upper_bound = y2.value();
305 : *lower_bound = x1;
306 : *upper_bound = x2;
307 : }
308 :
309 : /*!
310 : * \ingroup NumericalAlgorithmsGroup
311 : * \brief Brackets the single root of the
312 : * function `f` for each element in a `DataVector`, assuming the root
313 : * lies in the given interval and that `f` may be undefined at some
314 : * points in the interval.
315 : *
316 : * `f` is a binary invokable that takes a `double` and a `size_t` as
317 : * arguments. The `double` is the current value at which to evaluate
318 : * `f`, and the `size_t` is the index into the `DataVector`s. `f`
319 : * returns a `std::optional<double>` which evaluates to false if the
320 : * function is undefined at the supplied point.
321 : *
322 : * Assumes that there is only one root in the interval.
323 : *
324 : * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
325 : * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
326 : * between \f$x_1\f$ and \f$x_2\f$.
327 : *
328 : * On input, assumes that the root lies in the interval
329 : * [`lower_bound`,`upper_bound`]. Optionally takes a `guess` for the
330 : * location of the root.
331 : *
332 : * On return, `lower_bound` and `upper_bound` are replaced with values that
333 : * bracket the root and for which the function is defined, and
334 : * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
335 : * `f` evaluated at those bracketing points.
336 : *
337 : */
338 : template <typename Functor>
339 1 : void bracket_possibly_undefined_function_in_interval(
340 : const gsl::not_null<DataVector*> lower_bound,
341 : const gsl::not_null<DataVector*> upper_bound,
342 : const gsl::not_null<DataVector*> f_at_lower_bound,
343 : const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f,
344 : const DataVector& guess) {
345 : for (size_t s = 0; s < lower_bound->size(); ++s) {
346 : bracket_possibly_undefined_function_in_interval(
347 : &((*lower_bound)[s]), &((*upper_bound)[s]), &((*f_at_lower_bound)[s]),
348 : &((*f_at_upper_bound)[s]), [&f, &s](const double x) { return f(x, s); },
349 : guess[s]);
350 : }
351 : }
352 :
353 : /*
354 : * Version of `bracket_possibly_undefined_function_in_interval`
355 : * without a supplied initial guess; uses the mean of `lower_bound` and
356 : * `upper_bound` as the guess.
357 : */
358 : template <typename Functor>
359 0 : void bracket_possibly_undefined_function_in_interval(
360 : const gsl::not_null<double*> lower_bound,
361 : const gsl::not_null<double*> upper_bound,
362 : const gsl::not_null<double*> f_at_lower_bound,
363 : const gsl::not_null<double*> f_at_upper_bound, const Functor& f) {
364 : bracket_possibly_undefined_function_in_interval(
365 : lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
366 : *lower_bound + 0.5 * (*upper_bound - *lower_bound));
367 : }
368 :
369 : /*
370 : * Version of `bracket_possibly_undefined_function_in_interval`
371 : * without a supplied initial guess; uses the mean of `lower_bound` and
372 : * `upper_bound` as the guess.
373 : */
374 : template <typename Functor>
375 0 : void bracket_possibly_undefined_function_in_interval(
376 : const gsl::not_null<DataVector*> lower_bound,
377 : const gsl::not_null<DataVector*> upper_bound,
378 : const gsl::not_null<DataVector*> f_at_lower_bound,
379 : const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f) {
380 : bracket_possibly_undefined_function_in_interval(
381 : lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
382 : *lower_bound + 0.5 * (*upper_bound - *lower_bound));
383 : }
384 : } // namespace RootFinder
|