SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/RootFinding - RootBracketing.hpp Hit Total Coverage
Commit: a8efe75339f4781ca06d43fed14c40144d5e8a08 Lines: 2 5 40.0 %
Date: 2024-10-17 21:19:21
Legend: Lines: hit not hit

          Line data    Source code
       1           0 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : #pragma once
       5             : 
       6             : #include <optional>
       7             : #include <sstream>
       8             : #include <stdexcept>
       9             : #include <tuple>
      10             : #include <vector>
      11             : 
      12             : #include "DataStructures/DataVector.hpp"
      13             : #include "Utilities/ErrorHandling/Error.hpp"
      14             : #include "Utilities/Gsl.hpp"
      15             : 
      16             : namespace RootFinder {
      17             : 
      18             : namespace bracketing_detail {
      19             : // Brackets a root, given a functor f(x) that returns a
      20             : // std::optional<double> and given two arrays x and y (with y=f(x))
      21             : // containing points that have already been tried for bracketing.
      22             : //
      23             : // Returns a std::tuple<double, double, double, double>
      24             : // containing {{x1,x2,y1,y2}} where
      25             : // x1 and x2 bracket the root, and y1=f(x1) and y2=f(x2).
      26             : //
      27             : // Note that y might be undefined (i.e. an invalid std::optional)
      28             : // for values of x at or near the endpoints of the interval.  We
      29             : // assume that if f(x1) and f(x2) are valid for some x1 and x2, then
      30             : // f(x) is valid for all x between x1 and x2.
      31             : //
      32             : // Assumes that there is a root between the first and last points of
      33             : // the array.
      34             : // We also assume that there is only one root.
      35             : //
      36             : // So this means we have only 2 possibilities for the validity of the
      37             : // points in the input x,y arrays:
      38             : // 1) All points are invalid, e.g. "X X X X X X X X".
      39             : //    (here X represents an invalid point)
      40             : // 2) All valid points are adjacent, with the same sign, e.g. "X X o o X X"
      41             : //    or "o o X X X" or "X X X o o".
      42             : //    (here o represents a valid point)
      43             : // Note that we assume that all valid points have the same sign; otherwise
      44             : // the caller would have known that the root was bracketed and the caller would
      45             : // not have called bracket_by_contracting.
      46             : //
      47             : // Also note that we exclude the case "o o o o o" (no roots), since the
      48             : // caller would have known that too.  If such a case is found, an error is
      49             : // thrown.  An error is also thrown if the size of the region where the
      50             : // sign changes is so small that the number of iterations is exceeded.
      51             : //
      52             : // For case 1) above, we bisect each pair of points, and call
      53             : // bracket_by_contracting recursively until we find a valid point.
      54             : // For case 2) above, it is sufficent to check for a bracket only
      55             : // between valid and invalid points.  That is, for "X X + + X X" we
      56             : // check only between points 1 and 2 and between points 3 and 4 (where
      57             : // points are numbered starting from zero).  For "+ + X X X" we check
      58             : // only between points 1 and 2.
      59             : template <typename Functor>
      60             : auto bracket_by_contracting(const std::vector<double>& x,
      61             :                             const std::vector<std::optional<double>>& y,
      62             :                             const Functor& f, const size_t level = 0)
      63             :     -> std::tuple<double, double, double, double> {
      64             :   constexpr size_t max_level = 27;
      65             :   if (level > max_level) {
      66             :     std::stringstream ss;
      67             :     ss << "Too many iterations in bracket_by_contracting.  Either the "
      68             :           "region where the root changes sign is so small that we cannot "
      69             :           "find it, or the given interval does not actually bracket a "
      70             :           "root.  The points we are checking are "
      71             :        << x.size() << " almost-evenly-spaced points from " << x.front()
      72             :        << " to " << x.back() << " with difference " << x.back()-x.front();
      73             :     throw std::runtime_error(ss.str());
      74             :   }
      75             : 
      76             :   // First check if we have any valid points.
      77             :   size_t last_valid_index = y.size();
      78             :   for (size_t i = y.size(); i >= 1; --i) {
      79             :     if (y[i - 1].has_value()) {
      80             :       last_valid_index = i - 1;
      81             :       break;
      82             :     }
      83             :   }
      84             : 
      85             :   if (last_valid_index == y.size()) {
      86             :     // No valid points!
      87             : 
      88             :     // Create larger arrays with one point between each of the already
      89             :     // computed points.
      90             :     std::vector<double> bisected_x(x.size() * 2 - 1);
      91             :     std::vector<std::optional<double>> bisected_y(y.size() * 2 - 1);
      92             : 
      93             :     // Copy all even-numbered points in the range.
      94             :     for (size_t i = 0; i < x.size(); ++i) {
      95             :       bisected_x[2 * i] = x[i];
      96             :       bisected_y[2 * i] = y[i];
      97             :     }
      98             : 
      99             :     // Fill midpoints and check for bracket on each one.
     100             :     for (size_t i = 0; i < x.size() - 1; ++i) {
     101             :       bisected_x[2 * i + 1] = x[i] + 0.5 * (x[i + 1] - x[i]);
     102             :       bisected_y[2 * i + 1] = f(bisected_x[2 * i + 1]);
     103             :       if (bisected_y[2 * i + 1].has_value()) {
     104             :         // Valid point! We know that all the other points are
     105             :         // invalid, so we need to check only 3 points in the next
     106             :         // iteration: the new valid point and its neighbors.
     107             :         return bracket_by_contracting({{x[i], bisected_x[2 * i + 1], x[i + 1]}},
     108             :                                       {{y[i], bisected_y[2 * i + 1], y[i + 1]}},
     109             :                                       f, level + 1);
     110             :       }
     111             :     }
     112             :     // We still have no valid points. So recurse, using all points.
     113             :     // The next iteration will bisect all the points.
     114             :     return bracket_by_contracting(bisected_x, bisected_y, f, level + 1);
     115             :   }
     116             : 
     117             :   // If we get here, we have found a valid point; in particular we have
     118             :   // found the last valid point in the array.
     119             : 
     120             :   // Find the first valid point in the array.
     121             :   size_t first_valid_index = 0;
     122             :   for (size_t i = 0; i < y.size(); ++i) {
     123             :     if (y[i].has_value()) {
     124             :       first_valid_index = i;
     125             :       break;
     126             :     }
     127             :   }
     128             : 
     129             :   // Make a new set of points that includes only the points that
     130             :   // neighbor the boundary between valid and invalid points.
     131             :   std::vector<double> x_near_valid_point;
     132             :   std::vector<std::optional<double>> y_near_valid_point;
     133             : 
     134             :   if (first_valid_index == 0 and last_valid_index == y.size() - 1) {
     135             :     ERROR(
     136             :         "bracket_while_contracting: found a case where all points are valid,"
     137             :         "which should not happen under our assumptions.");
     138             :   }
     139             : 
     140             :   if (first_valid_index > 0) {
     141             :     // Check for a root between first_valid_index-1 and first_valid_index.
     142             :     const double x_test =
     143             :         x[first_valid_index - 1] +
     144             :         0.5 * (x[first_valid_index] - x[first_valid_index - 1]);
     145             :     const auto y_test = f(x_test);
     146             :     if (y_test.has_value() and
     147             :         y[first_valid_index].value() * y_test.value() <= 0.0) {
     148             :       // Bracketed!
     149             :       return {
     150             :           x_test, x[first_valid_index], y_test.value(),
     151             :           y[first_valid_index].value()};
     152             :     } else {
     153             :       x_near_valid_point.push_back(x[first_valid_index - 1]);
     154             :       y_near_valid_point.push_back(y[first_valid_index - 1]);
     155             :       x_near_valid_point.push_back(x_test);
     156             :       y_near_valid_point.push_back(y_test);
     157             :       x_near_valid_point.push_back(x[first_valid_index]);
     158             :       y_near_valid_point.push_back(y[first_valid_index]);
     159             :     }
     160             :   }
     161             :   if (last_valid_index < y.size() - 1) {
     162             :     // Check for a root between last_valid_index and last_valid_index+1.
     163             :     const double x_test = x[last_valid_index] +
     164             :                           0.5 * (x[last_valid_index + 1] - x[last_valid_index]);
     165             :     const auto y_test = f(x_test);
     166             :     if (y_test.has_value() and
     167             :         y[last_valid_index].value() * y_test.value() <= 0.0) {
     168             :       // Bracketed!
     169             :       return {x[last_valid_index], x_test, y[last_valid_index].value(),
     170             :               y_test.value()};
     171             :     } else {
     172             :       if (first_valid_index != last_valid_index or first_valid_index == 0) {
     173             :         x_near_valid_point.push_back(x[last_valid_index]);
     174             :         y_near_valid_point.push_back(y[last_valid_index]);
     175             :       }  // else we already pushed back last_valid_index (==first_valid_index).
     176             :       x_near_valid_point.push_back(x_test);
     177             :       y_near_valid_point.push_back(y_test);
     178             :       x_near_valid_point.push_back(x[last_valid_index + 1]);
     179             :       y_near_valid_point.push_back(y[last_valid_index + 1]);
     180             :     }
     181             :   }
     182             : 
     183             :   // We have one or more valid points but we didn't find a bracket.
     184             :   // That is, we have something like "X X o o X X" or "X X o o" or "o o X X".
     185             :   // So recurse, zooming in to the boundary (either one boundary or two
     186             :   // boundaries) between valid and invalid points.
     187             :   // Note that "o o o o" is prohibited by our assumptions, and checked for
     188             :   // above just in case it occurs by mistake.
     189             :   return bracket_by_contracting(x_near_valid_point, y_near_valid_point, f,
     190             :                                 level + 1);
     191             : }
     192             : }  // namespace bracketing_detail
     193             : 
     194             : /*!
     195             :  * \ingroup NumericalAlgorithmsGroup
     196             :  * \brief Brackets the root of the function `f`, assuming a single
     197             :  * root in a given interval \f$f[x_\mathrm{lo},x_\mathrm{up}]\f$
     198             :  * and assuming that `f` is defined only in an unknown smaller
     199             :  * interval \f$f[x_a,x_b]\f$ where
     200             :  * \f$x_\mathrm{lo} \leq x_a \leq x_b \leq x_\mathrm{hi}\f$.
     201             :  *
     202             :  * `f` is a unary invokable that takes a `double` which is the current value at
     203             :  * which to evaluate `f`.  `f` returns a `std::optional<double>` which
     204             :  * evaluates to false if the function is undefined at the supplied point.
     205             :  *
     206             :  * Assumes that there is only one root in the interval.
     207             :  *
     208             :  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
     209             :  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
     210             :  * between \f$x_1\f$ and \f$x_2\f$.
     211             :  *
     212             :  * On input, assumes that the root lies in the interval
     213             :  * [`lower_bound`,`upper_bound`].  Optionally takes a `guess` for the
     214             :  * location of the root.  If `guess` is supplied, then evaluates the
     215             :  * function first at `guess` and `upper_bound` before trying
     216             :  * `lower_bound`: this means that it would be optimal if `guess`
     217             :  * underestimates the actual root and if `upper_bound` was less likely
     218             :  * to be undefined than `lower_bound`.
     219             :  *
     220             :  * On return, `lower_bound` and `upper_bound` are replaced with values that
     221             :  * bracket the root and for which the function is defined, and
     222             :  * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
     223             :  * `f` evaluated at those bracketing points.
     224             :  *
     225             :  * `bracket_possibly_undefined_function_in_interval` throws an error if
     226             :  *  all points are valid but of the same sign (because that would indicate
     227             :  *  multiple roots but we assume only one root), if no root exists, or
     228             :  *  if the range of a sign change is sufficently small relative to the
     229             :  *  given interval that the number of iterations to find the root is exceeded.
     230             :  *
     231             :  */
     232             : template <typename Functor>
     233           1 : void bracket_possibly_undefined_function_in_interval(
     234             :     const gsl::not_null<double*> lower_bound,
     235             :     const gsl::not_null<double*> upper_bound,
     236             :     const gsl::not_null<double*> f_at_lower_bound,
     237             :     const gsl::not_null<double*> f_at_upper_bound, const Functor& f,
     238             :     const double guess) {
     239             :   // Initial values of x1,x2,y1,y2.  Use `guess` and `upper_bound`,
     240             :   // because in typical usage `guess` underestimates the actual
     241             :   // root, and `lower_bound` is more likely than `upper_bound` to be
     242             :   // invalid.
     243             :   double x1 = guess;
     244             :   double x2 = *upper_bound;
     245             :   auto y1 = f(x1);
     246             :   auto y2 = f(x2);
     247             :   const bool y1_defined = y1.has_value();
     248             :   const bool y2_defined = y2.has_value();
     249             :   if (not(y1_defined and y2_defined and y1.value() * y2.value() <= 0.0)) {
     250             :     // Root is not bracketed.
     251             :     // Before moving to the general algorithm, try the remaining
     252             :     // input point that was supplied.
     253             :     const double x3 = *lower_bound;
     254             :     const auto y3 = f(x3);
     255             :     const bool y3_defined = y3.has_value();
     256             :     if (y1_defined and y3_defined and y1.value() * y3.value() <= 0.0) {
     257             :       // Bracketed! Throw out x2,y2.  Rename variables to keep x1 < x2.
     258             :       x2 = x1;
     259             :       y2 = y1;
     260             :       x1 = x3;
     261             :       y1 = y3;
     262             :     } else {
     263             :       // Our simple checks didn't work, so call the more general method.
     264             :       // There are 8 cases:
     265             :       //
     266             :       // y3 y1 y2
     267             :       // --------
     268             :       // X  X  X
     269             :       // o  X  X
     270             :       // X  o  X
     271             :       // o  o  X
     272             :       // X  o  o
     273             :       // o  o  o
     274             :       // X  X  o
     275             :       // o  X  o
     276             :       //
     277             :       // where X means an invalid point, o means a valid point.
     278             :       // All valid points have the same sign, or we would have found a
     279             :       // bracket already.
     280             :       //
     281             :       // Before calling the general case, error on "o o o" and "o X o".
     282             :       // Both of these are prohibited by our assumptions (we
     283             :       // assume the root is in the interval so no "o o o", and we
     284             :       // assume that all invalid points are at the end of interval, so no
     285             :       // "o X o").
     286             :       if (y2_defined and y3_defined) {
     287             :         ERROR(
     288             :             "bracket_possibly_undefined_function_in_interval: found "
     289             :             "case that should not happen under our assumptions.");
     290             :       }
     291             :       try {
     292             :         std::tie(x1, x2, y1, y2) = bracketing_detail::bracket_by_contracting(
     293             :             {{x3, x1, x2}}, {{y3, y1, y2}}, f);
     294             :       } catch (std::runtime_error& e) {
     295             :         std::stringstream ss;
     296             :         ss << "bracket_by_contracting: Cannot bracket root between "
     297             :            << *lower_bound << " and " << *upper_bound
     298             :            << ". Internal error message is: '" << e.what() << "'";
     299             :         throw std::runtime_error(ss.str());
     300             :       }
     301             :     }
     302             :   }
     303             :   *f_at_lower_bound = y1.value();
     304             :   *f_at_upper_bound = y2.value();
     305             :   *lower_bound = x1;
     306             :   *upper_bound = x2;
     307             : }
     308             : 
     309             : /*!
     310             :  * \ingroup NumericalAlgorithmsGroup
     311             :  * \brief Brackets the single root of the
     312             :  * function `f` for each element in a `DataVector`, assuming the root
     313             :  * lies in the given interval and that `f` may be undefined at some
     314             :  * points in the interval.
     315             :  *
     316             :  * `f` is a binary invokable that takes a `double` and a `size_t` as
     317             :  * arguments.  The `double` is the current value at which to evaluate
     318             :  * `f`, and the `size_t` is the index into the `DataVector`s.  `f`
     319             :  * returns a `std::optional<double>` which evaluates to false if the
     320             :  * function is undefined at the supplied point.
     321             :  *
     322             :  * Assumes that there is only one root in the interval.
     323             :  *
     324             :  * Assumes that if \f$f(x_1)\f$ and \f$f(x_2)\f$ are both defined for
     325             :  * some \f$(x_1,x_2)\f$, then \f$f(x)\f$ is defined for all \f$x\f$
     326             :  * between \f$x_1\f$ and \f$x_2\f$.
     327             :  *
     328             :  * On input, assumes that the root lies in the interval
     329             :  * [`lower_bound`,`upper_bound`].  Optionally takes a `guess` for the
     330             :  * location of the root.
     331             :  *
     332             :  * On return, `lower_bound` and `upper_bound` are replaced with values that
     333             :  * bracket the root and for which the function is defined, and
     334             :  * `f_at_lower_bound` and `f_at_upper_bound` are replaced with
     335             :  * `f` evaluated at those bracketing points.
     336             :  *
     337             :  */
     338             : template <typename Functor>
     339           1 : void bracket_possibly_undefined_function_in_interval(
     340             :     const gsl::not_null<DataVector*> lower_bound,
     341             :     const gsl::not_null<DataVector*> upper_bound,
     342             :     const gsl::not_null<DataVector*> f_at_lower_bound,
     343             :     const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f,
     344             :     const DataVector& guess) {
     345             :   for (size_t s = 0; s < lower_bound->size(); ++s) {
     346             :     bracket_possibly_undefined_function_in_interval(
     347             :         &((*lower_bound)[s]), &((*upper_bound)[s]), &((*f_at_lower_bound)[s]),
     348             :         &((*f_at_upper_bound)[s]), [&f, &s](const double x) { return f(x, s); },
     349             :         guess[s]);
     350             :   }
     351             : }
     352             : 
     353             : /*
     354             :  * Version of `bracket_possibly_undefined_function_in_interval`
     355             :  * without a supplied initial guess; uses the mean of `lower_bound` and
     356             :  * `upper_bound` as the guess.
     357             :  */
     358             : template <typename Functor>
     359           0 : void bracket_possibly_undefined_function_in_interval(
     360             :     const gsl::not_null<double*> lower_bound,
     361             :     const gsl::not_null<double*> upper_bound,
     362             :     const gsl::not_null<double*> f_at_lower_bound,
     363             :     const gsl::not_null<double*> f_at_upper_bound, const Functor& f) {
     364             :   bracket_possibly_undefined_function_in_interval(
     365             :       lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
     366             :       *lower_bound + 0.5 * (*upper_bound - *lower_bound));
     367             : }
     368             : 
     369             : /*
     370             :  * Version of `bracket_possibly_undefined_function_in_interval`
     371             :  * without a supplied initial guess; uses the mean of `lower_bound` and
     372             :  * `upper_bound` as the guess.
     373             :  */
     374             : template <typename Functor>
     375           0 : void bracket_possibly_undefined_function_in_interval(
     376             :     const gsl::not_null<DataVector*> lower_bound,
     377             :     const gsl::not_null<DataVector*> upper_bound,
     378             :     const gsl::not_null<DataVector*> f_at_lower_bound,
     379             :     const gsl::not_null<DataVector*> f_at_upper_bound, const Functor& f) {
     380             :   bracket_possibly_undefined_function_in_interval(
     381             :       lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, f,
     382             :       *lower_bound + 0.5 * (*upper_bound - *lower_bound));
     383             : }
     384             : }  // namespace RootFinder

Generated by: LCOV version 1.14