RootBracketing.hpp
1 // Distributed under the MIT License.
2 // See LICENSE.txt for details.
3 
4 #pragma once
5 
#include <array>
#include <cstddef>
#include <optional>
#include <vector>

#include "DataStructures/DataVector.hpp"
#include "Utilities/Gsl.hpp"
// NOTE(review): ERROR(...) is used below; the include that provides it
// (Error.hpp) appears to have been lost from this listing — confirm the
// project path and restore it.
12 
13 namespace RootFinder {
14 namespace bracketing_detail {
15 // Brackets a root, given a functor f(x) that returns a
16 // std::optional<double> and given two arrays x and y (with y=f(x))
17 // containing points that have already been tried for bracketing.
18 //
19 // Returns a std::array<double,4> containing {{x1,x2,y1,y2}} where
20 // x1 and x2 bracket the root, and y1=f(x1) and y2=f(x2).
21 //
22 // Note that y might be undefined (i.e. an invalid std::optional)
23 // for values of x at or near the endpoints of the interval. We
24 // assume that if f(x1) and f(x2) are valid for some x1 and x2, then
25 // f(x) is valid for all x between x1 and x2.
26 //
27 // Assumes that there is a root between the first and last points of
28 // the array.
29 // We also assume that there is only one root.
30 //
31 // So this means we have only 2 possibilities for the validity of the
32 // points in the input x,y arrays:
33 // 1) All points are invalid, e.g. "X X X X X X X X".
34 // (here X represents an invalid point)
35 // 2) All valid points are adjacent, with the same sign, e.g. "X X o o X X"
36 // or "o o X X X" or "X X X o o".
37 // (here o represents a valid point)
38 // Note that we assume that all valid points have the same sign; otherwise
39 // the caller would have known that the root was bracketed and the caller would
40 // not have called bracket_by_contracting.
41 //
42 // Also note that we exclude the case "o o o o o" (no roots), since the
43 // caller would have known that too. If such a case is found, an error is
44 // thrown. An error is also thrown if the size of the region where the
45 // sign changes is so small that the number of iterations is exceeded.
46 //
47 // For case 1) above, we bisect each pair of points, and call
48 // bracket_by_contracting recursively until we find a valid point.
49 // For case 2) above, it is sufficent to check for a bracket only
50 // between valid and invalid points. That is, for "X X + + X X" we
51 // check only between points 1 and 2 and between points 3 and 4 (where
52 // points are numbered starting from zero). For "+ + X X X" we check
53 // only between points 1 and 2.
54 template <typename Functor>
55 std::array<double, 4> bracket_by_contracting(
57  const Functor& f, const size_t level = 0) noexcept {
58  constexpr size_t max_level = 6;
59  if (level > max_level) {
60  ERROR("Too many iterations in bracket_by_contracting. Either refine the "
61  "initial range/guess or increase max_level.");
62  }
63 
64  // First check if we have any valid points.
65  size_t last_valid_index = y.size();
66  for (size_t i = y.size(); i >= 1; --i) {
67  if (y[i - 1].has_value()) {
68  last_valid_index = i - 1;
69  break;
70  }
71  }
72 
73  if (last_valid_index == y.size()) {
74  // No valid points!
75 
76  // Create larger arrays with one point between each of the already
77  // computed points.
78  std::vector<double> bisected_x(x.size() * 2 - 1);
79  std::vector<std::optional<double>> bisected_y(y.size() * 2 - 1);
80 
81  // Copy all even-numbered points in the range.
82  for (size_t i = 0; i < x.size(); ++i) {
83  bisected_x[2 * i] = x[i];
84  bisected_y[2 * i] = y[i];
85  }
86 
87  // Fill midpoints and check for bracket on each one.
88  for (size_t i = 0; i < x.size() - 1; ++i) {
89  bisected_x[2 * i + 1] = x[i] + 0.5 * (x[i + 1] - x[i]);
90  bisected_y[2 * i + 1] = f(bisected_x[2 * i + 1]);
91  if (bisected_y[2 * i + 1].has_value()) {
92  // Valid point! We know that all the other points are
93  // invalid, so we need to check only 3 points in the next
94  // iteration: the new valid point and its neighbors.
95  return bracket_by_contracting({{x[i], bisected_x[2 * i + 1], x[i + 1]}},
96  {{y[i], bisected_y[2 * i + 1], y[i + 1]}},
97  f, level + 1);
98  }
99  }
100  // We still have no valid points. So recurse, using all points.
101  // The next iteration will bisect all the points.
102  return bracket_by_contracting(bisected_x, bisected_y, f, level + 1);
103  }
104 
105  // If we get here, we have found a valid point; in particular we have
106  // found the last valid point in the array.
107 
108  // Find the first valid point in the array.
109  size_t first_valid_index = 0;
110  for (size_t i = 0; i < y.size(); ++i) {
111  if (y[i].has_value()) {
112  first_valid_index = i;
113  break;
114  }
115  }
116 
117  // Make a new set of points that includes only the points that
118  // neighbor the boundary between valid and invalid points.
119  std::vector<double> x_near_valid_point;
120  std::vector<std::optional<double>> y_near_valid_point;
121 
122  if (first_valid_index == 0 and last_valid_index == y.size() - 1) {
123  ERROR(
124  "bracket_while_contracting: found a case where all points are valid,"
125  "which should not happen under our assumptions.");
126  }
127 
128  if (first_valid_index > 0) {
129  // Check for a root between first_valid_index-1 and first_valid_index.
130  const double x_test =
131  x[first_valid_index - 1] +
132  0.5 * (x[first_valid_index] - x[first_valid_index - 1]);
133  const auto y_test = f(x_test);
134  if (y_test.has_value() and
135  y[first_valid_index].value() * y_test.value() <= 0.0) {
136  // Bracketed!
137  return std::array<double, 4>{{x_test, x[first_valid_index],
138  y_test.value(),
139  y[first_valid_index].value()}};
140  } else {
141  x_near_valid_point.push_back(x[first_valid_index - 1]);
142  y_near_valid_point.push_back(y[first_valid_index - 1]);
143  x_near_valid_point.push_back(x_test);
144  y_near_valid_point.push_back(y_test);
145  x_near_valid_point.push_back(x[first_valid_index]);
146  y_near_valid_point.push_back(y[first_valid_index]);
147  }
148  }
149  if (last_valid_index < y.size() - 1) {
150  // Check for a root between last_valid_index and last_valid_index+1.
151  const double x_test = x[last_valid_index] +
152  0.5 * (x[last_valid_index + 1] - x[last_valid_index]);
153  const auto y_test = f(x_test);
154  if (y_test.has_value() and
155  y[last_valid_index].value() * y_test.value() <= 0.0) {
156  // Bracketed!
157  return std::array<double, 4>{{x[last_valid_index], x_test,
158  y[last_valid_index].value(),
159  y_test.value()}};
160  } else {
161  if (first_valid_index != last_valid_index or first_valid_index == 0) {
162  x_near_valid_point.push_back(x[last_valid_index]);
163  y_near_valid_point.push_back(y[last_valid_index]);
164  } // else we already pushed back last_valid_index (==first_valid_index).
165  x_near_valid_point.push_back(x_test);
166  y_near_valid_point.push_back(y_test);
167  x_near_valid_point.push_back(x[last_valid_index + 1]);
168  y_near_valid_point.push_back(y[last_valid_index + 1]);
169  }
170  }
171 
172  // We have one or more valid points but we didn't find a bracket.
173  // That is, we have something like "X X o o X X" or "X X o o" or "o o X X".
174  // So recurse, zooming in to the boundary (either one boundary or two
175  // boundaries) between valid and invalid points.
176  // Note that "o o o o" is prohibited by our assumptions, and checked for
177  // above just in case it occurs by mistake.
178  return bracket_by_contracting(x_near_valid_point, y_near_valid_point, f,
179  level + 1);
180 }
181 } // namespace bracketing_detail
182 
183 /*!
184  * \ingroup NumericalAlgorithmsGroup
185  * \brief Brackets the root of the function `f`, assuming a single
186  * root in a given interval \f$f[x_\mathrm{lo},x_\mathrm{up}]\f$
187  * and assuming that `f` is defined only in an unknown smaller
188  * interval \f$f[x_a,x_b]\f$ where
189  * \f$x_\mathrm{lo} \leq x_a \leq x_b \leq x_\mathrm{hi}\f$.
190  *
191  * `f` is a unary invokable that takes a `double` which is the current value at
192  * which to evaluate `f`. `f` returns a `std::optional<double>` which
193  * evaluates to false if the function is undefined at the supplied point.
194  *
195  * Assumes that there is only one root in the interval.
196  *
197  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
198  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
199  * between \f$x_1\f$ and \f$x_2\f$.
200  *
201  * On input, assumes that the root lies in the interval
202  * [`lower_bound`,`upper_bound`]. Optionally takes a `guess` for the
203  * location of the root. If `guess` is supplied, then evaluates the
204  * function first at `guess` and `upper_bound` before trying
205  * `lower_bound`: this means that it would be optimal if `guess`
206  * underestimates the actual root and if `upper_bound` was less likely
207  * to be undefined than `lower_bound`.
208  *
209  * On return, `lower_bound` and `upper_bound` are replaced with values that
210  * bracket the root and for which the function is defined, and
211  * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
212  * `f` evaluated at those bracketing points.
213  *
214  * `bracket_possibly_undefined_function_in_interval` throws an error if
215  * all points are valid but of the same sign (because that would indicate
216  * multiple roots but we assume only one root), if no root exists, or
217  * if the range of a sign change is sufficently small relative to the
218  * given interval that the number of iterations to find the root is exceeded.
219  *
220  */
221 template <typename Functor>
223  const gsl::not_null<double*> lower_bound,
224  const gsl::not_null<double*> upper_bound,
225  const gsl::not_null<double*> f_at_lower_bound,
226  const gsl::not_null<double*> f_at_upper_bound, const Functor& f,
227  const double guess) noexcept {
228  // Initial values of x1,x2,y1,y2. Use `guess` and `upper_bound`,
229  // because in typical usage `guess` underestimates the actual
230  // root, and `lower_bound` is more likely than `upper_bound` to be
231  // invalid.
232  double x1 = guess;
233  double x2 = *upper_bound;
234  auto y1 = f(x1);
235  auto y2 = f(x2);
236  const bool y1_defined = y1.has_value();
237  const bool y2_defined = y2.has_value();
238  if (not(y1_defined and y2_defined and y1.value() * y2.value() <= 0.0)) {
239  // Root is not bracketed.
240  // Before moving to the general algorithm, try the remaining
241  // input point that was supplied.
242  const double x3 = *lower_bound;
243  const auto y3 = f(x3);
244  const bool y3_defined = y3.has_value();
245  if (y1_defined and y3_defined and y1.value() * y3.value() <= 0.0) {
246  // Bracketed! Throw out x2,y2. Rename variables to keep x1 < x2.
247  x2 = x1;
248  y2 = y1;
249  x1 = x3;
250  y1 = y3;
251  } else {
252  // Our simple checks didn't work, so call the more general method.
253  // There are 8 cases:
254  //
255  // y3 y1 y2
256  // --------
257  // X X X
258  // o X X
259  // X o X
260  // o o X
261  // X o o
262  // o o o
263  // X X o
264  // o X o
265  //
266  // where X means an invalid point, o means a valid point.
267  // All valid points have the same sign, or we would have found a
268  // bracket already.
269  //
270  // Before calling the general case, error on "o o o" and "o X o".
271  // Both of these are prohibited by our assumptions (we
272  // assume the root is in the interval so no "o o o", and we
273  // assume that all invalid points are at the end of interval, so no
274  // "o X o").
275  if (y2_defined and y3_defined) {
276  ERROR(
277  "bracket_possibly_undefined_function_in_interval: found "
278  "case that should not happen under our assumptions.");
279  }
280  std::array<double, 4> tmp = bracketing_detail::bracket_by_contracting(
281  {{x3, x1, x2}}, {{y3, y1, y2}}, f);
282  x1 = tmp[0];
283  x2 = tmp[1];
284  y1 = tmp[2];
285  y2 = tmp[3];
286  }
287  }
288  *f_at_lower_bound = y1.value();
289  *f_at_upper_bound = y2.value();
290  *lower_bound = x1;
291  *upper_bound = x2;
292 }
293 
294 /*!
295  * \ingroup NumericalAlgorithmsGroup
296  * \brief Brackets the single root of the
297  * function `f` for each element in a `DataVector`, assuming the root
298  * lies in the given interval and that `f` may be undefined at some
299  * points in the interval.
300  *
301  * `f` is a binary invokable that takes a `double` and a `size_t` as
302  * arguments. The `double` is the current value at which to evaluate
303  * `f`, and the `size_t` is the index into the `DataVector`s. `f`
304  * returns a `std::optional<double>` which evaluates to false if the
305  * function is undefined at the supplied point.
306  *
307  * Assumes that there is only one root in the interval.
308  *
309  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
310  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
311  * between \f$x_1\f$ and \f$x_2\f$.
312  *
313  * On input, assumes that the root lies in the interval
314  * [`lower_bound`,`upper_bound`]. Optionally takes a `guess` for the
315  * location of the root.
316  *
317  * On return, `lower_bound` and `upper_bound` are replaced with values that
318  * bracket the root and for which the function is defined, and
319  * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
320  * `f` evaluated at those bracketing points.
321  *
322  */
323 template <typename Functor>
325  const gsl::not_null<DataVector*> lower_bound,
326  const gsl::not_null<DataVector*> upper_bound,
327  const gsl::not_null<DataVector*> f_at_lower_bound,
328  const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f,
329  const DataVector& guess) noexcept {
330  for (size_t s = 0; s < lower_bound->size(); ++s) {
332  &((*lower_bound)[s]), &((*upper_bound)[s]), &((*f_at_lower_bound)[s]),
333  &((*f_at_upper_bound)[s]),
334  [&f, &s ](const double x) noexcept { return f(x, s); }, guess[s]);
335  }
336 }
337 
338 /*
339  * Version of `bracket_possibly_undefined_function_in_interval`
340  * without a supplied initial guess; uses the mean of `lower_bound` and
341  * `upper_bound` as the guess.
342  */
343 template <typename Functor>
345  const gsl::not_null<double*> lower_bound,
346  const gsl::not_null<double*> upper_bound,
347  const gsl::not_null<double*> f_at_lower_bound,
348  const gsl::not_null<double*> f_at_upper_bound, const Functor& f) noexcept {
350  lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
351  *lower_bound + 0.5 * (*upper_bound - *lower_bound));
352 }
353 
354 /*
355  * Version of `bracket_possibly_undefined_function_in_interval`
356  * without a supplied initial guess; uses the mean of `lower_bound` and
357  * `upper_bound` as the guess.
358  */
359 template <typename Functor>
361  const gsl::not_null<DataVector*> lower_bound,
362  const gsl::not_null<DataVector*> upper_bound,
363  const gsl::not_null<DataVector*> f_at_lower_bound,
364  const gsl::not_null<DataVector*> f_at_upper_bound,
365  const Functor& f) noexcept {
367  lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
368  *lower_bound + 0.5 * (*upper_bound - *lower_bound));
369 }
370 } // namespace RootFinder
vector
Error.hpp
ERROR
#define ERROR(m)
prints an error message to the standard error stream and aborts the program.
Definition: Error.hpp:36
RootFinder::bracket_possibly_undefined_function_in_interval
void bracket_possibly_undefined_function_in_interval(const gsl::not_null< DataVector * > lower_bound, const gsl::not_null< DataVector * > upper_bound, const gsl::not_null< DataVector * > f_at_lower_bound, const gsl::not_null< DataVector * > f_at_upper_bound, const Functor &f, const DataVector &guess) noexcept
Brackets the single root of the function f for each element in a DataVector, assuming the root lies i...
Definition: RootBracketing.hpp:324
std::array< double, 4 >
DataVector
Stores a collection of function values.
Definition: DataVector.hpp:42
Gsl.hpp
optional
gsl::not_null
Require a pointer to not be a nullptr
Definition: ReadSpecThirdOrderPiecewisePolynomial.hpp:13