SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/RootFinding - TOMS748.hpp Hit Total Coverage
Commit: 817e13c5144619b701c7cd870655d8dbf94ab8ce Lines: 5 5 100.0 %
Date: 2024-07-19 22:17:05
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares function RootFinder::toms748
       6             : 
       7             : #pragma once
       8             : 
       9             : #include <functional>
      10             : #include <iomanip>
      11             : #include <ios>
      12             : #include <limits>
      13             : #include <string>
      14             : #include <type_traits>
      15             : 
      16             : #include "DataStructures/DataVector.hpp"
      17             : #include "Utilities/ErrorHandling/Error.hpp"
      18             : #include "Utilities/ErrorHandling/Exceptions.hpp"
      19             : #include "Utilities/GetOutput.hpp"
      20             : #include "Utilities/MakeString.hpp"
      21             : #include "Utilities/Simd/Simd.hpp"
      22             : 
      23             : namespace RootFinder {
      24             : namespace toms748_detail {
      25             : // Original implementation of TOMS748 is from Boost:
      26             : //  (C) Copyright John Maddock 2006.
      27             : //  Use, modification and distribution are subject to the
      28             : //  Boost Software License, Version 1.0. (See accompanying file
      29             : //  LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
      30             : //
      31             : // Significant changes made to pretty much all of it to support SIMD by Nils
      32             : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
      33             : 
      34             : template <typename T>
      35             : T safe_div(const T& num, const T& denom, const T& r) {
      36             :   if constexpr (std::is_floating_point_v<T>) {
      37             :     using std::abs;
      38             :     if (abs(denom) < static_cast<T>(1)) {
      39             :       if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
      40             :         return r;
      41             :       }
      42             :     }
      43             :     return num / denom;
      44             :   } else {
      45             :     // return num / denom without overflow, return r if overflow would occur.
      46             :     const auto mask0 = fabs(denom) < static_cast<T>(1);
      47             :     // Note: if denom >= 1.0 you get an FPE because of overflow from
      48             :     // `max() * (1. + value)`
      49             :     const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(0)) *
      50             :                            std::numeric_limits<T>::max()) <= fabs(num);
      51             :     return simd::select(mask0 and mask, r,
      52             :                         num / simd::select(mask, denom, static_cast<T>(1)));
      53             :   }
      54             : }
      55             : 
      56             : template <typename T>
      57             : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb) {
      58             :   //
      59             :   // Performs standard secant interpolation of [a,b] given
      60             :   // function evaluations f(a) and f(b).  Performs a bisection
      61             :   // if secant interpolation would leave us very close to either
      62             :   // a or b.  Rationale: we only call this function when at least
      63             :   // one other form of interpolation has already failed, so we know
      64             :   // that the function is unlikely to be smooth with a root very
      65             :   // close to a or b.
      66             :   //
      67             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
      68             :   // WARNING: There are several different ways to implement the interpolation
      69             :   // that all have different rounding properties. Unfortunately this means that
      70             :   // tolerances even at 1e-14 can be difficult to achieve generically.
      71             :   //
      72             :   // `g` below is:
      73             :   // const T g = (fa / (fb - fa));
      74             :   //
      75             :   // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
      76             :   //
      77             :   // const T c = a * (static_cast<T>(1) - g) + g * b; // works
      78             :   //
      79             :   // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
      80             :   //
      81             :   // Original Boost code:
      82             :   // const T c = a - (fa / (fb - fa)) * (b - a); // works
      83             : 
      84             :   const T c = a - (fa / (fb - fa)) * (b - a);
      85             :   return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
      86             :                           (c >= simd::fnma(fabs(b), tol_batch, b)),
      87             :                       static_cast<T>(0.5) * (a + b), c);
      88             : }
      89             : 
      90             : template <bool AssumeFinite, typename T>
      91             : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
      92             :                         const T& fb, const T& fd,
      93             :                         const simd::mask_type_t<T>& incomplete_mask,
      94             :                         const unsigned count) {
      95             :   // Performs quadratic interpolation to determine the next point,
      96             :   // takes count Newton steps to find the location of the
      97             :   // quadratic polynomial.
      98             :   //
      99             :   // Point d must lie outside of the interval [a,b], it is the third
     100             :   // best approximation to the root, after a and b.
     101             :   //
     102             :   // Note: this does not guarantee to find a root
     103             :   // inside [a, b], so we fall back to a secant step should
     104             :   // the result be out of range.
     105             :   //
     106             :   // Start by obtaining the coefficients of the quadratic polynomial:
     107             :   const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
     108             :   T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
     109             :   A = safe_div(A - B, d - a, static_cast<T>(0));
     110             : 
     111             :   const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
     112             :   T result_secant{};
     113             :   if (UNLIKELY(simd::any(secant_failure_mask))) {
     114             :     // failure to determine coefficients, try a secant step:
     115             :     result_secant = secant_interpolate(a, b, fa, fb);
     116             :     if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
     117             :       return result_secant;
     118             :     }
     119             :   }
     120             : 
     121             :   // Determine the starting point of the Newton steps:
     122             :   //
     123             :   // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
     124             :   // code quite a bit.
     125             :   T c = AssumeFinite
     126             :             ? simd::select(A * fa > static_cast<T>(0), a, b)
     127             :             : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
     128             :                            a, b);
     129             : 
     130             :   // Take the Newton steps:
     131             :   const T two_A = static_cast<T>(2) * A;
     132             :   const T half_a_plus_b = 0.5 * (a + b);
     133             :   const T one_minus_a = static_cast<T>(1) - a;
     134             :   const T B_minus_A_times_b = B - A * b;
     135             :   for (unsigned i = 1; i <= count; ++i) {
     136             :     c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
     137             :                   simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
     138             :   }
     139             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     140             :       simd::any(mask)) {
     141             :     // Failure, try a secant step:
     142             :     c = simd::select(mask, secant_interpolate(a, b, fa, fb), c);
     143             :   }
     144             :   return simd::select(secant_failure_mask, result_secant, c);
     145             : }
     146             : 
     147             : template <bool AssumeFinite, typename T>
     148             : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
     149             :                     const T& fb, const T& fd, const T& fe,
     150             :                     const simd::mask_type_t<T>& incomplete_mask) {
     151             :   // Uses inverse cubic interpolation of f(x) at points
     152             :   // [a,b,d,e] to obtain an approximate root of f(x).
     153             :   // Points d and e lie outside the interval [a,b]
     154             :   // and are the third and forth best approximations
     155             :   // to the root that we have found so far.
     156             :   //
     157             :   // Note: this does not guarantee to find a root
     158             :   // inside [a, b], so we fall back to quadratic
     159             :   // interpolation in case of an erroneous result.
     160             :   //
     161             :   // This commented chunk is the original Boost implementation translated into
     162             :   // simd. The actual code below is a heavily optimized version.
     163             :   //
     164             :   // const T q11 = (d - e) * fd / (fe - fd);
     165             :   // const T q21 = (b - d) * fb / (fd - fb);
     166             :   // const T q31 = (a - b) * fa / (fb - fa);
     167             :   // const T d21 = (b - d) * fd / (fd - fb);
     168             :   // const T d31 = (a - b) * fb / (fb - fa);
     169             :   //
     170             :   // const T q22 = (d21 - q11) * fb / (fe - fb);
     171             :   // const T q32 = (d31 - q21) * fa / (fd - fa);
     172             :   // const T d32 = (d31 - q21) * fd / (fd - fa);
     173             :   // const T q33 = (d32 - q22) * fa / (fe - fa);
     174             :   //
     175             :   // T c = q31 + q32 + q33 + a;
     176             : 
     177             :   // The optimized implementation here is L1-cache bound. That is, we aren't
     178             :   // able to completely saturate the FP units because we are waiting on the L1
     179             :   // cache. While not ideal, that's okay and just part of the algorithm.
     180             :   const T denom_fb_fa = fb - fa;
     181             :   const T denom_fd_fb = fd - fb;
     182             :   const T denom_fd_fa = fd - fa;
     183             :   const T denom_fe_fd = fe - fd;
     184             :   const T denom_fe_fb = fe - fb;
     185             :   const T denom =
     186             :       denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
     187             : 
     188             :   // Avoid division by zero with mask.
     189             :   const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
     190             :                                           static_cast<T>(1));
     191             : 
     192             :   const T d31 = (a - b);
     193             :   const T q21 = (b - d);
     194             :   const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
     195             :                 denom_fe_fd;
     196             :   const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
     197             :                 denom_fb_fa * denom_fd_fa;
     198             : 
     199             :   // Note: the reduction in rounding error that comes from the improvement by
     200             :   // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
     201             :   // we do below. Alternative ways of evaluating polynomials do not delay this
     202             :   // inclusion of `a`, and so then when the correction to `a` is small,
     203             :   // floating point errors decrease the accuracy of the result.
     204             :   T c = simd::fma(
     205             :       fa_by_denom,
     206             :       simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
     207             : 
     208             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     209             :       simd::any(mask)) {
     210             :     // Out of bounds step, fall back to quadratic interpolation:
     211             :     //
     212             :     // Note: we only apply quadratic interpolation at points where cubic
     213             :     // failed and that aren't already at a root.
     214             :     c = simd::select(
     215             :         mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
     216             :         c);
     217             :   }
     218             : 
     219             :   return c;
     220             : }
     221             : 
     222             : template <bool AssumeFinite, typename F, typename T>
     223             : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
     224             :              const simd::mask_type_t<T>& incomplete_mask) {
     225             :   // Given a point c inside the existing enclosing interval
     226             :   // [a, b] sets a = c if f(c) == 0, otherwise finds the new
     227             :   // enclosing interval: either [a, c] or [c, b] and sets
     228             :   // d and fd to the point that has just been removed from
     229             :   // the interval.  In other words d is the third best guess
     230             :   // to the root.
     231             :   //
     232             :   // Note: `bracket` will only modify slots marked as `true` in
     233             :   //       `incomplete_mask`
     234             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
     235             : 
     236             :   // If the interval [a,b] is very small, or if c is too close
     237             :   // to one end of the interval then we need to adjust the
     238             :   // location of c accordingly. This is:
     239             :   //
     240             :   //   if ((b - a) < 2 * tol * a) {
     241             :   //     c = a + (b - a) / 2;
     242             :   //   } else if (c <= a + fabs(a) * tol) {
     243             :   //     c = a + fabs(a) * tol;
     244             :   //   } else if (c >= b - fabs(b) * tol) {
     245             :   //     c = b - fabs(b) * tol;
     246             :   //   }
     247             :   const T a_filt = simd::fma(fabs(a), tol_batch, a);
     248             :   const T b_filt = simd::fnma(fabs(b), tol_batch, b);
     249             :   const T b_minus_a = b - a;
     250             :   c = simd::select(
     251             :       (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
     252             :       simd::fma(b_minus_a, static_cast<T>(0.5), a),
     253             :       simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
     254             : 
     255             :   // Invoke f(c):
     256             :   T fc = f(c);
     257             : 
     258             :   // if we have a zero then we have an exact solution to the root:
     259             :   const auto fc_is_zero_mask = (fc == static_cast<T>(0));
     260             :   if (const auto mask = fc_is_zero_mask and incomplete_mask;
     261             :       UNLIKELY(simd::any(mask))) {
     262             :     a = simd::select(mask, c, a);
     263             :     fa = simd::select(mask, static_cast<T>(0), fa);
     264             :     d = simd::select(mask, static_cast<T>(0), d);
     265             :     fd = simd::select(mask, static_cast<T>(0), fd);
     266             :     if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
     267             :       return;
     268             :     }
     269             :   }
     270             : 
     271             :   // Non-zero fc, update the interval:
     272             :   //
     273             :   // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
     274             :   // code quite a bit.
     275             :   //
     276             :   // Boost code is:
     277             :   // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
     278             :   using simd::sign;
     279             :   const auto sign_mask = AssumeFinite
     280             :                              ? (fa * fc < static_cast<T>(0))
     281             :                              : (sign(fa) * sign(fc) < static_cast<T>(0));
     282             :   const auto mask_if =
     283             :       (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
     284             :   d = simd::select(mask_if, b, d);
     285             :   fd = simd::select(mask_if, fb, fd);
     286             :   b = simd::select(mask_if, c, b);
     287             :   fb = simd::select(mask_if, fc, fb);
     288             : 
     289             :   const auto mask_else =
     290             :       ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
     291             :   d = simd::select(mask_else, a, d);
     292             :   fd = simd::select(mask_else, fa, fd);
     293             :   a = simd::select(mask_else, c, a);
     294             :   fa = simd::select(mask_else, fc, fa);
     295             : }
     296             : 
     297             : template <bool AssumeFinite, class F, class T, class Tol>
     298             : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
     299             :                               const T& fbx, Tol tol,
     300             :                               const simd::mask_type_t<T>& ignore_filter,
     301             :                               size_t& max_iter) {
     302             :   // Main entry point and logic for Toms Algorithm 748
     303             :   // root finder.
     304             :   if (UNLIKELY(simd::any(ax >= bx))) {
     305             :     throw std::domain_error("Lower bound is larger than upper bound");
     306             :   }
     307             : 
     308             :   // Sanity check - are we allowed to iterate at all?
     309             :   if (UNLIKELY(max_iter == 0)) {
     310             :     return std::pair{ax, bx};
     311             :   }
     312             : 
     313             :   size_t count = max_iter;
     314             :   // mu is a parameter in the algorithm that must be between (0, 1).
     315             :   static const T mu = 0.5f;
     316             : 
     317             :   // initialise a, b and fa, fb:
     318             :   T a = ax;
     319             :   T b = bx;
     320             :   T fa = fax;
     321             :   T fb = fbx;
     322             : 
     323             :   const auto fa_is_zero_mask = (fa == static_cast<T>(0));
     324             :   const auto fb_is_zero_mask = (fb == static_cast<T>(0));
     325             :   auto completion_mask =
     326             :       tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
     327             :   auto incomplete_mask = not completion_mask;
     328             :   if (UNLIKELY(simd::all(completion_mask))) {
     329             :     max_iter = 0;
     330             :     return std::pair{simd::select(fb_is_zero_mask, b, a),
     331             :                      simd::select(fa_is_zero_mask, a, b)};
     332             :   }
     333             : 
     334             :   // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
     335             :   // This speeds up the code quite a bit.
     336             :   if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
     337             :                                        : (simd::sign(fa) * simd::sign(fb) >
     338             :                                           static_cast<T>(0))) and
     339             :                          (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
     340             :     throw std::domain_error(
     341             :         "Parameters lower and upper bounds do not bracket a root");
     342             :   }
     343             :   // dummy value for fd, e and fe:
     344             :   T fe(static_cast<T>(1e5F));
     345             :   T e(static_cast<T>(1e5F));
     346             :   T fd(static_cast<T>(1e5F));
     347             : 
     348             :   T c(std::numeric_limits<T>::signaling_NaN());
     349             :   T d(std::numeric_limits<T>::signaling_NaN());
     350             : 
     351             :   const T nan(std::numeric_limits<T>::signaling_NaN());
     352             :   auto completed_a = simd::select(completion_mask, a, nan);
     353             :   auto completed_b = simd::select(completion_mask, b, nan);
     354             :   auto completed_fa = simd::select(completion_mask, fa, nan);
     355             :   auto completed_fb = simd::select(completion_mask, fb, nan);
     356             :   const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
     357             :                                  &completed_a, &completed_b, &a, &b,
     358             :                                  &completed_fa, &completed_fb, &tol]() {
     359             :     const auto new_completed =
     360             :         (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
     361             :     completed_a = simd::select(new_completed, a, completed_a);
     362             :     completed_b = simd::select(new_completed, b, completed_b);
     363             :     completed_fa = simd::select(new_completed, fa, completed_fa);
     364             :     completed_fb = simd::select(new_completed, fb, completed_fb);
     365             :     completion_mask = new_completed or completion_mask;
     366             :     incomplete_mask = not completion_mask;
     367             :     // returns true if _all_ simd registers have been completed
     368             :     return simd::all(completion_mask);
     369             :   };
     370             : 
     371             :   if (simd::any(fa != static_cast<T>(0))) {
     372             :     // On the first step we take a secant step:
     373             :     c = toms748_detail::secant_interpolate(a, b, fa, fb);
     374             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     375             :                                           incomplete_mask);
     376             :     --count;
     377             : 
     378             :     // Note: The Boost fa!=0 check is handled with the completion_mask.
     379             :     if (count and not update_completed()) {
     380             :       // On the second step we take a quadratic interpolation:
     381             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     382             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     383             :       e = d;
     384             :       fe = fd;
     385             :       toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     386             :                                             incomplete_mask);
     387             :       --count;
     388             :       update_completed();
     389             :     }
     390             :   }
     391             : 
     392             :   T u(std::numeric_limits<T>::signaling_NaN());
     393             :   T fu(std::numeric_limits<T>::signaling_NaN());
     394             :   T a0(std::numeric_limits<T>::signaling_NaN());
     395             :   T b0(std::numeric_limits<T>::signaling_NaN());
     396             : 
     397             :   // Note: The Boost fa!=0 check is handled with the completion_mask.
     398             :   while (count and not simd::all(completion_mask)) {
     399             :     // save our brackets:
     400             :     a0 = a;
     401             :     b0 = b;
     402             :     // Starting with the third step taken
     403             :     // we can use either quadratic or cubic interpolation.
     404             :     // Cubic interpolation requires that all four function values
     405             :     // fa, fb, fd, and fe are distinct, should that not be the case
     406             :     // then variable prof will get set to true, and we'll end up
     407             :     // taking a quadratic step instead.
     408             :     static const T min_diff = std::numeric_limits<T>::min() * 32;
     409             :     bool prof =
     410             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     411             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     412             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     413             :                   incomplete_mask);
     414             :     if (prof) {
     415             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     416             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     417             :     } else {
     418             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     419             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     420             :     }
     421             :     // re-bracket, and check for termination:
     422             :     e = d;
     423             :     fe = fd;
     424             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     425             :                                           incomplete_mask);
     426             :     if ((0 == --count) or update_completed()) {
     427             :       break;
     428             :     }
     429             :     // Now another interpolated step:
     430             :     prof =
     431             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     432             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     433             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     434             :                   incomplete_mask);
     435             :     if (prof) {
     436             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     437             :           a, b, d, fa, fb, fd, incomplete_mask, 3);
     438             :     } else {
     439             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     440             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     441             :     }
     442             :     // Bracket again, and check termination condition, update e:
     443             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     444             :                                           incomplete_mask);
     445             :     if ((0 == --count) or update_completed()) {
     446             :       break;
     447             :     }
     448             : 
     449             :     // Now we take a double-length secant step:
     450             :     const auto fabs_fa_less_fabs_fb_mask =
     451             :         (fabs(fa) < fabs(fb)) and incomplete_mask;
     452             :     u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
     453             :     fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
     454             :     const T b_minus_a = b - a;
     455             :     // Assumes that bounds a & b are not so close that fa == fb. Boost makes
     456             :     // this assumption too. If this is violated then the algorithm doesn't
     457             :     // work since the function at least appears constant.
     458             :     c = simd::fnma(static_cast<T>(2) * (fu / (fb - fa)), b_minus_a, u);
     459             :     c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
     460             :                      simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
     461             : 
     462             :     // Bracket again, and check termination condition:
     463             :     e = d;
     464             :     fe = fd;
     465             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     466             :                                           incomplete_mask);
     467             :     if ((0 == --count) or update_completed()) {
     468             :       break;
     469             :     }
     470             : 
     471             :     // And finally... check to see if an additional bisection step is
     472             :     // to be taken, we do this if we're not converging fast enough:
     473             :     const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
     474             :     if (LIKELY(simd::none(bisection_mask))) {
     475             :       continue;
     476             :     }
     477             :     // bracket again on a bisection:
     478             :     //
     479             :     // Note: the mask ensures we only ever modify the slots that the mask has
     480             :     // identified as needing to be modified.
     481             :     e = simd::select(bisection_mask, d, e);
     482             :     fe = simd::select(bisection_mask, fd, fe);
     483             :     toms748_detail::bracket<AssumeFinite>(
     484             :         f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
     485             :         bisection_mask);
     486             :     --count;
     487             :     if (update_completed()) {
     488             :       break;
     489             :     }
     490             :   }  // while loop
     491             : 
     492             :   max_iter -= count;
     493             :   completed_b =
     494             :       simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
     495             :   completed_a =
     496             :       simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
     497             :   return std::pair{completed_a, completed_b};
     498             : }
     499             : }  // namespace toms748_detail
     500             : 
     501             : /*!
     502             :  * \ingroup NumericalAlgorithmsGroup
     503             :  * \brief Finds the root of the function `f` with the TOMS_748 method.
     504             :  *
     505             :  * `f` is a unary invokable that takes a `double` which is the current value at
     506             :  * which to evaluate `f`. An example is below.
     507             :  *
     508             :  * \snippet Test_TOMS748.cpp double_root_find
     509             :  *
     510             :  * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
     511             :  * `upper_bound`], and will throw if this interval does not bracket a root,
     512             :  * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
     513             :  *
     514             :  * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
     515             :  * are the function values at `lower_bound` and `upper_bound`. These function
     516             :  * values are often known because the user typically checks if a root is
     517             :  * bracketed before calling `toms748`; passing the function values here saves
     518             :  * two function evaluations.
     519             :  *
     520             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     521             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     522             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     523             :  * that products like `fa * fb` are also finite.
     524             :  *
     525             :  * \requires Function `f` is invokable with a `double`
     526             :  *
     527             :  * \throws `convergence_error` if the requested tolerance is not met after
     528             :  *                            `max_iterations` iterations.
     529             :  */
     530             : template <bool AssumeFinite = false, typename Function, typename T>
     531           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     532             :           const T f_at_lower_bound, const T f_at_upper_bound,
     533             :           const simd::scalar_type_t<T> absolute_tolerance,
     534             :           const simd::scalar_type_t<T> relative_tolerance,
     535             :           const size_t max_iterations = 100,
     536             :           const simd::mask_type_t<T> ignore_filter =
     537             :               static_cast<simd::mask_type_t<T>>(0)) {
     538             :   ASSERT(relative_tolerance >
     539             :              std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
     540             :          "The relative tolerance is too small. Got "
     541             :              << relative_tolerance << " but must be at least "
     542             :              << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
     543             :   if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
     544             :     ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
     545             :                                    << ", f(" << upper_bound
     546             :                                    << ") = " << f_at_upper_bound);
     547             :   }
     548             : 
     549             :   std::size_t max_iters = max_iterations;
     550             : 
     551             :   // This solver requires tol to be passed as a termination condition. This
     552             :   // termination condition is equivalent to the convergence criteria used by the
     553             :   // GSL
     554             :   const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
     555             :                                                             const T& rhs) {
     556             :     return simd::abs(lhs - rhs) <=
     557             :            simd::fma(T(relative_tolerance),
     558             :                      simd::min(simd::abs(lhs), simd::abs(rhs)),
     559             :                      T(absolute_tolerance));
     560             :   };
     561             :   auto result = toms748_detail::toms748_solve<AssumeFinite>(
     562             :       f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
     563             :       ignore_filter, max_iters);
     564             :   if (max_iters >= max_iterations) {
     565             :     throw convergence_error(
     566             :         MakeString{}
     567             :         << std::setprecision(8) << std::scientific
     568             :         << "toms748 reached max iterations without converging.\nAbsolute "
     569             :            "tolerance: "
     570             :         << absolute_tolerance << "\nRelative tolerance: " << relative_tolerance
     571             :         << "\nResult: " << get_output(result.first) << " "
     572             :         << get_output(result.second));
     573             :   }
     574             :   return simd::fma(static_cast<T>(0.5), (result.second - result.first),
     575             :                    result.first);
     576             : }
     577             : 
     578             : /*!
     579             :  * \ingroup NumericalAlgorithmsGroup
     580             :  * \brief Finds the root of the function `f` with the TOMS_748 method, where
     581             :  * function values are not supplied at the lower and upper bounds.
     582             :  *
     583             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     584             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     585             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     586             :  * that products like `fa * fb` are also finite.
     587             :  */
     588             : template <bool AssumeFinite = false, typename Function, typename T>
     589           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     590             :           const simd::scalar_type_t<T> absolute_tolerance,
     591             :           const simd::scalar_type_t<T> relative_tolerance,
     592             :           const size_t max_iterations = 100,
     593             :           const simd::mask_type_t<T> ignore_filter =
     594             :               static_cast<simd::mask_type_t<T>>(0)) {
     595             :   return toms748<AssumeFinite>(
     596             :       f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
     597             :       absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
     598             : }
     599             : 
     600             : /*!
     601             :  * \ingroup NumericalAlgorithmsGroup
     602             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     603             :  * element in a `DataVector`.
     604             :  *
     605             :  * `f` is a binary invokable that takes a `double` as its first argument and a
     606             :  * `size_t` as its second. The `double` is the current value at which to
     607             :  * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
     608             :  * Below is an example of how to root find different functions by indexing into
     609             :  * a lambda-captured `DataVector` using the `size_t` passed to `f`.
     610             :  *
     611             :  * \snippet Test_TOMS748.cpp datavector_root_find
     612             :  *
     613             :  * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
     614             :  * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
     615             :  * this interval does not bracket a root,
     616             :  * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
     617             :  *
     618             :  * See the [Boost](http://www.boost.org/) documentation for more details.
     619             :  *
     620             :  * \requires Function `f` be callable with a `double` and a `size_t`
     621             :  *
     622             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     623             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     624             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     625             :  * that products like `fa * fb` are also finite.
     626             :  *
     627             :  * \throws `convergence_error` if, for any index, the requested tolerance is not
     628             :  * met after `max_iterations` iterations.
     629             :  */
     630             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     631           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     632             :                    const DataVector& upper_bound,
     633             :                    const double absolute_tolerance,
     634             :                    const double relative_tolerance,
     635             :                    const size_t max_iterations = 100) {
     636             :   DataVector result_vector{lower_bound.size()};
     637             :   if constexpr (UseSimd) {
     638             :     constexpr size_t simd_width{simd::size<
     639             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     640             :     const size_t vectorized_size =
     641             :         lower_bound.size() - lower_bound.size() % simd_width;
     642             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {
     643             :       simd::store_unaligned(
     644             :           &result_vector[i],
     645             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
     646             :                                 simd::load_unaligned(&lower_bound[i]),
     647             :                                 simd::load_unaligned(&upper_bound[i]),
     648             :                                 absolute_tolerance, relative_tolerance,
     649             :                                 max_iterations));
     650             :     }
     651             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
     652             :       result_vector[i] = toms748<AssumeFinite>(
     653             :           [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
     654             :           upper_bound[i], absolute_tolerance, relative_tolerance,
     655             :           max_iterations);
     656             :     }
     657             :   } else {
     658             :     for (size_t i = 0; i < result_vector.size(); ++i) {
     659             :       result_vector[i] = toms748<AssumeFinite>(
     660             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     661             :           absolute_tolerance, relative_tolerance, max_iterations);
     662             :     }
     663             :   }
     664             :   return result_vector;
     665             : }
     666             : 
     667             : /*!
     668             :  * \ingroup NumericalAlgorithmsGroup
     669             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     670             :  * element in a `DataVector`, where function values are supplied at the lower
     671             :  * and upper bounds.
     672             :  *
     673             :  * Supplying function values is an optimization that saves two
     674             :  * function calls per point.  The function values are often available
     675             :  * because one often checks if the root is bracketed before calling `toms748`.
     676             :  *
     677             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     678             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     679             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     680             :  * that products like `fa * fb` are also finite.
     681             :  */
     682             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     683           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     684             :                    const DataVector& upper_bound,
     685             :                    const DataVector& f_at_lower_bound,
     686             :                    const DataVector& f_at_upper_bound,
     687             :                    const double absolute_tolerance,
     688             :                    const double relative_tolerance,
     689             :                    const size_t max_iterations = 100) {
     690             :   DataVector result_vector{lower_bound.size()};
     691             :   if constexpr (UseSimd) {
     692             :     constexpr size_t simd_width{simd::size<
     693             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     694             :     const size_t vectorized_size =
     695             :         lower_bound.size() - lower_bound.size() % simd_width;
     696             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {
     697             :       simd::store_unaligned(
     698             :           &result_vector[i],
     699             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
     700             :                                 simd::load_unaligned(&lower_bound[i]),
     701             :                                 simd::load_unaligned(&upper_bound[i]),
     702             :                                 simd::load_unaligned(&f_at_lower_bound[i]),
     703             :                                 simd::load_unaligned(&f_at_upper_bound[i]),
     704             :                                 absolute_tolerance, relative_tolerance,
     705             :                                 max_iterations));
     706             :     }
     707             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
     708             :       result_vector[i] = toms748<AssumeFinite>(
     709             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     710             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     711             :           relative_tolerance, max_iterations);
     712             :     }
     713             :   } else {
     714             :     for (size_t i = 0; i < lower_bound.size(); ++i) {
     715             :       result_vector[i] = toms748<AssumeFinite>(
     716             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     717             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     718             :           relative_tolerance, max_iterations);
     719             :     }
     720             :   }
     721             :   return result_vector;
     722             : }
     723             : }  // namespace RootFinder

Generated by: LCOV version 1.14