1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
#include <algorithm>
#include <complex>
#include <cstddef>
#include <limits>
#include <type_traits>

#include "Utilities/Algorithm.hpp"
#include "Utilities/ConstantExpressions.hpp"
#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/ForceInline.hpp"
#include "Utilities/TypeTraits/IsIterable.hpp"
#include "Utilities/TypeTraits/IsMaplike.hpp"
16 :
namespace EqualWithinRoundoffImpls {
/*!
 * \brief Customization point of the `equal_within_roundoff` function.
 *
 * To make `equal_within_roundoff(lhs, rhs)` work for a new pair of argument
 * types, specialize this class template for that pair. Because the comparison
 * is meant to be symmetric, provide specializations for both argument orders
 * (`<A, B>` and `<B, A>`); one of them can simply forward to the other with
 * the arguments swapped.
 *
 * The third template parameter defaults to `std::nullptr_t`. A partial
 * specialization can therefore be constrained by passing a SFINAE expression
 * that yields `std::nullptr_t` when the constraint holds.
 *
 * Every specialization must provide a static member function of the form
 *
 * ```cpp
 * static bool apply(const Lhs& lhs, const Rhs& rhs, const double eps,
 *                   const double scale);
 * ```
 *
 * Implementations typically reduce the comparison to calls of
 * `equal_within_roundoff` on floating-point components.
 */
template <class Lhs, class Rhs, class Enable = std::nullptr_t>
struct EqualWithinRoundoffImpl;
}  // namespace EqualWithinRoundoffImpls
36 :
37 : /*!
38 : * \ingroup UtilitiesGroup
39 : * \brief Checks if two values `lhs` and `rhs` are equal within roundoff, by
40 : * comparing `abs(lhs - rhs) < (max(abs(lhs), abs(rhs)) + scale) * eps`.
41 : *
42 : * The two values can be floating-point numbers, or any types for which
43 : * `EqualWithinRoundoffImpls::EqualWithinRoundoffImpl` has been specialized. For
44 : * example, a default implementation exists for the case where `lhs`, `rhs`, or
45 : * both, are iterable, and compares the values point-wise.
46 : */
47 : template <typename Lhs, typename Rhs>
48 1 : constexpr SPECTRE_ALWAYS_INLINE bool equal_within_roundoff(
49 : const Lhs& lhs, const Rhs& rhs,
50 : const double eps = std::numeric_limits<double>::epsilon() * 100.0,
51 : const double scale = 1.0) {
52 : return EqualWithinRoundoffImpls::EqualWithinRoundoffImpl<Lhs, Rhs>::apply(
53 : lhs, rhs, eps, scale);
54 : }
55 :
56 : /// Specializations of `EqualWithinRoundoffImpl` for custom types, to add
57 : /// support for the `equal_within_roundoff` function.
58 : namespace EqualWithinRoundoffImpls {
59 :
60 : // Compare two floating points
61 : template <typename Floating>
62 : struct EqualWithinRoundoffImpl<Floating, Floating,
63 : Requires<std::is_floating_point_v<Floating>>> {
64 : static constexpr SPECTRE_ALWAYS_INLINE bool apply(const Floating& lhs,
65 : const Floating& rhs,
66 : const double eps,
67 : const double scale) {
68 : return ce_fabs(lhs - rhs) <=
69 : (std::max(ce_fabs(lhs), ce_fabs(rhs)) + scale) * eps;
70 : }
71 : };
72 :
73 : // Compare a complex number to a floating point, interpreting the latter as a
74 : // real number
75 : template <typename Floating>
76 : struct EqualWithinRoundoffImpl<std::complex<Floating>, Floating,
77 : Requires<std::is_floating_point_v<Floating>>> {
78 : static SPECTRE_ALWAYS_INLINE bool apply(const std::complex<Floating>& lhs,
79 : const Floating& rhs, const double eps,
80 : const double scale) {
81 : return equal_within_roundoff(lhs.real(), rhs, eps, scale) and
82 : equal_within_roundoff(lhs.imag(), 0., eps, scale);
83 : }
84 : };
85 :
86 : // Compare a floating point to a complex number, interpreting the former as a
87 : // real number
88 : template <typename Floating>
89 : struct EqualWithinRoundoffImpl<Floating, std::complex<Floating>,
90 : Requires<std::is_floating_point_v<Floating>>> {
91 : static SPECTRE_ALWAYS_INLINE bool apply(const std::complex<Floating>& lhs,
92 : const Floating& rhs, const double eps,
93 : const double scale) {
94 : return equal_within_roundoff(rhs, lhs, eps, scale);
95 : }
96 : };
97 :
98 : // Compare two complex numbers
99 : template <typename Floating>
100 : struct EqualWithinRoundoffImpl<std::complex<Floating>, std::complex<Floating>,
101 : Requires<std::is_floating_point_v<Floating>>> {
102 : static SPECTRE_ALWAYS_INLINE bool apply(const std::complex<Floating>& lhs,
103 : const std::complex<Floating>& rhs,
104 : const double eps,
105 : const double scale) {
106 : return equal_within_roundoff(lhs.real(), rhs.real(), eps, scale) and
107 : equal_within_roundoff(lhs.imag(), rhs.imag(), eps, scale);
108 : }
109 : };
110 :
111 : // Compare an iterable to a floating point
112 : template <typename Lhs, typename Rhs>
113 : struct EqualWithinRoundoffImpl<
114 : Lhs, Rhs,
115 : Requires<tt::is_iterable_v<Lhs> and not tt::is_maplike_v<Lhs> and
116 : std::is_floating_point_v<Rhs>>> {
117 : static SPECTRE_ALWAYS_INLINE bool apply(const Lhs& lhs, const Rhs& rhs,
118 : const double eps,
119 : const double scale) {
120 : return alg::all_of(lhs, [&rhs, &eps, &scale](const auto& lhs_element) {
121 : return equal_within_roundoff(lhs_element, rhs, eps, scale);
122 : });
123 : }
124 : };
125 :
126 : // Compare a floating point to an iterable
127 : template <typename Lhs, typename Rhs>
128 : struct EqualWithinRoundoffImpl<
129 : Lhs, Rhs,
130 : Requires<tt::is_iterable_v<Rhs> and not tt::is_maplike_v<Rhs> and
131 : std::is_floating_point_v<Lhs>>> {
132 : static SPECTRE_ALWAYS_INLINE bool apply(const Lhs& lhs, const Rhs& rhs,
133 : const double eps,
134 : const double scale) {
135 : return equal_within_roundoff(rhs, lhs, eps, scale);
136 : }
137 : };
138 :
139 : // Compare two iterables
140 : template <typename Lhs, typename Rhs>
141 0 : struct EqualWithinRoundoffImpl<
142 : Lhs, Rhs,
143 : Requires<tt::is_iterable_v<Lhs> and not tt::is_maplike_v<Lhs> and
144 : tt::is_iterable_v<Rhs> and not tt::is_maplike_v<Rhs>>> {
145 0 : static bool apply(const Lhs& lhs, const Rhs& rhs, const double eps,
146 : const double scale) {
147 : auto lhs_it = lhs.begin();
148 : auto rhs_it = rhs.begin();
149 : while (lhs_it != lhs.end() and rhs_it != rhs.end()) {
150 : if (not equal_within_roundoff(*lhs_it, *rhs_it, eps, scale)) {
151 : return false;
152 : }
153 : ++lhs_it;
154 : ++rhs_it;
155 : }
156 : ASSERT(lhs_it == lhs.end() and rhs_it == rhs.end(),
157 : "Can't compare lhs and rhs because they have different lengths.");
158 : return true;
159 : }
160 : };
161 :
162 : } // namespace EqualWithinRoundoffImpls