RootBracketing.hpp
2 // See LICENSE.txt for details.
3
4 #pragma once
5
6 #include <boost/none.hpp>
7 #include <boost/optional.hpp>
8 #include <vector>
9
10 #include "DataStructures/DataVector.hpp"
11 #include "ErrorHandling/Error.hpp"
12 #include "Utilities/Gsl.hpp"
13
14 namespace RootFinder {
15 namespace bracketing_detail {
16 // Brackets a root, given a functor f(x) that returns a
17 // boost::optional<double> and given two arrays x and y (with y=f(x))
18 // containing points that have already been tried for bracketing.
19 //
20 // Returns a std::array<double,4> containing {{x1,x2,y1,y2}} where
21 // x1 and x2 bracket the root, and y1=f(x1) and y2=f(x2).
22 //
23 // Note that y might be undefined (i.e. an invalid boost::optional)
24 // for values of x at or near the endpoints of the interval. We
25 // assume that if f(x1) and f(x2) are valid for some x1 and x2, then
26 // f(x) is valid for all x between x1 and x2.
27 //
28 // Assumes that there is a root between the first and last points of
29 // the array.
30 // We also assume that there is only one root.
31 //
32 // So this means we have only 2 possibilities for the validity of the
33 // points in the input x,y arrays:
34 // 1) All points are invalid, e.g. "X X X X X X X X".
35 // (here X represents an invalid point)
36 // 2) All valid points are adjacent, with the same sign, e.g. "X X o o X X"
37 // or "o o X X X" or "X X X o o".
38 // (here o represents a valid point)
39 // Note that we assume that all valid points have the same sign; otherwise
40 // the caller would have known that the root was bracketed and the caller would
41 // not have called bracket_by_contracting.
42 //
43 // Also note that we exclude the case "o o o o o" (no roots), since the
44 // caller would have known that too. If such a case is found, an error is
45 // thrown. An error is also thrown if the size of the region where the
46 // sign changes is so small that the number of iterations is exceeded.
47 //
48 // For case 1) above, we bisect each pair of points, and call
49 // bracket_by_contracting recursively until we find a valid point.
50 // For case 2) above, it is sufficent to check for a bracket only
51 // between valid and invalid points. That is, for "X X + + X X" we
52 // check only between points 1 and 2 and between points 3 and 4 (where
53 // points are numbered starting from zero). For "+ + X X X" we check
54 // only between points 1 and 2.
55 template <typename Functor>
56 std::array<double, 4> bracket_by_contracting(
57  const std::vector<double>& x, const std::vector<boost::optional<double>>& y,
58  const Functor& f, const size_t level = 0) noexcept {
59  constexpr size_t max_level = 6;
60  if (level > max_level) {
61  ERROR("Too many iterations in bracket_by_contracting. Either refine the "
62  "initial range/guess or increase max_level.");
63  }
64
65  // First check if we have any valid points.
66  size_t last_valid_index = y.size();
67  for (size_t i = y.size(); i >= 1; --i) {
68  if (y[i - 1]) {
69  last_valid_index = i - 1;
70  break;
71  }
72  }
73
74  if (last_valid_index == y.size()) {
75  // No valid points!
76
77  // Create larger arrays with one point between each of the already
78  // computed points.
79  std::vector<double> bisected_x(x.size() * 2 - 1);
80  std::vector<boost::optional<double>> bisected_y(y.size() * 2 - 1);
81
82  // Copy all even-numbered points in the range.
83  for (size_t i = 0; i < x.size(); ++i) {
84  bisected_x[2 * i] = x[i];
85  bisected_y[2 * i] = y[i];
86  }
87
88  // Fill midpoints and check for bracket on each one.
89  for (size_t i = 0; i < x.size() - 1; ++i) {
90  bisected_x[2 * i + 1] = x[i] + 0.5 * (x[i + 1] - x[i]);
91  bisected_y[2 * i + 1] = f(bisected_x[2 * i + 1]);
92  if (bisected_y[2 * i + 1]) {
93  // Valid point! We know that all the other points are
94  // invalid, so we need to check only 3 points in the next
95  // iteration: the new valid point and its neighbors.
96  return bracket_by_contracting({{x[i], bisected_x[2 * i + 1], x[i + 1]}},
97  {{y[i], bisected_y[2 * i + 1], y[i + 1]}},
98  f, level + 1);
99  }
100  }
101  // We still have no valid points. So recurse, using all points.
102  // The next iteration will bisect all the points.
103  return bracket_by_contracting(bisected_x, bisected_y, f, level + 1);
104  }
105
106  // If we get here, we have found a valid point; in particular we have
107  // found the last valid point in the array.
108
109  // Find the first valid point in the array.
110  size_t first_valid_index = 0;
111  for (size_t i = 0; i < y.size(); ++i) {
112  if (y[i]) {
113  first_valid_index = i;
114  break;
115  }
116  }
117
118  // Make a new set of points that includes only the points that
119  // neighbor the boundary between valid and invalid points.
120  std::vector<double> x_near_valid_point;
121  std::vector<boost::optional<double>> y_near_valid_point;
122
123  if (first_valid_index == 0 and last_valid_index == y.size() - 1) {
124  ERROR(
125  "bracket_while_contracting: found a case where all points are valid,"
126  "which should not happen under our assumptions.");
127  }
128
129  if (first_valid_index > 0) {
130  // Check for a root between first_valid_index-1 and first_valid_index.
131  const double x_test =
132  x[first_valid_index - 1] +
133  0.5 * (x[first_valid_index] - x[first_valid_index - 1]);
134  const auto y_test = f(x_test);
135  if (y_test and y[first_valid_index].get() * y_test.get() <= 0.0) {
136  // Bracketed!
137  return std::array<double, 4>{{x_test, x[first_valid_index], y_test.get(),
138  y[first_valid_index].get()}};
139  } else {
140  x_near_valid_point.push_back(x[first_valid_index - 1]);
141  y_near_valid_point.push_back(y[first_valid_index - 1]);
142  x_near_valid_point.push_back(x_test);
143  y_near_valid_point.push_back(y_test);
144  x_near_valid_point.push_back(x[first_valid_index]);
145  y_near_valid_point.push_back(y[first_valid_index]);
146  }
147  }
148  if (last_valid_index < y.size() - 1) {
149  // Check for a root between last_valid_index and last_valid_index+1.
150  const double x_test = x[last_valid_index] +
151  0.5 * (x[last_valid_index + 1] - x[last_valid_index]);
152  const auto y_test = f(x_test);
153  if (y_test and y[last_valid_index].get() * y_test.get() <= 0.0) {
154  // Bracketed!
155  return std::array<double, 4>{{x[last_valid_index], x_test,
156  y[last_valid_index].get(), y_test.get()}};
157  } else {
158  if (first_valid_index != last_valid_index or first_valid_index == 0) {
159  x_near_valid_point.push_back(x[last_valid_index]);
160  y_near_valid_point.push_back(y[last_valid_index]);
161  } // else we already pushed back last_valid_index (==first_valid_index).
162  x_near_valid_point.push_back(x_test);
163  y_near_valid_point.push_back(y_test);
164  x_near_valid_point.push_back(x[last_valid_index + 1]);
165  y_near_valid_point.push_back(y[last_valid_index + 1]);
166  }
167  }
168
169  // We have one or more valid points but we didn't find a bracket.
170  // That is, we have something like "X X o o X X" or "X X o o" or "o o X X".
171  // So recurse, zooming in to the boundary (either one boundary or two
172  // boundaries) between valid and invalid points.
173  // Note that "o o o o" is prohibited by our assumptions, and checked for
174  // above just in case it occurs by mistake.
175  return bracket_by_contracting(x_near_valid_point, y_near_valid_point, f,
176  level + 1);
177 }
178 } // namespace bracketing_detail
179
180 /*!
181  * \ingroup NumericalAlgorithmsGroup
182  * \brief Brackets the root of the function f, assuming a single
183  * root in a given interval \f$f[x_\mathrm{lo},x_\mathrm{up}]\f$
184  * and assuming that f is defined only in an unknown smaller
185  * interval \f$f[x_a,x_b]\f$ where
186  * \f$x_\mathrm{lo} \leq x_a \leq x_b \leq x_\mathrm{hi}\f$.
187  *
188  * f is a unary invokable that takes a double which is the current value at
189  * which to evaluate f. f returns a boost::optional<double> which
190  * evaluates to false if the function is undefined at the supplied point.
191  *
192  * Assumes that there is only one root in the interval.
193  *
194  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
195  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
196  * between \f$x_1\f$ and \f$x_2\f$.
197  *
198  * On input, assumes that the root lies in the interval
199  * [lower_bound,upper_bound]. Optionally takes a guess for the
200  * location of the root. If guess is supplied, then evaluates the
201  * function first at guess and upper_bound before trying
202  * lower_bound: this means that it would be optimal if guess
203  * underestimates the actual root and if upper_bound was less likely
204  * to be undefined than lower_bound.
205  *
206  * On return, lower_bound and upper_bound are replaced with values that
207  * bracket the root and for which the function is defined, and
208  * f_at_lower_bound and f_at_upper_bound are replaced with
209  * f evaluated at those bracketing points.
210  *
211  * bracket_possibly_undefined_function_in_interval throws an error if
212  * all points are valid but of the same sign (because that would indicate
213  * multiple roots but we assume only one root), if no root exists, or
214  * if the range of a sign change is sufficently small relative to the
215  * given interval that the number of iterations to find the root is exceeded.
216  *
217  */
218 template <typename Functor>
220  const gsl::not_null<double*> lower_bound,
221  const gsl::not_null<double*> upper_bound,
222  const gsl::not_null<double*> f_at_lower_bound,
223  const gsl::not_null<double*> f_at_upper_bound, const Functor& f,
224  const double guess) noexcept {
225  // Initial values of x1,x2,y1,y2. Use guess and upper_bound,
226  // because in typical usage guess underestimates the actual
227  // root, and lower_bound is more likely than upper_bound to be
228  // invalid.
229  double x1 = guess;
230  double x2 = *upper_bound;
231  auto y1 = f(x1);
232  auto y2 = f(x2);
233  const bool y1_defined = static_cast<bool>(y1);
234  const bool y2_defined = static_cast<bool>(y2);
235  if (not(y1_defined and y2_defined and y1.get() * y2.get() <= 0.0)) {
236  // Root is not bracketed.
237  // Before moving to the general algorithm, try the remaining
238  // input point that was supplied.
239  const double x3 = *lower_bound;
240  const auto y3 = f(x3);
241  const bool y3_defined = static_cast<bool>(y3);
242  if (y1_defined and y3_defined and y1.get() * y3.get() <= 0.0) {
243  // Bracketed! Throw out x2,y2. Rename variables to keep x1 < x2.
244  x2 = x1;
245  y2 = y1;
246  x1 = x3;
247  y1 = y3;
248  } else {
249  // Our simple checks didn't work, so call the more general method.
250  // There are 8 cases:
251  //
252  // y3 y1 y2
253  // --------
254  // X X X
255  // o X X
256  // X o X
257  // o o X
258  // X o o
259  // o o o
260  // X X o
261  // o X o
262  //
263  // where X means an invalid point, o means a valid point.
264  // All valid points have the same sign, or we would have found a
265  // bracket already.
266  //
267  // Before calling the general case, error on "o o o" and "o X o".
268  // Both of these are prohibited by our assumptions (we
269  // assume the root is in the interval so no "o o o", and we
270  // assume that all invalid points are at the end of interval, so no
271  // "o X o").
272  if (y2_defined and y3_defined) {
273  ERROR(
274  "bracket_possibly_undefined_function_in_interval: found "
275  "case that should not happen under our assumptions.");
276  }
277  std::array<double, 4> tmp = bracketing_detail::bracket_by_contracting(
278  {{x3, x1, x2}}, {{y3, y1, y2}}, f);
279  x1 = tmp[0];
280  x2 = tmp[1];
281  y1 = tmp[2];
282  y2 = tmp[3];
283  }
284  }
285  *f_at_lower_bound = y1.get();
286  *f_at_upper_bound = y2.get();
287  *lower_bound = x1;
288  *upper_bound = x2;
289 }
290
291 /*!
292  * \ingroup NumericalAlgorithmsGroup
293  * \brief Brackets the single root of the
294  * function f for each element in a DataVector, assuming the root
295  * lies in the given interval and that f may be undefined at some
296  * points in the interval.
297  *
298  * f is a binary invokable that takes a double and a size_t as
299  * arguments. The double is the current value at which to evaluate
300  * f, and the size_t is the index into the DataVectors. f
301  * returns a boost::optional<double> which evaluates to false if the
302  * function is undefined at the supplied point.
303  *
304  * Assumes that there is only one root in the interval.
305  *
306  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
307  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
308  * between \f$x_1\f$ and \f$x_2\f$.
309  *
310  * On input, assumes that the root lies in the interval
311  * [lower_bound,upper_bound]. Optionally takes a guess for the
312  * location of the root.
313  *
314  * On return, lower_bound and upper_bound are replaced with values that
315  * bracket the root and for which the function is defined, and
316  * f_at_lower_bound and f_at_upper_bound are replaced with
317  * f evaluated at those bracketing points.
318  *
319  */
320 template <typename Functor>
322  const gsl::not_null<DataVector*> lower_bound,
323  const gsl::not_null<DataVector*> upper_bound,
324  const gsl::not_null<DataVector*> f_at_lower_bound,
325  const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f,
326  const DataVector& guess) noexcept {
327  for (size_t s = 0; s < lower_bound->size(); ++s) {
329  &((*lower_bound)[s]), &((*upper_bound)[s]), &((*f_at_lower_bound)[s]),
330  &((*f_at_upper_bound)[s]),
331  [&f, &s ](const double x) noexcept { return f(x, s); }, guess[s]);
332  }
333 }
334
335 /*
336  * Version of bracket_possibly_undefined_function_in_interval
337  * without a supplied initial guess; uses the mean of lower_bound and
338  * upper_bound as the guess.
339  */
340 template <typename Functor>
342  const gsl::not_null<double*> lower_bound,
343  const gsl::not_null<double*> upper_bound,
344  const gsl::not_null<double*> f_at_lower_bound,
345  const gsl::not_null<double*> f_at_upper_bound, const Functor& f) noexcept {
347  lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
348  *lower_bound + 0.5 * (*upper_bound - *lower_bound));
349 }
350
351 /*
352  * Version of bracket_possibly_undefined_function_in_interval
353  * without a supplied initial guess; uses the mean of lower_bound and
354  * upper_bound as the guess.
355  */
356 template <typename Functor>
358  const gsl::not_null<DataVector*> lower_bound,
359  const gsl::not_null<DataVector*> upper_bound,
360  const gsl::not_null<DataVector*> f_at_lower_bound,
361  const gsl::not_null<DataVector*> f_at_upper_bound,
362  const Functor& f) noexcept {
364  lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
365  *lower_bound + 0.5 * (*upper_bound - *lower_bound));
366 }
367 } // namespace RootFinder
get
constexpr Tag::type & get(Variables< TagList > &v) noexcept
Return Tag::type pointing into the contiguous array.
Definition: Variables.hpp:639
vector
Error.hpp
ERROR
#define ERROR(m)
prints an error message to the standard error stream and aborts the program.
Definition: Error.hpp:36
RootFinder::bracket_possibly_undefined_function_in_interval
void bracket_possibly_undefined_function_in_interval(const gsl::not_null< DataVector * > lower_bound, const gsl::not_null< DataVector * > upper_bound, const gsl::not_null< DataVector * > f_at_lower_bound, const gsl::not_null< DataVector * > f_at_upper_bound, const Functor &f, const DataVector &guess) noexcept
Brackets the single root of the function f for each element in a DataVector, assuming the root lies i...
Definition: RootBracketing.hpp:321
std::array< double, 4 >
DataVector
Stores a collection of function values.
Definition: DataVector.hpp:42
Gsl.hpp
gsl::not_null
Require a pointer to not be a nullptr
Definition: Gsl.hpp:183