SpECTRE Documentation Coverage Report
Current view: top level - NumericalAlgorithms/RootFinding - TOMS748.hpp
Commit: 6e1258ccd353220e12442198913007fb6c170b6b | Lines: 5 hit / 5 total (100.0 % coverage)
Date: 2024-10-23 19:54:13
Legend: Lines: hit | not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares function RootFinder::toms748
       6             : 
       7             : #pragma once
       8             : 
       9             : #include <functional>
      10             : #include <iomanip>
      11             : #include <ios>
      12             : #include <limits>
      13             : #include <string>
      14             : #include <type_traits>
      15             : 
      16             : #include "DataStructures/DataVector.hpp"
      17             : #include "Utilities/ErrorHandling/Error.hpp"
      18             : #include "Utilities/ErrorHandling/Exceptions.hpp"
      19             : #include "Utilities/GetOutput.hpp"
      20             : #include "Utilities/MakeString.hpp"
      21             : #include "Utilities/Simd/Simd.hpp"
      22             : 
      23             : namespace RootFinder {
      24             : namespace toms748_detail {
      25             : // Original implementation of TOMS748 is from Boost:
      26             : //  (C) Copyright John Maddock 2006.
      27             : //  Use, modification and distribution are subject to the
      28             : //  Boost Software License, Version 1.0. (See accompanying file
      29             : //  LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
      30             : //
      31             : // Significant changes made to pretty much all of it to support SIMD by Nils
      32             : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
      33             : 
      34             : template <typename T>
      35             : T safe_div(const T& num, const T& denom, const T& r) {
      36             :   if constexpr (std::is_floating_point_v<T>) {
      37             :     using std::abs;
      38             :     if (abs(denom) < static_cast<T>(1)) {
      39             :       if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
      40             :         return r;
      41             :       }
      42             :     }
      43             :     return num / denom;
      44             :   } else {
      45             :     // return num / denom without overflow, return r if overflow would occur.
      46             :     //
      47             :     // To do this we need to handle the following cases:
      48             :     // 1. fabs(denom) < 1 AND fabs(denom * max) <= fabs(num):
      49             :     //    return r
      50             :     // 2. return num / denom
      51             :     //
      52             :     // We do this by creating 2 masks.
      53             :     // 1. `mask0` selects where fabs(denom) < 1. This is where we _may_ have
      54             :     //    issues with division by zero or overflow.
      55             :     // 2. `mask` selects where fabs(denom) < 1 AND fabs(denom * max)<=fabs(num)
      56             :     //    * Note: the edge case of fabs(num)==max could be problematic, but if
      57             :     //            you're dealing with numbers like that you are likely in
      58             :     //            trouble anyway.
      59             :     // The second mask is where we must avoid division by zero and instead
      60             :     // return `r`.
      61             :     const auto mask0 = fabs(denom) < static_cast<T>(1);
      62             :     // Note: if denom >= 1.0 you get an FPE because of overflow from
      63             :     // `max() * (1. + value)`, which is why `mask0` is necessary.
      64             :     const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(1.0)) *
      65             :                            std::numeric_limits<T>::max()) <= fabs(num);
      66             :     const auto new_denom = simd::select(mask, static_cast<T>(1), denom);
      67             :     return simd::select(mask, r, num / new_denom);
      68             :   }
      69             : }
      70             : 
      71             : template <typename T>
      72             : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb) {
      73             :   //
      74             :   // Performs standard secant interpolation of [a,b] given
      75             :   // function evaluations f(a) and f(b).  Performs a bisection
      76             :   // if secant interpolation would leave us very close to either
      77             :   // a or b.  Rationale: we only call this function when at least
      78             :   // one other form of interpolation has already failed, so we know
      79             :   // that the function is unlikely to be smooth with a root very
      80             :   // close to a or b.
      81             :   //
      82             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
      83             :   // WARNING: There are several different ways to implement the interpolation
      84             :   // that all have different rounding properties. Unfortunately this means that
      85             :   // tolerances even at 1e-14 can be difficult to achieve generically.
      86             :   //
      87             :   // `g` below is:
      88             :   // const T g = (fa / (fb - fa));
      89             :   //
      90             :   // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
      91             :   //
      92             :   // const T c = a * (static_cast<T>(1) - g) + g * b; // works
      93             :   //
      94             :   // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
      95             :   //
      96             :   // Original Boost code:
      97             :   // const T c = a - (fa / (fb - fa)) * (b - a); // works
      98             : 
      99             :   const T c = a - (fa / (fb - fa)) * (b - a);
     100             :   return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
     101             :                           (c >= simd::fnma(fabs(b), tol_batch, b)),
     102             :                       static_cast<T>(0.5) * (a + b), c);
     103             : }
     104             : 
     105             : template <bool AssumeFinite, typename T>
     106             : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
     107             :                         const T& fb, const T& fd,
     108             :                         const simd::mask_type_t<T>& incomplete_mask,
     109             :                         const unsigned count) {
     110             :   // Performs quadratic interpolation to determine the next point,
     111             :   // takes count Newton steps to find the location of the
     112             :   // quadratic polynomial.
     113             :   //
     114             :   // Point d must lie outside of the interval [a,b], it is the third
     115             :   // best approximation to the root, after a and b.
     116             :   //
     117             :   // Note: this does not guarantee to find a root
     118             :   // inside [a, b], so we fall back to a secant step should
     119             :   // the result be out of range.
     120             :   //
     121             :   // Start by obtaining the coefficients of the quadratic polynomial:
     122             :   const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
     123             :   T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
     124             :   A = safe_div(A - B, d - a, static_cast<T>(0));
     125             : 
     126             :   const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
     127             :   T result_secant{};
     128             :   if (UNLIKELY(simd::any(secant_failure_mask))) {
     129             :     // failure to determine coefficients, try a secant step:
     130             :     result_secant = secant_interpolate(a, b, fa, fb);
     131             :     if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
     132             :       return result_secant;
     133             :     }
     134             :   }
     135             : 
     136             :   // Determine the starting point of the Newton steps:
     137             :   //
     138             :   // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
     139             :   // code quite a bit.
     140             :   T c = AssumeFinite
     141             :             ? simd::select(A * fa > static_cast<T>(0), a, b)
     142             :             : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
     143             :                            a, b);
     144             : 
     145             :   // Take the Newton steps:
     146             :   const T two_A = static_cast<T>(2) * A;
     147             :   const T half_a_plus_b = 0.5 * (a + b);
     148             :   const T one_minus_a = static_cast<T>(1) - a;
     149             :   const T B_minus_A_times_b = B - A * b;
     150             :   for (unsigned i = 1; i <= count; ++i) {
     151             :     c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
     152             :                   simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
     153             :   }
     154             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     155             :       simd::any(mask)) {
     156             :     // Failure, try a secant step:
     157             :     c = simd::select(mask, secant_interpolate(a, b, fa, fb), c);
     158             :   }
     159             :   return simd::select(secant_failure_mask, result_secant, c);
     160             : }
     161             : 
     162             : template <bool AssumeFinite, typename T>
     163             : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
     164             :                     const T& fb, const T& fd, const T& fe,
     165             :                     const simd::mask_type_t<T>& incomplete_mask) {
     166             :   // Uses inverse cubic interpolation of f(x) at points
     167             :   // [a,b,d,e] to obtain an approximate root of f(x).
     168             :   // Points d and e lie outside the interval [a,b]
     169             :   // and are the third and forth best approximations
     170             :   // to the root that we have found so far.
     171             :   //
     172             :   // Note: this does not guarantee to find a root
     173             :   // inside [a, b], so we fall back to quadratic
     174             :   // interpolation in case of an erroneous result.
     175             :   //
     176             :   // This commented chunk is the original Boost implementation translated into
     177             :   // simd. The actual code below is a heavily optimized version.
     178             :   //
     179             :   // const T q11 = (d - e) * fd / (fe - fd);
     180             :   // const T q21 = (b - d) * fb / (fd - fb);
     181             :   // const T q31 = (a - b) * fa / (fb - fa);
     182             :   // const T d21 = (b - d) * fd / (fd - fb);
     183             :   // const T d31 = (a - b) * fb / (fb - fa);
     184             :   //
     185             :   // const T q22 = (d21 - q11) * fb / (fe - fb);
     186             :   // const T q32 = (d31 - q21) * fa / (fd - fa);
     187             :   // const T d32 = (d31 - q21) * fd / (fd - fa);
     188             :   // const T q33 = (d32 - q22) * fa / (fe - fa);
     189             :   //
     190             :   // T c = q31 + q32 + q33 + a;
     191             : 
     192             :   // The optimized implementation here is L1-cache bound. That is, we aren't
     193             :   // able to completely saturate the FP units because we are waiting on the L1
     194             :   // cache. While not ideal, that's okay and just part of the algorithm.
     195             :   const T denom_fb_fa = fb - fa;
     196             :   const T denom_fd_fb = fd - fb;
     197             :   const T denom_fd_fa = fd - fa;
     198             :   const T denom_fe_fd = fe - fd;
     199             :   const T denom_fe_fb = fe - fb;
     200             :   const T denom =
     201             :       denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
     202             : 
     203             :   // Avoid division by zero with mask.
     204             :   const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
     205             :                                           static_cast<T>(1));
     206             : 
     207             :   const T d31 = (a - b);
     208             :   const T q21 = (b - d);
     209             :   const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
     210             :                 denom_fe_fd;
     211             :   const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
     212             :                 denom_fb_fa * denom_fd_fa;
     213             : 
     214             :   // Note: the reduction in rounding error that comes from the improvement by
     215             :   // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
     216             :   // we do below. Alternative ways of evaluating polynomials do not delay this
     217             :   // inclusion of `a`, and so then when the correction to `a` is small,
     218             :   // floating point errors decrease the accuracy of the result.
     219             :   T c = simd::fma(
     220             :       fa_by_denom,
     221             :       simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
     222             : 
     223             :   if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
     224             :       simd::any(mask)) {
     225             :     // Out of bounds step, fall back to quadratic interpolation:
     226             :     //
     227             :     // Note: we only apply quadratic interpolation at points where cubic
     228             :     // failed and that aren't already at a root.
     229             :     c = simd::select(
     230             :         mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
     231             :         c);
     232             :   }
     233             : 
     234             :   return c;
     235             : }
     236             : 
     237             : template <bool AssumeFinite, typename F, typename T>
     238             : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
     239             :              const simd::mask_type_t<T>& incomplete_mask) {
     240             :   // Given a point c inside the existing enclosing interval
     241             :   // [a, b] sets a = c if f(c) == 0, otherwise finds the new
     242             :   // enclosing interval: either [a, c] or [c, b] and sets
     243             :   // d and fd to the point that has just been removed from
     244             :   // the interval.  In other words d is the third best guess
     245             :   // to the root.
     246             :   //
     247             :   // Note: `bracket` will only modify slots marked as `true` in
     248             :   //       `incomplete_mask`
     249             :   const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
     250             : 
     251             :   // If the interval [a,b] is very small, or if c is too close
     252             :   // to one end of the interval then we need to adjust the
     253             :   // location of c accordingly. This is:
     254             :   //
     255             :   //   if ((b - a) < 2 * tol * a) {
     256             :   //     c = a + (b - a) / 2;
     257             :   //   } else if (c <= a + fabs(a) * tol) {
     258             :   //     c = a + fabs(a) * tol;
     259             :   //   } else if (c >= b - fabs(b) * tol) {
     260             :   //     c = b - fabs(b) * tol;
     261             :   //   }
     262             :   const T a_filt = simd::fma(fabs(a), tol_batch, a);
     263             :   const T b_filt = simd::fnma(fabs(b), tol_batch, b);
     264             :   const T b_minus_a = b - a;
     265             :   c = simd::select(
     266             :       (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
     267             :       simd::fma(b_minus_a, static_cast<T>(0.5), a),
     268             :       simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
     269             : 
     270             :   // Invoke f(c):
     271             :   T fc = f(c);
     272             : 
     273             :   // if we have a zero then we have an exact solution to the root:
     274             :   const auto fc_is_zero_mask = (fc == static_cast<T>(0));
     275             :   if (const auto mask = fc_is_zero_mask and incomplete_mask;
     276             :       UNLIKELY(simd::any(mask))) {
     277             :     a = simd::select(mask, c, a);
     278             :     fa = simd::select(mask, static_cast<T>(0), fa);
     279             :     d = simd::select(mask, static_cast<T>(0), d);
     280             :     fd = simd::select(mask, static_cast<T>(0), fd);
     281             :     if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
     282             :       return;
     283             :     }
     284             :   }
     285             : 
     286             :   // Non-zero fc, update the interval:
     287             :   //
     288             :   // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
     289             :   // code quite a bit.
     290             :   //
     291             :   // Boost code is:
     292             :   // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
     293             :   using simd::sign;
     294             :   const auto sign_mask = AssumeFinite
     295             :                              ? (fa * fc < static_cast<T>(0))
     296             :                              : (sign(fa) * sign(fc) < static_cast<T>(0));
     297             :   const auto mask_if =
     298             :       (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
     299             :   d = simd::select(mask_if, b, d);
     300             :   fd = simd::select(mask_if, fb, fd);
     301             :   b = simd::select(mask_if, c, b);
     302             :   fb = simd::select(mask_if, fc, fb);
     303             : 
     304             :   const auto mask_else =
     305             :       ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
     306             :   d = simd::select(mask_else, a, d);
     307             :   fd = simd::select(mask_else, fa, fd);
     308             :   a = simd::select(mask_else, c, a);
     309             :   fa = simd::select(mask_else, fc, fa);
     310             : }
     311             : 
     312             : template <bool AssumeFinite, class F, class T, class Tol>
     313             : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
     314             :                               const T& fbx, Tol tol,
     315             :                               const simd::mask_type_t<T>& ignore_filter,
     316             :                               size_t& max_iter) {
     317             :   // Main entry point and logic for Toms Algorithm 748
     318             :   // root finder.
     319             :   if (UNLIKELY(simd::any(ax >= bx))) {
     320             :     throw std::domain_error("Lower bound is larger than upper bound");
     321             :   }
     322             : 
     323             :   // Sanity check - are we allowed to iterate at all?
     324             :   if (UNLIKELY(max_iter == 0)) {
     325             :     return std::pair{ax, bx};
     326             :   }
     327             : 
     328             :   size_t count = max_iter;
     329             :   // mu is a parameter in the algorithm that must be between (0, 1).
     330             :   static const T mu = 0.5f;
     331             : 
     332             :   // initialise a, b and fa, fb:
     333             :   T a = ax;
     334             :   T b = bx;
     335             :   T fa = fax;
     336             :   T fb = fbx;
     337             : 
     338             :   const auto fa_is_zero_mask = (fa == static_cast<T>(0));
     339             :   const auto fb_is_zero_mask = (fb == static_cast<T>(0));
     340             :   auto completion_mask =
     341             :       tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
     342             :   auto incomplete_mask = not completion_mask;
     343             :   if (UNLIKELY(simd::all(completion_mask))) {
     344             :     max_iter = 0;
     345             :     return std::pair{simd::select(fb_is_zero_mask, b, a),
     346             :                      simd::select(fa_is_zero_mask, a, b)};
     347             :   }
     348             : 
     349             :   // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
     350             :   // This speeds up the code quite a bit.
     351             :   if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
     352             :                                        : (simd::sign(fa) * simd::sign(fb) >
     353             :                                           static_cast<T>(0))) and
     354             :                          (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
     355             :     throw std::domain_error(
     356             :         "Parameters lower and upper bounds do not bracket a root");
     357             :   }
     358             :   // dummy value for fd, e and fe:
     359             :   T fe(static_cast<T>(1e5F));
     360             :   T e(static_cast<T>(1e5F));
     361             :   T fd(static_cast<T>(1e5F));
     362             : 
     363             :   T c(0.0);
     364             :   T d(0.0);
     365             : 
     366             :   // We can't use signaling_NaN() because we do a max comparison at the end to
     367             :   // 0, and signaling_NaN when compared to 0 raises an FPE.
     368             :   const T nan(std::numeric_limits<T>::quiet_NaN());
     369             :   auto completed_a = simd::select(completion_mask, a, nan);
     370             :   auto completed_b = simd::select(completion_mask, b, nan);
     371             :   auto completed_fa = simd::select(completion_mask, fa, nan);
     372             :   auto completed_fb = simd::select(completion_mask, fb, nan);
     373             :   const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
     374             :                                  &completed_a, &completed_b, &a, &b,
     375             :                                  &completed_fa, &completed_fb, &tol]() {
     376             :     const auto new_completed =
     377             :         (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
     378             :     completed_a = simd::select(new_completed, a, completed_a);
     379             :     completed_b = simd::select(new_completed, b, completed_b);
     380             :     completed_fa = simd::select(new_completed, fa, completed_fa);
     381             :     completed_fb = simd::select(new_completed, fb, completed_fb);
     382             :     completion_mask = new_completed or completion_mask;
     383             :     incomplete_mask = not completion_mask;
     384             :     // returns true if _all_ simd registers have been completed
     385             :     return simd::all(completion_mask);
     386             :   };
     387             : 
     388             :   if (simd::any(fa != static_cast<T>(0))) {
     389             :     // On the first step we take a secant step:
     390             :     c = toms748_detail::secant_interpolate(a, b, fa, fb);
     391             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     392             :                                           incomplete_mask);
     393             :     --count;
     394             : 
     395             :     // Note: The Boost fa!=0 check is handled with the completion_mask.
     396             :     if (count and not update_completed()) {
     397             :       // On the second step we take a quadratic interpolation:
     398             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     399             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     400             :       e = d;
     401             :       fe = fd;
     402             :       toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     403             :                                             incomplete_mask);
     404             :       --count;
     405             :       update_completed();
     406             :     }
     407             :   }
     408             : 
     409             :   T u(std::numeric_limits<T>::signaling_NaN());
     410             :   T fu(std::numeric_limits<T>::signaling_NaN());
     411             :   T a0(std::numeric_limits<T>::signaling_NaN());
     412             :   T b0(std::numeric_limits<T>::signaling_NaN());
     413             : 
     414             :   // Note: The Boost fa!=0 check is handled with the completion_mask.
     415             :   while (count and not simd::all(completion_mask)) {
     416             :     // save our brackets:
     417             :     a0 = a;
     418             :     b0 = b;
     419             :     // Starting with the third step taken
     420             :     // we can use either quadratic or cubic interpolation.
     421             :     // Cubic interpolation requires that all four function values
     422             :     // fa, fb, fd, and fe are distinct, should that not be the case
     423             :     // then variable prof will get set to true, and we'll end up
     424             :     // taking a quadratic step instead.
     425             :     static const T min_diff = std::numeric_limits<T>::min() * 32;
     426             :     bool prof =
     427             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     428             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     429             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     430             :                   incomplete_mask);
     431             :     if (prof) {
     432             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     433             :           a, b, d, fa, fb, fd, incomplete_mask, 2);
     434             :     } else {
     435             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     436             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     437             :     }
     438             :     // re-bracket, and check for termination:
     439             :     e = d;
     440             :     fe = fd;
     441             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     442             :                                           incomplete_mask);
     443             :     if ((0 == --count) or update_completed()) {
     444             :       break;
     445             :     }
     446             :     // Now another interpolated step:
     447             :     prof =
     448             :         simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
     449             :                    (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
     450             :                    (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
     451             :                   incomplete_mask);
     452             :     if (prof) {
     453             :       c = toms748_detail::quadratic_interpolate<AssumeFinite>(
     454             :           a, b, d, fa, fb, fd, incomplete_mask, 3);
     455             :     } else {
     456             :       c = toms748_detail::cubic_interpolate<AssumeFinite>(
     457             :           a, b, d, e, fa, fb, fd, fe, incomplete_mask);
     458             :     }
     459             :     // Bracket again, and check termination condition, update e:
     460             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     461             :                                           incomplete_mask);
     462             :     if ((0 == --count) or update_completed()) {
     463             :       break;
     464             :     }
     465             : 
     466             :     // Now we take a double-length secant step:
     467             :     const auto fabs_fa_less_fabs_fb_mask =
     468             :         (fabs(fa) < fabs(fb)) and incomplete_mask;
     469             :     u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
     470             :     fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
     471             :     const T b_minus_a = b - a;
     472             :     // Assumes that bounds a & b are not so close that fa == fb. Boost makes
     473             :     // this assumption too. If this is violated then the algorithm doesn't
     474             :     // work since the function at least appears constant.
     475             :     c = simd::fnma(static_cast<T>(2) * (fu / (fb - fa)), b_minus_a, u);
     476             :     c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
     477             :                      simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
     478             : 
     479             :     // Bracket again, and check termination condition:
     480             :     e = d;
     481             :     fe = fd;
     482             :     toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
     483             :                                           incomplete_mask);
     484             :     if ((0 == --count) or update_completed()) {
     485             :       break;
     486             :     }
     487             : 
     488             :     // And finally... check to see if an additional bisection step is
     489             :     // to be taken, we do this if we're not converging fast enough:
     490             :     const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
     491             :     if (LIKELY(simd::none(bisection_mask))) {
     492             :       continue;
     493             :     }
     494             :     // bracket again on a bisection:
     495             :     //
     496             :     // Note: the mask ensures we only ever modify the slots that the mask has
     497             :     // identified as needing to be modified.
     498             :     e = simd::select(bisection_mask, d, e);
     499             :     fe = simd::select(bisection_mask, fd, fe);
     500             :     toms748_detail::bracket<AssumeFinite>(
     501             :         f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
     502             :         bisection_mask);
     503             :     --count;
     504             :     if (update_completed()) {
     505             :       break;
     506             :     }
     507             :   }  // while loop
     508             : 
     509             :   max_iter -= count;
     510             :   completed_b =
     511             :       simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
     512             :   completed_a =
     513             :       simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
     514             :   return std::pair{completed_a, completed_b};
     515             : }
     516             : }  // namespace toms748_detail
     517             : 
     518             : /*!
     519             :  * \ingroup NumericalAlgorithmsGroup
     520             :  * \brief Finds the root of the function `f` with the TOMS_748 method.
     521             :  *
     522             :  * `f` is a unary invokable that takes a `double` which is the current value at
     523             :  * which to evaluate `f`. An example is below.
     524             :  *
     525             :  * \snippet Test_TOMS748.cpp double_root_find
     526             :  *
     527             :  * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
     528             :  * `upper_bound`], and will throw if this interval does not bracket a root,
     529             :  * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
     530             :  *
     531             :  * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
     532             :  * are the function values at `lower_bound` and `upper_bound`. These function
     533             :  * values are often known because the user typically checks if a root is
     534             :  * bracketed before calling `toms748`; passing the function values here saves
     535             :  * two function evaluations.
     536             :  *
     537             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     538             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     539             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     540             :  * that products like `fa * fb` are also finite.
     541             :  *
     542             :  * \requires Function `f` is invokable with a `double`
     543             :  *
     544             :  * \throws `convergence_error` if the requested tolerance is not met after
     545             :  *                            `max_iterations` iterations.
     546             :  */
     547             : template <bool AssumeFinite = false, typename Function, typename T>
     548           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     549             :           const T f_at_lower_bound, const T f_at_upper_bound,
     550             :           const simd::scalar_type_t<T> absolute_tolerance,
     551             :           const simd::scalar_type_t<T> relative_tolerance,
     552             :           const size_t max_iterations = 100,
     553             :           const simd::mask_type_t<T> ignore_filter =
     554             :               static_cast<simd::mask_type_t<T>>(0)) {
     555             :   ASSERT(relative_tolerance >
     556             :              std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
     557             :          "The relative tolerance is too small. Got "
     558             :              << relative_tolerance << " but must be at least "
     559             :              << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
     560             :   if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
     561             :     ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
     562             :                                    << ", f(" << upper_bound
     563             :                                    << ") = " << f_at_upper_bound);
     564             :   }
     565             : 
     566             :   std::size_t max_iters = max_iterations;
     567             : 
     568             :   // This solver requires tol to be passed as a termination condition. This
     569             :   // termination condition is equivalent to the convergence criteria used by the
     570             :   // GSL
     571             :   const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
     572             :                                                             const T& rhs) {
     573             :     return simd::abs(lhs - rhs) <=
     574             :            simd::fma(T(relative_tolerance),
     575             :                      simd::min(simd::abs(lhs), simd::abs(rhs)),
     576             :                      T(absolute_tolerance));
     577             :   };
     578             :   auto result = toms748_detail::toms748_solve<AssumeFinite>(
     579             :       f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
     580             :       ignore_filter, max_iters);
     581             :   if (max_iters >= max_iterations) {
     582             :     throw convergence_error(
     583             :         MakeString{}
     584             :         << std::setprecision(8) << std::scientific
     585             :         << "toms748 reached max iterations without converging.\nAbsolute "
     586             :            "tolerance: "
     587             :         << absolute_tolerance << "\nRelative tolerance: " << relative_tolerance
     588             :         << "\nResult: " << get_output(result.first) << " "
     589             :         << get_output(result.second));
     590             :   }
     591             :   return simd::fma(static_cast<T>(0.5), (result.second - result.first),
     592             :                    result.first);
     593             : }
     594             : 
     595             : /*!
     596             :  * \ingroup NumericalAlgorithmsGroup
     597             :  * \brief Finds the root of the function `f` with the TOMS_748 method, where
     598             :  * function values are not supplied at the lower and upper bounds.
     599             :  *
     600             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     601             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     602             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     603             :  * that products like `fa * fb` are also finite.
     604             :  */
     605             : template <bool AssumeFinite = false, typename Function, typename T>
     606           1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
     607             :           const simd::scalar_type_t<T> absolute_tolerance,
     608             :           const simd::scalar_type_t<T> relative_tolerance,
     609             :           const size_t max_iterations = 100,
     610             :           const simd::mask_type_t<T> ignore_filter =
     611             :               static_cast<simd::mask_type_t<T>>(0)) {
     612             :   return toms748<AssumeFinite>(
     613             :       f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
     614             :       absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
     615             : }
     616             : 
     617             : /*!
     618             :  * \ingroup NumericalAlgorithmsGroup
     619             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     620             :  * element in a `DataVector`.
     621             :  *
     622             :  * `f` is a binary invokable that takes a `double` as its first argument and a
     623             :  * `size_t` as its second. The `double` is the current value at which to
     624             :  * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
     625             :  * Below is an example of how to root find different functions by indexing into
     626             :  * a lambda-captured `DataVector` using the `size_t` passed to `f`.
     627             :  *
     628             :  * \snippet Test_TOMS748.cpp datavector_root_find
     629             :  *
     630             :  * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
     631             :  * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
     632             :  * this interval does not bracket a root,
     633             :  * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
     634             :  *
     635             :  * See the [Boost](http://www.boost.org/) documentation for more details.
     636             :  *
     637             :  * \requires Function `f` be callable with a `double` and a `size_t`
     638             :  *
     639             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     640             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     641             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     642             :  * that products like `fa * fb` are also finite.
     643             :  *
     644             :  * \throws `convergence_error` if, for any index, the requested tolerance is not
     645             :  * met after `max_iterations` iterations.
     646             :  */
     647             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     648           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     649             :                    const DataVector& upper_bound,
     650             :                    const double absolute_tolerance,
     651             :                    const double relative_tolerance,
     652             :                    const size_t max_iterations = 100) {
     653             :   DataVector result_vector{lower_bound.size()};
     654             :   if constexpr (UseSimd) {
     655             :     constexpr size_t simd_width{simd::size<
     656             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     657             :     const size_t vectorized_size =
     658             :         lower_bound.size() - lower_bound.size() % simd_width;
     659             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {
     660             :       simd::store_unaligned(
     661             :           &result_vector[i],
     662             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
     663             :                                 simd::load_unaligned(&lower_bound[i]),
     664             :                                 simd::load_unaligned(&upper_bound[i]),
     665             :                                 absolute_tolerance, relative_tolerance,
     666             :                                 max_iterations));
     667             :     }
     668             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
     669             :       result_vector[i] = toms748<AssumeFinite>(
     670             :           [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
     671             :           upper_bound[i], absolute_tolerance, relative_tolerance,
     672             :           max_iterations);
     673             :     }
     674             :   } else {
     675             :     for (size_t i = 0; i < result_vector.size(); ++i) {
     676             :       result_vector[i] = toms748<AssumeFinite>(
     677             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     678             :           absolute_tolerance, relative_tolerance, max_iterations);
     679             :     }
     680             :   }
     681             :   return result_vector;
     682             : }
     683             : 
     684             : /*!
     685             :  * \ingroup NumericalAlgorithmsGroup
     686             :  * \brief Finds the root of the function `f` with the TOMS_748 method on each
     687             :  * element in a `DataVector`, where function values are supplied at the lower
     688             :  * and upper bounds.
     689             :  *
     690             :  * Supplying function values is an optimization that saves two
     691             :  * function calls per point.  The function values are often available
     692             :  * because one often checks if the root is bracketed before calling `toms748`.
     693             :  *
     694             :  * \note if `AssumeFinite` is true than the code assumes all numbers are
     695             :  * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
     696             :  * runtime but will cause bugs if the numbers aren't finite. It also assumes
     697             :  * that products like `fa * fb` are also finite.
     698             :  */
     699             : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
     700           1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
     701             :                    const DataVector& upper_bound,
     702             :                    const DataVector& f_at_lower_bound,
     703             :                    const DataVector& f_at_upper_bound,
     704             :                    const double absolute_tolerance,
     705             :                    const double relative_tolerance,
     706             :                    const size_t max_iterations = 100) {
     707             :   DataVector result_vector{lower_bound.size()};
     708             :   if constexpr (UseSimd) {
     709             :     constexpr size_t simd_width{simd::size<
     710             :         std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
     711             :     const size_t vectorized_size =
     712             :         lower_bound.size() - lower_bound.size() % simd_width;
     713             :     for (size_t i = 0; i < vectorized_size; i += simd_width) {
     714             :       simd::store_unaligned(
     715             :           &result_vector[i],
     716             :           toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
     717             :                                 simd::load_unaligned(&lower_bound[i]),
     718             :                                 simd::load_unaligned(&upper_bound[i]),
     719             :                                 simd::load_unaligned(&f_at_lower_bound[i]),
     720             :                                 simd::load_unaligned(&f_at_upper_bound[i]),
     721             :                                 absolute_tolerance, relative_tolerance,
     722             :                                 max_iterations));
     723             :     }
     724             :     for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
     725             :       result_vector[i] = toms748<AssumeFinite>(
     726             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     727             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     728             :           relative_tolerance, max_iterations);
     729             :     }
     730             :   } else {
     731             :     for (size_t i = 0; i < lower_bound.size(); ++i) {
     732             :       result_vector[i] = toms748<AssumeFinite>(
     733             :           [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
     734             :           f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
     735             :           relative_tolerance, max_iterations);
     736             :     }
     737             :   }
     738             :   return result_vector;
     739             : }
     740             : }  // namespace RootFinder

Generated by: LCOV version 1.14