Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares function RootFinder::toms748
6 :
7 : #pragma once
8 :
9 : #include <functional>
10 : #include <iomanip>
11 : #include <ios>
12 : #include <limits>
13 : #include <string>
14 : #include <type_traits>
15 :
16 : #include "DataStructures/DataVector.hpp"
17 : #include "Utilities/ErrorHandling/CaptureForError.hpp"
18 : #include "Utilities/ErrorHandling/Error.hpp"
19 : #include "Utilities/ErrorHandling/Exceptions.hpp"
20 : #include "Utilities/GetOutput.hpp"
21 : #include "Utilities/MakeString.hpp"
22 : #include "Utilities/Simd/Simd.hpp"
23 :
24 : namespace RootFinder {
25 : namespace toms748_detail {
26 : // Original implementation of TOMS748 is from Boost:
27 : // (C) Copyright John Maddock 2006.
28 : // Use, modification and distribution are subject to the
29 : // Boost Software License, Version 1.0. (See accompanying file
30 : // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
31 : //
32 : // Significant changes made to pretty much all of it to support SIMD by Nils
33 : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
34 :
35 : template <typename T>
36 : T safe_div(const T& num, const T& denom, const T& r) {
37 : if constexpr (std::is_floating_point_v<T>) {
38 : using std::abs;
39 : if (abs(denom) < static_cast<T>(1)) {
40 : if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
41 : return r;
42 : }
43 : }
44 : return num / denom;
45 : } else {
46 : // return num / denom without overflow, return r if overflow would occur.
47 : //
48 : // To do this we need to handle the following cases:
49 : // 1. fabs(denom) < 1 AND fabs(denom * max) <= fabs(num):
50 : // return r
51 : // 2. return num / denom
52 : //
53 : // We do this by creating 2 masks.
54 : // 1. `mask0` selects where fabs(denom) < 1. This is where we _may_ have
55 : // issues with division by zero or overflow.
56 : // 2. `mask` selects where fabs(denom) < 1 AND fabs(denom * max)<=fabs(num)
57 : // * Note: the edge case of fabs(num)==max could be problematic, but if
58 : // you're dealing with numbers like that you are likely in
59 : // trouble anyway.
60 : // The second mask is where we must avoid division by zero and instead
61 : // return `r`.
62 : const auto mask0 = fabs(denom) < static_cast<T>(1);
63 : // Note: if denom >= 1.0 you get an FPE because of overflow from
64 : // `max() * (1. + value)`, which is why `mask0` is necessary.
65 : const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(1.0)) *
66 : std::numeric_limits<T>::max()) <= fabs(num);
67 : const auto new_denom = simd::select(mask, static_cast<T>(1), denom);
68 : return simd::select(mask, r, num / new_denom);
69 : }
70 : }
71 :
72 : template <typename T>
73 : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb) {
74 : //
75 : // Performs standard secant interpolation of [a,b] given
76 : // function evaluations f(a) and f(b). Performs a bisection
77 : // if secant interpolation would leave us very close to either
78 : // a or b. Rationale: we only call this function when at least
79 : // one other form of interpolation has already failed, so we know
80 : // that the function is unlikely to be smooth with a root very
81 : // close to a or b.
82 : //
83 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
84 : // WARNING: There are several different ways to implement the interpolation
85 : // that all have different rounding properties. Unfortunately this means that
86 : // tolerances even at 1e-14 can be difficult to achieve generically.
87 : //
88 : // `g` below is:
89 : // const T g = (fa / (fb - fa));
90 : //
91 : // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
92 : //
93 : // const T c = a * (static_cast<T>(1) - g) + g * b; // works
94 : //
95 : // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
96 : //
97 : // Original Boost code:
98 : // const T c = a - (fa / (fb - fa)) * (b - a); // works
99 :
100 : const T c = a - (fa / (fb - fa)) * (b - a);
101 : return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
102 : (c >= simd::fnma(fabs(b), tol_batch, b)),
103 : static_cast<T>(0.5) * (a + b), c);
104 : }
105 :
106 : template <bool AssumeFinite, typename T>
107 : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
108 : const T& fb, const T& fd,
109 : const simd::mask_type_t<T>& incomplete_mask,
110 : const unsigned count) {
111 : // Performs quadratic interpolation to determine the next point,
112 : // takes count Newton steps to find the location of the
113 : // quadratic polynomial.
114 : //
115 : // Point d must lie outside of the interval [a,b], it is the third
116 : // best approximation to the root, after a and b.
117 : //
118 : // Note: this does not guarantee to find a root
119 : // inside [a, b], so we fall back to a secant step should
120 : // the result be out of range.
121 : //
122 : // Start by obtaining the coefficients of the quadratic polynomial:
123 : const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
124 : T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
125 : A = safe_div(A - B, d - a, static_cast<T>(0));
126 :
127 : const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
128 : T result_secant{};
129 : if (UNLIKELY(simd::any(secant_failure_mask))) {
130 : // failure to determine coefficients, try a secant step:
131 : result_secant = secant_interpolate(a, b, fa, fb);
132 : if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
133 : return result_secant;
134 : }
135 : }
136 :
137 : // Determine the starting point of the Newton steps:
138 : //
139 : // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
140 : // code quite a bit.
141 : T c = AssumeFinite
142 : ? simd::select(A * fa > static_cast<T>(0), a, b)
143 : : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
144 : a, b);
145 :
146 : // Take the Newton steps:
147 : const T two_A = static_cast<T>(2) * A;
148 : const T half_a_plus_b = 0.5 * (a + b);
149 : const T one_minus_a = static_cast<T>(1) - a;
150 : const T B_minus_A_times_b = B - A * b;
151 : for (unsigned i = 1; i <= count; ++i) {
152 : c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
153 : simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
154 : }
155 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
156 : simd::any(mask)) {
157 : // Failure, try a secant step:
158 : c = simd::select(mask, secant_interpolate(a, b, fa, fb), c);
159 : }
160 : return simd::select(secant_failure_mask, result_secant, c);
161 : }
162 :
163 : template <bool AssumeFinite, typename T>
164 : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
165 : const T& fb, const T& fd, const T& fe,
166 : const simd::mask_type_t<T>& incomplete_mask) {
167 : // Uses inverse cubic interpolation of f(x) at points
168 : // [a,b,d,e] to obtain an approximate root of f(x).
169 : // Points d and e lie outside the interval [a,b]
170 : // and are the third and forth best approximations
171 : // to the root that we have found so far.
172 : //
173 : // Note: this does not guarantee to find a root
174 : // inside [a, b], so we fall back to quadratic
175 : // interpolation in case of an erroneous result.
176 : //
177 : // This commented chunk is the original Boost implementation translated into
178 : // simd. The actual code below is a heavily optimized version.
179 : //
180 : // const T q11 = (d - e) * fd / (fe - fd);
181 : // const T q21 = (b - d) * fb / (fd - fb);
182 : // const T q31 = (a - b) * fa / (fb - fa);
183 : // const T d21 = (b - d) * fd / (fd - fb);
184 : // const T d31 = (a - b) * fb / (fb - fa);
185 : //
186 : // const T q22 = (d21 - q11) * fb / (fe - fb);
187 : // const T q32 = (d31 - q21) * fa / (fd - fa);
188 : // const T d32 = (d31 - q21) * fd / (fd - fa);
189 : // const T q33 = (d32 - q22) * fa / (fe - fa);
190 : //
191 : // T c = q31 + q32 + q33 + a;
192 :
193 : // The optimized implementation here is L1-cache bound. That is, we aren't
194 : // able to completely saturate the FP units because we are waiting on the L1
195 : // cache. While not ideal, that's okay and just part of the algorithm.
196 : const T denom_fb_fa = fb - fa;
197 : const T denom_fd_fb = fd - fb;
198 : const T denom_fd_fa = fd - fa;
199 : const T denom_fe_fd = fe - fd;
200 : const T denom_fe_fb = fe - fb;
201 : const T denom =
202 : denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
203 :
204 : // Avoid division by zero with mask.
205 : const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
206 : static_cast<T>(1));
207 :
208 : const T d31 = (a - b);
209 : const T q21 = (b - d);
210 : const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
211 : denom_fe_fd;
212 : const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
213 : denom_fb_fa * denom_fd_fa;
214 :
215 : // Note: the reduction in rounding error that comes from the improvement by
216 : // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
217 : // we do below. Alternative ways of evaluating polynomials do not delay this
218 : // inclusion of `a`, and so then when the correction to `a` is small,
219 : // floating point errors decrease the accuracy of the result.
220 : T c = simd::fma(
221 : fa_by_denom,
222 : simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
223 :
224 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
225 : simd::any(mask)) {
226 : // Out of bounds step, fall back to quadratic interpolation:
227 : //
228 : // Note: we only apply quadratic interpolation at points where cubic
229 : // failed and that aren't already at a root.
230 : c = simd::select(
231 : mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
232 : c);
233 : }
234 :
235 : return c;
236 : }
237 :
238 : template <bool AssumeFinite, typename F, typename T>
239 : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
240 : const simd::mask_type_t<T>& incomplete_mask) {
241 : // Given a point c inside the existing enclosing interval
242 : // [a, b] sets a = c if f(c) == 0, otherwise finds the new
243 : // enclosing interval: either [a, c] or [c, b] and sets
244 : // d and fd to the point that has just been removed from
245 : // the interval. In other words d is the third best guess
246 : // to the root.
247 : //
248 : // Note: `bracket` will only modify slots marked as `true` in
249 : // `incomplete_mask`
250 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
251 :
252 : // If the interval [a,b] is very small, or if c is too close
253 : // to one end of the interval then we need to adjust the
254 : // location of c accordingly. This is:
255 : //
256 : // if ((b - a) < 2 * tol * a) {
257 : // c = a + (b - a) / 2;
258 : // } else if (c <= a + fabs(a) * tol) {
259 : // c = a + fabs(a) * tol;
260 : // } else if (c >= b - fabs(b) * tol) {
261 : // c = b - fabs(b) * tol;
262 : // }
263 : const T a_filt = simd::fma(fabs(a), tol_batch, a);
264 : const T b_filt = simd::fnma(fabs(b), tol_batch, b);
265 : const T b_minus_a = b - a;
266 : c = simd::select(
267 : (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
268 : simd::fma(b_minus_a, static_cast<T>(0.5), a),
269 : simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
270 :
271 : // Invoke f(c):
272 : T fc = f(c);
273 :
274 : // if we have a zero then we have an exact solution to the root:
275 : const auto fc_is_zero_mask = (fc == static_cast<T>(0));
276 : if (const auto mask = fc_is_zero_mask and incomplete_mask;
277 : UNLIKELY(simd::any(mask))) {
278 : a = simd::select(mask, c, a);
279 : fa = simd::select(mask, static_cast<T>(0), fa);
280 : d = simd::select(mask, static_cast<T>(0), d);
281 : fd = simd::select(mask, static_cast<T>(0), fd);
282 : if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
283 : return;
284 : }
285 : }
286 :
287 : // Non-zero fc, update the interval:
288 : //
289 : // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
290 : // code quite a bit.
291 : //
292 : // Boost code is:
293 : // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
294 : using simd::sign;
295 : const auto sign_mask = AssumeFinite
296 : ? (fa * fc < static_cast<T>(0))
297 : : (sign(fa) * sign(fc) < static_cast<T>(0));
298 : const auto mask_if =
299 : (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
300 : d = simd::select(mask_if, b, d);
301 : fd = simd::select(mask_if, fb, fd);
302 : b = simd::select(mask_if, c, b);
303 : fb = simd::select(mask_if, fc, fb);
304 :
305 : const auto mask_else =
306 : ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
307 : d = simd::select(mask_else, a, d);
308 : fd = simd::select(mask_else, fa, fd);
309 : a = simd::select(mask_else, c, a);
310 : fa = simd::select(mask_else, fc, fa);
311 : }
312 :
313 : template <bool AssumeFinite, class F, class T, class Tol>
314 : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
315 : const T& fbx, Tol tol,
316 : const simd::mask_type_t<T>& ignore_filter,
317 : size_t& max_iter) {
318 : // Main entry point and logic for Toms Algorithm 748
319 : // root finder.
320 : if (UNLIKELY(simd::any(ax > bx))) {
321 : ERROR_AS("Lower bound is larger than upper bound", std::domain_error);
322 : }
323 :
324 : // Sanity check - are we allowed to iterate at all?
325 : if (UNLIKELY(max_iter == 0)) {
326 : return std::pair{ax, bx};
327 : }
328 :
329 : size_t count = max_iter;
330 : // mu is a parameter in the algorithm that must be between (0, 1).
331 : static const T mu = 0.5f;
332 :
333 : // initialise a, b and fa, fb:
334 : T a = ax;
335 : T b = bx;
336 : T fa = fax;
337 : T fb = fbx;
338 : CAPTURE_FOR_ERROR(ax);
339 : CAPTURE_FOR_ERROR(bx);
340 : CAPTURE_FOR_ERROR(fax);
341 : CAPTURE_FOR_ERROR(fbx);
342 :
343 : const auto fa_is_zero_mask = (fa == static_cast<T>(0));
344 : const auto fb_is_zero_mask = (fb == static_cast<T>(0));
345 : auto completion_mask =
346 : tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
347 : auto incomplete_mask = not completion_mask;
348 : if (UNLIKELY(simd::all(completion_mask))) {
349 : max_iter = 0;
350 : return std::pair{simd::select(fb_is_zero_mask, b, a),
351 : simd::select(fa_is_zero_mask, a, b)};
352 : }
353 :
354 : // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
355 : // This speeds up the code quite a bit.
356 : if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
357 : : (simd::sign(fa) * simd::sign(fb) >
358 : static_cast<T>(0))) and
359 : (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
360 : ERROR_AS("Parameters lower and upper bounds do not bracket a root.",
361 : std::domain_error);
362 : }
363 : // dummy value for fd, e and fe:
364 : T fe(static_cast<T>(1e5F));
365 : T e(static_cast<T>(1e5F));
366 : T fd(static_cast<T>(1e5F));
367 :
368 : T c(0.0);
369 : T d(0.0);
370 :
371 : // We can't use signaling_NaN() because we do a max comparison at the end to
372 : // 0, and signaling_NaN when compared to 0 raises an FPE.
373 : const T nan(std::numeric_limits<T>::quiet_NaN());
374 : auto completed_a = simd::select(completion_mask, a, nan);
375 : auto completed_b = simd::select(completion_mask, b, nan);
376 : auto completed_fa = simd::select(completion_mask, fa, nan);
377 : auto completed_fb = simd::select(completion_mask, fb, nan);
378 : const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
379 : &completed_a, &completed_b, &a, &b,
380 : &completed_fa, &completed_fb, &tol]() {
381 : const auto new_completed =
382 : (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
383 : completed_a = simd::select(new_completed, a, completed_a);
384 : completed_b = simd::select(new_completed, b, completed_b);
385 : completed_fa = simd::select(new_completed, fa, completed_fa);
386 : completed_fb = simd::select(new_completed, fb, completed_fb);
387 : completion_mask = new_completed or completion_mask;
388 : incomplete_mask = not completion_mask;
389 : // returns true if _all_ simd registers have been completed
390 : return simd::all(completion_mask);
391 : };
392 :
393 : if (simd::any(fa != static_cast<T>(0))) {
394 : // On the first step we take a secant step:
395 : c = toms748_detail::secant_interpolate(a, b, fa, fb);
396 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
397 : incomplete_mask);
398 : --count;
399 :
400 : // Note: The Boost fa!=0 check is handled with the completion_mask.
401 : if (count and not update_completed()) {
402 : // On the second step we take a quadratic interpolation:
403 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
404 : a, b, d, fa, fb, fd, incomplete_mask, 2);
405 : e = d;
406 : fe = fd;
407 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
408 : incomplete_mask);
409 : --count;
410 : update_completed();
411 : }
412 : }
413 :
414 : T u(std::numeric_limits<T>::signaling_NaN());
415 : T fu(std::numeric_limits<T>::signaling_NaN());
416 : T a0(std::numeric_limits<T>::signaling_NaN());
417 : T b0(std::numeric_limits<T>::signaling_NaN());
418 :
419 : // Note: The Boost fa!=0 check is handled with the completion_mask.
420 : while (count and not simd::all(completion_mask)) {
421 : // save our brackets:
422 : a0 = a;
423 : b0 = b;
424 : // Starting with the third step taken
425 : // we can use either quadratic or cubic interpolation.
426 : // Cubic interpolation requires that all four function values
427 : // fa, fb, fd, and fe are distinct, should that not be the case
428 : // then variable prof will get set to true, and we'll end up
429 : // taking a quadratic step instead.
430 : static const T min_diff = std::numeric_limits<T>::min() * 32;
431 : bool prof =
432 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
433 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
434 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
435 : incomplete_mask);
436 : if (prof) {
437 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
438 : a, b, d, fa, fb, fd, incomplete_mask, 2);
439 : } else {
440 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
441 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
442 : }
443 : // re-bracket, and check for termination:
444 : e = d;
445 : fe = fd;
446 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
447 : incomplete_mask);
448 : if ((0 == --count) or update_completed()) {
449 : break;
450 : }
451 : // Now another interpolated step:
452 : prof =
453 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
454 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
455 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
456 : incomplete_mask);
457 : if (prof) {
458 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
459 : a, b, d, fa, fb, fd, incomplete_mask, 3);
460 : } else {
461 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
462 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
463 : }
464 : // Bracket again, and check termination condition, update e:
465 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
466 : incomplete_mask);
467 : if ((0 == --count) or update_completed()) {
468 : break;
469 : }
470 :
471 : // Now we take a double-length secant step:
472 : const auto fabs_fa_less_fabs_fb_mask =
473 : (fabs(fa) < fabs(fb)) and incomplete_mask;
474 : u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
475 : fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
476 : const T b_minus_a = b - a;
477 : // Assumes that bounds a & b are not so close that fa == fb. Boost makes
478 : // this assumption too. If this is violated then the algorithm doesn't
479 : // work since the function at least appears constant.
480 : c = simd::fnma(static_cast<T>(2) * (fu / (fb - fa)), b_minus_a, u);
481 : c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
482 : simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
483 :
484 : // Bracket again, and check termination condition:
485 : e = d;
486 : fe = fd;
487 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
488 : incomplete_mask);
489 : if ((0 == --count) or update_completed()) {
490 : break;
491 : }
492 :
493 : // And finally... check to see if an additional bisection step is
494 : // to be taken, we do this if we're not converging fast enough:
495 : const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
496 : if (LIKELY(simd::none(bisection_mask))) {
497 : continue;
498 : }
499 : // bracket again on a bisection:
500 : //
501 : // Note: the mask ensures we only ever modify the slots that the mask has
502 : // identified as needing to be modified.
503 : e = simd::select(bisection_mask, d, e);
504 : fe = simd::select(bisection_mask, fd, fe);
505 : toms748_detail::bracket<AssumeFinite>(
506 : f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
507 : bisection_mask);
508 : --count;
509 : if (update_completed()) {
510 : break;
511 : }
512 : } // while loop
513 :
514 : max_iter -= count;
515 : completed_b =
516 : simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
517 : completed_a =
518 : simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
519 : return std::pair{completed_a, completed_b};
520 : }
521 : } // namespace toms748_detail
522 :
523 : /*!
524 : * \ingroup NumericalAlgorithmsGroup
525 : * \brief Finds the root of the function `f` with the TOMS_748 method.
526 : *
527 : * `f` is a unary invokable that takes a `double` which is the current value at
528 : * which to evaluate `f`. An example is below.
529 : *
530 : * \snippet Test_TOMS748.cpp double_root_find
531 : *
532 : * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
533 : * `upper_bound`], and will throw if this interval does not bracket a root,
534 : * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
535 : *
536 : * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
537 : * are the function values at `lower_bound` and `upper_bound`. These function
538 : * values are often known because the user typically checks if a root is
539 : * bracketed before calling `toms748`; passing the function values here saves
540 : * two function evaluations.
541 : *
542 : * \note if `AssumeFinite` is true than the code assumes all numbers are
543 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
544 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
545 : * that products like `fa * fb` are also finite.
546 : *
547 : * \requires Function `f` is invokable with a `double`
548 : *
549 : * \throws `convergence_error` if the requested tolerance is not met after
550 : * `max_iterations` iterations.
551 : */
552 : template <bool AssumeFinite = false, typename Function, typename T>
553 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
554 : const T f_at_lower_bound, const T f_at_upper_bound,
555 : const simd::scalar_type_t<T> absolute_tolerance,
556 : const simd::scalar_type_t<T> relative_tolerance,
557 : const size_t max_iterations = 100,
558 : const simd::mask_type_t<T> ignore_filter =
559 : static_cast<simd::mask_type_t<T>>(0)) {
560 : ASSERT(relative_tolerance >
561 : std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
562 : "The relative tolerance is too small. Got "
563 : << relative_tolerance << " but must be at least "
564 : << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
565 : if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
566 : ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
567 : << ", f(" << upper_bound
568 : << ") = " << f_at_upper_bound);
569 : }
570 :
571 : std::size_t max_iters = max_iterations;
572 :
573 : // This solver requires tol to be passed as a termination condition. This
574 : // termination condition is equivalent to the convergence criteria used by the
575 : // GSL
576 : const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
577 : const T& rhs) {
578 : return simd::abs(lhs - rhs) <=
579 : simd::fma(T(relative_tolerance),
580 : simd::min(simd::abs(lhs), simd::abs(rhs)),
581 : T(absolute_tolerance));
582 : };
583 : auto result = toms748_detail::toms748_solve<AssumeFinite>(
584 : f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
585 : ignore_filter, max_iters);
586 : if (max_iters >= max_iterations) {
587 : ERROR_AS(
588 : "toms748 reached max iterations without converging.\nAbsolute "
589 : "tolerance: "
590 : << absolute_tolerance << "\nRelative tolerance: "
591 : << relative_tolerance << "\nResult: " << get_output(result.first)
592 : << " " << get_output(result.second),
593 : convergence_error);
594 : }
595 : return simd::fma(static_cast<T>(0.5), (result.second - result.first),
596 : result.first);
597 : }
598 :
599 : /*!
600 : * \ingroup NumericalAlgorithmsGroup
601 : * \brief Finds the root of the function `f` with the TOMS_748 method, where
602 : * function values are not supplied at the lower and upper bounds.
603 : *
604 : * \note if `AssumeFinite` is true than the code assumes all numbers are
605 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
606 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
607 : * that products like `fa * fb` are also finite.
608 : */
609 : template <bool AssumeFinite = false, typename Function, typename T>
610 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
611 : const simd::scalar_type_t<T> absolute_tolerance,
612 : const simd::scalar_type_t<T> relative_tolerance,
613 : const size_t max_iterations = 100,
614 : const simd::mask_type_t<T> ignore_filter =
615 : static_cast<simd::mask_type_t<T>>(0)) {
616 : return toms748<AssumeFinite>(
617 : f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
618 : absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
619 : }
620 :
621 : /*!
622 : * \ingroup NumericalAlgorithmsGroup
623 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
624 : * element in a `DataVector`.
625 : *
626 : * `f` is a binary invokable that takes a `double` as its first argument and a
627 : * `size_t` as its second. The `double` is the current value at which to
628 : * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
629 : * Below is an example of how to root find different functions by indexing into
630 : * a lambda-captured `DataVector` using the `size_t` passed to `f`.
631 : *
632 : * \snippet Test_TOMS748.cpp datavector_root_find
633 : *
634 : * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
635 : * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
636 : * this interval does not bracket a root,
637 : * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
638 : *
639 : * See the [Boost](http://www.boost.org/) documentation for more details.
640 : *
641 : * \requires Function `f` be callable with a `double` and a `size_t`
642 : *
643 : * \note if `AssumeFinite` is true than the code assumes all numbers are
644 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
645 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
646 : * that products like `fa * fb` are also finite.
647 : *
648 : * \throws `convergence_error` if, for any index, the requested tolerance is not
649 : * met after `max_iterations` iterations.
650 : */
651 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
652 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
653 : const DataVector& upper_bound,
654 : const double absolute_tolerance,
655 : const double relative_tolerance,
656 : const size_t max_iterations = 100) {
657 : DataVector result_vector{lower_bound.size()};
658 : if constexpr (UseSimd) {
659 : constexpr size_t simd_width{simd::size<
660 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
661 : const size_t vectorized_size =
662 : lower_bound.size() - lower_bound.size() % simd_width;
663 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
664 : simd::store_unaligned(
665 : &result_vector[i],
666 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
667 : simd::load_unaligned(&lower_bound[i]),
668 : simd::load_unaligned(&upper_bound[i]),
669 : absolute_tolerance, relative_tolerance,
670 : max_iterations));
671 : }
672 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
673 : result_vector[i] = toms748<AssumeFinite>(
674 : [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
675 : upper_bound[i], absolute_tolerance, relative_tolerance,
676 : max_iterations);
677 : }
678 : } else {
679 : for (size_t i = 0; i < result_vector.size(); ++i) {
680 : result_vector[i] = toms748<AssumeFinite>(
681 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
682 : absolute_tolerance, relative_tolerance, max_iterations);
683 : }
684 : }
685 : return result_vector;
686 : }
687 :
688 : /*!
689 : * \ingroup NumericalAlgorithmsGroup
690 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
691 : * element in a `DataVector`, where function values are supplied at the lower
692 : * and upper bounds.
693 : *
694 : * Supplying function values is an optimization that saves two
695 : * function calls per point. The function values are often available
696 : * because one often checks if the root is bracketed before calling `toms748`.
697 : *
698 : * \note if `AssumeFinite` is true than the code assumes all numbers are
699 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
700 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
701 : * that products like `fa * fb` are also finite.
702 : */
703 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
704 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
705 : const DataVector& upper_bound,
706 : const DataVector& f_at_lower_bound,
707 : const DataVector& f_at_upper_bound,
708 : const double absolute_tolerance,
709 : const double relative_tolerance,
710 : const size_t max_iterations = 100) {
711 : DataVector result_vector{lower_bound.size()};
712 : if constexpr (UseSimd) {
713 : constexpr size_t simd_width{simd::size<
714 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
715 : const size_t vectorized_size =
716 : lower_bound.size() - lower_bound.size() % simd_width;
717 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
718 : simd::store_unaligned(
719 : &result_vector[i],
720 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
721 : simd::load_unaligned(&lower_bound[i]),
722 : simd::load_unaligned(&upper_bound[i]),
723 : simd::load_unaligned(&f_at_lower_bound[i]),
724 : simd::load_unaligned(&f_at_upper_bound[i]),
725 : absolute_tolerance, relative_tolerance,
726 : max_iterations));
727 : }
728 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
729 : result_vector[i] = toms748<AssumeFinite>(
730 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
731 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
732 : relative_tolerance, max_iterations);
733 : }
734 : } else {
735 : for (size_t i = 0; i < lower_bound.size(); ++i) {
736 : result_vector[i] = toms748<AssumeFinite>(
737 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
738 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
739 : relative_tolerance, max_iterations);
740 : }
741 : }
742 : return result_vector;
743 : }
744 : } // namespace RootFinder
|