Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares function RootFinder::toms748
6 :
7 : #pragma once
8 :
9 : #include <functional>
10 : #include <iomanip>
11 : #include <ios>
12 : #include <limits>
13 : #include <string>
14 : #include <type_traits>
15 :
16 : #include "DataStructures/DataVector.hpp"
17 : #include "Utilities/ErrorHandling/CaptureForError.hpp"
18 : #include "Utilities/ErrorHandling/Error.hpp"
19 : #include "Utilities/ErrorHandling/Exceptions.hpp"
20 : #include "Utilities/GetOutput.hpp"
21 : #include "Utilities/MakeString.hpp"
22 : #include "Utilities/Simd/Simd.hpp"
23 :
24 : namespace RootFinder {
25 : namespace toms748_detail {
26 : // Original implementation of TOMS748 is from Boost:
27 : // (C) Copyright John Maddock 2006.
28 : // Use, modification and distribution are subject to the
29 : // Boost Software License, Version 1.0. (See accompanying file
30 : // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
31 : //
32 : // Significant changes made to pretty much all of it to support SIMD by Nils
33 : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
34 :
35 : template <typename T>
36 : T safe_div(const T& num, const T& denom, const T& r) {
37 : if constexpr (std::is_floating_point_v<T>) {
38 : using std::abs;
39 : if (abs(denom) < static_cast<T>(1)) {
40 : if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
41 : return r;
42 : }
43 : }
44 : return num / denom;
45 : } else {
46 : // return num / denom without overflow, return r if overflow would occur.
47 : //
48 : // To do this we need to handle the following cases:
49 : // 1. fabs(denom) < 1 AND fabs(denom * max) <= fabs(num):
50 : // return r
51 : // 2. return num / denom
52 : //
53 : // We do this by creating 2 masks.
54 : // 1. `mask0` selects where fabs(denom) < 1. This is where we _may_ have
55 : // issues with division by zero or overflow.
56 : // 2. `mask` selects where fabs(denom) < 1 AND fabs(denom * max)<=fabs(num)
57 : // * Note: the edge case of fabs(num)==max could be problematic, but if
58 : // you're dealing with numbers like that you are likely in
59 : // trouble anyway.
60 : // The second mask is where we must avoid division by zero and instead
61 : // return `r`.
62 : const auto mask0 = fabs(denom) < static_cast<T>(1);
63 : // Note: if denom >= 1.0 you get an FPE because of overflow from
64 : // `max() * (1. + value)`, which is why `mask0` is necessary.
65 : const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(1.0)) *
66 : std::numeric_limits<T>::max()) <= fabs(num);
67 : const auto new_denom = simd::select(mask, static_cast<T>(1), denom);
68 : return simd::select(mask, r, num / new_denom);
69 : }
70 : }
71 :
72 : template <typename T>
73 : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb,
74 : const simd::mask_type_t<T>& incomplete_mask) {
75 : //
76 : // Performs standard secant interpolation of [a,b] given
77 : // function evaluations f(a) and f(b). Performs a bisection
78 : // if secant interpolation would leave us very close to either
79 : // a or b. Rationale: we only call this function when at least
80 : // one other form of interpolation has already failed, so we know
81 : // that the function is unlikely to be smooth with a root very
82 : // close to a or b.
83 : //
84 : // For lanes where `incomplete_mask` is false (already converged), the
85 : // denominator `fb - fa` may be stale or zero. We substitute 1 to avoid
86 : // undefined behavior (e.g. division by zero); the returned value for those
87 : // lanes is meaningless and should not be used by the caller. See
88 : // `cubic_interpolate` for the same pattern.
89 : //
90 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
91 : // WARNING: There are several different ways to implement the interpolation
92 : // that all have different rounding properties. Unfortunately this means that
93 : // tolerances even at 1e-14 can be difficult to achieve generically.
94 : //
95 : // `g` below is:
96 : // const T g = (fa / (fb - fa));
97 : //
98 : // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
99 : //
100 : // const T c = a * (static_cast<T>(1) - g) + g * b; // works
101 : //
102 : // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
103 : //
104 : // Original Boost code:
105 : // const T c = a - (fa / (fb - fa)) * (b - a); // works
106 :
107 : const T c =
108 : a - (fa / simd::select(incomplete_mask, fb - fa, static_cast<T>(1))) *
109 : (b - a);
110 : return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
111 : (c >= simd::fnma(fabs(b), tol_batch, b)),
112 : static_cast<T>(0.5) * (a + b), c);
113 : }
114 :
115 : template <bool AssumeFinite, typename T>
116 : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
117 : const T& fb, const T& fd,
118 : const simd::mask_type_t<T>& incomplete_mask,
119 : const unsigned count) {
120 : // Performs quadratic interpolation to determine the next point,
121 : // takes count Newton steps to find the location of the
122 : // quadratic polynomial.
123 : //
124 : // Point d must lie outside of the interval [a,b], it is the third
125 : // best approximation to the root, after a and b.
126 : //
127 : // Note: this does not guarantee to find a root
128 : // inside [a, b], so we fall back to a secant step should
129 : // the result be out of range.
130 : //
131 : // Start by obtaining the coefficients of the quadratic polynomial:
132 : const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
133 : T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
134 : A = safe_div(A - B, d - a, static_cast<T>(0));
135 :
136 : const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
137 : T result_secant{};
138 : if (UNLIKELY(simd::any(secant_failure_mask))) {
139 : // failure to determine coefficients, try a secant step:
140 : result_secant = secant_interpolate(a, b, fa, fb, secant_failure_mask);
141 : if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
142 : return result_secant;
143 : }
144 : }
145 :
146 : // Determine the starting point of the Newton steps:
147 : //
148 : // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
149 : // code quite a bit.
150 : T c = AssumeFinite
151 : ? simd::select(A * fa > static_cast<T>(0), a, b)
152 : : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
153 : a, b);
154 :
155 : // Take the Newton steps:
156 : const T two_A = static_cast<T>(2) * A;
157 : const T half_a_plus_b = 0.5 * (a + b);
158 : const T one_minus_a = static_cast<T>(1) - a;
159 : const T B_minus_A_times_b = B - A * b;
160 : for (unsigned i = 1; i <= count; ++i) {
161 : c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
162 : simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
163 : }
164 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
165 : simd::any(mask)) {
166 : // Failure, try a secant step:
167 : c = simd::select(mask, secant_interpolate(a, b, fa, fb, mask), c);
168 : }
169 : return simd::select(secant_failure_mask, result_secant, c);
170 : }
171 :
172 : template <bool AssumeFinite, typename T>
173 : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
174 : const T& fb, const T& fd, const T& fe,
175 : const simd::mask_type_t<T>& incomplete_mask) {
176 : // Uses inverse cubic interpolation of f(x) at points
177 : // [a,b,d,e] to obtain an approximate root of f(x).
178 : // Points d and e lie outside the interval [a,b]
179 : // and are the third and forth best approximations
180 : // to the root that we have found so far.
181 : //
182 : // Note: this does not guarantee to find a root
183 : // inside [a, b], so we fall back to quadratic
184 : // interpolation in case of an erroneous result.
185 : //
186 : // This commented chunk is the original Boost implementation translated into
187 : // simd. The actual code below is a heavily optimized version.
188 : //
189 : // const T q11 = (d - e) * fd / (fe - fd);
190 : // const T q21 = (b - d) * fb / (fd - fb);
191 : // const T q31 = (a - b) * fa / (fb - fa);
192 : // const T d21 = (b - d) * fd / (fd - fb);
193 : // const T d31 = (a - b) * fb / (fb - fa);
194 : //
195 : // const T q22 = (d21 - q11) * fb / (fe - fb);
196 : // const T q32 = (d31 - q21) * fa / (fd - fa);
197 : // const T d32 = (d31 - q21) * fd / (fd - fa);
198 : // const T q33 = (d32 - q22) * fa / (fe - fa);
199 : //
200 : // T c = q31 + q32 + q33 + a;
201 :
202 : // The optimized implementation here is L1-cache bound. That is, we aren't
203 : // able to completely saturate the FP units because we are waiting on the L1
204 : // cache. While not ideal, that's okay and just part of the algorithm.
205 : const T denom_fb_fa = fb - fa;
206 : const T denom_fd_fb = fd - fb;
207 : const T denom_fd_fa = fd - fa;
208 : const T denom_fe_fd = fe - fd;
209 : const T denom_fe_fb = fe - fb;
210 : const T denom =
211 : denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
212 :
213 : // Avoid division by zero with mask.
214 : const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
215 : static_cast<T>(1));
216 :
217 : const T d31 = (a - b);
218 : const T q21 = (b - d);
219 : const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
220 : denom_fe_fd;
221 : const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
222 : denom_fb_fa * denom_fd_fa;
223 :
224 : // Note: the reduction in rounding error that comes from the improvement by
225 : // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
226 : // we do below. Alternative ways of evaluating polynomials do not delay this
227 : // inclusion of `a`, and so then when the correction to `a` is small,
228 : // floating point errors decrease the accuracy of the result.
229 : T c = simd::fma(
230 : fa_by_denom,
231 : simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
232 :
233 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
234 : simd::any(mask)) {
235 : // Out of bounds step, fall back to quadratic interpolation:
236 : //
237 : // Note: we only apply quadratic interpolation at points where cubic
238 : // failed and that aren't already at a root.
239 : c = simd::select(
240 : mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
241 : c);
242 : }
243 :
244 : return c;
245 : }
246 :
247 : template <bool AssumeFinite, typename F, typename T>
248 : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
249 : const simd::mask_type_t<T>& incomplete_mask) {
250 : // Given a point c inside the existing enclosing interval
251 : // [a, b] sets a = c if f(c) == 0, otherwise finds the new
252 : // enclosing interval: either [a, c] or [c, b] and sets
253 : // d and fd to the point that has just been removed from
254 : // the interval. In other words d is the third best guess
255 : // to the root.
256 : //
257 : // Note: `bracket` will only modify slots marked as `true` in
258 : // `incomplete_mask`
259 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
260 :
261 : // If the interval [a,b] is very small, or if c is too close
262 : // to one end of the interval then we need to adjust the
263 : // location of c accordingly. This is:
264 : //
265 : // if ((b - a) < 2 * tol * a) {
266 : // c = a + (b - a) / 2;
267 : // } else if (c <= a + fabs(a) * tol) {
268 : // c = a + fabs(a) * tol;
269 : // } else if (c >= b - fabs(b) * tol) {
270 : // c = b - fabs(b) * tol;
271 : // }
272 : const T a_filt = simd::fma(fabs(a), tol_batch, a);
273 : const T b_filt = simd::fnma(fabs(b), tol_batch, b);
274 : const T b_minus_a = b - a;
275 : c = simd::select(
276 : (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
277 : simd::fma(b_minus_a, static_cast<T>(0.5), a),
278 : simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
279 :
280 : // Invoke f(c):
281 : T fc = f(c);
282 :
283 : // if we have a zero then we have an exact solution to the root:
284 : const auto fc_is_zero_mask = (fc == static_cast<T>(0));
285 : if (const auto mask = fc_is_zero_mask and incomplete_mask;
286 : UNLIKELY(simd::any(mask))) {
287 : a = simd::select(mask, c, a);
288 : fa = simd::select(mask, static_cast<T>(0), fa);
289 : d = simd::select(mask, static_cast<T>(0), d);
290 : fd = simd::select(mask, static_cast<T>(0), fd);
291 : if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
292 : return;
293 : }
294 : }
295 :
296 : // Non-zero fc, update the interval:
297 : //
298 : // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
299 : // code quite a bit.
300 : //
301 : // Boost code is:
302 : // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
303 : using simd::sign;
304 : const auto sign_mask = AssumeFinite
305 : ? (fa * fc < static_cast<T>(0))
306 : : (sign(fa) * sign(fc) < static_cast<T>(0));
307 : const auto mask_if =
308 : (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
309 : d = simd::select(mask_if, b, d);
310 : fd = simd::select(mask_if, fb, fd);
311 : b = simd::select(mask_if, c, b);
312 : fb = simd::select(mask_if, fc, fb);
313 :
314 : const auto mask_else =
315 : ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
316 : d = simd::select(mask_else, a, d);
317 : fd = simd::select(mask_else, fa, fd);
318 : a = simd::select(mask_else, c, a);
319 : fa = simd::select(mask_else, fc, fa);
320 : }
321 :
322 : template <bool AssumeFinite, class F, class T, class Tol>
323 : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
324 : const T& fbx, Tol tol,
325 : const simd::mask_type_t<T>& ignore_filter,
326 : size_t& max_iter) {
327 : // Main entry point and logic for Toms Algorithm 748
328 : // root finder.
329 : if (UNLIKELY(simd::any(ax > bx))) {
330 : ERROR_AS("Lower bound is larger than upper bound", std::domain_error);
331 : }
332 :
333 : // Sanity check - are we allowed to iterate at all?
334 : if (UNLIKELY(max_iter == 0)) {
335 : return std::pair{ax, bx};
336 : }
337 :
338 : size_t count = max_iter;
339 : // mu is a parameter in the algorithm that must be between (0, 1).
340 : static const T mu = 0.5f;
341 :
342 : // initialise a, b and fa, fb:
343 : T a = ax;
344 : T b = bx;
345 : T fa = fax;
346 : T fb = fbx;
347 : CAPTURE_FOR_ERROR(ax);
348 : CAPTURE_FOR_ERROR(bx);
349 : CAPTURE_FOR_ERROR(fax);
350 : CAPTURE_FOR_ERROR(fbx);
351 :
352 : const auto fa_is_zero_mask = (fa == static_cast<T>(0));
353 : const auto fb_is_zero_mask = (fb == static_cast<T>(0));
354 : auto completion_mask =
355 : tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
356 : auto incomplete_mask = not completion_mask;
357 : if (UNLIKELY(simd::all(completion_mask))) {
358 : max_iter = 0;
359 : return std::pair{simd::select(fb_is_zero_mask, b, a),
360 : simd::select(fa_is_zero_mask, a, b)};
361 : }
362 :
363 : // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
364 : // This speeds up the code quite a bit.
365 : if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
366 : : (simd::sign(fa) * simd::sign(fb) >
367 : static_cast<T>(0))) and
368 : (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
369 : ERROR_AS("Parameters lower and upper bounds do not bracket a root.",
370 : std::domain_error);
371 : }
372 : // dummy value for fd, e and fe:
373 : T fe(static_cast<T>(1e5F));
374 : T e(static_cast<T>(1e5F));
375 : T fd(static_cast<T>(1e5F));
376 :
377 : T c(0.0);
378 : T d(0.0);
379 :
380 : // We can't use signaling_NaN() because we do a max comparison at the end to
381 : // 0, and signaling_NaN when compared to 0 raises an FPE.
382 : const T nan(std::numeric_limits<T>::quiet_NaN());
383 : auto completed_a = simd::select(completion_mask, a, nan);
384 : auto completed_b = simd::select(completion_mask, b, nan);
385 : auto completed_fa = simd::select(completion_mask, fa, nan);
386 : auto completed_fb = simd::select(completion_mask, fb, nan);
387 : const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
388 : &completed_a, &completed_b, &a, &b,
389 : &completed_fa, &completed_fb, &tol]() {
390 : const auto new_completed =
391 : (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
392 : completed_a = simd::select(new_completed, a, completed_a);
393 : completed_b = simd::select(new_completed, b, completed_b);
394 : completed_fa = simd::select(new_completed, fa, completed_fa);
395 : completed_fb = simd::select(new_completed, fb, completed_fb);
396 : completion_mask = new_completed or completion_mask;
397 : incomplete_mask = not completion_mask;
398 : // returns true if _all_ simd registers have been completed
399 : return simd::all(completion_mask);
400 : };
401 :
402 : if (simd::any(fa != static_cast<T>(0))) {
403 : // On the first step we take a secant step:
404 : c = toms748_detail::secant_interpolate(a, b, fa, fb, incomplete_mask);
405 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
406 : incomplete_mask);
407 : --count;
408 :
409 : // Note: The Boost fa!=0 check is handled with the completion_mask.
410 : if (count and not update_completed()) {
411 : // On the second step we take a quadratic interpolation:
412 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
413 : a, b, d, fa, fb, fd, incomplete_mask, 2);
414 : e = d;
415 : fe = fd;
416 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
417 : incomplete_mask);
418 : --count;
419 : update_completed();
420 : }
421 : }
422 :
423 : T u(std::numeric_limits<T>::signaling_NaN());
424 : T fu(std::numeric_limits<T>::signaling_NaN());
425 : T a0(std::numeric_limits<T>::signaling_NaN());
426 : T b0(std::numeric_limits<T>::signaling_NaN());
427 :
428 : // Note: The Boost fa!=0 check is handled with the completion_mask.
429 : while (count and not simd::all(completion_mask)) {
430 : // save our brackets:
431 : a0 = a;
432 : b0 = b;
433 : // Starting with the third step taken
434 : // we can use either quadratic or cubic interpolation.
435 : // Cubic interpolation requires that all four function values
436 : // fa, fb, fd, and fe are distinct, should that not be the case
437 : // then variable prof will get set to true, and we'll end up
438 : // taking a quadratic step instead.
439 : static const T min_diff = std::numeric_limits<T>::min() * 32;
440 : bool prof =
441 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
442 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
443 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
444 : incomplete_mask);
445 : if (prof) {
446 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
447 : a, b, d, fa, fb, fd, incomplete_mask, 2);
448 : } else {
449 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
450 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
451 : }
452 : // re-bracket, and check for termination:
453 : e = d;
454 : fe = fd;
455 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
456 : incomplete_mask);
457 : if ((0 == --count) or update_completed()) {
458 : break;
459 : }
460 : // Now another interpolated step:
461 : prof =
462 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
463 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
464 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
465 : incomplete_mask);
466 : if (prof) {
467 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
468 : a, b, d, fa, fb, fd, incomplete_mask, 3);
469 : } else {
470 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
471 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
472 : }
473 : // Bracket again, and check termination condition, update e:
474 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
475 : incomplete_mask);
476 : if ((0 == --count) or update_completed()) {
477 : break;
478 : }
479 :
480 : // Now we take a double-length secant step:
481 : const auto fabs_fa_less_fabs_fb_mask =
482 : (fabs(fa) < fabs(fb)) and incomplete_mask;
483 : u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
484 : fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
485 : const T b_minus_a = b - a;
486 : // Assumes that bounds a & b are not so close that fa == fb. Boost makes
487 : // this assumption too. If this is violated then the algorithm doesn't
488 : // work since the function at least appears constant.
489 :
490 : // Safeguard complete lanes: `fb - fa` may be stale or zero for lanes
491 : // that have already converged. Use 1 as a harmless dummy denominator;
492 : // the result for complete lanes is discarded downstream.
493 : c = simd::fnma(
494 : static_cast<T>(2) *
495 : (fu / simd::select(incomplete_mask, fb - fa, static_cast<T>(1))),
496 : b_minus_a, u);
497 : c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
498 : simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
499 :
500 : // Bracket again, and check termination condition:
501 : e = d;
502 : fe = fd;
503 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
504 : incomplete_mask);
505 : if ((0 == --count) or update_completed()) {
506 : break;
507 : }
508 :
509 : // And finally... check to see if an additional bisection step is
510 : // to be taken, we do this if we're not converging fast enough:
511 : const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
512 : if (LIKELY(simd::none(bisection_mask))) {
513 : continue;
514 : }
515 : // bracket again on a bisection:
516 : //
517 : // Note: the mask ensures we only ever modify the slots that the mask has
518 : // identified as needing to be modified.
519 : e = simd::select(bisection_mask, d, e);
520 : fe = simd::select(bisection_mask, fd, fe);
521 : toms748_detail::bracket<AssumeFinite>(
522 : f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
523 : bisection_mask);
524 : --count;
525 : if (update_completed()) {
526 : break;
527 : }
528 : } // while loop
529 :
530 : max_iter -= count;
531 : completed_b =
532 : simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
533 : completed_a =
534 : simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
535 : return std::pair{completed_a, completed_b};
536 : }
537 : } // namespace toms748_detail
538 :
539 : /*!
540 : * \ingroup NumericalAlgorithmsGroup
541 : * \brief Finds the root of the function `f` with the TOMS_748 method.
542 : *
543 : * `f` is a unary invokable that takes a `double` which is the current value at
544 : * which to evaluate `f`. An example is below.
545 : *
546 : * \snippet Test_TOMS748.cpp double_root_find
547 : *
548 : * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
549 : * `upper_bound`], and will throw if this interval does not bracket a root,
550 : * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
551 : *
552 : * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
553 : * are the function values at `lower_bound` and `upper_bound`. These function
554 : * values are often known because the user typically checks if a root is
555 : * bracketed before calling `toms748`; passing the function values here saves
556 : * two function evaluations.
557 : *
558 : * \note if `AssumeFinite` is true than the code assumes all numbers are
559 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
560 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
561 : * that products like `fa * fb` are also finite.
562 : *
563 : * \requires Function `f` is invokable with a `double`
564 : *
565 : * \throws `convergence_error` if the requested tolerance is not met after
566 : * `max_iterations` iterations.
567 : */
568 : template <bool AssumeFinite = false, typename Function, typename T>
569 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
570 : const T f_at_lower_bound, const T f_at_upper_bound,
571 : const simd::scalar_type_t<T> absolute_tolerance,
572 : const simd::scalar_type_t<T> relative_tolerance,
573 : const size_t max_iterations = 100,
574 : const simd::mask_type_t<T> ignore_filter =
575 : static_cast<simd::mask_type_t<T>>(0)) {
576 : ASSERT(relative_tolerance >
577 : std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
578 : "The relative tolerance is too small. Got "
579 : << relative_tolerance << " but must be at least "
580 : << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
581 : if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
582 : ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
583 : << ", f(" << upper_bound
584 : << ") = " << f_at_upper_bound);
585 : }
586 :
587 : std::size_t max_iters = max_iterations;
588 :
589 : // This solver requires tol to be passed as a termination condition. This
590 : // termination condition is equivalent to the convergence criteria used by the
591 : // GSL
592 : const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
593 : const T& rhs) {
594 : return simd::abs(lhs - rhs) <=
595 : simd::fma(T(relative_tolerance),
596 : simd::min(simd::abs(lhs), simd::abs(rhs)),
597 : T(absolute_tolerance));
598 : };
599 : auto result = toms748_detail::toms748_solve<AssumeFinite>(
600 : f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
601 : ignore_filter, max_iters);
602 : if (max_iters >= max_iterations) {
603 : ERROR_AS(
604 : "toms748 reached max iterations without converging.\nAbsolute "
605 : "tolerance: "
606 : << absolute_tolerance << "\nRelative tolerance: "
607 : << relative_tolerance << "\nResult: " << get_output(result.first)
608 : << " " << get_output(result.second),
609 : convergence_error);
610 : }
611 : return simd::fma(static_cast<T>(0.5), (result.second - result.first),
612 : result.first);
613 : }
614 :
615 : /*!
616 : * \ingroup NumericalAlgorithmsGroup
617 : * \brief Finds the root of the function `f` with the TOMS_748 method, where
618 : * function values are not supplied at the lower and upper bounds.
619 : *
620 : * \note if `AssumeFinite` is true than the code assumes all numbers are
621 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
622 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
623 : * that products like `fa * fb` are also finite.
624 : */
625 : template <bool AssumeFinite = false, typename Function, typename T>
626 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
627 : const simd::scalar_type_t<T> absolute_tolerance,
628 : const simd::scalar_type_t<T> relative_tolerance,
629 : const size_t max_iterations = 100,
630 : const simd::mask_type_t<T> ignore_filter =
631 : static_cast<simd::mask_type_t<T>>(0)) {
632 : return toms748<AssumeFinite>(
633 : f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
634 : absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
635 : }
636 :
637 : /*!
638 : * \ingroup NumericalAlgorithmsGroup
639 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
640 : * element in a `DataVector`.
641 : *
642 : * `f` is a binary invokable that takes a `double` as its first argument and a
643 : * `size_t` as its second. The `double` is the current value at which to
644 : * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
645 : * Below is an example of how to root find different functions by indexing into
646 : * a lambda-captured `DataVector` using the `size_t` passed to `f`.
647 : *
648 : * \snippet Test_TOMS748.cpp datavector_root_find
649 : *
650 : * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
651 : * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
652 : * this interval does not bracket a root,
653 : * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
654 : *
655 : * See the [Boost](http://www.boost.org/) documentation for more details.
656 : *
657 : * \requires Function `f` be callable with a `double` and a `size_t`
658 : *
659 : * \note if `AssumeFinite` is true than the code assumes all numbers are
660 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
661 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
662 : * that products like `fa * fb` are also finite.
663 : *
664 : * \throws `convergence_error` if, for any index, the requested tolerance is not
665 : * met after `max_iterations` iterations.
666 : */
667 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
668 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
669 : const DataVector& upper_bound,
670 : const double absolute_tolerance,
671 : const double relative_tolerance,
672 : const size_t max_iterations = 100) {
673 : DataVector result_vector{lower_bound.size()};
674 : if constexpr (UseSimd) {
675 : constexpr size_t simd_width{simd::size<
676 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
677 : const size_t vectorized_size =
678 : lower_bound.size() - lower_bound.size() % simd_width;
679 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
680 : simd::store_unaligned(
681 : &result_vector[i],
682 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
683 : simd::load_unaligned(&lower_bound[i]),
684 : simd::load_unaligned(&upper_bound[i]),
685 : absolute_tolerance, relative_tolerance,
686 : max_iterations));
687 : }
688 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
689 : result_vector[i] = toms748<AssumeFinite>(
690 : [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
691 : upper_bound[i], absolute_tolerance, relative_tolerance,
692 : max_iterations);
693 : }
694 : } else {
695 : for (size_t i = 0; i < result_vector.size(); ++i) {
696 : result_vector[i] = toms748<AssumeFinite>(
697 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
698 : absolute_tolerance, relative_tolerance, max_iterations);
699 : }
700 : }
701 : return result_vector;
702 : }
703 :
704 : /*!
705 : * \ingroup NumericalAlgorithmsGroup
706 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
707 : * element in a `DataVector`, where function values are supplied at the lower
708 : * and upper bounds.
709 : *
710 : * Supplying function values is an optimization that saves two
711 : * function calls per point. The function values are often available
712 : * because one often checks if the root is bracketed before calling `toms748`.
713 : *
714 : * \note if `AssumeFinite` is true than the code assumes all numbers are
715 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
716 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
717 : * that products like `fa * fb` are also finite.
718 : */
719 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
720 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
721 : const DataVector& upper_bound,
722 : const DataVector& f_at_lower_bound,
723 : const DataVector& f_at_upper_bound,
724 : const double absolute_tolerance,
725 : const double relative_tolerance,
726 : const size_t max_iterations = 100) {
727 : DataVector result_vector{lower_bound.size()};
728 : if constexpr (UseSimd) {
729 : constexpr size_t simd_width{simd::size<
730 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
731 : const size_t vectorized_size =
732 : lower_bound.size() - lower_bound.size() % simd_width;
733 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
734 : simd::store_unaligned(
735 : &result_vector[i],
736 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
737 : simd::load_unaligned(&lower_bound[i]),
738 : simd::load_unaligned(&upper_bound[i]),
739 : simd::load_unaligned(&f_at_lower_bound[i]),
740 : simd::load_unaligned(&f_at_upper_bound[i]),
741 : absolute_tolerance, relative_tolerance,
742 : max_iterations));
743 : }
744 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
745 : result_vector[i] = toms748<AssumeFinite>(
746 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
747 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
748 : relative_tolerance, max_iterations);
749 : }
750 : } else {
751 : for (size_t i = 0; i < lower_bound.size(); ++i) {
752 : result_vector[i] = toms748<AssumeFinite>(
753 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
754 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
755 : relative_tolerance, max_iterations);
756 : }
757 : }
758 : return result_vector;
759 : }
760 : } // namespace RootFinder
|