SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/RootFinding - TOMS748.hpp Hit Total Coverage
Commit: 2068747df712b64688243d3254666212942d85f2 Lines: 5 5 100.0 %
Date: 2026-05-22 23:35:16
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares function RootFinder::toms748
       6             : 
       7             : #pragma once
       8             : 
       9             : #include <functional>
      10             : #include <iomanip>
      11             : #include <ios>
      12             : #include <limits>
      13             : #include <string>
      14             : #include <type_traits>
      15             : 
      16             : #include "DataStructures/DataVector.hpp"
      17             : #include "Utilities/ErrorHandling/CaptureForError.hpp"
      18             : #include "Utilities/ErrorHandling/Error.hpp"
      19             : #include "Utilities/ErrorHandling/Exceptions.hpp"
      20             : #include "Utilities/GetOutput.hpp"
      21             : #include "Utilities/MakeString.hpp"
      22             : #include "Utilities/Simd/Simd.hpp"
      23             : 
      24             : namespace RootFinder {
      25             : namespace toms748_detail {
      26             : // Original implementation of TOMS748 is from Boost:
      27             : //  (C) Copyright John Maddock 2006.
      28             : //  Use, modification and distribution are subject to the
      29             : //  Boost Software License, Version 1.0. (See accompanying file
      30             : //  LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
      31             : //
      32             : // Significant changes made to pretty much all of it to support SIMD by Nils
      33             : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
      34             : 
      35             : template <typename T>
      36             : T safe_div(const T& num, const T& denom, const T& r) {
      37             :   if constexpr (std::is_floating_point_v<T>) {
      38             :     using std::abs;
      39             :     if (abs(denom) < static_cast<T>(1)) {
      40             :       if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
      41             :         return r;
      42             :       }
      43             :     }
      44             :     return num / denom;
      45             :   } else {
      46             :     // return num / denom without overflow, return r if overflow would occur.
      47             :     //
      48             :     // To do this we need to handle the following cases:
      49             :     // 1. fabs(denom) < 1 AND fabs(denom * max) <= fabs(num):
      50             :     //    return r
      51             :     // 2. return num / denom
      52             :     //
      53             :     // We do this by creating 2 masks.
      54             :     // 1. `mask0` selects where fabs(denom) < 1. This is where we _may_ have
      55             :     //    issues with division by zero or overflow.
      56             :     // 2. `mask` selects where fabs(denom) < 1 AND fabs(denom * max)<=fabs(num)
      57             :     //    * Note: the edge case of fabs(num)==max could be problematic, but if
      58             :     //            you're dealing with numbers like that you are likely in
      59             :     //            trouble anyway.
      60             :     // The second mask is where we must avoid division by zero and instead
      61             :     // return `r`.
      62             :     const auto mask0 = fabs(denom) < static_cast<T>(1);
      63             :     // Note: if fabs(denom) > 1.0 you get an FPE because `denom * max()`
      64             :     // overflows, which is why `mask0` is necessary.
      65             :     const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(1.0)) *
      66             :                            std::numeric_limits<T>::max()) <= fabs(num);
      67             :     const auto new_denom = simd::select(mask, static_cast<T>(1), denom);
      68             :     return simd::select(mask, r, num / new_denom);
      69             :   }
      70             : }
      71             : 
      72             : template <typename T>
      73             : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb,
      74             :                      const simd::mask_type_t<T>& incomplete_mask) {
      75             :   //
      76             :   // Performs standard secant interpolation of [a,b] given
      77             :   // function evaluations f(a) and f(b).  Performs a bisection
      78             :   // if secant interpolation would leave us very close to either
      79             :   // a or b.  Rationale: we only call this function when at least
      80             :   // one other form of interpolation has already failed, so we know
      81             :   // that the function is unlikely to be smooth with a root very
      82             :   // close to a or b.
      83             :   //
      84             :   // For lanes where `incomplete_mask` is false (already converged), the
      85             :   // denominator `fb - fa` may be stale or zero. We substitute 1 to avoid
      86             :   // undefined behavior (e.g. division by zero); the returned value for those
      87             :   // lanes is meaningless and should not be used by the caller. See
      88             :   // `cubic_interpolate` for the same pattern.
      89             :   //
      90             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
      91             :   // WARNING: There are several different ways to implement the interpolation
      92             :   // that all have different rounding properties. Unfortunately this means that
      93             :   // tolerances even at 1e-14 can be difficult to achieve generically.
      94             :   //
      95             :   // `g` below is:
      96             :   // const T g = (fa / (fb - fa));
      97             :   //
      98             :   // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
      99             :   //
     100             :   // const T c = a * (static_cast<T>(1) - g) + g * b; // works
     101             :   //
     102             :   // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
     103             :   //
     104             :   // Original Boost code:
     105             :   // const T c = a - (fa / (fb - fa)) * (b - a); // works
     106             : 
     107             :   const T c =
     108             :       a - (fa / simd::select(incomplete_mask, fb - fa, static_cast<T>(1))) *
     109             :               (b - a);
     110             :   return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
     111             :                           (c >= simd::fnma(fabs(b), tol_batch, b)),
     112             :                       static_cast<T>(0.5) * (a + b), c);
     113             : }
     114             : 
     115             : template <bool AssumeFinite, typename T>
     116             : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
     117             :                         const T& fb, const T& fd,
     118             :                         const simd::mask_type_t<T>& incomplete_mask,
     119             :                         const unsigned count) {
     120             :   // Performs quadratic interpolation to determine the next point,
     121             :   // takes `count` Newton steps to locate the root of the
     122             :   // quadratic polynomial.
     123             :   //
     124             :   // Point d must lie outside of the interval [a,b], it is the third
     125             :   // best approximation to the root, after a and b.
     126             :   //
     127             :   // Note: this does not guarantee to find a root
     128             :   // inside [a, b], so we fall back to a secant step should
     129             :   // the result be out of range.
     130             :   //
     131             :   // Start by obtaining the coefficients of the quadratic polynomial:
     132             :   const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
     133             :   T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
     134             :   A = safe_div(A - B, d - a, static_cast<T>(0));
     135             : 
     136             :   const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
     137             :   T result_secant{};
     138             :   if (UNLIKELY(simd::any(secant_failure_mask))) {
     139             :     // failure to determine coefficients, try a secant step:
     140             :     result_secant = secant_interpolate(a, b, fa, fb, secant_failure_mask);
     141             :     if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
     142             :       return result_secant;
     143             :     }
     144             :   }
     145             : 
     146             :   // Determine the starting point of the Newton steps:
     147             :   //
     148             :   // Note: unlike Boost, when `AssumeFinite` is true we assume A*fa doesn't
     149             :   // overflow. This speeds up the code quite a bit.
     150             :   T c = AssumeFinite
     151             :             ? simd::select(A * fa > static_cast<T>(0), a, b)
     152             :             : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
     153             :                            a, b);
     154             : 
     155             :   // Take the Newton steps:
     156             :   const T two_A = static_cast<T>(2) * A;
     157             :   const T half_a_plus_b = 0.5 * (a + b);
     158             :   const T one_minus_a = static_cast<T>(1) - a;
     159             :   const T B_minus_A_times_b = B - A * b;
     160             :   for (unsigned i = 1; i <= count; ++i) {
     161             :     c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
     162             :                   simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
     163             :   }
     164             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     165             :       simd::any(mask)) {
     166             :     // Failure, try a secant step:
     167             :     c = simd::select(mask, secant_interpolate(a, b, fa, fb, mask), c);
     168             :   }
     169             :   return simd::select(secant_failure_mask, result_secant, c);
     170             : }
     171             : 
     172             : template <bool AssumeFinite, typename T>
     173             : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
     174             :                     const T& fb, const T& fd, const T& fe,
     175             :                     const simd::mask_type_t<T>& incomplete_mask) {
     176             :   // Uses inverse cubic interpolation of f(x) at points
     177             :   // [a,b,d,e] to obtain an approximate root of f(x).
     178             :   // Points d and e lie outside the interval [a,b]
     179             :   // and are the third and fourth best approximations
     180             :   // to the root that we have found so far.
     181             :   //
     182             :   // Note: this does not guarantee to find a root
     183             :   // inside [a, b], so we fall back to quadratic
     184             :   // interpolation in case of an erroneous result.
     185             :   //
     186             :   // This commented chunk is the original Boost implementation translated into
     187             :   // simd. The actual code below is a heavily optimized version.
     188             :   //
     189             :   // const T q11 = (d - e) * fd / (fe - fd);
     190             :   // const T q21 = (b - d) * fb / (fd - fb);
     191             :   // const T q31 = (a - b) * fa / (fb - fa);
     192             :   // const T d21 = (b - d) * fd / (fd - fb);
     193             :   // const T d31 = (a - b) * fb / (fb - fa);
     194             :   //
     195             :   // const T q22 = (d21 - q11) * fb / (fe - fb);
     196             :   // const T q32 = (d31 - q21) * fa / (fd - fa);
     197             :   // const T d32 = (d31 - q21) * fd / (fd - fa);
     198             :   // const T q33 = (d32 - q22) * fa / (fe - fa);
     199             :   //
     200             :   // T c = q31 + q32 + q33 + a;
     201             : 
     202             :   // The optimized implementation here is L1-cache bound. That is, we aren't
     203             :   // able to completely saturate the FP units because we are waiting on the L1
     204             :   // cache. While not ideal, that's okay and just part of the algorithm.
     205             :   const T denom_fb_fa = fb - fa;
     206             :   const T denom_fd_fb = fd - fb;
     207             :   const T denom_fd_fa = fd - fa;
     208             :   const T denom_fe_fd = fe - fd;
     209             :   const T denom_fe_fb = fe - fb;
     210             :   const T denom =
     211             :       denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
     212             : 
     213             :   // Avoid division by zero with mask.
     214             :   const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
     215             :                                           static_cast<T>(1));
     216             : 
     217             :   const T d31 = (a - b);
     218             :   const T q21 = (b - d);
     219             :   const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
     220             :                 denom_fe_fd;
     221             :   const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
     222             :                 denom_fb_fa * denom_fd_fa;
     223             : 
     224             :   // Note: the reduction in rounding error that comes from the improvement by
     225             :   // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
     226             :   // we do below. Alternative ways of evaluating polynomials do not delay this
     227             :   // inclusion of `a`, and so then when the correction to `a` is small,
     228             :   // floating point errors decrease the accuracy of the result.
     229             :   T c = simd::fma(
     230             :       fa_by_denom,
     231             :       simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
     232             : 
     233             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     234             :       simd::any(mask)) {
     235             :     // Out of bounds step, fall back to quadratic interpolation:
     236             :     //
     237             :     // Note: we only apply quadratic interpolation at points where cubic
     238             :     // failed and that aren't already at a root.
     239             :     c = simd::select(
     240             :         mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
     241             :         c);
     242             :   }
     243             : 
     244             :   return c;
     245             : }
     246             : 
     247             : template <bool AssumeFinite, typename F, typename T>
     248             : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
     249             :              const simd::mask_type_t<T>& incomplete_mask) {
     250             :   // Given a point c inside the existing enclosing interval
     251             :   // [a, b] sets a = c if f(c) == 0, otherwise finds the new
     252             :   // enclosing interval: either [a, c] or [c, b] and sets
     253             :   // d and fd to the point that has just been removed from
     254             :   // the interval.  In other words d is the third best guess
     255             :   // to the root.
     256             :   //
     257             :   // Note: `bracket` will only modify slots marked as `true` in
     258             :   //       `incomplete_mask`
     259             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
     260             : 
     261             :   // If the interval [a,b] is very small, or if c is too close
     262             :   // to one end of the interval then we need to adjust the
     263             :   // location of c accordingly. This is:
     264             :   //
     265             :   //   if ((b - a) < 2 * tol * a) {
     266             :   //     c = a + (b - a) / 2;
     267             :   //   } else if (c <= a + fabs(a) * tol) {
     268             :   //     c = a + fabs(a) * tol;
     269             :   //   } else if (c >= b - fabs(b) * tol) {
     270             :   //     c = b - fabs(b) * tol;
     271             :   //   }
     272             :   const T a_filt = simd::fma(fabs(a), tol_batch, a);
     273             :   const T b_filt = simd::fnma(fabs(b), tol_batch, b);
     274             :   const T b_minus_a = b - a;
     275             :   c = simd::select(
     276             :       (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
     277             :       simd::fma(b_minus_a, static_cast<T>(0.5), a),
     278             :       simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
     279             : 
     280             :   // Invoke f(c):
     281             :   T fc = f(c);
     282             : 
     283             :   // if we have a zero then we have an exact solution to the root:
     284             :   const auto fc_is_zero_mask = (fc == static_cast<T>(0));
     285             :   if (const auto mask = fc_is_zero_mask and incomplete_mask;
     286             :       UNLIKELY(simd::any(mask))) {
     287             :     a = simd::select(mask, c, a);
     288             :     fa = simd::select(mask, static_cast<T>(0), fa);
     289             :     d = simd::select(mask, static_cast<T>(0), d);
     290             :     fd = simd::select(mask, static_cast<T>(0), fd);
     291             :     if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
     292             :       return;
     293             :     }
     294             :   }
     295             : 
     296             :   // Non-zero fc, update the interval:
     297             :   //
     298             :   // Note: unlike Boost, when `AssumeFinite` is true we assume fa*fc doesn't
     299             :   // overflow. This speeds up the code quite a bit.
     300             :   //
     301             :   // Boost code is:
     302             :   // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
     303             :   using simd::sign;
     304             :   const auto sign_mask = AssumeFinite
     305             :                              ? (fa * fc < static_cast<T>(0))
     306             :                              : (sign(fa) * sign(fc) < static_cast<T>(0));
     307             :   const auto mask_if =
     308             :       (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
     309             :   d = simd::select(mask_if, b, d);
     310             :   fd = simd::select(mask_if, fb, fd);
     311             :   b = simd::select(mask_if, c, b);
     312             :   fb = simd::select(mask_if, fc, fb);
     313             : 
     314             :   const auto mask_else =
     315             :       ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
     316             :   d = simd::select(mask_else, a, d);
     317             :   fd = simd::select(mask_else, fa, fd);
     318             :   a = simd::select(mask_else, c, a);
     319             :   fa = simd::select(mask_else, fc, fa);
     320             : }
     321             : 
     322             : template <bool AssumeFinite, class F, class T, class Tol>
     323             : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
     324             :                               const T& fbx, Tol tol,
     325             :                               const simd::mask_type_t<T>& ignore_filter,
     326             :                               size_t& max_iter) {
     327             :   // Main entry point and logic for Toms Algorithm 748
     328             :   // root finder.
     329             :   if (UNLIKELY(simd::any(ax > bx))) {
     330             :     ERROR_AS("Lower bound is larger than upper bound", std::domain_error);
     331             :   }
     332             : 
     333             :   // Sanity check - are we allowed to iterate at all?
     334             :   if (UNLIKELY(max_iter == 0)) {
     335             :     return std::pair{ax, bx};
     336             :   }
     337             : 
     338             :   size_t count = max_iter;
     339             :   // mu is an algorithm parameter that must lie in the open interval (0, 1).
     340             :   static const T mu = 0.5f;
     341             : 
     342             :   // initialise a, b and fa, fb:
     343             :   T a = ax;
     344             :   T b = bx;
     345             :   T fa = fax;
     346             :   T fb = fbx;
     347             :   CAPTURE_FOR_ERROR(ax);
     348             :   CAPTURE_FOR_ERROR(bx);
     349             :   CAPTURE_FOR_ERROR(fax);
     350             :   CAPTURE_FOR_ERROR(fbx);
     351             : 
     352             :   const auto fa_is_zero_mask = (fa == static_cast<T>(0));
     353             :   const auto fb_is_zero_mask = (fb == static_cast<T>(0));
     354             :   auto completion_mask =
     355             :       tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
     356             :   auto incomplete_mask = not completion_mask;
     357             :   if (UNLIKELY(simd::all(completion_mask))) {
     358             :     max_iter = 0;
     359             :     return std::pair{simd::select(fb_is_zero_mask, b, a),
     360             :                      simd::select(fa_is_zero_mask, a, b)};
     361             :   }
     362             : 
     363             :   // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
     364             :   // This speeds up the code quite a bit.
     365             :   if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
     366             :                                        : (simd::sign(fa) * simd::sign(fb) >
     367             :                                           static_cast<T>(0))) and
     368             :                          (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
     369             :     ERROR_AS("Parameters lower and upper bounds do not bracket a root.",
     370             :              std::domain_error);
     371             :   }
     372             :   // dummy value for fd, e and fe:
     373             :   T fe(static_cast<T>(1e5F));
     374             :   T e(static_cast<T>(1e5F));
     375             :   T fd(static_cast<T>(1e5F));
     376             : 
     377             :   T c(0.0);
     378             :   T d(0.0);
     379             : 
     380             :   // We can't use signaling_NaN() because we do a max comparison at the end to
     381             :   // 0, and signaling_NaN when compared to 0 raises an FPE.
     382             :   const T nan(std::numeric_limits<T>::quiet_NaN());
     383             :   auto completed_a = simd::select(completion_mask, a, nan);
     384             :   auto completed_b = simd::select(completion_mask, b, nan);
     385             :   auto completed_fa = simd::select(completion_mask, fa, nan);
     386             :   auto completed_fb = simd::select(completion_mask, fb, nan);
     387             :   const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
     388             :                                  &completed_a, &completed_b, &a, &b,
     389             :                                  &completed_fa, &completed_fb, &tol]() {
     390             :     const auto new_completed =
     391             :         (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
     392             :     completed_a = simd::select(new_completed, a, completed_a);
     393             :     completed_b = simd::select(new_completed, b, completed_b);
     394             :     completed_fa = simd::select(new_completed, fa, completed_fa);
     395             :     completed_fb = simd::select(new_completed, fb, completed_fb);
     396             :     completion_mask = new_completed or completion_mask;
     397             :     incomplete_mask = not completion_mask;
     398             :     // returns true if _all_ simd lanes have been completed
     399             :     return simd::all(completion_mask);
     400             :   };
     401             : 
     402             :   if (simd::any(fa != static_cast<T>(0))) {
     403             :     // On the first step we take a secant step:
     404             :     c = toms748_detail::secant_interpolate(a, b, fa, fb, incomplete_mask);
     405             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     406             :                                           incomplete_mask);
     407             :     --count;
     408             : 
     409             :     // Note: The Boost fa!=0 check is handled with the completion_mask.
     410             :     if (count and not update_completed()) {
     411             :       // On the second step we take a quadratic interpolation:
     412             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     413             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     414             :       e = d;
     415             :       fe = fd;
     416             :       toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     417             :                                             incomplete_mask);
     418             :       --count;
     419             :       update_completed();
     420             :     }
     421             :   }
     422             : 
     423             :   T u(std::numeric_limits<T>::signaling_NaN());
     424             :   T fu(std::numeric_limits<T>::signaling_NaN());
     425             :   T a0(std::numeric_limits<T>::signaling_NaN());
     426             :   T b0(std::numeric_limits<T>::signaling_NaN());
     427             : 
     428             :   // Note: The Boost fa!=0 check is handled with the completion_mask.
     429             :   while (count and not simd::all(completion_mask)) {
     430             :     // save our brackets:
     431             :     a0 = a;
     432             :     b0 = b;
     433             :     // Starting with the third step taken
     434             :     // we can use either quadratic or cubic interpolation.
     435             :     // Cubic interpolation requires that all four function values
     436             :     // fa, fb, fd, and fe are distinct. Should that not be the case,
     437             :     // the variable `prof` is set to true and we end up
     438             :     // taking a quadratic step instead.
     439             :     static const T min_diff = std::numeric_limits<T>::min() * 32;
     440             :     bool prof =
     441             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     442             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     443             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     444             :                   incomplete_mask);
     445             :     if (prof) {
     446             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     447             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     448             :     } else {
     449             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     450             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     451             :     }
     452             :     // re-bracket, and check for termination:
     453             :     e = d;
     454             :     fe = fd;
     455             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     456             :                                           incomplete_mask);
     457             :     if ((0 == --count) or update_completed()) {
     458             :       break;
     459             :     }
     460             :     // Now another interpolated step:
     461             :     prof =
     462             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     463             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     464             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     465             :                   incomplete_mask);
     466             :     if (prof) {
     467             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     468             :           a, b, d, fa, fb, fd, incomplete_mask, 3);
     469             :     } else {
     470             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     471             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     472             :     }
     473             :     // Bracket again, and check termination condition, update e:
     474             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     475             :                                           incomplete_mask);
     476             :     if ((0 == --count) or update_completed()) {
     477             :       break;
     478             :     }
     479             : 
     480             :     // Now we take a double-length secant step:
     481             :     const auto fabs_fa_less_fabs_fb_mask =
     482             :         (fabs(fa) < fabs(fb)) and incomplete_mask;
     483             :     u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
     484             :     fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
     485             :     const T b_minus_a = b - a;
     486             :     // Assumes that bounds a & b are not so close that fa == fb. Boost makes
     487             :     // this assumption too. If this is violated then the algorithm doesn't
     488             :     // work since the function at least appears constant.
     489             : 
     490             :     // Safeguard complete lanes: `fb - fa` may be stale or zero for lanes
     491             :     // that have already converged. Use 1 as a harmless dummy denominator;
     492             :     // the result for complete lanes is discarded downstream.
     493             :     c = simd::fnma(
     494             :         static_cast<T>(2) *
     495             :             (fu / simd::select(incomplete_mask, fb - fa, static_cast<T>(1))),
     496             :         b_minus_a, u);
     497             :     c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
     498             :                      simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
     499             : 
     500             :     // Bracket again, and check termination condition:
     501             :     e = d;
     502             :     fe = fd;
     503             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     504             :                                           incomplete_mask);
     505             :     if ((0 == --count) or update_completed()) {
     506             :       break;
     507             :     }
     508             : 
     509             :     // And finally... check to see if an additional bisection step is
     510             :     // to be taken, we do this if we're not converging fast enough:
     511             :     const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
     512             :     if (LIKELY(simd::none(bisection_mask))) {
     513             :       continue;
     514             :     }
     515             :     // bracket again on a bisection:
     516             :     //
     517             :     // Note: the mask ensures we only ever modify the slots that the mask has
     518             :     // identified as needing to be modified.
     519             :     e = simd::select(bisection_mask, d, e);
     520             :     fe = simd::select(bisection_mask, fd, fe);
     521             :     toms748_detail::bracket<AssumeFinite>(
     522             :         f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
     523             :         bisection_mask);
     524             :     --count;
     525             :     if (update_completed()) {
     526             :       break;
     527             :     }
     528             :   }  // while loop
     529             : 
     530             :   max_iter -= count;
     531             :   completed_b =
     532             :       simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
     533             :   completed_a =
     534             :       simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
     535             :   return std::pair{completed_a, completed_b};
     536             : }
     537             : }  // namespace toms748_detail
     538             : 
     539             : /*!
     540             :  * \ingroup NumericalAlgorithmsGroup
     541             :  * \brief Finds the root of the function `f` with the TOMS_748 method.
     542             :  *
     543             :  * `f` is a unary invokable that takes a `double` which is the current value at
     544             :  * which to evaluate `f`. An example is below.
     545             :  *
     546             :  * \snippet Test_TOMS748.cpp double_root_find
     547             :  *
     548             :  * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
     549             :  * `upper_bound`], and will throw if this interval does not bracket a root,
     550             :  * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
     551             :  *
     552             :  * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
     553             :  * are the function values at `lower_bound` and `upper_bound`. These function
     554             :  * values are often known because the user typically checks if a root is
     555             :  * bracketed before calling `toms748`; passing the function values here saves
     556             :  * two function evaluations.
     557             :  *
     558             :  * \note if `AssumeFinite` is true then the code assumes all numbers are
     559             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     560             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     561             :  * that products like `fa * fb` are also finite.
     562             :  *
     563             :  * \requires Function `f` is invokable with a `double`
     564             :  *
     565             :  * \throws `convergence_error` if the requested tolerance is not met after
     566             :  *                            `max_iterations` iterations.
     567             :  */
     568             : template <bool AssumeFinite = false, typename Function, typename T>
     569           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     570             :           const T f_at_lower_bound, const T f_at_upper_bound,
     571             :           const simd::scalar_type_t<T> absolute_tolerance,
     572             :           const simd::scalar_type_t<T> relative_tolerance,
     573             :           const size_t max_iterations = 100,
     574             :           const simd::mask_type_t<T> ignore_filter =
     575             :               static_cast<simd::mask_type_t<T>>(0)) {
     576             :   ASSERT(relative_tolerance >  // precondition: strictly above the scalar type's machine epsilon
     577             :              std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
     578             :          "The relative tolerance is too small. Got "
     579             :              << relative_tolerance << " but must be at least "
     580             :              << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
     581             :   if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {  // same sign at both bounds in at least one element
     582             :     ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
     583             :                                    << ", f(" << upper_bound
     584             :                                    << ") = " << f_at_upper_bound);
     585             :   }
     586             : 
     587             :   std::size_t max_iters = max_iterations;  // in/out: updated by toms748_solve, then compared with max_iterations
     588             : 
     589             :   // This solver requires tol to be passed as a termination condition. It is
     590             :   // |lhs - rhs| <= absolute_tolerance + relative_tolerance * min(|lhs|, |rhs|),
     591             :   // which is equivalent to the convergence criteria used by the GSL
     592             :   const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
     593             :                                                             const T& rhs) {
     594             :     return simd::abs(lhs - rhs) <=
     595             :            simd::fma(T(relative_tolerance),
     596             :                      simd::min(simd::abs(lhs), simd::abs(rhs)),
     597             :                      T(absolute_tolerance));
     598             :   };
     599             :   auto result = toms748_detail::toms748_solve<AssumeFinite>(  // pair {completed_a, completed_b}: the final bracket
     600             :       f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
     601             :       ignore_filter, max_iters);
     602             :   if (max_iters >= max_iterations) {  // iteration budget used up: report failure to converge
     603             :     ERROR_AS(
     604             :         "toms748 reached max iterations without converging.\nAbsolute "
     605             :         "tolerance: "
     606             :             << absolute_tolerance << "\nRelative tolerance: "
     607             :             << relative_tolerance << "\nResult: " << get_output(result.first)
     608             :             << " " << get_output(result.second),
     609             :         convergence_error);
     610             :   }
     611             :   return simd::fma(static_cast<T>(0.5), (result.second - result.first),  // midpoint: first + 0.5 * (second - first)
     612             :                    result.first);
     613             : }
     614             : 
     615             : /*!
     616             :  * \ingroup NumericalAlgorithmsGroup
     617             :  * \brief Finds the root of the function `f` with the TOMS_748 method, where
     618             :  * function values are not supplied at the lower and upper bounds.
     619             :  *
     620             :  * \note if `AssumeFinite` is true then the code assumes all numbers are
     621             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     622             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     623             :  * that products like `fa * fb` are also finite.
     624             :  */
     625             : template <bool AssumeFinite = false, typename Function, typename T>
     626           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     627             :           const simd::scalar_type_t<T> absolute_tolerance,
     628             :           const simd::scalar_type_t<T> relative_tolerance,
     629             :           const size_t max_iterations = 100,
     630             :           const simd::mask_type_t<T> ignore_filter =
     631             :               static_cast<simd::mask_type_t<T>>(0)) {
     632             :   return toms748<AssumeFinite>(  // evaluate f at both bounds, then delegate to the overload taking function values
     633             :       f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
     634             :       absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
     635             : }
     636             : 
     637             : /*!
     638             :  * \ingroup NumericalAlgorithmsGroup
     639             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     640             :  * element in a `DataVector`.
     641             :  *
     642             :  * `f` is a binary invokable that takes the value at which to evaluate `f` and
     643             :  * a `size_t` index into the `DataVector`s. The value is a `double` or, when
     644             :  * `UseSimd` is true, a SIMD value paired with the index of its first element.
     645             :  * Below is an example of how to root find different functions by indexing into
     646             :  * a lambda-captured `DataVector` using the `size_t` passed to `f`.
     647             :  *
     648             :  * \snippet Test_TOMS748.cpp datavector_root_find
     649             :  *
     650             :  * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
     651             :  * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
     652             :  * this interval does not bracket a root,
     653             :  * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
     654             :  *
     655             :  * See the [Boost](http://www.boost.org/) documentation for more details.
     656             :  *
     657             :  * \requires Function `f` be callable with a `double` (and, if `UseSimd` is true, the SIMD type) and a `size_t`
     658             :  *
     659             :  * \note if `AssumeFinite` is true then the code assumes all numbers are
     660             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     661             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     662             :  * that products like `fa * fb` are also finite.
     663             :  *
     664             :  * \throws `convergence_error` if, for any index, the requested tolerance is not
     665             :  * met after `max_iterations` iterations.
     666             :  */
     667             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     668           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     669             :                    const DataVector& upper_bound,
     670             :                    const double absolute_tolerance,
     671             :                    const double relative_tolerance,
     672             :                    const size_t max_iterations = 100) {
     673             :   DataVector result_vector{lower_bound.size()};  // NOTE(review): upper_bound's size is not checked against lower_bound here
     674             :   if constexpr (UseSimd) {
     675             :     constexpr size_t simd_width{simd::size<  // element count of the type returned by simd::load_unaligned
     676             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     677             :     const size_t vectorized_size =  // size rounded down to a multiple of simd_width
     678             :         lower_bound.size() - lower_bound.size() % simd_width;
     679             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {  // whole SIMD chunks
     680             :       simd::store_unaligned(
     681             :           &result_vector[i],
     682             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },  // i is the index of the chunk's first element
     683             :                                 simd::load_unaligned(&lower_bound[i]),
     684             :                                 simd::load_unaligned(&upper_bound[i]),
     685             :                                 absolute_tolerance, relative_tolerance,
     686             :                                 max_iterations));
     687             :     }
     688             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {  // leftover elements, solved one at a time
     689             :       result_vector[i] = toms748<AssumeFinite>(
     690             :           [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
     691             :           upper_bound[i], absolute_tolerance, relative_tolerance,
     692             :           max_iterations);
     693             :     }
     694             :   } else {  // UseSimd == false: solve every element separately
     695             :     for (size_t i = 0; i < result_vector.size(); ++i) {
     696             :       result_vector[i] = toms748<AssumeFinite>(
     697             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     698             :           absolute_tolerance, relative_tolerance, max_iterations);
     699             :     }
     700             :   }
     701             :   return result_vector;
     702             : }
     703             : 
     704             : /*!
     705             :  * \ingroup NumericalAlgorithmsGroup
     706             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     707             :  * element in a `DataVector`, where function values are supplied at the lower
     708             :  * and upper bounds.
     709             :  *
     710             :  * Supplying function values is an optimization that saves two
     711             :  * function calls per point.  The function values are often available
     712             :  * because one often checks if the root is bracketed before calling `toms748`.
     713             :  *
     714             :  * \note if `AssumeFinite` is true then the code assumes all numbers are
     715             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     716             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     717             :  * that products like `fa * fb` are also finite.
     718             :  */
     719             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     720           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     721             :                    const DataVector& upper_bound,
     722             :                    const DataVector& f_at_lower_bound,
     723             :                    const DataVector& f_at_upper_bound,
     724             :                    const double absolute_tolerance,
     725             :                    const double relative_tolerance,
     726             :                    const size_t max_iterations = 100) {
     727             :   DataVector result_vector{lower_bound.size()};  // NOTE(review): sizes of the other inputs are not checked against lower_bound here
     728             :   if constexpr (UseSimd) {
     729             :     constexpr size_t simd_width{simd::size<  // element count of the type returned by simd::load_unaligned
     730             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     731             :     const size_t vectorized_size =  // size rounded down to a multiple of simd_width
     732             :         lower_bound.size() - lower_bound.size() % simd_width;
     733             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {  // whole SIMD chunks
     734             :       simd::store_unaligned(
     735             :           &result_vector[i],
     736             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },  // i is the index of the chunk's first element
     737             :                                 simd::load_unaligned(&lower_bound[i]),
     738             :                                 simd::load_unaligned(&upper_bound[i]),
     739             :                                 simd::load_unaligned(&f_at_lower_bound[i]),
     740             :                                 simd::load_unaligned(&f_at_upper_bound[i]),
     741             :                                 absolute_tolerance, relative_tolerance,
     742             :                                 max_iterations));
     743             :     }
     744             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {  // leftover elements, solved one at a time
     745             :       result_vector[i] = toms748<AssumeFinite>(
     746             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     747             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     748             :           relative_tolerance, max_iterations);
     749             :     }
     750             :   } else {  // UseSimd == false: solve every element separately
     751             :     for (size_t i = 0; i < lower_bound.size(); ++i) {
     752             :       result_vector[i] = toms748<AssumeFinite>(
     753             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     754             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     755             :           relative_tolerance, max_iterations);
     756             :     }
     757             :   }
     758             :   return result_vector;
     759             : }
     760             : }  // namespace RootFinder

Generated by: LCOV version 1.14