Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares function RootFinder::toms748
6 :
7 : #pragma once
8 :
9 : #include <functional>
10 : #include <iomanip>
11 : #include <ios>
12 : #include <limits>
13 : #include <string>
14 : #include <type_traits>
15 :
16 : #include "DataStructures/DataVector.hpp"
17 : #include "Utilities/ErrorHandling/Error.hpp"
18 : #include "Utilities/ErrorHandling/Exceptions.hpp"
19 : #include "Utilities/GetOutput.hpp"
20 : #include "Utilities/MakeString.hpp"
21 : #include "Utilities/Simd/Simd.hpp"
22 :
23 : namespace RootFinder {
24 : namespace toms748_detail {
25 : // Original implementation of TOMS748 is from Boost:
26 : // (C) Copyright John Maddock 2006.
27 : // Use, modification and distribution are subject to the
28 : // Boost Software License, Version 1.0. (See accompanying file
29 : // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
30 : //
31 : // Significant changes made to pretty much all of it to support SIMD by Nils
32 : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
33 :
34 : template <typename T>
35 : T safe_div(const T& num, const T& denom, const T& r) {
36 : if constexpr (std::is_floating_point_v<T>) {
37 : using std::abs;
38 : if (abs(denom) < static_cast<T>(1)) {
39 : if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
40 : return r;
41 : }
42 : }
43 : return num / denom;
44 : } else {
45 : // return num / denom without overflow, return r if overflow would occur.
46 : //
47 : // To do this we need to handle the following cases:
48 : // 1. fabs(denom) < 1 AND fabs(denom * max) <= fabs(num):
49 : // return r
50 : // 2. return num / denom
51 : //
52 : // We do this by creating 2 masks.
53 : // 1. `mask0` selects where fabs(denom) < 1. This is where we _may_ have
54 : // issues with division by zero or overflow.
55 : // 2. `mask` selects where fabs(denom) < 1 AND fabs(denom * max)<=fabs(num)
56 : // * Note: the edge case of fabs(num)==max could be problematic, but if
57 : // you're dealing with numbers like that you are likely in
58 : // trouble anyway.
59 : // The second mask is where we must avoid division by zero and instead
60 : // return `r`.
61 : const auto mask0 = fabs(denom) < static_cast<T>(1);
62 : // Note: if denom >= 1.0 you get an FPE because of overflow from
63 : // `max() * (1. + value)`, which is why `mask0` is necessary.
64 : const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(1.0)) *
65 : std::numeric_limits<T>::max()) <= fabs(num);
66 : const auto new_denom = simd::select(mask, static_cast<T>(1), denom);
67 : return simd::select(mask, r, num / new_denom);
68 : }
69 : }
70 :
71 : template <typename T>
72 : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb) {
73 : //
74 : // Performs standard secant interpolation of [a,b] given
75 : // function evaluations f(a) and f(b). Performs a bisection
76 : // if secant interpolation would leave us very close to either
77 : // a or b. Rationale: we only call this function when at least
78 : // one other form of interpolation has already failed, so we know
79 : // that the function is unlikely to be smooth with a root very
80 : // close to a or b.
81 : //
82 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
83 : // WARNING: There are several different ways to implement the interpolation
84 : // that all have different rounding properties. Unfortunately this means that
85 : // tolerances even at 1e-14 can be difficult to achieve generically.
86 : //
87 : // `g` below is:
88 : // const T g = (fa / (fb - fa));
89 : //
90 : // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
91 : //
92 : // const T c = a * (static_cast<T>(1) - g) + g * b; // works
93 : //
94 : // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
95 : //
96 : // Original Boost code:
97 : // const T c = a - (fa / (fb - fa)) * (b - a); // works
98 :
99 : const T c = a - (fa / (fb - fa)) * (b - a);
100 : return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
101 : (c >= simd::fnma(fabs(b), tol_batch, b)),
102 : static_cast<T>(0.5) * (a + b), c);
103 : }
104 :
105 : template <bool AssumeFinite, typename T>
106 : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
107 : const T& fb, const T& fd,
108 : const simd::mask_type_t<T>& incomplete_mask,
109 : const unsigned count) {
110 : // Performs quadratic interpolation to determine the next point,
111 : // takes count Newton steps to find the location of the
112 : // quadratic polynomial.
113 : //
114 : // Point d must lie outside of the interval [a,b], it is the third
115 : // best approximation to the root, after a and b.
116 : //
117 : // Note: this does not guarantee to find a root
118 : // inside [a, b], so we fall back to a secant step should
119 : // the result be out of range.
120 : //
121 : // Start by obtaining the coefficients of the quadratic polynomial:
122 : const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
123 : T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
124 : A = safe_div(A - B, d - a, static_cast<T>(0));
125 :
126 : const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
127 : T result_secant{};
128 : if (UNLIKELY(simd::any(secant_failure_mask))) {
129 : // failure to determine coefficients, try a secant step:
130 : result_secant = secant_interpolate(a, b, fa, fb);
131 : if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
132 : return result_secant;
133 : }
134 : }
135 :
136 : // Determine the starting point of the Newton steps:
137 : //
138 : // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
139 : // code quite a bit.
140 : T c = AssumeFinite
141 : ? simd::select(A * fa > static_cast<T>(0), a, b)
142 : : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
143 : a, b);
144 :
145 : // Take the Newton steps:
146 : const T two_A = static_cast<T>(2) * A;
147 : const T half_a_plus_b = 0.5 * (a + b);
148 : const T one_minus_a = static_cast<T>(1) - a;
149 : const T B_minus_A_times_b = B - A * b;
150 : for (unsigned i = 1; i <= count; ++i) {
151 : c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
152 : simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
153 : }
154 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
155 : simd::any(mask)) {
156 : // Failure, try a secant step:
157 : c = simd::select(mask, secant_interpolate(a, b, fa, fb), c);
158 : }
159 : return simd::select(secant_failure_mask, result_secant, c);
160 : }
161 :
162 : template <bool AssumeFinite, typename T>
163 : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
164 : const T& fb, const T& fd, const T& fe,
165 : const simd::mask_type_t<T>& incomplete_mask) {
166 : // Uses inverse cubic interpolation of f(x) at points
167 : // [a,b,d,e] to obtain an approximate root of f(x).
168 : // Points d and e lie outside the interval [a,b]
169 : // and are the third and forth best approximations
170 : // to the root that we have found so far.
171 : //
172 : // Note: this does not guarantee to find a root
173 : // inside [a, b], so we fall back to quadratic
174 : // interpolation in case of an erroneous result.
175 : //
176 : // This commented chunk is the original Boost implementation translated into
177 : // simd. The actual code below is a heavily optimized version.
178 : //
179 : // const T q11 = (d - e) * fd / (fe - fd);
180 : // const T q21 = (b - d) * fb / (fd - fb);
181 : // const T q31 = (a - b) * fa / (fb - fa);
182 : // const T d21 = (b - d) * fd / (fd - fb);
183 : // const T d31 = (a - b) * fb / (fb - fa);
184 : //
185 : // const T q22 = (d21 - q11) * fb / (fe - fb);
186 : // const T q32 = (d31 - q21) * fa / (fd - fa);
187 : // const T d32 = (d31 - q21) * fd / (fd - fa);
188 : // const T q33 = (d32 - q22) * fa / (fe - fa);
189 : //
190 : // T c = q31 + q32 + q33 + a;
191 :
192 : // The optimized implementation here is L1-cache bound. That is, we aren't
193 : // able to completely saturate the FP units because we are waiting on the L1
194 : // cache. While not ideal, that's okay and just part of the algorithm.
195 : const T denom_fb_fa = fb - fa;
196 : const T denom_fd_fb = fd - fb;
197 : const T denom_fd_fa = fd - fa;
198 : const T denom_fe_fd = fe - fd;
199 : const T denom_fe_fb = fe - fb;
200 : const T denom =
201 : denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
202 :
203 : // Avoid division by zero with mask.
204 : const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
205 : static_cast<T>(1));
206 :
207 : const T d31 = (a - b);
208 : const T q21 = (b - d);
209 : const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
210 : denom_fe_fd;
211 : const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
212 : denom_fb_fa * denom_fd_fa;
213 :
214 : // Note: the reduction in rounding error that comes from the improvement by
215 : // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
216 : // we do below. Alternative ways of evaluating polynomials do not delay this
217 : // inclusion of `a`, and so then when the correction to `a` is small,
218 : // floating point errors decrease the accuracy of the result.
219 : T c = simd::fma(
220 : fa_by_denom,
221 : simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
222 :
223 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
224 : simd::any(mask)) {
225 : // Out of bounds step, fall back to quadratic interpolation:
226 : //
227 : // Note: we only apply quadratic interpolation at points where cubic
228 : // failed and that aren't already at a root.
229 : c = simd::select(
230 : mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
231 : c);
232 : }
233 :
234 : return c;
235 : }
236 :
237 : template <bool AssumeFinite, typename F, typename T>
238 : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
239 : const simd::mask_type_t<T>& incomplete_mask) {
240 : // Given a point c inside the existing enclosing interval
241 : // [a, b] sets a = c if f(c) == 0, otherwise finds the new
242 : // enclosing interval: either [a, c] or [c, b] and sets
243 : // d and fd to the point that has just been removed from
244 : // the interval. In other words d is the third best guess
245 : // to the root.
246 : //
247 : // Note: `bracket` will only modify slots marked as `true` in
248 : // `incomplete_mask`
249 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
250 :
251 : // If the interval [a,b] is very small, or if c is too close
252 : // to one end of the interval then we need to adjust the
253 : // location of c accordingly. This is:
254 : //
255 : // if ((b - a) < 2 * tol * a) {
256 : // c = a + (b - a) / 2;
257 : // } else if (c <= a + fabs(a) * tol) {
258 : // c = a + fabs(a) * tol;
259 : // } else if (c >= b - fabs(b) * tol) {
260 : // c = b - fabs(b) * tol;
261 : // }
262 : const T a_filt = simd::fma(fabs(a), tol_batch, a);
263 : const T b_filt = simd::fnma(fabs(b), tol_batch, b);
264 : const T b_minus_a = b - a;
265 : c = simd::select(
266 : (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
267 : simd::fma(b_minus_a, static_cast<T>(0.5), a),
268 : simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
269 :
270 : // Invoke f(c):
271 : T fc = f(c);
272 :
273 : // if we have a zero then we have an exact solution to the root:
274 : const auto fc_is_zero_mask = (fc == static_cast<T>(0));
275 : if (const auto mask = fc_is_zero_mask and incomplete_mask;
276 : UNLIKELY(simd::any(mask))) {
277 : a = simd::select(mask, c, a);
278 : fa = simd::select(mask, static_cast<T>(0), fa);
279 : d = simd::select(mask, static_cast<T>(0), d);
280 : fd = simd::select(mask, static_cast<T>(0), fd);
281 : if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
282 : return;
283 : }
284 : }
285 :
286 : // Non-zero fc, update the interval:
287 : //
288 : // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
289 : // code quite a bit.
290 : //
291 : // Boost code is:
292 : // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
293 : using simd::sign;
294 : const auto sign_mask = AssumeFinite
295 : ? (fa * fc < static_cast<T>(0))
296 : : (sign(fa) * sign(fc) < static_cast<T>(0));
297 : const auto mask_if =
298 : (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
299 : d = simd::select(mask_if, b, d);
300 : fd = simd::select(mask_if, fb, fd);
301 : b = simd::select(mask_if, c, b);
302 : fb = simd::select(mask_if, fc, fb);
303 :
304 : const auto mask_else =
305 : ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
306 : d = simd::select(mask_else, a, d);
307 : fd = simd::select(mask_else, fa, fd);
308 : a = simd::select(mask_else, c, a);
309 : fa = simd::select(mask_else, fc, fa);
310 : }
311 :
312 : template <bool AssumeFinite, class F, class T, class Tol>
313 : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
314 : const T& fbx, Tol tol,
315 : const simd::mask_type_t<T>& ignore_filter,
316 : size_t& max_iter) {
317 : // Main entry point and logic for Toms Algorithm 748
318 : // root finder.
319 : if (UNLIKELY(simd::any(ax >= bx))) {
320 : throw std::domain_error("Lower bound is larger than upper bound");
321 : }
322 :
323 : // Sanity check - are we allowed to iterate at all?
324 : if (UNLIKELY(max_iter == 0)) {
325 : return std::pair{ax, bx};
326 : }
327 :
328 : size_t count = max_iter;
329 : // mu is a parameter in the algorithm that must be between (0, 1).
330 : static const T mu = 0.5f;
331 :
332 : // initialise a, b and fa, fb:
333 : T a = ax;
334 : T b = bx;
335 : T fa = fax;
336 : T fb = fbx;
337 :
338 : const auto fa_is_zero_mask = (fa == static_cast<T>(0));
339 : const auto fb_is_zero_mask = (fb == static_cast<T>(0));
340 : auto completion_mask =
341 : tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
342 : auto incomplete_mask = not completion_mask;
343 : if (UNLIKELY(simd::all(completion_mask))) {
344 : max_iter = 0;
345 : return std::pair{simd::select(fb_is_zero_mask, b, a),
346 : simd::select(fa_is_zero_mask, a, b)};
347 : }
348 :
349 : // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
350 : // This speeds up the code quite a bit.
351 : if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
352 : : (simd::sign(fa) * simd::sign(fb) >
353 : static_cast<T>(0))) and
354 : (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
355 : throw std::domain_error(
356 : "Parameters lower and upper bounds do not bracket a root");
357 : }
358 : // dummy value for fd, e and fe:
359 : T fe(static_cast<T>(1e5F));
360 : T e(static_cast<T>(1e5F));
361 : T fd(static_cast<T>(1e5F));
362 :
363 : T c(0.0);
364 : T d(0.0);
365 :
366 : // We can't use signaling_NaN() because we do a max comparison at the end to
367 : // 0, and signaling_NaN when compared to 0 raises an FPE.
368 : const T nan(std::numeric_limits<T>::quiet_NaN());
369 : auto completed_a = simd::select(completion_mask, a, nan);
370 : auto completed_b = simd::select(completion_mask, b, nan);
371 : auto completed_fa = simd::select(completion_mask, fa, nan);
372 : auto completed_fb = simd::select(completion_mask, fb, nan);
373 : const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
374 : &completed_a, &completed_b, &a, &b,
375 : &completed_fa, &completed_fb, &tol]() {
376 : const auto new_completed =
377 : (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
378 : completed_a = simd::select(new_completed, a, completed_a);
379 : completed_b = simd::select(new_completed, b, completed_b);
380 : completed_fa = simd::select(new_completed, fa, completed_fa);
381 : completed_fb = simd::select(new_completed, fb, completed_fb);
382 : completion_mask = new_completed or completion_mask;
383 : incomplete_mask = not completion_mask;
384 : // returns true if _all_ simd registers have been completed
385 : return simd::all(completion_mask);
386 : };
387 :
388 : if (simd::any(fa != static_cast<T>(0))) {
389 : // On the first step we take a secant step:
390 : c = toms748_detail::secant_interpolate(a, b, fa, fb);
391 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
392 : incomplete_mask);
393 : --count;
394 :
395 : // Note: The Boost fa!=0 check is handled with the completion_mask.
396 : if (count and not update_completed()) {
397 : // On the second step we take a quadratic interpolation:
398 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
399 : a, b, d, fa, fb, fd, incomplete_mask, 2);
400 : e = d;
401 : fe = fd;
402 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
403 : incomplete_mask);
404 : --count;
405 : update_completed();
406 : }
407 : }
408 :
409 : T u(std::numeric_limits<T>::signaling_NaN());
410 : T fu(std::numeric_limits<T>::signaling_NaN());
411 : T a0(std::numeric_limits<T>::signaling_NaN());
412 : T b0(std::numeric_limits<T>::signaling_NaN());
413 :
414 : // Note: The Boost fa!=0 check is handled with the completion_mask.
415 : while (count and not simd::all(completion_mask)) {
416 : // save our brackets:
417 : a0 = a;
418 : b0 = b;
419 : // Starting with the third step taken
420 : // we can use either quadratic or cubic interpolation.
421 : // Cubic interpolation requires that all four function values
422 : // fa, fb, fd, and fe are distinct, should that not be the case
423 : // then variable prof will get set to true, and we'll end up
424 : // taking a quadratic step instead.
425 : static const T min_diff = std::numeric_limits<T>::min() * 32;
426 : bool prof =
427 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
428 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
429 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
430 : incomplete_mask);
431 : if (prof) {
432 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
433 : a, b, d, fa, fb, fd, incomplete_mask, 2);
434 : } else {
435 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
436 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
437 : }
438 : // re-bracket, and check for termination:
439 : e = d;
440 : fe = fd;
441 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
442 : incomplete_mask);
443 : if ((0 == --count) or update_completed()) {
444 : break;
445 : }
446 : // Now another interpolated step:
447 : prof =
448 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
449 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
450 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
451 : incomplete_mask);
452 : if (prof) {
453 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
454 : a, b, d, fa, fb, fd, incomplete_mask, 3);
455 : } else {
456 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
457 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
458 : }
459 : // Bracket again, and check termination condition, update e:
460 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
461 : incomplete_mask);
462 : if ((0 == --count) or update_completed()) {
463 : break;
464 : }
465 :
466 : // Now we take a double-length secant step:
467 : const auto fabs_fa_less_fabs_fb_mask =
468 : (fabs(fa) < fabs(fb)) and incomplete_mask;
469 : u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
470 : fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
471 : const T b_minus_a = b - a;
472 : // Assumes that bounds a & b are not so close that fa == fb. Boost makes
473 : // this assumption too. If this is violated then the algorithm doesn't
474 : // work since the function at least appears constant.
475 : c = simd::fnma(static_cast<T>(2) * (fu / (fb - fa)), b_minus_a, u);
476 : c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
477 : simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
478 :
479 : // Bracket again, and check termination condition:
480 : e = d;
481 : fe = fd;
482 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
483 : incomplete_mask);
484 : if ((0 == --count) or update_completed()) {
485 : break;
486 : }
487 :
488 : // And finally... check to see if an additional bisection step is
489 : // to be taken, we do this if we're not converging fast enough:
490 : const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
491 : if (LIKELY(simd::none(bisection_mask))) {
492 : continue;
493 : }
494 : // bracket again on a bisection:
495 : //
496 : // Note: the mask ensures we only ever modify the slots that the mask has
497 : // identified as needing to be modified.
498 : e = simd::select(bisection_mask, d, e);
499 : fe = simd::select(bisection_mask, fd, fe);
500 : toms748_detail::bracket<AssumeFinite>(
501 : f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
502 : bisection_mask);
503 : --count;
504 : if (update_completed()) {
505 : break;
506 : }
507 : } // while loop
508 :
509 : max_iter -= count;
510 : completed_b =
511 : simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
512 : completed_a =
513 : simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
514 : return std::pair{completed_a, completed_b};
515 : }
516 : } // namespace toms748_detail
517 :
518 : /*!
519 : * \ingroup NumericalAlgorithmsGroup
520 : * \brief Finds the root of the function `f` with the TOMS_748 method.
521 : *
522 : * `f` is a unary invokable that takes a `double` which is the current value at
523 : * which to evaluate `f`. An example is below.
524 : *
525 : * \snippet Test_TOMS748.cpp double_root_find
526 : *
527 : * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
528 : * `upper_bound`], and will throw if this interval does not bracket a root,
529 : * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
530 : *
531 : * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
532 : * are the function values at `lower_bound` and `upper_bound`. These function
533 : * values are often known because the user typically checks if a root is
534 : * bracketed before calling `toms748`; passing the function values here saves
535 : * two function evaluations.
536 : *
537 : * \note if `AssumeFinite` is true than the code assumes all numbers are
538 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
539 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
540 : * that products like `fa * fb` are also finite.
541 : *
542 : * \requires Function `f` is invokable with a `double`
543 : *
544 : * \throws `convergence_error` if the requested tolerance is not met after
545 : * `max_iterations` iterations.
546 : */
547 : template <bool AssumeFinite = false, typename Function, typename T>
548 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
549 : const T f_at_lower_bound, const T f_at_upper_bound,
550 : const simd::scalar_type_t<T> absolute_tolerance,
551 : const simd::scalar_type_t<T> relative_tolerance,
552 : const size_t max_iterations = 100,
553 : const simd::mask_type_t<T> ignore_filter =
554 : static_cast<simd::mask_type_t<T>>(0)) {
555 : ASSERT(relative_tolerance >
556 : std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
557 : "The relative tolerance is too small. Got "
558 : << relative_tolerance << " but must be at least "
559 : << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
560 : if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
561 : ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
562 : << ", f(" << upper_bound
563 : << ") = " << f_at_upper_bound);
564 : }
565 :
566 : std::size_t max_iters = max_iterations;
567 :
568 : // This solver requires tol to be passed as a termination condition. This
569 : // termination condition is equivalent to the convergence criteria used by the
570 : // GSL
571 : const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
572 : const T& rhs) {
573 : return simd::abs(lhs - rhs) <=
574 : simd::fma(T(relative_tolerance),
575 : simd::min(simd::abs(lhs), simd::abs(rhs)),
576 : T(absolute_tolerance));
577 : };
578 : auto result = toms748_detail::toms748_solve<AssumeFinite>(
579 : f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
580 : ignore_filter, max_iters);
581 : if (max_iters >= max_iterations) {
582 : throw convergence_error(
583 : MakeString{}
584 : << std::setprecision(8) << std::scientific
585 : << "toms748 reached max iterations without converging.\nAbsolute "
586 : "tolerance: "
587 : << absolute_tolerance << "\nRelative tolerance: " << relative_tolerance
588 : << "\nResult: " << get_output(result.first) << " "
589 : << get_output(result.second));
590 : }
591 : return simd::fma(static_cast<T>(0.5), (result.second - result.first),
592 : result.first);
593 : }
594 :
595 : /*!
596 : * \ingroup NumericalAlgorithmsGroup
597 : * \brief Finds the root of the function `f` with the TOMS_748 method, where
598 : * function values are not supplied at the lower and upper bounds.
599 : *
600 : * \note if `AssumeFinite` is true than the code assumes all numbers are
601 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
602 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
603 : * that products like `fa * fb` are also finite.
604 : */
605 : template <bool AssumeFinite = false, typename Function, typename T>
606 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
607 : const simd::scalar_type_t<T> absolute_tolerance,
608 : const simd::scalar_type_t<T> relative_tolerance,
609 : const size_t max_iterations = 100,
610 : const simd::mask_type_t<T> ignore_filter =
611 : static_cast<simd::mask_type_t<T>>(0)) {
612 : return toms748<AssumeFinite>(
613 : f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
614 : absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
615 : }
616 :
617 : /*!
618 : * \ingroup NumericalAlgorithmsGroup
619 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
620 : * element in a `DataVector`.
621 : *
622 : * `f` is a binary invokable that takes a `double` as its first argument and a
623 : * `size_t` as its second. The `double` is the current value at which to
624 : * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
625 : * Below is an example of how to root find different functions by indexing into
626 : * a lambda-captured `DataVector` using the `size_t` passed to `f`.
627 : *
628 : * \snippet Test_TOMS748.cpp datavector_root_find
629 : *
630 : * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
631 : * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
632 : * this interval does not bracket a root,
633 : * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
634 : *
635 : * See the [Boost](http://www.boost.org/) documentation for more details.
636 : *
637 : * \requires Function `f` be callable with a `double` and a `size_t`
638 : *
639 : * \note if `AssumeFinite` is true than the code assumes all numbers are
640 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
641 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
642 : * that products like `fa * fb` are also finite.
643 : *
644 : * \throws `convergence_error` if, for any index, the requested tolerance is not
645 : * met after `max_iterations` iterations.
646 : */
647 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
648 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
649 : const DataVector& upper_bound,
650 : const double absolute_tolerance,
651 : const double relative_tolerance,
652 : const size_t max_iterations = 100) {
653 : DataVector result_vector{lower_bound.size()};
654 : if constexpr (UseSimd) {
655 : constexpr size_t simd_width{simd::size<
656 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
657 : const size_t vectorized_size =
658 : lower_bound.size() - lower_bound.size() % simd_width;
659 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
660 : simd::store_unaligned(
661 : &result_vector[i],
662 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
663 : simd::load_unaligned(&lower_bound[i]),
664 : simd::load_unaligned(&upper_bound[i]),
665 : absolute_tolerance, relative_tolerance,
666 : max_iterations));
667 : }
668 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
669 : result_vector[i] = toms748<AssumeFinite>(
670 : [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
671 : upper_bound[i], absolute_tolerance, relative_tolerance,
672 : max_iterations);
673 : }
674 : } else {
675 : for (size_t i = 0; i < result_vector.size(); ++i) {
676 : result_vector[i] = toms748<AssumeFinite>(
677 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
678 : absolute_tolerance, relative_tolerance, max_iterations);
679 : }
680 : }
681 : return result_vector;
682 : }
683 :
684 : /*!
685 : * \ingroup NumericalAlgorithmsGroup
686 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
687 : * element in a `DataVector`, where function values are supplied at the lower
688 : * and upper bounds.
689 : *
690 : * Supplying function values is an optimization that saves two
691 : * function calls per point. The function values are often available
692 : * because one often checks if the root is bracketed before calling `toms748`.
693 : *
694 : * \note if `AssumeFinite` is true than the code assumes all numbers are
695 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
696 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
697 : * that products like `fa * fb` are also finite.
698 : */
699 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
700 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
701 : const DataVector& upper_bound,
702 : const DataVector& f_at_lower_bound,
703 : const DataVector& f_at_upper_bound,
704 : const double absolute_tolerance,
705 : const double relative_tolerance,
706 : const size_t max_iterations = 100) {
707 : DataVector result_vector{lower_bound.size()};
708 : if constexpr (UseSimd) {
709 : constexpr size_t simd_width{simd::size<
710 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
711 : const size_t vectorized_size =
712 : lower_bound.size() - lower_bound.size() % simd_width;
713 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
714 : simd::store_unaligned(
715 : &result_vector[i],
716 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
717 : simd::load_unaligned(&lower_bound[i]),
718 : simd::load_unaligned(&upper_bound[i]),
719 : simd::load_unaligned(&f_at_lower_bound[i]),
720 : simd::load_unaligned(&f_at_upper_bound[i]),
721 : absolute_tolerance, relative_tolerance,
722 : max_iterations));
723 : }
724 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
725 : result_vector[i] = toms748<AssumeFinite>(
726 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
727 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
728 : relative_tolerance, max_iterations);
729 : }
730 : } else {
731 : for (size_t i = 0; i < lower_bound.size(); ++i) {
732 : result_vector[i] = toms748<AssumeFinite>(
733 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
734 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
735 : relative_tolerance, max_iterations);
736 : }
737 : }
738 : return result_vector;
739 : }
740 : } // namespace RootFinder
|