SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/RootFinding - RootBracketing.hpp Hit Total Coverage
Commit: 2db722c93a8e9b106e406b439b79c8e05c2057fb Lines: 2 5 40.0 %
Date: 2021-03-03 22:01:00
Legend: Lines: hit not hit

          Line data    Source code
       1           0 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : #pragma once
       5             : 
#include <array>
#include <cstddef>
#include <optional>
#include <vector>

#include "DataStructures/DataVector.hpp"
#include "Utilities/ErrorHandling/Error.hpp"
#include "Utilities/Gsl.hpp"
      12             : 
      13             : namespace RootFinder {
      14             : namespace bracketing_detail {
      15             : // Brackets a root, given a functor f(x) that returns a
      16             : // std::optional<double> and given two arrays x and y (with y=f(x))
      17             : // containing points that have already been tried for bracketing.
      18             : //
      19             : // Returns a std::array<double,4> containing {{x1,x2,y1,y2}} where
      20             : // x1 and x2 bracket the root, and y1=f(x1) and y2=f(x2).
      21             : //
      22             : // Note that y might be undefined (i.e. an invalid std::optional)
      23             : // for values of x at or near the endpoints of the interval.  We
      24             : // assume that if f(x1) and f(x2) are valid for some x1 and x2, then
      25             : // f(x) is valid for all x between x1 and x2.
      26             : //
      27             : // Assumes that there is a root between the first and last points of
      28             : // the array.
      29             : // We also assume that there is only one root.
      30             : //
      31             : // So this means we have only 2 possibilities for the validity of the
      32             : // points in the input x,y arrays:
      33             : // 1) All points are invalid, e.g. "X X X X X X X X".
      34             : //    (here X represents an invalid point)
      35             : // 2) All valid points are adjacent, with the same sign, e.g. "X X o o X X"
      36             : //    or "o o X X X" or "X X X o o".
      37             : //    (here o represents a valid point)
      38             : // Note that we assume that all valid points have the same sign; otherwise
      39             : // the caller would have known that the root was bracketed and the caller would
      40             : // not have called bracket_by_contracting.
      41             : //
      42             : // Also note that we exclude the case "o o o o o" (no roots), since the
      43             : // caller would have known that too.  If such a case is found, an error is
      44             : // thrown.  An error is also thrown if the size of the region where the
      45             : // sign changes is so small that the number of iterations is exceeded.
      46             : //
      47             : // For case 1) above, we bisect each pair of points, and call
      48             : // bracket_by_contracting recursively until we find a valid point.
      49             : // For case 2) above, it is sufficent to check for a bracket only
      50             : // between valid and invalid points.  That is, for "X X + + X X" we
      51             : // check only between points 1 and 2 and between points 3 and 4 (where
      52             : // points are numbered starting from zero).  For "+ + X X X" we check
      53             : // only between points 1 and 2.
      54             : template <typename Functor>
      55             : std::array<double, 4> bracket_by_contracting(
      56             :     const std::vector<double>& x, const std::vector<std::optional<double>>& y,
      57             :     const Functor& f, const size_t level = 0) noexcept {
      58             :   constexpr size_t max_level = 6;
      59             :   if (level > max_level) {
      60             :     ERROR("Too many iterations in bracket_by_contracting. Either refine the "
      61             :           "initial range/guess or increase max_level.");
      62             :   }
      63             : 
      64             :   // First check if we have any valid points.
      65             :   size_t last_valid_index = y.size();
      66             :   for (size_t i = y.size(); i >= 1; --i) {
      67             :     if (y[i - 1].has_value()) {
      68             :       last_valid_index = i - 1;
      69             :       break;
      70             :     }
      71             :   }
      72             : 
      73             :   if (last_valid_index == y.size()) {
      74             :     // No valid points!
      75             : 
      76             :     // Create larger arrays with one point between each of the already
      77             :     // computed points.
      78             :     std::vector<double> bisected_x(x.size() * 2 - 1);
      79             :     std::vector<std::optional<double>> bisected_y(y.size() * 2 - 1);
      80             : 
      81             :     // Copy all even-numbered points in the range.
      82             :     for (size_t i = 0; i < x.size(); ++i) {
      83             :       bisected_x[2 * i] = x[i];
      84             :       bisected_y[2 * i] = y[i];
      85             :     }
      86             : 
      87             :     // Fill midpoints and check for bracket on each one.
      88             :     for (size_t i = 0; i < x.size() - 1; ++i) {
      89             :       bisected_x[2 * i + 1] = x[i] + 0.5 * (x[i + 1] - x[i]);
      90             :       bisected_y[2 * i + 1] = f(bisected_x[2 * i + 1]);
      91             :       if (bisected_y[2 * i + 1].has_value()) {
      92             :         // Valid point! We know that all the other points are
      93             :         // invalid, so we need to check only 3 points in the next
      94             :         // iteration: the new valid point and its neighbors.
      95             :         return bracket_by_contracting({{x[i], bisected_x[2 * i + 1], x[i + 1]}},
      96             :                                       {{y[i], bisected_y[2 * i + 1], y[i + 1]}},
      97             :                                       f, level + 1);
      98             :       }
      99             :     }
     100             :     // We still have no valid points. So recurse, using all points.
     101             :     // The next iteration will bisect all the points.
     102             :     return bracket_by_contracting(bisected_x, bisected_y, f, level + 1);
     103             :   }
     104             : 
     105             :   // If we get here, we have found a valid point; in particular we have
     106             :   // found the last valid point in the array.
     107             : 
     108             :   // Find the first valid point in the array.
     109             :   size_t first_valid_index = 0;
     110             :   for (size_t i = 0; i < y.size(); ++i) {
     111             :     if (y[i].has_value()) {
     112             :       first_valid_index = i;
     113             :       break;
     114             :     }
     115             :   }
     116             : 
     117             :   // Make a new set of points that includes only the points that
     118             :   // neighbor the boundary between valid and invalid points.
     119             :   std::vector<double> x_near_valid_point;
     120             :   std::vector<std::optional<double>> y_near_valid_point;
     121             : 
     122             :   if (first_valid_index == 0 and last_valid_index == y.size() - 1) {
     123             :     ERROR(
     124             :         "bracket_while_contracting: found a case where all points are valid,"
     125             :         "which should not happen under our assumptions.");
     126             :   }
     127             : 
     128             :   if (first_valid_index > 0) {
     129             :     // Check for a root between first_valid_index-1 and first_valid_index.
     130             :     const double x_test =
     131             :         x[first_valid_index - 1] +
     132             :         0.5 * (x[first_valid_index] - x[first_valid_index - 1]);
     133             :     const auto y_test = f(x_test);
     134             :     if (y_test.has_value() and
     135             :         y[first_valid_index].value() * y_test.value() <= 0.0) {
     136             :       // Bracketed!
     137             :       return std::array<double, 4>{{x_test, x[first_valid_index],
     138             :                                     y_test.value(),
     139             :                                     y[first_valid_index].value()}};
     140             :     } else {
     141             :       x_near_valid_point.push_back(x[first_valid_index - 1]);
     142             :       y_near_valid_point.push_back(y[first_valid_index - 1]);
     143             :       x_near_valid_point.push_back(x_test);
     144             :       y_near_valid_point.push_back(y_test);
     145             :       x_near_valid_point.push_back(x[first_valid_index]);
     146             :       y_near_valid_point.push_back(y[first_valid_index]);
     147             :     }
     148             :   }
     149             :   if (last_valid_index < y.size() - 1) {
     150             :     // Check for a root between last_valid_index and last_valid_index+1.
     151             :     const double x_test = x[last_valid_index] +
     152             :                           0.5 * (x[last_valid_index + 1] - x[last_valid_index]);
     153             :     const auto y_test = f(x_test);
     154             :     if (y_test.has_value() and
     155             :         y[last_valid_index].value() * y_test.value() <= 0.0) {
     156             :       // Bracketed!
     157             :       return std::array<double, 4>{{x[last_valid_index], x_test,
     158             :                                     y[last_valid_index].value(),
     159             :                                     y_test.value()}};
     160             :     } else {
     161             :       if (first_valid_index != last_valid_index or first_valid_index == 0) {
     162             :         x_near_valid_point.push_back(x[last_valid_index]);
     163             :         y_near_valid_point.push_back(y[last_valid_index]);
     164             :       }  // else we already pushed back last_valid_index (==first_valid_index).
     165             :       x_near_valid_point.push_back(x_test);
     166             :       y_near_valid_point.push_back(y_test);
     167             :       x_near_valid_point.push_back(x[last_valid_index + 1]);
     168             :       y_near_valid_point.push_back(y[last_valid_index + 1]);
     169             :     }
     170             :   }
     171             : 
     172             :   // We have one or more valid points but we didn't find a bracket.
     173             :   // That is, we have something like "X X o o X X" or "X X o o" or "o o X X".
     174             :   // So recurse, zooming in to the boundary (either one boundary or two
     175             :   // boundaries) between valid and invalid points.
     176             :   // Note that "o o o o" is prohibited by our assumptions, and checked for
     177             :   // above just in case it occurs by mistake.
     178             :   return bracket_by_contracting(x_near_valid_point, y_near_valid_point, f,
     179             :                                 level + 1);
     180             : }
     181             : }  // namespace bracketing_detail
     182             : 
     183             : /*!
     184             :  * \ingroup NumericalAlgorithmsGroup
     185             :  * \brief Brackets the root of the function `f`, assuming a single
     186             :  * root in a given interval \f$f[x_\mathrm{lo},x_\mathrm{up}]\f$
     187             :  * and assuming that `f` is defined only in an unknown smaller
     188             :  * interval \f$f[x_a,x_b]\f$ where
     189             :  * \f$x_\mathrm{lo} \leq x_a \leq x_b \leq x_\mathrm{hi}\f$.
     190             :  *
     191             :  * `f` is a unary invokable that takes a `double` which is the current value at
     192             :  * which to evaluate `f`.  `f` returns a `std::optional<double>` which
     193             :  * evaluates to false if the function is undefined at the supplied point.
     194             :  *
     195             :  * Assumes that there is only one root in the interval.
     196             :  *
     197             :  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
     198             :  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
     199             :  * between \f$x_1\f$ and \f$x_2\f$.
     200             :  *
     201             :  * On input, assumes that the root lies in the interval
     202             :  * [`lower_bound`,`upper_bound`].  Optionally takes a `guess` for the
     203             :  * location of the root.  If `guess` is supplied, then evaluates the
     204             :  * function first at `guess` and `upper_bound` before trying
     205             :  * `lower_bound`: this means that it would be optimal if `guess`
     206             :  * underestimates the actual root and if `upper_bound` was less likely
     207             :  * to be undefined than `lower_bound`.
     208             :  *
     209             :  * On return, `lower_bound` and `upper_bound` are replaced with values that
     210             :  * bracket the root and for which the function is defined, and
     211             :  * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
     212             :  * `f` evaluated at those bracketing points.
     213             :  *
     214             :  * `bracket_possibly_undefined_function_in_interval` throws an error if
     215             :  *  all points are valid but of the same sign (because that would indicate
     216             :  *  multiple roots but we assume only one root), if no root exists, or
     217             :  *  if the range of a sign change is sufficently small relative to the
     218             :  *  given interval that the number of iterations to find the root is exceeded.
     219             :  *
     220             :  */
     221             : template <typename Functor>
     222           1 : void bracket_possibly_undefined_function_in_interval(
     223             :     const gsl::not_null<double*> lower_bound,
     224             :     const gsl::not_null<double*> upper_bound,
     225             :     const gsl::not_null<double*> f_at_lower_bound,
     226             :     const gsl::not_null<double*> f_at_upper_bound, const Functor& f,
     227             :     const double guess) noexcept {
     228             :   // Initial values of x1,x2,y1,y2.  Use `guess` and `upper_bound`,
     229             :   // because in typical usage `guess` underestimates the actual
     230             :   // root, and `lower_bound` is more likely than `upper_bound` to be
     231             :   // invalid.
     232             :   double x1 = guess;
     233             :   double x2 = *upper_bound;
     234             :   auto y1 = f(x1);
     235             :   auto y2 = f(x2);
     236             :   const bool y1_defined = y1.has_value();
     237             :   const bool y2_defined = y2.has_value();
     238             :   if (not(y1_defined and y2_defined and y1.value() * y2.value() <= 0.0)) {
     239             :     // Root is not bracketed.
     240             :     // Before moving to the general algorithm, try the remaining
     241             :     // input point that was supplied.
     242             :     const double x3 = *lower_bound;
     243             :     const auto y3 = f(x3);
     244             :     const bool y3_defined = y3.has_value();
     245             :     if (y1_defined and y3_defined and y1.value() * y3.value() <= 0.0) {
     246             :       // Bracketed! Throw out x2,y2.  Rename variables to keep x1 < x2.
     247             :       x2 = x1;
     248             :       y2 = y1;
     249             :       x1 = x3;
     250             :       y1 = y3;
     251             :     } else {
     252             :       // Our simple checks didn't work, so call the more general method.
     253             :       // There are 8 cases:
     254             :       //
     255             :       // y3 y1 y2
     256             :       // --------
     257             :       // X  X  X
     258             :       // o  X  X
     259             :       // X  o  X
     260             :       // o  o  X
     261             :       // X  o  o
     262             :       // o  o  o
     263             :       // X  X  o
     264             :       // o  X  o
     265             :       //
     266             :       // where X means an invalid point, o means a valid point.
     267             :       // All valid points have the same sign, or we would have found a
     268             :       // bracket already.
     269             :       //
     270             :       // Before calling the general case, error on "o o o" and "o X o".
     271             :       // Both of these are prohibited by our assumptions (we
     272             :       // assume the root is in the interval so no "o o o", and we
     273             :       // assume that all invalid points are at the end of interval, so no
     274             :       // "o X o").
     275             :       if (y2_defined and y3_defined) {
     276             :         ERROR(
     277             :             "bracket_possibly_undefined_function_in_interval: found "
     278             :             "case that should not happen under our assumptions.");
     279             :       }
     280             :       std::array<double, 4> tmp = bracketing_detail::bracket_by_contracting(
     281             :           {{x3, x1, x2}}, {{y3, y1, y2}}, f);
     282             :       x1 = tmp[0];
     283             :       x2 = tmp[1];
     284             :       y1 = tmp[2];
     285             :       y2 = tmp[3];
     286             :     }
     287             :   }
     288             :   *f_at_lower_bound = y1.value();
     289             :   *f_at_upper_bound = y2.value();
     290             :   *lower_bound = x1;
     291             :   *upper_bound = x2;
     292             : }
     293             : 
     294             : /*!
     295             :  * \ingroup NumericalAlgorithmsGroup
     296             :  * \brief Brackets the single root of the
     297             :  * function `f` for each element in a `DataVector`, assuming the root
     298             :  * lies in the given interval and that `f` may be undefined at some
     299             :  * points in the interval.
     300             :  *
     301             :  * `f` is a binary invokable that takes a `double` and a `size_t` as
     302             :  * arguments.  The `double` is the current value at which to evaluate
     303             :  * `f`, and the `size_t` is the index into the `DataVector`s.  `f`
     304             :  * returns a `std::optional<double>` which evaluates to false if the
     305             :  * function is undefined at the supplied point.
     306             :  *
     307             :  * Assumes that there is only one root in the interval.
     308             :  *
     309             :  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
     310             :  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
     311             :  * between \f$x_1\f$ and \f$x_2\f$.
     312             :  *
     313             :  * On input, assumes that the root lies in the interval
     314             :  * [`lower_bound`,`upper_bound`].  Optionally takes a `guess` for the
     315             :  * location of the root.
     316             :  *
     317             :  * On return, `lower_bound` and `upper_bound` are replaced with values that
     318             :  * bracket the root and for which the function is defined, and
     319             :  * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
     320             :  * `f` evaluated at those bracketing points.
     321             :  *
     322             :  */
     323             : template <typename Functor>
     324           1 : void bracket_possibly_undefined_function_in_interval(
     325             :     const gsl::not_null<DataVector*> lower_bound,
     326             :     const gsl::not_null<DataVector*> upper_bound,
     327             :     const gsl::not_null<DataVector*> f_at_lower_bound,
     328             :     const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f,
     329             :     const DataVector& guess) noexcept {
     330             :   for (size_t s = 0; s < lower_bound->size(); ++s) {
     331             :     bracket_possibly_undefined_function_in_interval(
     332             :         &((*lower_bound)[s]), &((*upper_bound)[s]), &((*f_at_lower_bound)[s]),
     333             :         &((*f_at_upper_bound)[s]),
     334             :         [&f, &s ](const double x) noexcept { return f(x, s); }, guess[s]);
     335             :   }
     336             : }
     337             : 
     338             : /*
     339             :  * Version of `bracket_possibly_undefined_function_in_interval`
     340             :  * without a supplied initial guess; uses the mean of `lower_bound` and
     341             :  * `upper_bound` as the guess.
     342             :  */
     343             : template <typename Functor>
     344           0 : void bracket_possibly_undefined_function_in_interval(
     345             :     const gsl::not_null<double*> lower_bound,
     346             :     const gsl::not_null<double*> upper_bound,
     347             :     const gsl::not_null<double*> f_at_lower_bound,
     348             :     const gsl::not_null<double*> f_at_upper_bound, const Functor& f) noexcept {
     349             :   bracket_possibly_undefined_function_in_interval(
     350             :       lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
     351             :       *lower_bound + 0.5 * (*upper_bound - *lower_bound));
     352             : }
     353             : 
     354             : /*
     355             :  * Version of `bracket_possibly_undefined_function_in_interval`
     356             :  * without a supplied initial guess; uses the mean of `lower_bound` and
     357             :  * `upper_bound` as the guess.
     358             :  */
     359             : template <typename Functor>
     360           0 : void bracket_possibly_undefined_function_in_interval(
     361             :     const gsl::not_null<DataVector*> lower_bound,
     362             :     const gsl::not_null<DataVector*> upper_bound,
     363             :     const gsl::not_null<DataVector*> f_at_lower_bound,
     364             :     const gsl::not_null<DataVector*> f_at_upper_bound,
     365             :     const Functor& f) noexcept {
     366             :   bracket_possibly_undefined_function_in_interval(
     367             :       lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
     368             :       *lower_bound + 0.5 * (*upper_bound - *lower_bound));
     369             : }
     370             : }  // namespace RootFinder

Generated by: LCOV version 1.14