SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/RootFinding - TOMS748.hpp Hit Total Coverage
Commit: 1f2210958b4f38fdc0400907ee7c6d5af5111418 Lines: 5 5 100.0 %
Date: 2025-12-05 05:03:31
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares function RootFinder::toms748
       6             : 
       7             : #pragma once
       8             : 
       9             : #include <functional>
      10             : #include <iomanip>
      11             : #include <ios>
      12             : #include <limits>
      13             : #include <string>
      14             : #include <type_traits>
      15             : 
      16             : #include "DataStructures/DataVector.hpp"
      17             : #include "Utilities/ErrorHandling/CaptureForError.hpp"
      18             : #include "Utilities/ErrorHandling/Error.hpp"
      19             : #include "Utilities/ErrorHandling/Exceptions.hpp"
      20             : #include "Utilities/GetOutput.hpp"
      21             : #include "Utilities/MakeString.hpp"
      22             : #include "Utilities/Simd/Simd.hpp"
      23             : 
      24             : namespace RootFinder {
      25             : namespace toms748_detail {
      26             : // Original implementation of TOMS748 is from Boost:
      27             : //  (C) Copyright John Maddock 2006.
      28             : //  Use, modification and distribution are subject to the
      29             : //  Boost Software License, Version 1.0. (See accompanying file
      30             : //  LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
      31             : //
      32             : // Significant changes made to pretty much all of it to support SIMD by Nils
      33             : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
      34             : 
      35             : template <typename T>
      36             : T safe_div(const T& num, const T& denom, const T& r) {
      37             :   if constexpr (std::is_floating_point_v<T>) {
      38             :     using std::abs;
      39             :     if (abs(denom) < static_cast<T>(1)) {
      40             :       if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
      41             :         return r;
      42             :       }
      43             :     }
      44             :     return num / denom;
      45             :   } else {
      46             :     // return num / denom without overflow, return r if overflow would occur.
      47             :     //
      48             :     // To do this we need to handle the following cases:
      49             :     // 1. fabs(denom) < 1 AND fabs(denom * max) <= fabs(num):
      50             :     //    return r
      51             :     // 2. return num / denom
      52             :     //
      53             :     // We do this by creating 2 masks.
      54             :     // 1. `mask0` selects where fabs(denom) < 1. This is where we _may_ have
      55             :     //    issues with division by zero or overflow.
      56             :     // 2. `mask` selects where fabs(denom) < 1 AND fabs(denom * max)<=fabs(num)
      57             :     //    * Note: the edge case of fabs(num)==max could be problematic, but if
      58             :     //            you're dealing with numbers like that you are likely in
      59             :     //            trouble anyway.
      60             :     // The second mask is where we must avoid division by zero and instead
      61             :     // return `r`.
      62             :     const auto mask0 = fabs(denom) < static_cast<T>(1);
      63             :     // Note: if denom >= 1.0 you get an FPE because of overflow from
      64             :     // `max() * (1. + value)`, which is why `mask0` is necessary.
      65             :     const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(1.0)) *
      66             :                            std::numeric_limits<T>::max()) <= fabs(num);
      67             :     const auto new_denom = simd::select(mask, static_cast<T>(1), denom);
      68             :     return simd::select(mask, r, num / new_denom);
      69             :   }
      70             : }
      71             : 
      72             : template <typename T>
      73             : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb) {
      74             :   //
      75             :   // Performs standard secant interpolation of [a,b] given
      76             :   // function evaluations f(a) and f(b).  Performs a bisection
      77             :   // if secant interpolation would leave us very close to either
      78             :   // a or b.  Rationale: we only call this function when at least
      79             :   // one other form of interpolation has already failed, so we know
      80             :   // that the function is unlikely to be smooth with a root very
      81             :   // close to a or b.
      82             :   //
      83             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
      84             :   // WARNING: There are several different ways to implement the interpolation
      85             :   // that all have different rounding properties. Unfortunately this means that
      86             :   // tolerances even at 1e-14 can be difficult to achieve generically.
      87             :   //
      88             :   // `g` below is:
      89             :   // const T g = (fa / (fb - fa));
      90             :   //
      91             :   // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
      92             :   //
      93             :   // const T c = a * (static_cast<T>(1) - g) + g * b; // works
      94             :   //
      95             :   // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
      96             :   //
      97             :   // Original Boost code:
      98             :   // const T c = a - (fa / (fb - fa)) * (b - a); // works
      99             : 
     100             :   const T c = a - (fa / (fb - fa)) * (b - a);
     101             :   return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
     102             :                           (c >= simd::fnma(fabs(b), tol_batch, b)),
     103             :                       static_cast<T>(0.5) * (a + b), c);
     104             : }
     105             : 
     106             : template <bool AssumeFinite, typename T>
     107             : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
     108             :                         const T& fb, const T& fd,
     109             :                         const simd::mask_type_t<T>& incomplete_mask,
     110             :                         const unsigned count) {
     111             :   // Performs quadratic interpolation to determine the next point,
     112             :   // takes count Newton steps to find the location of the
     113             :   // quadratic polynomial.
     114             :   //
     115             :   // Point d must lie outside of the interval [a,b], it is the third
     116             :   // best approximation to the root, after a and b.
     117             :   //
     118             :   // Note: this does not guarantee to find a root
     119             :   // inside [a, b], so we fall back to a secant step should
     120             :   // the result be out of range.
     121             :   //
     122             :   // Start by obtaining the coefficients of the quadratic polynomial:
     123             :   const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
     124             :   T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
     125             :   A = safe_div(A - B, d - a, static_cast<T>(0));
     126             : 
     127             :   const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
     128             :   T result_secant{};
     129             :   if (UNLIKELY(simd::any(secant_failure_mask))) {
     130             :     // failure to determine coefficients, try a secant step:
     131             :     result_secant = secant_interpolate(a, b, fa, fb);
     132             :     if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
     133             :       return result_secant;
     134             :     }
     135             :   }
     136             : 
     137             :   // Determine the starting point of the Newton steps:
     138             :   //
     139             :   // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
     140             :   // code quite a bit.
     141             :   T c = AssumeFinite
     142             :             ? simd::select(A * fa > static_cast<T>(0), a, b)
     143             :             : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
     144             :                            a, b);
     145             : 
     146             :   // Take the Newton steps:
     147             :   const T two_A = static_cast<T>(2) * A;
     148             :   const T half_a_plus_b = 0.5 * (a + b);
     149             :   const T one_minus_a = static_cast<T>(1) - a;
     150             :   const T B_minus_A_times_b = B - A * b;
     151             :   for (unsigned i = 1; i <= count; ++i) {
     152             :     c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
     153             :                   simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
     154             :   }
     155             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     156             :       simd::any(mask)) {
     157             :     // Failure, try a secant step:
     158             :     c = simd::select(mask, secant_interpolate(a, b, fa, fb), c);
     159             :   }
     160             :   return simd::select(secant_failure_mask, result_secant, c);
     161             : }
     162             : 
template <bool AssumeFinite, typename T>
T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
                    const T& fb, const T& fd, const T& fe,
                    const simd::mask_type_t<T>& incomplete_mask) {
  // Uses inverse cubic interpolation of f(x) at points
  // [a,b,d,e] to obtain an approximate root of f(x).
  // Points d and e lie outside the interval [a,b]
  // and are the third and fourth best approximations
  // to the root that we have found so far.
  //
  // Note: this does not guarantee to find a root
  // inside [a, b], so we fall back to quadratic
  // interpolation in case of an erroneous result.
  //
  // Preconditions on the lanes set in `incomplete_mask`: the denominator
  // below is the product of all six pairwise differences of fa, fb, fd, fe,
  // so it vanishes if any two of them coincide. toms748_solve only calls
  // this function when no incomplete lane has two of these values closer
  // than 32 * std::numeric_limits<T>::min(); otherwise it takes a quadratic
  // step instead.
  //
  // Lanes not set in `incomplete_mask` divide by 1 instead (so they cannot
  // divide by zero) and never take the quadratic fallback. Their result is
  // not a meaningful interpolant; bracket() leaves such lanes unchanged.
  //
  // This commented chunk is the original Boost implementation translated into
  // simd. The actual code below is a heavily optimized version.
  //
  // const T q11 = (d - e) * fd / (fe - fd);
  // const T q21 = (b - d) * fb / (fd - fb);
  // const T q31 = (a - b) * fa / (fb - fa);
  // const T d21 = (b - d) * fd / (fd - fb);
  // const T d31 = (a - b) * fb / (fb - fa);
  //
  // const T q22 = (d21 - q11) * fb / (fe - fb);
  // const T q32 = (d31 - q21) * fa / (fd - fa);
  // const T d32 = (d31 - q21) * fd / (fd - fa);
  // const T q33 = (d32 - q22) * fa / (fe - fa);
  //
  // T c = q31 + q32 + q33 + a;

  // The optimized implementation here is L1-cache bound. That is, we aren't
  // able to completely saturate the FP units because we are waiting on the L1
  // cache. While not ideal, that's okay and just part of the algorithm.
  //
  // The optimized form brings all terms over a single common denominator so
  // that only one division is performed (in `fa_by_denom`). The grouping of
  // the fma/fms calls is deliberate: it controls rounding, so do not
  // reassociate these expressions.
  const T denom_fb_fa = fb - fa;
  const T denom_fd_fb = fd - fb;
  const T denom_fd_fa = fd - fa;
  const T denom_fe_fd = fe - fd;
  const T denom_fe_fb = fe - fb;
  const T denom =
      denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);

  // Avoid division by zero with mask.
  const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
                                          static_cast<T>(1));

  // Numerator pieces of the Boost q/d terms above, rescaled to the common
  // denominator.
  const T d31 = (a - b);
  const T q21 = (b - d);
  const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
                denom_fe_fd;
  const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
                denom_fb_fa * denom_fd_fa;

  // Note: the reduction in rounding error that comes from the improvement by
  // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
  // we do below. Alternative ways of evaluating polynomials do not delay this
  // inclusion of `a`, and so then when the correction to `a` is small,
  // floating point errors decrease the accuracy of the result.
  T c = simd::fma(
      fa_by_denom,
      simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);

  if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
      simd::any(mask)) {
    // Out of bounds step, fall back to quadratic interpolation:
    //
    // Note: we only apply quadratic interpolation at points where cubic
    // failed and that aren't already at a root.
    c = simd::select(
        mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
        c);
  }

  return c;
}
     237             : 
     238             : template <bool AssumeFinite, typename F, typename T>
     239             : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
     240             :              const simd::mask_type_t<T>& incomplete_mask) {
     241             :   // Given a point c inside the existing enclosing interval
     242             :   // [a, b] sets a = c if f(c) == 0, otherwise finds the new
     243             :   // enclosing interval: either [a, c] or [c, b] and sets
     244             :   // d and fd to the point that has just been removed from
     245             :   // the interval.  In other words d is the third best guess
     246             :   // to the root.
     247             :   //
     248             :   // Note: `bracket` will only modify slots marked as `true` in
     249             :   //       `incomplete_mask`
     250             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
     251             : 
     252             :   // If the interval [a,b] is very small, or if c is too close
     253             :   // to one end of the interval then we need to adjust the
     254             :   // location of c accordingly. This is:
     255             :   //
     256             :   //   if ((b - a) < 2 * tol * a) {
     257             :   //     c = a + (b - a) / 2;
     258             :   //   } else if (c <= a + fabs(a) * tol) {
     259             :   //     c = a + fabs(a) * tol;
     260             :   //   } else if (c >= b - fabs(b) * tol) {
     261             :   //     c = b - fabs(b) * tol;
     262             :   //   }
     263             :   const T a_filt = simd::fma(fabs(a), tol_batch, a);
     264             :   const T b_filt = simd::fnma(fabs(b), tol_batch, b);
     265             :   const T b_minus_a = b - a;
     266             :   c = simd::select(
     267             :       (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
     268             :       simd::fma(b_minus_a, static_cast<T>(0.5), a),
     269             :       simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
     270             : 
     271             :   // Invoke f(c):
     272             :   T fc = f(c);
     273             : 
     274             :   // if we have a zero then we have an exact solution to the root:
     275             :   const auto fc_is_zero_mask = (fc == static_cast<T>(0));
     276             :   if (const auto mask = fc_is_zero_mask and incomplete_mask;
     277             :       UNLIKELY(simd::any(mask))) {
     278             :     a = simd::select(mask, c, a);
     279             :     fa = simd::select(mask, static_cast<T>(0), fa);
     280             :     d = simd::select(mask, static_cast<T>(0), d);
     281             :     fd = simd::select(mask, static_cast<T>(0), fd);
     282             :     if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
     283             :       return;
     284             :     }
     285             :   }
     286             : 
     287             :   // Non-zero fc, update the interval:
     288             :   //
     289             :   // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
     290             :   // code quite a bit.
     291             :   //
     292             :   // Boost code is:
     293             :   // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
     294             :   using simd::sign;
     295             :   const auto sign_mask = AssumeFinite
     296             :                              ? (fa * fc < static_cast<T>(0))
     297             :                              : (sign(fa) * sign(fc) < static_cast<T>(0));
     298             :   const auto mask_if =
     299             :       (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
     300             :   d = simd::select(mask_if, b, d);
     301             :   fd = simd::select(mask_if, fb, fd);
     302             :   b = simd::select(mask_if, c, b);
     303             :   fb = simd::select(mask_if, fc, fb);
     304             : 
     305             :   const auto mask_else =
     306             :       ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
     307             :   d = simd::select(mask_else, a, d);
     308             :   fd = simd::select(mask_else, fa, fd);
     309             :   a = simd::select(mask_else, c, a);
     310             :   fa = simd::select(mask_else, fc, fa);
     311             : }
     312             : 
     313             : template <bool AssumeFinite, class F, class T, class Tol>
     314             : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
     315             :                               const T& fbx, Tol tol,
     316             :                               const simd::mask_type_t<T>& ignore_filter,
     317             :                               size_t& max_iter) {
     318             :   // Main entry point and logic for Toms Algorithm 748
     319             :   // root finder.
     320             :   if (UNLIKELY(simd::any(ax > bx))) {
     321             :     ERROR_AS("Lower bound is larger than upper bound", std::domain_error);
     322             :   }
     323             : 
     324             :   // Sanity check - are we allowed to iterate at all?
     325             :   if (UNLIKELY(max_iter == 0)) {
     326             :     return std::pair{ax, bx};
     327             :   }
     328             : 
     329             :   size_t count = max_iter;
     330             :   // mu is a parameter in the algorithm that must be between (0, 1).
     331             :   static const T mu = 0.5f;
     332             : 
     333             :   // initialise a, b and fa, fb:
     334             :   T a = ax;
     335             :   T b = bx;
     336             :   T fa = fax;
     337             :   T fb = fbx;
     338             :   CAPTURE_FOR_ERROR(ax);
     339             :   CAPTURE_FOR_ERROR(bx);
     340             :   CAPTURE_FOR_ERROR(fax);
     341             :   CAPTURE_FOR_ERROR(fbx);
     342             : 
     343             :   const auto fa_is_zero_mask = (fa == static_cast<T>(0));
     344             :   const auto fb_is_zero_mask = (fb == static_cast<T>(0));
     345             :   auto completion_mask =
     346             :       tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
     347             :   auto incomplete_mask = not completion_mask;
     348             :   if (UNLIKELY(simd::all(completion_mask))) {
     349             :     max_iter = 0;
     350             :     return std::pair{simd::select(fb_is_zero_mask, b, a),
     351             :                      simd::select(fa_is_zero_mask, a, b)};
     352             :   }
     353             : 
     354             :   // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
     355             :   // This speeds up the code quite a bit.
     356             :   if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
     357             :                                        : (simd::sign(fa) * simd::sign(fb) >
     358             :                                           static_cast<T>(0))) and
     359             :                          (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
     360             :     ERROR_AS("Parameters lower and upper bounds do not bracket a root.",
     361             :              std::domain_error);
     362             :   }
     363             :   // dummy value for fd, e and fe:
     364             :   T fe(static_cast<T>(1e5F));
     365             :   T e(static_cast<T>(1e5F));
     366             :   T fd(static_cast<T>(1e5F));
     367             : 
     368             :   T c(0.0);
     369             :   T d(0.0);
     370             : 
     371             :   // We can't use signaling_NaN() because we do a max comparison at the end to
     372             :   // 0, and signaling_NaN when compared to 0 raises an FPE.
     373             :   const T nan(std::numeric_limits<T>::quiet_NaN());
     374             :   auto completed_a = simd::select(completion_mask, a, nan);
     375             :   auto completed_b = simd::select(completion_mask, b, nan);
     376             :   auto completed_fa = simd::select(completion_mask, fa, nan);
     377             :   auto completed_fb = simd::select(completion_mask, fb, nan);
     378             :   const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
     379             :                                  &completed_a, &completed_b, &a, &b,
     380             :                                  &completed_fa, &completed_fb, &tol]() {
     381             :     const auto new_completed =
     382             :         (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
     383             :     completed_a = simd::select(new_completed, a, completed_a);
     384             :     completed_b = simd::select(new_completed, b, completed_b);
     385             :     completed_fa = simd::select(new_completed, fa, completed_fa);
     386             :     completed_fb = simd::select(new_completed, fb, completed_fb);
     387             :     completion_mask = new_completed or completion_mask;
     388             :     incomplete_mask = not completion_mask;
     389             :     // returns true if _all_ simd registers have been completed
     390             :     return simd::all(completion_mask);
     391             :   };
     392             : 
     393             :   if (simd::any(fa != static_cast<T>(0))) {
     394             :     // On the first step we take a secant step:
     395             :     c = toms748_detail::secant_interpolate(a, b, fa, fb);
     396             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     397             :                                           incomplete_mask);
     398             :     --count;
     399             : 
     400             :     // Note: The Boost fa!=0 check is handled with the completion_mask.
     401             :     if (count and not update_completed()) {
     402             :       // On the second step we take a quadratic interpolation:
     403             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     404             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     405             :       e = d;
     406             :       fe = fd;
     407             :       toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     408             :                                             incomplete_mask);
     409             :       --count;
     410             :       update_completed();
     411             :     }
     412             :   }
     413             : 
     414             :   T u(std::numeric_limits<T>::signaling_NaN());
     415             :   T fu(std::numeric_limits<T>::signaling_NaN());
     416             :   T a0(std::numeric_limits<T>::signaling_NaN());
     417             :   T b0(std::numeric_limits<T>::signaling_NaN());
     418             : 
     419             :   // Note: The Boost fa!=0 check is handled with the completion_mask.
     420             :   while (count and not simd::all(completion_mask)) {
     421             :     // save our brackets:
     422             :     a0 = a;
     423             :     b0 = b;
     424             :     // Starting with the third step taken
     425             :     // we can use either quadratic or cubic interpolation.
     426             :     // Cubic interpolation requires that all four function values
     427             :     // fa, fb, fd, and fe are distinct, should that not be the case
     428             :     // then variable prof will get set to true, and we'll end up
     429             :     // taking a quadratic step instead.
     430             :     static const T min_diff = std::numeric_limits<T>::min() * 32;
     431             :     bool prof =
     432             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     433             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     434             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     435             :                   incomplete_mask);
     436             :     if (prof) {
     437             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     438             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     439             :     } else {
     440             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     441             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     442             :     }
     443             :     // re-bracket, and check for termination:
     444             :     e = d;
     445             :     fe = fd;
     446             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     447             :                                           incomplete_mask);
     448             :     if ((0 == --count) or update_completed()) {
     449             :       break;
     450             :     }
     451             :     // Now another interpolated step:
     452             :     prof =
     453             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     454             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     455             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     456             :                   incomplete_mask);
     457             :     if (prof) {
     458             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     459             :           a, b, d, fa, fb, fd, incomplete_mask, 3);
     460             :     } else {
     461             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     462             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     463             :     }
     464             :     // Bracket again, and check termination condition, update e:
     465             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     466             :                                           incomplete_mask);
     467             :     if ((0 == --count) or update_completed()) {
     468             :       break;
     469             :     }
     470             : 
     471             :     // Now we take a double-length secant step:
     472             :     const auto fabs_fa_less_fabs_fb_mask =
     473             :         (fabs(fa) < fabs(fb)) and incomplete_mask;
     474             :     u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
     475             :     fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
     476             :     const T b_minus_a = b - a;
     477             :     // Assumes that bounds a & b are not so close that fa == fb. Boost makes
     478             :     // this assumption too. If this is violated then the algorithm doesn't
     479             :     // work since the function at least appears constant.
     480             :     c = simd::fnma(static_cast<T>(2) * (fu / (fb - fa)), b_minus_a, u);
     481             :     c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
     482             :                      simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
     483             : 
     484             :     // Bracket again, and check termination condition:
     485             :     e = d;
     486             :     fe = fd;
     487             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     488             :                                           incomplete_mask);
     489             :     if ((0 == --count) or update_completed()) {
     490             :       break;
     491             :     }
     492             : 
     493             :     // And finally... check to see if an additional bisection step is
     494             :     // to be taken, we do this if we're not converging fast enough:
     495             :     const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
     496             :     if (LIKELY(simd::none(bisection_mask))) {
     497             :       continue;
     498             :     }
     499             :     // bracket again on a bisection:
     500             :     //
     501             :     // Note: the mask ensures we only ever modify the slots that the mask has
     502             :     // identified as needing to be modified.
     503             :     e = simd::select(bisection_mask, d, e);
     504             :     fe = simd::select(bisection_mask, fd, fe);
     505             :     toms748_detail::bracket<AssumeFinite>(
     506             :         f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
     507             :         bisection_mask);
     508             :     --count;
     509             :     if (update_completed()) {
     510             :       break;
     511             :     }
     512             :   }  // while loop
     513             : 
     514             :   max_iter -= count;
     515             :   completed_b =
     516             :       simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
     517             :   completed_a =
     518             :       simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
     519             :   return std::pair{completed_a, completed_b};
     520             : }
     521             : }  // namespace toms748_detail
     522             : 
     523             : /*!
     524             :  * \ingroup NumericalAlgorithmsGroup
     525             :  * \brief Finds the root of the function `f` with the TOMS_748 method.
     526             :  *
     527             :  * `f` is a unary invokable that takes a `double` which is the current value at
     528             :  * which to evaluate `f`. An example is below.
     529             :  *
     530             :  * \snippet Test_TOMS748.cpp double_root_find
     531             :  *
     532             :  * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
     533             :  * `upper_bound`], and will throw if this interval does not bracket a root,
     534             :  * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
     535             :  *
     536             :  * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
     537             :  * are the function values at `lower_bound` and `upper_bound`. These function
     538             :  * values are often known because the user typically checks if a root is
     539             :  * bracketed before calling `toms748`; passing the function values here saves
     540             :  * two function evaluations.
     541             :  *
     542             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     543             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     544             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     545             :  * that products like `fa * fb` are also finite.
     546             :  *
     547             :  * \requires Function `f` is invokable with a `double`
     548             :  *
     549             :  * \throws `convergence_error` if the requested tolerance is not met after
     550             :  *                            `max_iterations` iterations.
     551             :  */
     552             : template <bool AssumeFinite = false, typename Function, typename T>
     553           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     554             :           const T f_at_lower_bound, const T f_at_upper_bound,
     555             :           const simd::scalar_type_t<T> absolute_tolerance,
     556             :           const simd::scalar_type_t<T> relative_tolerance,
     557             :           const size_t max_iterations = 100,
     558             :           const simd::mask_type_t<T> ignore_filter =
     559             :               static_cast<simd::mask_type_t<T>>(0)) {
     560             :   ASSERT(relative_tolerance >
     561             :              std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
     562             :          "The relative tolerance is too small. Got "
     563             :              << relative_tolerance << " but must be at least "
     564             :              << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
     565             :   if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
     566             :     ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
     567             :                                    << ", f(" << upper_bound
     568             :                                    << ") = " << f_at_upper_bound);
     569             :   }
     570             : 
     571             :   std::size_t max_iters = max_iterations;
     572             : 
     573             :   // This solver requires tol to be passed as a termination condition. This
     574             :   // termination condition is equivalent to the convergence criteria used by the
     575             :   // GSL
     576             :   const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
     577             :                                                             const T& rhs) {
     578             :     return simd::abs(lhs - rhs) <=
     579             :            simd::fma(T(relative_tolerance),
     580             :                      simd::min(simd::abs(lhs), simd::abs(rhs)),
     581             :                      T(absolute_tolerance));
     582             :   };
     583             :   auto result = toms748_detail::toms748_solve<AssumeFinite>(
     584             :       f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
     585             :       ignore_filter, max_iters);
     586             :   if (max_iters >= max_iterations) {
     587             :     ERROR_AS(
     588             :         "toms748 reached max iterations without converging.\nAbsolute "
     589             :         "tolerance: "
     590             :             << absolute_tolerance << "\nRelative tolerance: "
     591             :             << relative_tolerance << "\nResult: " << get_output(result.first)
     592             :             << " " << get_output(result.second),
     593             :         convergence_error);
     594             :   }
     595             :   return simd::fma(static_cast<T>(0.5), (result.second - result.first),
     596             :                    result.first);
     597             : }
     598             : 
     599             : /*!
     600             :  * \ingroup NumericalAlgorithmsGroup
     601             :  * \brief Finds the root of the function `f` with the TOMS_748 method, where
     602             :  * function values are not supplied at the lower and upper bounds.
     603             :  *
     604             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     605             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     606             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     607             :  * that products like `fa * fb` are also finite.
     608             :  */
     609             : template <bool AssumeFinite = false, typename Function, typename T>
     610           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     611             :           const simd::scalar_type_t<T> absolute_tolerance,
     612             :           const simd::scalar_type_t<T> relative_tolerance,
     613             :           const size_t max_iterations = 100,
     614             :           const simd::mask_type_t<T> ignore_filter =
     615             :               static_cast<simd::mask_type_t<T>>(0)) {
     616             :   return toms748<AssumeFinite>(
     617             :       f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
     618             :       absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
     619             : }
     620             : 
     621             : /*!
     622             :  * \ingroup NumericalAlgorithmsGroup
     623             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     624             :  * element in a `DataVector`.
     625             :  *
     626             :  * `f` is a binary invokable that takes a `double` as its first argument and a
     627             :  * `size_t` as its second. The `double` is the current value at which to
     628             :  * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
     629             :  * Below is an example of how to root find different functions by indexing into
     630             :  * a lambda-captured `DataVector` using the `size_t` passed to `f`.
     631             :  *
     632             :  * \snippet Test_TOMS748.cpp datavector_root_find
     633             :  *
     634             :  * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
     635             :  * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
     636             :  * this interval does not bracket a root,
     637             :  * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
     638             :  *
     639             :  * See the [Boost](http://www.boost.org/) documentation for more details.
     640             :  *
     641             :  * \requires Function `f` be callable with a `double` and a `size_t`
     642             :  *
     643             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     644             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     645             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     646             :  * that products like `fa * fb` are also finite.
     647             :  *
     648             :  * \throws `convergence_error` if, for any index, the requested tolerance is not
     649             :  * met after `max_iterations` iterations.
     650             :  */
     651             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     652           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     653             :                    const DataVector& upper_bound,
     654             :                    const double absolute_tolerance,
     655             :                    const double relative_tolerance,
     656             :                    const size_t max_iterations = 100) {
     657             :   DataVector result_vector{lower_bound.size()};
     658             :   if constexpr (UseSimd) {
     659             :     constexpr size_t simd_width{simd::size<
     660             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     661             :     const size_t vectorized_size =
     662             :         lower_bound.size() - lower_bound.size() % simd_width;
     663             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {
     664             :       simd::store_unaligned(
     665             :           &result_vector[i],
     666             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
     667             :                                 simd::load_unaligned(&lower_bound[i]),
     668             :                                 simd::load_unaligned(&upper_bound[i]),
     669             :                                 absolute_tolerance, relative_tolerance,
     670             :                                 max_iterations));
     671             :     }
     672             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
     673             :       result_vector[i] = toms748<AssumeFinite>(
     674             :           [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
     675             :           upper_bound[i], absolute_tolerance, relative_tolerance,
     676             :           max_iterations);
     677             :     }
     678             :   } else {
     679             :     for (size_t i = 0; i < result_vector.size(); ++i) {
     680             :       result_vector[i] = toms748<AssumeFinite>(
     681             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     682             :           absolute_tolerance, relative_tolerance, max_iterations);
     683             :     }
     684             :   }
     685             :   return result_vector;
     686             : }
     687             : 
     688             : /*!
     689             :  * \ingroup NumericalAlgorithmsGroup
     690             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     691             :  * element in a `DataVector`, where function values are supplied at the lower
     692             :  * and upper bounds.
     693             :  *
     694             :  * Supplying function values is an optimization that saves two
     695             :  * function calls per point.  The function values are often available
     696             :  * because one often checks if the root is bracketed before calling `toms748`.
     697             :  *
     698             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     699             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     700             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     701             :  * that products like `fa * fb` are also finite.
     702             :  */
     703             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     704           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     705             :                    const DataVector& upper_bound,
     706             :                    const DataVector& f_at_lower_bound,
     707             :                    const DataVector& f_at_upper_bound,
     708             :                    const double absolute_tolerance,
     709             :                    const double relative_tolerance,
     710             :                    const size_t max_iterations = 100) {
     711             :   DataVector result_vector{lower_bound.size()};
     712             :   if constexpr (UseSimd) {
     713             :     constexpr size_t simd_width{simd::size<
     714             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     715             :     const size_t vectorized_size =
     716             :         lower_bound.size() - lower_bound.size() % simd_width;
     717             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {
     718             :       simd::store_unaligned(
     719             :           &result_vector[i],
     720             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
     721             :                                 simd::load_unaligned(&lower_bound[i]),
     722             :                                 simd::load_unaligned(&upper_bound[i]),
     723             :                                 simd::load_unaligned(&f_at_lower_bound[i]),
     724             :                                 simd::load_unaligned(&f_at_upper_bound[i]),
     725             :                                 absolute_tolerance, relative_tolerance,
     726             :                                 max_iterations));
     727             :     }
     728             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
     729             :       result_vector[i] = toms748<AssumeFinite>(
     730             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     731             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     732             :           relative_tolerance, max_iterations);
     733             :     }
     734             :   } else {
     735             :     for (size_t i = 0; i < lower_bound.size(); ++i) {
     736             :       result_vector[i] = toms748<AssumeFinite>(
     737             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     738             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     739             :           relative_tolerance, max_iterations);
     740             :     }
     741             :   }
     742             :   return result_vector;
     743             : }
     744             : }  // namespace RootFinder

Generated by: LCOV version 1.14