Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares function RootFinder::toms748
6 :
7 : #pragma once
8 :
#include <boost/math/tools/roots.hpp>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>

#include "DataStructures/DataVector.hpp"
#include "Utilities/ErrorHandling/Exceptions.hpp"
15 :
16 : namespace RootFinder {
17 :
18 : /*!
19 : * \ingroup NumericalAlgorithmsGroup
20 : * \brief Finds the root of the function `f` with the TOMS_748 method.
21 : *
22 : * `f` is a unary invokable that takes a `double` which is the current value at
23 : * which to evaluate `f`. An example is below.
24 : *
25 : * \snippet Test_TOMS748.cpp double_root_find
26 : *
27 : * The TOMS_748 algorithm searches for a root in the interval [`lower_bound`,
28 : * `upper_bound`], and will throw if this interval does not bracket a root,
29 : * i.e. if `f(lower_bound) * f(upper_bound) > 0`.
30 : *
31 : * The arguments `f_at_lower_bound` and `f_at_upper_bound` are optional, and
32 : * are the function values at `lower_bound` and `upper_bound`. These function
33 : * values are often known because the user typically checks if a root is
34 : * bracketed before calling `toms748`; passing the function values here saves
35 : * two function evaluations.
36 : *
37 : * See the [Boost](http://www.boost.org/) documentation for more details.
38 : *
39 : * \requires Function `f` is invokable with a `double`
40 : *
41 : * \throws `std::domain_error` if the bounds do not bracket a root.
42 : * \throws `convergence_error` if the requested tolerance is not met after
43 : * `max_iterations` iterations.
44 : */
45 : template <typename Function>
46 1 : double toms748(const Function& f, const double lower_bound,
47 : const double upper_bound, const double f_at_lower_bound,
48 : const double f_at_upper_bound, const double absolute_tolerance,
49 : const double relative_tolerance,
50 : const size_t max_iterations = 100) {
51 : ASSERT(relative_tolerance > std::numeric_limits<double>::epsilon(),
52 : "The relative tolerance is too small.");
53 :
54 : boost::uintmax_t max_iters = max_iterations;
55 :
56 : // This solver requires tol to be passed as a termination condition. This
57 : // termination condition is equivalent to the convergence criteria used by the
58 : // GSL
59 : auto tol = [absolute_tolerance, relative_tolerance](double lhs, double rhs) {
60 : return (fabs(lhs - rhs) <=
61 : absolute_tolerance +
62 : relative_tolerance * fmin(fabs(lhs), fabs(rhs)));
63 : };
64 : // clang-tidy: internal boost warning, can't fix it.
65 : auto result = boost::math::tools::toms748_solve( // NOLINT
66 : f, lower_bound, upper_bound, f_at_lower_bound, f_at_upper_bound, tol,
67 : max_iters);
68 : if (max_iters >= max_iterations) {
69 : throw convergence_error(
70 : "toms748 reached max iterations without converging");
71 : }
72 : return result.first + 0.5 * (result.second - result.first);
73 : }
74 :
/*!
 * \ingroup NumericalAlgorithmsGroup
 * \brief Finds the root of the function `f` with the TOMS_748 method,
 * where function values are not supplied at the lower and upper
 * bounds.
 *
 * `f` is evaluated once at `lower_bound` and once at `upper_bound`, and those
 * values are forwarded to the overload that accepts precomputed function
 * values. All other arguments have the same meaning as in that overload.
 */
template <typename Function>
double toms748(const Function& f, const double lower_bound,
               const double upper_bound, const double absolute_tolerance,
               const double relative_tolerance,
               const size_t max_iterations = 100) {
  const double f_lower = f(lower_bound);
  const double f_upper = f(upper_bound);
  return toms748(f, lower_bound, upper_bound, f_lower, f_upper,
                 absolute_tolerance, relative_tolerance, max_iterations);
}
89 :
90 : namespace detail {
91 : template <typename Function>
92 : DataVector toms748_impl(const Function& f, const DataVector& lower_bound,
93 : const DataVector& upper_bound,
94 : const DataVector& f_at_lower_bound,
95 : const DataVector& f_at_upper_bound,
96 : const double absolute_tolerance,
97 : const double relative_tolerance,
98 : const size_t max_iterations,
99 : const bool function_values_are_supplied) {
100 : ASSERT(relative_tolerance > std::numeric_limits<double>::epsilon(),
101 : "The relative tolerance is too small.");
102 : // This solver requires tol to be passed as a termination condition. This
103 : // termination condition is equivalent to the convergence criteria used by the
104 : // GSL
105 : auto tol = [absolute_tolerance, relative_tolerance](const double lhs,
106 : const double rhs) {
107 : return (fabs(lhs - rhs) <=
108 : absolute_tolerance +
109 : relative_tolerance * fmin(fabs(lhs), fabs(rhs)));
110 : };
111 : DataVector result_vector{lower_bound.size()};
112 : for (size_t i = 0; i < result_vector.size(); ++i) {
113 : // toms748_solver modifies the max_iters after the root is found to the
114 : // number of iterations that it took to find the root, so we reset it to
115 : // max_iterations after each root find.
116 : boost::uintmax_t max_iters = max_iterations;
117 : auto result = function_values_are_supplied
118 : ?
119 : // clang-tidy: internal boost warning, can't fix it.
120 : boost::math::tools::toms748_solve( // NOLINT
121 : [&f, i](double x) { return f(x, i); }, lower_bound[i],
122 : upper_bound[i], f_at_lower_bound[i],
123 : f_at_upper_bound[i], tol, max_iters)
124 : :
125 : // clang-tidy: internal boost warning, can't fix it.
126 : boost::math::tools::toms748_solve( // NOLINT
127 : [&f, i](double x) { return f(x, i); }, lower_bound[i],
128 : upper_bound[i], tol, max_iters);
129 : if (max_iters >= max_iterations) {
130 : throw convergence_error(
131 : "toms748 reached max iterations without converging");
132 : }
133 : result_vector[i] = result.first + 0.5 * (result.second - result.first);
134 : }
135 : return result_vector;
136 : }
137 : } // namespace detail
138 :
139 : /*!
140 : * \ingroup NumericalAlgorithmsGroup
141 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
142 : * element in a `DataVector`.
143 : *
144 : * `f` is a binary invokable that takes a `double` as its first argument and a
145 : * `size_t` as its second. The `double` is the current value at which to
146 : * evaluate `f`, and the `size_t` is the current index into the `DataVector`s.
147 : * Below is an example of how to root find different functions by indexing into
148 : * a lambda-captured `DataVector` using the `size_t` passed to `f`.
149 : *
150 : * \snippet Test_TOMS748.cpp datavector_root_find
151 : *
152 : * For each index `i` into the DataVector, the TOMS_748 algorithm searches for a
153 : * root in the interval [`lower_bound[i]`, `upper_bound[i]`], and will throw if
154 : * this interval does not bracket a root,
155 : * i.e. if `f(lower_bound[i], i) * f(upper_bound[i], i) > 0`.
156 : *
157 : * See the [Boost](http://www.boost.org/) documentation for more details.
158 : *
159 : * \requires Function `f` be callable with a `double` and a `size_t`
160 : *
161 : * \throws `std::domain_error` if, for any index, the bounds do not bracket a
162 : * root.
163 : * \throws `convergence_error` if, for any index, the requested tolerance is not
164 : * met after `max_iterations` iterations.
165 : */
166 : template <typename Function>
167 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
168 : const DataVector& upper_bound,
169 : const double absolute_tolerance,
170 : const double relative_tolerance,
171 : const size_t max_iterations = 100) {
172 : return detail::toms748_impl(f, lower_bound, upper_bound, DataVector{},
173 : DataVector{}, absolute_tolerance,
174 : relative_tolerance, max_iterations, false);
175 : }
176 :
177 : /*!
178 : * \ingroup NumericalAlgorithmsGroup
179 : * \brief Finds the root of the function `f` with the TOMS_748 method on each
180 : * element in a `DataVector`, where function values are supplied at the lower
181 : * and upper bounds.
182 : *
183 : * Supplying function values is an optimization that saves two
184 : * function calls per point. The function values are often available
185 : * because one often checks if the root is bracketed before calling `toms748`.
186 : */
187 : template <typename Function>
188 1 : DataVector toms748(const Function& f, const DataVector& lower_bound,
189 : const DataVector& upper_bound,
190 : const DataVector& f_at_lower_bound,
191 : const DataVector& f_at_upper_bound,
192 : const double absolute_tolerance,
193 : const double relative_tolerance,
194 : const size_t max_iterations = 100) {
195 : return detail::toms748_impl(f, lower_bound, upper_bound, f_at_lower_bound,
196 : f_at_upper_bound, absolute_tolerance,
197 : relative_tolerance, max_iterations, true);
198 : }
199 :
200 : } // namespace RootFinder
|