Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #ifdef SPECTRE_USE_XSIMD
7 : #include <limits>
8 : #include <xsimd/xsimd.hpp>
9 :
10 : #include "Utilities/MakeWithValue.hpp"
11 : #include "Utilities/TypeTraits/CreateGetTypeAliasOrDefault.hpp"
12 :
13 : /// Namespace containing SIMD functions based on XSIMD.
14 : namespace simd = xsimd;
15 :
16 : namespace MakeWithValueImpls {
17 : template <typename U, typename T, typename Arch>
18 : struct MakeWithValueImpl<xsimd::batch<U, Arch>, T> {
19 : static SPECTRE_ALWAYS_INLINE xsimd::batch<U, Arch> apply(const T& /* input */,
20 : const U value) {
21 : return xsimd::batch<U, Arch>(value);
22 : }
23 : };
24 : } // namespace MakeWithValueImpls
25 :
26 : namespace xsimd {
27 : CREATE_GET_TYPE_ALIAS_OR_DEFAULT(value_type)
28 :
29 : namespace detail {
30 : template <typename T>
31 : struct size_impl : std::integral_constant<size_t, 1> {};
32 :
33 : template <typename T, typename A>
34 : struct size_impl<batch<T, A>>
35 : : std::integral_constant<size_t, batch<T, A>::size> {};
36 : } // namespace detail
37 :
38 : template <typename T>
39 : constexpr size_t size() {
40 : return detail::size_impl<T>::value;
41 : }
42 :
43 : namespace detail {
44 : template <typename T, size_t... Is>
45 : T make_sequence_impl(std::index_sequence<Is...> /*meta*/) {
46 : return T{static_cast<typename T::value_type>(Is)...};
47 : }
48 : } // namespace detail
49 :
50 : template <typename T>
51 : T make_sequence() {
52 : return detail::make_sequence_impl<T>(std::make_index_sequence<size<T>()>{});
53 : }
54 : } // namespace xsimd
55 :
56 : namespace std {
57 : template <typename T, typename Arch>
58 : class numeric_limits<::xsimd::batch<T, Arch>> : public numeric_limits<T> {
59 : public:
60 : static constexpr bool is_iec559 = false;
61 :
62 : static ::xsimd::batch<T, Arch> min() {
63 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::min());
64 : }
65 : static ::xsimd::batch<T, Arch> lowest() {
66 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::lowest());
67 : }
68 : static ::xsimd::batch<T, Arch> max() {
69 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::max());
70 : }
71 : static ::xsimd::batch<T, Arch> epsilon() {
72 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::epsilon());
73 : }
74 : static ::xsimd::batch<T, Arch> round_error() {
75 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::round_error());
76 : }
77 : static ::xsimd::batch<T, Arch> infinity() {
78 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::infinity());
79 : }
80 : static ::xsimd::batch<T, Arch> quiet_NaN() {
81 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::quiet_NaN());
82 : }
83 : static ::xsimd::batch<T, Arch> signaling_NaN() {
84 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::signaling_NaN());
85 : }
86 : static ::xsimd::batch<T, Arch> denorm_min() {
87 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::denorm_min());
88 : }
89 : };
90 : } // namespace std
91 : #else // no xsimd
92 :
93 : #include <cmath>
94 : #include <complex>
95 : #include <type_traits>
96 :
97 : #include "Utilities/Requires.hpp"
98 :
99 0 : namespace simd {
100 : template <typename T, typename A = void>
101 0 : class batch;
102 : template <typename T, typename A = void>
103 0 : class batch_bool;
104 :
105 : template <typename T>
106 0 : struct scalar_type {
107 0 : using type = T;
108 : };
109 : template <typename T>
110 0 : using scalar_type_t = typename scalar_type<T>::type;
111 :
112 : template <typename T>
113 0 : struct mask_type {
114 0 : using type = bool;
115 : };
116 : template <typename T>
117 0 : using mask_type_t = typename mask_type<T>::type;
118 :
119 : template <typename T>
120 0 : struct is_batch : std::false_type {};
121 :
122 : template <typename T, typename A>
123 0 : struct is_batch<batch<T, A>> : std::true_type {};
124 :
125 : namespace detail {
126 : template <typename T>
127 : struct size_impl : std::integral_constant<size_t, 1> {};
128 : } // namespace detail
129 :
130 : template <typename T>
131 0 : constexpr size_t size() {
132 : return detail::size_impl<T>::value;
133 : }
134 :
135 : namespace detail {
136 : template <typename T, size_t... Is>
137 : T make_sequence_impl(std::index_sequence<Is...> /*meta*/) {
138 : return T{static_cast<typename T::value_type>(Is)...};
139 : }
140 : } // namespace detail
141 :
142 : template <typename T>
143 0 : T make_sequence() {
144 : return detail::make_sequence_impl<T>(std::make_index_sequence<size<T>()>{});
145 : }
146 :
147 : // NOLINTBEGIN(misc-unused-using-decls)
148 : using std::abs;
149 : using std::acos;
150 : using std::acosh;
151 : using std::arg;
152 : using std::asin;
153 : using std::asinh;
154 : using std::atan;
155 : using std::atan2;
156 : using std::atanh;
157 : using std::cbrt;
158 : using std::ceil;
159 : using std::conj;
160 : using std::copysign;
161 : using std::cos;
162 : using std::cosh;
163 : using std::erf;
164 : using std::erfc;
165 : using std::exp;
166 : using std::exp2;
167 : using std::expm1;
168 : using std::fabs;
169 : using std::fdim;
170 : using std::floor;
171 : using std::fmax;
172 : using std::fmin;
173 : using std::fmod;
174 : using std::hypot;
175 : using std::isfinite;
176 : using std::isinf;
177 : using std::isnan;
178 : using std::ldexp;
179 : using std::lgamma;
180 : using std::log;
181 : using std::log10;
182 : using std::log1p;
183 : using std::log2;
184 : using std::max;
185 : using std::min;
186 : using std::modf;
187 : using std::nearbyint;
188 : using std::nextafter;
189 : using std::norm;
190 : using std::polar;
191 : using std::proj;
192 : using std::remainder;
193 : using std::rint;
194 : using std::round;
195 : using std::sin;
196 : using std::sinh;
197 : using std::sqrt;
198 : using std::tan;
199 : using std::tanh;
200 : using std::tgamma;
201 : using std::trunc;
202 : // NOLINTEND(misc-unused-using-decls)
203 :
204 0 : inline bool all(const bool mask) { return mask; }
205 :
206 0 : inline bool any(const bool mask) { return mask; }
207 :
208 0 : inline bool none(const bool mask) { return not mask; }
209 :
210 : template <typename T, Requires<std::is_scalar_v<T>> = nullptr>
211 0 : T clip(const T& val, const T& low, const T& hi) {
212 : assert(low <= hi && "ordered clipping bounds");
213 : return low > val ? low : (hi < val ? hi : val);
214 : }
215 :
216 : #if defined(__GLIBC__)
217 : inline float exp10(const float& x) { return ::exp10f(x); }
218 : inline double exp10(const double& x) { return ::exp10(x); }
219 : #else
220 0 : inline float exp10(const float& x) {
221 : const float ln10 = std::log(10.f);
222 : return std::exp(ln10 * x);
223 : }
224 0 : inline double exp10(const double& x) {
225 : const double ln10 = std::log(10.);
226 : return std::exp(ln10 * x);
227 : }
228 : #endif
229 :
230 0 : inline double sign(const bool& v) { return static_cast<double>(v); }
231 :
232 : template <typename T>
233 0 : T sign(const T& v) {
234 : return v < static_cast<T>(0) ? static_cast<T>(-1.)
235 : : v == static_cast<T>(0) ? static_cast<T>(0.)
236 : : static_cast<T>(1.);
237 : }
238 :
239 : template <typename T>
240 0 : T select(const bool cond, const T true_branch, const T false_branch) {
241 : return cond ? true_branch : false_branch;
242 : }
243 :
244 0 : inline std::pair<float, float> sincos(const float val) {
245 : float result_sin{};
246 : float result_cos{};
247 : __sincosf(val, &result_sin, &result_cos);
248 : return std::make_pair(result_sin, result_cos);
249 : }
250 :
251 0 : inline std::pair<double, double> sincos(const double val) {
252 : double result_sin{};
253 : double result_cos{};
254 : __sincos(val, &result_sin, &result_cos);
255 : return std::make_pair(result_sin, result_cos);
256 : }
257 :
258 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
259 0 : T fma(const T a, const T b, const T c) {
260 : return a * b + c;
261 : }
262 :
263 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
264 : T fma(const T a, const T b, const T c) {
265 : return std::fma(a, b, c);
266 : }
267 :
268 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
269 0 : T fms(const T& a, const T& b, const T& c) {
270 : return a * b - c;
271 : }
272 :
273 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
274 0 : T fms(const T a, const T b, const T c) {
275 : return std::fma(a, b, -c);
276 : }
277 :
278 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
279 0 : T fnma(const T a, const T b, const T c) {
280 : return -(a * b) + c;
281 : }
282 :
283 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
284 : T fnma(const T a, const T b, const T c) {
285 : return std::fma(-a, b, c);
286 : }
287 :
288 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
289 0 : T fnms(const T a, const T b, const T c) {
290 : return -(a * b) - c;
291 : }
292 :
293 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
294 : T fnms(const T a, const T b, const T c) {
295 : return -std::fma(a, b, c);
296 : }
297 :
298 : template <typename Arch = void, typename From>
299 0 : From load(From* mem) {
300 : static_assert(std::is_arithmetic_v<From>);
301 : return *mem;
302 : }
303 :
304 : template <typename Arch = void, typename T>
305 0 : void store(T* mem, const T& val) {
306 : static_assert(std::is_arithmetic_v<T>);
307 : *mem = val;
308 : }
309 :
310 : template <typename Arch = void, typename From>
311 0 : From load_unaligned(From* mem) {
312 : static_assert(std::is_arithmetic_v<From>);
313 : return *mem;
314 : }
315 :
316 : template <typename Arch = void, typename T>
317 0 : void store_unaligned(T* mem, const T& val) {
318 : static_assert(std::is_arithmetic_v<T>);
319 : *mem = val;
320 : }
321 : } // namespace simd
322 : #endif
|