Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #ifdef SPECTRE_USE_XSIMD
7 : #include <limits>
8 : #include <xsimd/xsimd.hpp>
9 :
10 : #include "Utilities/MakeWithValue.hpp"
11 : #include "Utilities/TypeTraits/CreateGetTypeAliasOrDefault.hpp"
12 :
13 : /// Namespace containing SIMD functions based on XSIMD.
14 : namespace simd = xsimd;
15 :
16 : namespace MakeWithValueImpls {
17 : template <typename U, typename T, typename Arch>
18 : struct MakeWithValueImpl<xsimd::batch<U, Arch>, T> {
19 : static SPECTRE_ALWAYS_INLINE xsimd::batch<U, Arch> apply(const T& /* input */,
20 : const U value) {
21 : return xsimd::batch<U, Arch>(value);
22 : }
23 : };
24 : } // namespace MakeWithValueImpls
25 :
26 : namespace xsimd {
27 : CREATE_GET_TYPE_ALIAS_OR_DEFAULT(value_type)
28 :
29 : namespace detail {
30 : template <typename T>
31 : struct size_impl : std::integral_constant<size_t, 1> {};
32 :
33 : template <typename T, typename A>
34 : struct size_impl<batch<T, A>>
35 : : std::integral_constant<size_t, batch<T, A>::size> {};
36 : } // namespace detail
37 :
38 : template <typename T>
39 : constexpr size_t size() {
40 : return detail::size_impl<T>::value;
41 : }
42 :
43 : namespace detail {
44 : template <typename T, size_t... Is>
45 : T make_sequence_impl(std::index_sequence<Is...> /*meta*/) {
46 : return T{static_cast<typename T::value_type>(Is)...};
47 : }
48 : } // namespace detail
49 :
50 : template <typename T>
51 : T make_sequence() {
52 : return detail::make_sequence_impl<T>(std::make_index_sequence<size<T>()>{});
53 : }
54 : } // namespace xsimd
55 :
56 : namespace std {
57 : template <typename T, typename Arch>
58 : class numeric_limits<::xsimd::batch<T, Arch>> : public numeric_limits<T> {
59 : public:
60 : static constexpr bool is_iec559 = false;
61 :
62 : static ::xsimd::batch<T, Arch> min() {
63 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::min());
64 : }
65 : static ::xsimd::batch<T, Arch> lowest() {
66 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::lowest());
67 : }
68 : static ::xsimd::batch<T, Arch> max() {
69 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::max());
70 : }
71 : static ::xsimd::batch<T, Arch> epsilon() {
72 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::epsilon());
73 : }
74 : static ::xsimd::batch<T, Arch> round_error() {
75 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::round_error());
76 : }
77 : static ::xsimd::batch<T, Arch> infinity() {
78 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::infinity());
79 : }
80 : static ::xsimd::batch<T, Arch> quiet_NaN() {
81 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::quiet_NaN());
82 : }
83 : static ::xsimd::batch<T, Arch> signaling_NaN() {
84 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::signaling_NaN());
85 : }
86 : static ::xsimd::batch<T, Arch> denorm_min() {
87 : return ::xsimd::batch<T, Arch>(numeric_limits<T>::denorm_min());
88 : }
89 : };
90 : } // namespace std
91 : #else // no xsimd
92 :
93 : #include <cmath>
94 : #include <complex>
95 : #include <type_traits>
96 :
97 : #include "Utilities/Requires.hpp"
98 :
99 0 : namespace simd {
100 : template <typename T, typename A = void>
101 0 : class batch;
102 : template <typename T, typename A = void>
103 0 : class batch_bool;
104 :
105 : template <typename T>
106 0 : struct scalar_type {
107 0 : using type = T;
108 : };
109 : template <typename T>
110 0 : using scalar_type_t = typename scalar_type<T>::type;
111 :
112 : template <typename T>
113 0 : struct mask_type {
114 0 : using type = bool;
115 : };
116 : template <typename T>
117 0 : using mask_type_t = typename mask_type<T>::type;
118 :
119 : template <typename T>
120 0 : struct is_batch : std::false_type {};
121 :
122 : template <typename T, typename A>
123 0 : struct is_batch<batch<T, A>> : std::true_type {};
124 :
125 : namespace detail {
126 : template <typename T>
127 : struct size_impl : std::integral_constant<size_t, 1> {};
128 : } // namespace detail
129 :
130 : template <typename T>
131 0 : constexpr size_t size() {
132 : return detail::size_impl<T>::value;
133 : }
134 :
135 : namespace detail {
136 : template <typename T, size_t... Is>
137 : T make_sequence_impl(std::index_sequence<Is...> /*meta*/) {
138 : return T{static_cast<typename T::value_type>(Is)...};
139 : }
140 : } // namespace detail
141 :
142 : template <typename T>
143 0 : T make_sequence() {
144 : return detail::make_sequence_impl<T>(std::make_index_sequence<size<T>()>{});
145 : }
146 :
147 : // NOLINTBEGIN(misc-unused-using-decls)
148 : using std::abs;
149 : using std::acos;
150 : using std::acosh;
151 : using std::arg;
152 : using std::asin;
153 : using std::asinh;
154 : using std::atan;
155 : using std::atan2;
156 : using std::atanh;
157 : using std::cbrt;
158 : using std::ceil;
159 : using std::conj;
160 : using std::copysign;
161 : using std::cos;
162 : using std::cosh;
163 : using std::erf;
164 : using std::erfc;
165 : using std::exp;
166 : using std::exp2;
167 : using std::expm1;
168 : using std::fabs;
169 : using std::fdim;
170 : using std::floor;
171 : using std::fmax;
172 : using std::fmin;
173 : using std::fmod;
174 : using std::hypot;
175 : using std::isfinite;
176 : using std::isinf;
177 : using std::isnan;
178 : using std::ldexp;
179 : using std::lgamma;
180 : using std::log;
181 : using std::log10;
182 : using std::log1p;
183 : using std::log2;
184 : using std::max;
185 : using std::min;
186 : using std::modf;
187 : using std::nearbyint;
188 : using std::nextafter;
189 : using std::norm;
190 : using std::polar;
191 : using std::proj;
192 : using std::remainder;
193 : using std::rint;
194 : using std::round;
195 : using std::sin;
196 : using std::sinh;
197 : using std::sqrt;
198 : using std::tan;
199 : using std::tanh;
200 : using std::tgamma;
201 : using std::trunc;
202 : // NOLINTEND(misc-unused-using-decls)
203 :
204 0 : inline bool all(const bool mask) { return mask; }
205 :
206 0 : inline bool any(const bool mask) { return mask; }
207 :
208 0 : inline bool none(const bool mask) { return not mask; }
209 :
210 : template <typename T, Requires<std::is_scalar_v<T>> = nullptr>
211 0 : T clip(const T& val, const T& low, const T& hi) {
212 : assert(low <= hi && "ordered clipping bounds");
213 : return low > val ? low : (hi < val ? hi : val);
214 : }
215 :
216 : #if defined(__GLIBC__)
217 : inline float exp10(const float& x) { return ::exp10f(x); }
218 : inline double exp10(const double& x) { return ::exp10(x); }
219 : #else
220 0 : inline float exp10(const float& x) {
221 : const float ln10 = std::log(10.f);
222 : return std::exp(ln10 * x);
223 : }
224 0 : inline double exp10(const double& x) {
225 : const double ln10 = std::log(10.);
226 : return std::exp(ln10 * x);
227 : }
228 : #endif
229 :
230 0 : inline double sign(const bool& v) { return static_cast<double>(v); }
231 :
232 : template <typename T>
233 0 : T sign(const T& v) {
234 : return v < static_cast<T>(0) ? static_cast<T>(-1.)
235 : : v == static_cast<T>(0) ? static_cast<T>(0.)
236 : : static_cast<T>(1.);
237 : }
238 :
239 : template <typename T>
240 0 : T select(const bool cond, const T true_branch, const T false_branch) {
241 : return cond ? true_branch : false_branch;
242 : }
243 :
244 0 : inline std::pair<float, float> sincos(const float val) {
245 : // The nvcc compiler's built-in __sincos is for GPU code, not CPU code. In
246 : // the case that we are running on a GPU (__CUDA_ARCH__ is defined) or we
247 : // are not using nvcc then use the builtin, otherwise call sin and cos
248 : // separately.
249 : #if (defined(__CUDACC__) && defined(__CUDA_ARCH__)) or (not defined(__CUDACC__))
250 : float result_sin{};
251 : float result_cos{};
252 : __sincosf(val, &result_sin, &result_cos);
253 : return std::pair{result_sin, result_cos};
254 : #else
255 : return std::pair{sin(val), cos(val)};
256 : #endif
257 : }
258 :
259 0 : inline std::pair<double, double> sincos(const double val) {
260 : // The nvcc compiler's built-in __sincos is for GPU code, not CPU code. In
261 : // the case that we are running on a GPU (__CUDA_ARCH__ is defined) or we
262 : // are not using nvcc then use the builtin, otherwise call sin and cos
263 : // separately.
264 : #if (defined(__CUDACC__) && defined(__CUDA_ARCH__)) or (not defined(__CUDACC__))
265 : double result_sin{};
266 : double result_cos{};
267 : __sincos(val, &result_sin, &result_cos);
268 : return std::pair{result_sin, result_cos};
269 : #else
270 : return std::pair{sin(val), cos(val)};
271 : #endif
272 : }
273 :
274 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
275 0 : T fma(const T a, const T b, const T c) {
276 : return a * b + c;
277 : }
278 :
279 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
280 : T fma(const T a, const T b, const T c) {
281 : return std::fma(a, b, c);
282 : }
283 :
284 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
285 0 : T fms(const T& a, const T& b, const T& c) {
286 : return a * b - c;
287 : }
288 :
289 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
290 0 : T fms(const T a, const T b, const T c) {
291 : return std::fma(a, b, -c);
292 : }
293 :
294 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
295 0 : T fnma(const T a, const T b, const T c) {
296 : return -(a * b) + c;
297 : }
298 :
299 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
300 : T fnma(const T a, const T b, const T c) {
301 : return std::fma(-a, b, c);
302 : }
303 :
304 : template <typename T, Requires<std::is_integral_v<T>> = nullptr>
305 0 : T fnms(const T a, const T b, const T c) {
306 : return -(a * b) - c;
307 : }
308 :
309 : template <typename T, Requires<std::is_floating_point_v<T>> = nullptr>
310 : T fnms(const T a, const T b, const T c) {
311 : return -std::fma(a, b, c);
312 : }
313 :
314 : template <typename Arch = void, typename From>
315 0 : From load(From* mem) {
316 : static_assert(std::is_arithmetic_v<From>);
317 : return *mem;
318 : }
319 :
320 : template <typename Arch = void, typename T>
321 0 : void store(T* mem, const T& val) {
322 : static_assert(std::is_arithmetic_v<T>);
323 : *mem = val;
324 : }
325 :
326 : template <typename Arch = void, typename From>
327 0 : From load_unaligned(From* mem) {
328 : static_assert(std::is_arithmetic_v<From>);
329 : return *mem;
330 : }
331 :
332 : template <typename Arch = void, typename T>
333 0 : void store_unaligned(T* mem, const T& val) {
334 : static_assert(std::is_arithmetic_v<T>);
335 : *mem = val;
336 : }
337 : } // namespace simd
338 : #endif
|