Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares function RootFinder::toms748
6 :
7 : #pragma once
8 :
9 : #include <functional>
10 : #include <iomanip>
11 : #include <ios>
12 : #include <limits>
13 : #include <string>
14 : #include <type_traits>
15 :
16 : #include "DataStructures/DataVector.hpp"
17 : #include "Utilities/ErrorHandling/Error.hpp"
18 : #include "Utilities/ErrorHandling/Exceptions.hpp"
19 : #include "Utilities/GetOutput.hpp"
20 : #include "Utilities/MakeString.hpp"
21 : #include "Utilities/Simd/Simd.hpp"
22 :
23 : namespace RootFinder {
24 : namespace toms748_detail {
25 : // Original implementation of TOMS748 is from Boost:
26 : // (C) Copyright John Maddock 2006.
27 : // Use, modification and distribution are subject to the
28 : // Boost Software License, Version 1.0. (See accompanying file
29 : // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
30 : //
31 : // Significant changes made to pretty much all of it to support SIMD by Nils
32 : // Deppe. Changes are copyrighted by SXS Collaboration under the MIT License.
33 :
34 : template <typename T>
35 : T safe_div(const T& num, const T& denom, const T& r) {
36 : if constexpr (std::is_floating_point_v<T>) {
37 : using std::abs;
38 : if (abs(denom) < static_cast<T>(1)) {
39 : if (abs(denom * std::numeric_limits<T>::max()) <= abs(num)) {
40 : return r;
41 : }
42 : }
43 : return num / denom;
44 : } else {
45 : // return num / denom without overflow, return r if overflow would occur.
46 : const auto mask0 = fabs(denom) < static_cast<T>(1);
47 : // Note: if denom >= 1.0 you get an FPE because of overflow from
48 : // `max() * (1. + value)`
49 : const auto mask = fabs(simd::select(mask0, denom, static_cast<T>(0)) *
50 : std::numeric_limits<T>::max()) <= fabs(num);
51 : return simd::select(mask0 and mask, r,
52 : num / simd::select(mask, denom, static_cast<T>(1)));
53 : }
54 : }
55 :
56 : template <typename T>
57 : T secant_interpolate(const T& a, const T& b, const T& fa, const T& fb) {
58 : //
59 : // Performs standard secant interpolation of [a,b] given
60 : // function evaluations f(a) and f(b). Performs a bisection
61 : // if secant interpolation would leave us very close to either
62 : // a or b. Rationale: we only call this function when at least
63 : // one other form of interpolation has already failed, so we know
64 : // that the function is unlikely to be smooth with a root very
65 : // close to a or b.
66 : //
67 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(5);
68 : // WARNING: There are several different ways to implement the interpolation
69 : // that all have different rounding properties. Unfortunately this means that
70 : // tolerances even at 1e-14 can be difficult to achieve generically.
71 : //
72 : // `g` below is:
73 : // const T g = (fa / (fb - fa));
74 : //
75 : // const T c = simd::fma((fa / (fb - fa)), (a - b), a); // fails
76 : //
77 : // const T c = a * (static_cast<T>(1) - g) + g * b; // works
78 : //
79 : // const T c = simd::fma(a, (static_cast<T>(1) - g), g * b); // works
80 : //
81 : // Original Boost code:
82 : // const T c = a - (fa / (fb - fa)) * (b - a); // works
83 :
84 : const T c = a - (fa / (fb - fa)) * (b - a);
85 : return simd::select((c <= simd::fma(fabs(a), tol_batch, a)) or
86 : (c >= simd::fnma(fabs(b), tol_batch, b)),
87 : static_cast<T>(0.5) * (a + b), c);
88 : }
89 :
90 : template <bool AssumeFinite, typename T>
91 : T quadratic_interpolate(const T& a, const T& b, const T& d, const T& fa,
92 : const T& fb, const T& fd,
93 : const simd::mask_type_t<T>& incomplete_mask,
94 : const unsigned count) {
95 : // Performs quadratic interpolation to determine the next point,
96 : // takes count Newton steps to find the location of the
97 : // quadratic polynomial.
98 : //
99 : // Point d must lie outside of the interval [a,b], it is the third
100 : // best approximation to the root, after a and b.
101 : //
102 : // Note: this does not guarantee to find a root
103 : // inside [a, b], so we fall back to a secant step should
104 : // the result be out of range.
105 : //
106 : // Start by obtaining the coefficients of the quadratic polynomial:
107 : const T B = safe_div(fb - fa, b - a, std::numeric_limits<T>::max());
108 : T A = safe_div(fd - fb, d - b, std::numeric_limits<T>::max());
109 : A = safe_div(A - B, d - a, static_cast<T>(0));
110 :
111 : const auto secant_failure_mask = A == static_cast<T>(0) and incomplete_mask;
112 : T result_secant{};
113 : if (UNLIKELY(simd::any(secant_failure_mask))) {
114 : // failure to determine coefficients, try a secant step:
115 : result_secant = secant_interpolate(a, b, fa, fb);
116 : if (UNLIKELY(simd::all(secant_failure_mask or (not incomplete_mask)))) {
117 : return result_secant;
118 : }
119 : }
120 :
121 : // Determine the starting point of the Newton steps:
122 : //
123 : // Note: unlike Boost, we assume A*fa doesn't overflow. This speeds up the
124 : // code quite a bit.
125 : T c = AssumeFinite
126 : ? simd::select(A * fa > static_cast<T>(0), a, b)
127 : : simd::select(simd::sign(A) * simd::sign(fa) > static_cast<T>(0),
128 : a, b);
129 :
130 : // Take the Newton steps:
131 : const T two_A = static_cast<T>(2) * A;
132 : const T half_a_plus_b = 0.5 * (a + b);
133 : const T one_minus_a = static_cast<T>(1) - a;
134 : const T B_minus_A_times_b = B - A * b;
135 : for (unsigned i = 1; i <= count; ++i) {
136 : c -= safe_div(simd::fma(simd::fma(A, c, B_minus_A_times_b), c - a, fa),
137 : simd::fma(two_A, c - half_a_plus_b, B), one_minus_a + c);
138 : }
139 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
140 : simd::any(mask)) {
141 : // Failure, try a secant step:
142 : c = simd::select(mask, secant_interpolate(a, b, fa, fb), c);
143 : }
144 : return simd::select(secant_failure_mask, result_secant, c);
145 : }
146 :
147 : template <bool AssumeFinite, typename T>
148 : T cubic_interpolate(const T& a, const T& b, const T& d, const T& e, const T& fa,
149 : const T& fb, const T& fd, const T& fe,
150 : const simd::mask_type_t<T>& incomplete_mask) {
151 : // Uses inverse cubic interpolation of f(x) at points
152 : // [a,b,d,e] to obtain an approximate root of f(x).
153 : // Points d and e lie outside the interval [a,b]
154 : // and are the third and forth best approximations
155 : // to the root that we have found so far.
156 : //
157 : // Note: this does not guarantee to find a root
158 : // inside [a, b], so we fall back to quadratic
159 : // interpolation in case of an erroneous result.
160 : //
161 : // This commented chunk is the original Boost implementation translated into
162 : // simd. The actual code below is a heavily optimized version.
163 : //
164 : // const T q11 = (d - e) * fd / (fe - fd);
165 : // const T q21 = (b - d) * fb / (fd - fb);
166 : // const T q31 = (a - b) * fa / (fb - fa);
167 : // const T d21 = (b - d) * fd / (fd - fb);
168 : // const T d31 = (a - b) * fb / (fb - fa);
169 : //
170 : // const T q22 = (d21 - q11) * fb / (fe - fb);
171 : // const T q32 = (d31 - q21) * fa / (fd - fa);
172 : // const T d32 = (d31 - q21) * fd / (fd - fa);
173 : // const T q33 = (d32 - q22) * fa / (fe - fa);
174 : //
175 : // T c = q31 + q32 + q33 + a;
176 :
177 : // The optimized implementation here is L1-cache bound. That is, we aren't
178 : // able to completely saturate the FP units because we are waiting on the L1
179 : // cache. While not ideal, that's okay and just part of the algorithm.
180 : const T denom_fb_fa = fb - fa;
181 : const T denom_fd_fb = fd - fb;
182 : const T denom_fd_fa = fd - fa;
183 : const T denom_fe_fd = fe - fd;
184 : const T denom_fe_fb = fe - fb;
185 : const T denom =
186 : denom_fe_fb * denom_fe_fd * denom_fd_fa * denom_fd_fb * (fe - fa);
187 :
188 : // Avoid division by zero with mask.
189 : const T fa_by_denom = fa / simd::select(incomplete_mask, denom_fb_fa * denom,
190 : static_cast<T>(1));
191 :
192 : const T d31 = (a - b);
193 : const T q21 = (b - d);
194 : const T q32 = simd::fms(denom_fd_fb, d31, denom_fb_fa * q21) * denom_fe_fb *
195 : denom_fe_fd;
196 : const T q22 = simd::fms(denom_fe_fd, q21, (d - e) * denom_fd_fb) * fd *
197 : denom_fb_fa * denom_fd_fa;
198 :
199 : // Note: the reduction in rounding error that comes from the improvement by
200 : // Stoer & Bulirsch to Neville's algorithm is adding `a` at the very end as
201 : // we do below. Alternative ways of evaluating polynomials do not delay this
202 : // inclusion of `a`, and so then when the correction to `a` is small,
203 : // floating point errors decrease the accuracy of the result.
204 : T c = simd::fma(
205 : fa_by_denom,
206 : simd::fma(fb, simd::fms(q32, (fe + denom_fd_fa), q22), d31 * denom), a);
207 :
208 : if (const auto mask = ((c <= a) or (c >= b)) and incomplete_mask;
209 : simd::any(mask)) {
210 : // Out of bounds step, fall back to quadratic interpolation:
211 : //
212 : // Note: we only apply quadratic interpolation at points where cubic
213 : // failed and that aren't already at a root.
214 : c = simd::select(
215 : mask, quadratic_interpolate<AssumeFinite>(a, b, d, fa, fb, fd, mask, 3),
216 : c);
217 : }
218 :
219 : return c;
220 : }
221 :
222 : template <bool AssumeFinite, typename F, typename T>
223 : void bracket(F f, T& a, T& b, T c, T& fa, T& fb, T& d, T& fd,
224 : const simd::mask_type_t<T>& incomplete_mask) {
225 : // Given a point c inside the existing enclosing interval
226 : // [a, b] sets a = c if f(c) == 0, otherwise finds the new
227 : // enclosing interval: either [a, c] or [c, b] and sets
228 : // d and fd to the point that has just been removed from
229 : // the interval. In other words d is the third best guess
230 : // to the root.
231 : //
232 : // Note: `bracket` will only modify slots marked as `true` in
233 : // `incomplete_mask`
234 : const T tol_batch = std::numeric_limits<T>::epsilon() * static_cast<T>(2);
235 :
236 : // If the interval [a,b] is very small, or if c is too close
237 : // to one end of the interval then we need to adjust the
238 : // location of c accordingly. This is:
239 : //
240 : // if ((b - a) < 2 * tol * a) {
241 : // c = a + (b - a) / 2;
242 : // } else if (c <= a + fabs(a) * tol) {
243 : // c = a + fabs(a) * tol;
244 : // } else if (c >= b - fabs(b) * tol) {
245 : // c = b - fabs(b) * tol;
246 : // }
247 : const T a_filt = simd::fma(fabs(a), tol_batch, a);
248 : const T b_filt = simd::fnma(fabs(b), tol_batch, b);
249 : const T b_minus_a = b - a;
250 : c = simd::select(
251 : (static_cast<T>(2) * tol_batch * a > b_minus_a) and incomplete_mask,
252 : simd::fma(b_minus_a, static_cast<T>(0.5), a),
253 : simd::select(c <= a_filt, a_filt, simd::select(c >= b_filt, b_filt, c)));
254 :
255 : // Invoke f(c):
256 : T fc = f(c);
257 :
258 : // if we have a zero then we have an exact solution to the root:
259 : const auto fc_is_zero_mask = (fc == static_cast<T>(0));
260 : if (const auto mask = fc_is_zero_mask and incomplete_mask;
261 : UNLIKELY(simd::any(mask))) {
262 : a = simd::select(mask, c, a);
263 : fa = simd::select(mask, static_cast<T>(0), fa);
264 : d = simd::select(mask, static_cast<T>(0), d);
265 : fd = simd::select(mask, static_cast<T>(0), fd);
266 : if (UNLIKELY(simd::all(mask or not incomplete_mask))) {
267 : return;
268 : }
269 : }
270 :
271 : // Non-zero fc, update the interval:
272 : //
273 : // Note: unlike Boost, we assume fa*fc doesn't overflow. This speeds up the
274 : // code quite a bit.
275 : //
276 : // Boost code is:
277 : // if (boost::math::sign(fa) * boost::math::sign(fc) < 0) {...} else {...}
278 : using simd::sign;
279 : const auto sign_mask = AssumeFinite
280 : ? (fa * fc < static_cast<T>(0))
281 : : (sign(fa) * sign(fc) < static_cast<T>(0));
282 : const auto mask_if =
283 : (sign_mask and (not fc_is_zero_mask)) and incomplete_mask;
284 : d = simd::select(mask_if, b, d);
285 : fd = simd::select(mask_if, fb, fd);
286 : b = simd::select(mask_if, c, b);
287 : fb = simd::select(mask_if, fc, fb);
288 :
289 : const auto mask_else =
290 : ((not sign_mask) and (not fc_is_zero_mask)) and incomplete_mask;
291 : d = simd::select(mask_else, a, d);
292 : fd = simd::select(mask_else, fa, fd);
293 : a = simd::select(mask_else, c, a);
294 : fa = simd::select(mask_else, fc, fa);
295 : }
296 :
297 : template <bool AssumeFinite, class F, class T, class Tol>
298 : std::pair<T, T> toms748_solve(F f, const T& ax, const T& bx, const T& fax,
299 : const T& fbx, Tol tol,
300 : const simd::mask_type_t<T>& ignore_filter,
301 : size_t& max_iter) {
302 : // Main entry point and logic for Toms Algorithm 748
303 : // root finder.
304 : if (UNLIKELY(simd::any(ax >= bx))) {
305 : throw std::domain_error("Lower bound is larger than upper bound");
306 : }
307 :
308 : // Sanity check - are we allowed to iterate at all?
309 : if (UNLIKELY(max_iter == 0)) {
310 : return std::pair{ax, bx};
311 : }
312 :
313 : size_t count = max_iter;
314 : // mu is a parameter in the algorithm that must be between (0, 1).
315 : static const T mu = 0.5f;
316 :
317 : // initialise a, b and fa, fb:
318 : T a = ax;
319 : T b = bx;
320 : T fa = fax;
321 : T fb = fbx;
322 :
323 : const auto fa_is_zero_mask = (fa == static_cast<T>(0));
324 : const auto fb_is_zero_mask = (fb == static_cast<T>(0));
325 : auto completion_mask =
326 : tol(a, b) or fa_is_zero_mask or fb_is_zero_mask or ignore_filter;
327 : auto incomplete_mask = not completion_mask;
328 : if (UNLIKELY(simd::all(completion_mask))) {
329 : max_iter = 0;
330 : return std::pair{simd::select(fb_is_zero_mask, b, a),
331 : simd::select(fa_is_zero_mask, a, b)};
332 : }
333 :
334 : // Note: unlike Boost, we can assume fa*fb doesn't overflow when possible.
335 : // This speeds up the code quite a bit.
336 : if (UNLIKELY(simd::any((AssumeFinite ? (fa * fb > static_cast<T>(0))
337 : : (simd::sign(fa) * simd::sign(fb) >
338 : static_cast<T>(0))) and
339 : (not fa_is_zero_mask) and (not fb_is_zero_mask)))) {
340 : throw std::domain_error(
341 : "Parameters lower and upper bounds do not bracket a root");
342 : }
343 : // dummy value for fd, e and fe:
344 : T fe(static_cast<T>(1e5F));
345 : T e(static_cast<T>(1e5F));
346 : T fd(static_cast<T>(1e5F));
347 :
348 : T c(std::numeric_limits<T>::signaling_NaN());
349 : T d(std::numeric_limits<T>::signaling_NaN());
350 :
351 : const T nan(std::numeric_limits<T>::signaling_NaN());
352 : auto completed_a = simd::select(completion_mask, a, nan);
353 : auto completed_b = simd::select(completion_mask, b, nan);
354 : auto completed_fa = simd::select(completion_mask, fa, nan);
355 : auto completed_fb = simd::select(completion_mask, fb, nan);
356 : const auto update_completed = [&fa, &fb, &completion_mask, &incomplete_mask,
357 : &completed_a, &completed_b, &a, &b,
358 : &completed_fa, &completed_fb, &tol]() {
359 : const auto new_completed =
360 : (fa == static_cast<T>(0) or tol(a, b)) and (not completion_mask);
361 : completed_a = simd::select(new_completed, a, completed_a);
362 : completed_b = simd::select(new_completed, b, completed_b);
363 : completed_fa = simd::select(new_completed, fa, completed_fa);
364 : completed_fb = simd::select(new_completed, fb, completed_fb);
365 : completion_mask = new_completed or completion_mask;
366 : incomplete_mask = not completion_mask;
367 : // returns true if _all_ simd registers have been completed
368 : return simd::all(completion_mask);
369 : };
370 :
371 : if (simd::any(fa != static_cast<T>(0))) {
372 : // On the first step we take a secant step:
373 : c = toms748_detail::secant_interpolate(a, b, fa, fb);
374 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
375 : incomplete_mask);
376 : --count;
377 :
378 : // Note: The Boost fa!=0 check is handled with the completion_mask.
379 : if (count and not update_completed()) {
380 : // On the second step we take a quadratic interpolation:
381 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
382 : a, b, d, fa, fb, fd, incomplete_mask, 2);
383 : e = d;
384 : fe = fd;
385 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
386 : incomplete_mask);
387 : --count;
388 : update_completed();
389 : }
390 : }
391 :
392 : T u(std::numeric_limits<T>::signaling_NaN());
393 : T fu(std::numeric_limits<T>::signaling_NaN());
394 : T a0(std::numeric_limits<T>::signaling_NaN());
395 : T b0(std::numeric_limits<T>::signaling_NaN());
396 :
397 : // Note: The Boost fa!=0 check is handled with the completion_mask.
398 : while (count and not simd::all(completion_mask)) {
399 : // save our brackets:
400 : a0 = a;
401 : b0 = b;
402 : // Starting with the third step taken
403 : // we can use either quadratic or cubic interpolation.
404 : // Cubic interpolation requires that all four function values
405 : // fa, fb, fd, and fe are distinct, should that not be the case
406 : // then variable prof will get set to true, and we'll end up
407 : // taking a quadratic step instead.
408 : static const T min_diff = std::numeric_limits<T>::min() * 32;
409 : bool prof =
410 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
411 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
412 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
413 : incomplete_mask);
414 : if (prof) {
415 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
416 : a, b, d, fa, fb, fd, incomplete_mask, 2);
417 : } else {
418 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
419 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
420 : }
421 : // re-bracket, and check for termination:
422 : e = d;
423 : fe = fd;
424 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
425 : incomplete_mask);
426 : if ((0 == --count) or update_completed()) {
427 : break;
428 : }
429 : // Now another interpolated step:
430 : prof =
431 : simd::any(((fabs(fa - fb) < min_diff) or (fabs(fa - fd) < min_diff) or
432 : (fabs(fa - fe) < min_diff) or (fabs(fb - fd) < min_diff) or
433 : (fabs(fb - fe) < min_diff) or (fabs(fd - fe) < min_diff)) and
434 : incomplete_mask);
435 : if (prof) {
436 : c = toms748_detail::quadratic_interpolate<AssumeFinite>(
437 : a, b, d, fa, fb, fd, incomplete_mask, 3);
438 : } else {
439 : c = toms748_detail::cubic_interpolate<AssumeFinite>(
440 : a, b, d, e, fa, fb, fd, fe, incomplete_mask);
441 : }
442 : // Bracket again, and check termination condition, update e:
443 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
444 : incomplete_mask);
445 : if ((0 == --count) or update_completed()) {
446 : break;
447 : }
448 :
449 : // Now we take a double-length secant step:
450 : const auto fabs_fa_less_fabs_fb_mask =
451 : (fabs(fa) < fabs(fb)) and incomplete_mask;
452 : u = simd::select(fabs_fa_less_fabs_fb_mask, a, b);
453 : fu = simd::select(fabs_fa_less_fabs_fb_mask, fa, fb);
454 : const T b_minus_a = b - a;
455 : // Assumes that bounds a & b are not so close that fa == fb. Boost makes
456 : // this assumption too. If this is violated then the algorithm doesn't
457 : // work since the function at least appears constant.
458 : c = simd::fnma(static_cast<T>(2) * (fu / (fb - fa)), b_minus_a, u);
459 : c = simd::select(static_cast<T>(2) * fabs(c - u) > b_minus_a,
460 : simd::fma(static_cast<T>(0.5), b_minus_a, a), c);
461 :
462 : // Bracket again, and check termination condition:
463 : e = d;
464 : fe = fd;
465 : toms748_detail::bracket<AssumeFinite>(f, a, b, c, fa, fb, d, fd,
466 : incomplete_mask);
467 : if ((0 == --count) or update_completed()) {
468 : break;
469 : }
470 :
471 : // And finally... check to see if an additional bisection step is
472 : // to be taken, we do this if we're not converging fast enough:
473 : const auto bisection_mask = (b - a) >= mu * (b0 - a0) and incomplete_mask;
474 : if (LIKELY(simd::none(bisection_mask))) {
475 : continue;
476 : }
477 : // bracket again on a bisection:
478 : //
479 : // Note: the mask ensures we only ever modify the slots that the mask has
480 : // identified as needing to be modified.
481 : e = simd::select(bisection_mask, d, e);
482 : fe = simd::select(bisection_mask, fd, fe);
483 : toms748_detail::bracket<AssumeFinite>(
484 : f, a, b, simd::fma((b - a), static_cast<T>(0.5), a), fa, fb, d, fd,
485 : bisection_mask);
486 : --count;
487 : if (update_completed()) {
488 : break;
489 : }
490 : } // while loop
491 :
492 : max_iter -= count;
493 : completed_b =
494 : simd::select(completed_fa == static_cast<T>(0), completed_a, completed_b);
495 : completed_a =
496 : simd::select(completed_fb == static_cast<T>(0), completed_b, completed_a);
497 : return std::pair{completed_a, completed_b};
498 : }
499 : } // namespace toms748_detail
500 :
501 : /*!
502 : * \ingroup NumericalAlgorithmsGroup
503 : * \brief Finds the root of the function `f` with the TOMS_748 method.
504 : *
505 : * `f` is a unary invokable that takes a `double` which is the current value at
506 : * which to evaluate `f`. An example is below.
507 : *
508 : * \snippet Test_TOMS748.cpp double_root_find
509 : *
510 : * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
511 : * `upper_bound`], and will throw if this interval does not bracket a root,
512 : * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
513 : *
514 : * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
515 : * are the function values at `lower_bound` and `upper_bound`. These function
516 : * values are often known because the user typically checks if a root is
517 : * bracketed before calling `toms748`; passing the function values here saves
518 : * two function evaluations.
519 : *
520 : * \note if `AssumeFinite` is true than the code assumes all numbers are
521 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
522 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
523 : * that products like `fa * fb` are also finite.
524 : *
525 : * \requires Function `f` is invokable with a `double`
526 : *
527 : * \throws `convergence_error` if the requested tolerance is not met after
528 : * `max_iterations` iterations.
529 : */
530 : template <bool AssumeFinite = false, typename Function, typename T>
531 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
532 : const T f_at_lower_bound, const T f_at_upper_bound,
533 : const simd::scalar_type_t<T> absolute_tolerance,
534 : const simd::scalar_type_t<T> relative_tolerance,
535 : const size_t max_iterations = 100,
536 : const simd::mask_type_t<T> ignore_filter =
537 : static_cast<simd::mask_type_t<T>>(0)) {
538 : ASSERT(relative_tolerance >
539 : std::numeric_limits<simd::scalar_type_t<T>>::epsilon(),
540 : "The relative tolerance is too small. Got "
541 : << relative_tolerance << " but must be at least "
542 : << std::numeric_limits<simd::scalar_type_t<T>>::epsilon());
543 : if (simd::any(f_at_lower_bound * f_at_upper_bound > 0.0)) {
544 : ERROR("Root not bracketed: f(" << lower_bound << ") = " << f_at_lower_bound
545 : << ", f(" << upper_bound
546 : << ") = " << f_at_upper_bound);
547 : }
548 :
549 : std::size_t max_iters = max_iterations;
550 :
551 : // This solver requires tol to be passed as a termination condition. This
552 : // termination condition is equivalent to the convergence criteria used by the
553 : // GSL
554 : const auto tol = [absolute_tolerance, relative_tolerance](const T& lhs,
555 : const T& rhs) {
556 : return simd::abs(lhs - rhs) <=
557 : simd::fma(T(relative_tolerance),
558 : simd::min(simd::abs(lhs), simd::abs(rhs)),
559 : T(absolute_tolerance));
560 : };
561 : auto result = toms748_detail::toms748_solve<AssumeFinite>(
562 : f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
563 : ignore_filter, max_iters);
564 : if (max_iters >= max_iterations) {
565 : throw convergence_error(
566 : MakeString{}
567 : << std::setprecision(8) << std::scientific
568 : << "toms748 reached max iterations without converging.\nAbsolute "
569 : "tolerance: "
570 : << absolute_tolerance << "\nRelative tolerance: " << relative_tolerance
571 : << "\nResult: " << get_output(result.first) << " "
572 : << get_output(result.second));
573 : }
574 : return simd::fma(static_cast<T>(0.5), (result.second - result.first),
575 : result.first);
576 : }
577 :
578 : /*!
579 : * \ingroup NumericalAlgorithmsGroup
580 : * \brief Finds the root of the function `f` with the TOMS_748 method, where
581 : * function values are not supplied at the lower and upper bounds.
582 : *
583 : * \note if `AssumeFinite` is true than the code assumes all numbers are
584 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
585 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
586 : * that products like `fa * fb` are also finite.
587 : */
588 : template <bool AssumeFinite = false, typename Function, typename T>
589 1 : T toms748(const Function& f, const T lower_bound, const T upper_bound,
590 : const simd::scalar_type_t<T> absolute_tolerance,
591 : const simd::scalar_type_t<T> relative_tolerance,
592 : const size_t max_iterations = 100,
593 : const simd::mask_type_t<T> ignore_filter =
594 : static_cast<simd::mask_type_t<T>>(0)) {
595 : return toms748<AssumeFinite>(
596 : f, lower_bound, upper_bound, f(lower_bound), f(upper_bound),
597 : absolute_tolerance, relative_tolerance, max_iterations, ignore_filter);
598 : }
599 :
600 : /*!
601 : * \ingroup NumericalAlgorithmsGroup
602 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
603 : * element in a `DataVector`.
604 : *
605 : * `f` is a binary invokable that takes a `double` as its first argument and a
606 : * `size_t` as its second. The `double` is the current value at which to
607 : * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
608 : * Below is an example of how to root find different functions by indexing into
609 : * a lambda-captured `DataVector` using the `size_t` passed to `f`.
610 : *
611 : * \snippet Test_TOMS748.cpp datavector_root_find
612 : *
613 : * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
614 : * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
615 : * this interval does not bracket a root,
616 : * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
617 : *
618 : * See the [Boost](http://www.boost.org/) documentation for more details.
619 : *
620 : * \requires Function `f` be callable with a `double` and a `size_t`
621 : *
622 : * \note if `AssumeFinite` is true than the code assumes all numbers are
623 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
624 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
625 : * that products like `fa * fb` are also finite.
626 : *
627 : * \throws `convergence_error` if, for any index, the requested tolerance is not
628 : * met after `max_iterations` iterations.
629 : */
630 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
631 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
632 : const DataVector& upper_bound,
633 : const double absolute_tolerance,
634 : const double relative_tolerance,
635 : const size_t max_iterations = 100) {
636 : DataVector result_vector{lower_bound.size()};
637 : if constexpr (UseSimd) {
638 : constexpr size_t simd_width{simd::size<
639 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
640 : const size_t vectorized_size =
641 : lower_bound.size() - lower_bound.size() % simd_width;
642 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
643 : simd::store_unaligned(
644 : &result_vector[i],
645 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
646 : simd::load_unaligned(&lower_bound[i]),
647 : simd::load_unaligned(&upper_bound[i]),
648 : absolute_tolerance, relative_tolerance,
649 : max_iterations));
650 : }
651 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
652 : result_vector[i] = toms748<AssumeFinite>(
653 : [&f, i](const auto x) { return f(x, i); }, lower_bound[i],
654 : upper_bound[i], absolute_tolerance, relative_tolerance,
655 : max_iterations);
656 : }
657 : } else {
658 : for (size_t i = 0; i < result_vector.size(); ++i) {
659 : result_vector[i] = toms748<AssumeFinite>(
660 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
661 : absolute_tolerance, relative_tolerance, max_iterations);
662 : }
663 : }
664 : return result_vector;
665 : }
666 :
667 : /*!
668 : * \ingroup NumericalAlgorithmsGroup
669 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
670 : * element in a `DataVector`, where function values are supplied at the lower
671 : * and upper bounds.
672 : *
673 : * Supplying function values is an optimization that saves two
674 : * function calls per point. The function values are often available
675 : * because one often checks if the root is bracketed before calling `toms748`.
676 : *
677 : * \note if `AssumeFinite` is true than the code assumes all numbers are
678 : * finite and that `a > 0` is equivalent to `sign(a) > 0`. This reduces
679 : * runtime but will cause bugs if the numbers aren't finite. It also assumes
680 : * that products like `fa * fb` are also finite.
681 : */
682 : template <bool UseSimd = true, bool AssumeFinite = false, typename Function>
683 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
684 : const DataVector& upper_bound,
685 : const DataVector& f_at_lower_bound,
686 : const DataVector& f_at_upper_bound,
687 : const double absolute_tolerance,
688 : const double relative_tolerance,
689 : const size_t max_iterations = 100) {
690 : DataVector result_vector{lower_bound.size()};
691 : if constexpr (UseSimd) {
692 : constexpr size_t simd_width{simd::size<
693 : std::decay_t<decltype(simd::load_unaligned(lower_bound.data()))>>()};
694 : const size_t vectorized_size =
695 : lower_bound.size() - lower_bound.size() % simd_width;
696 : for (size_t i = 0; i < vectorized_size; i += simd_width) {
697 : simd::store_unaligned(
698 : &result_vector[i],
699 : toms748<AssumeFinite>([&f, i](const auto x) { return f(x, i); },
700 : simd::load_unaligned(&lower_bound[i]),
701 : simd::load_unaligned(&upper_bound[i]),
702 : simd::load_unaligned(&f_at_lower_bound[i]),
703 : simd::load_unaligned(&f_at_upper_bound[i]),
704 : absolute_tolerance, relative_tolerance,
705 : max_iterations));
706 : }
707 : for (size_t i = vectorized_size; i < lower_bound.size(); ++i) {
708 : result_vector[i] = toms748<AssumeFinite>(
709 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
710 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
711 : relative_tolerance, max_iterations);
712 : }
713 : } else {
714 : for (size_t i = 0; i < lower_bound.size(); ++i) {
715 : result_vector[i] = toms748<AssumeFinite>(
716 : [&f, i](double x) { return f(x, i); }, lower_bound[i], upper_bound[i],
717 : f_at_lower_bound[i], f_at_upper_bound[i], absolute_tolerance,
718 : relative_tolerance, max_iterations);
719 : }
720 : }
721 : return result_vector;
722 : }
723 : } // namespace RootFinder
|