Line data Source code
1 0 : // Distributed under the MIT License. 2 : // See LICENSE.txt for details. 3 : 4 : #pragma once 5 : 6 : #include <blaze/math/DenseVector.h> 7 : #include <blaze/math/constraints/SIMDPack.h> 8 : #include <blaze/math/simd/BasicTypes.h> 9 : #include <blaze/system/Inline.h> 10 : #include <blaze/system/Vectorization.h> 11 : 12 : namespace blaze { 13 : // This vectorized implementation of the step function is necessary because 14 : // blaze does not offer its own version of a vectorized step function. 15 : template <typename T> 16 0 : BLAZE_ALWAYS_INLINE SIMDdouble step_function(const SIMDf64<T>& v) 17 : #if BLAZE_AVX512F_MODE || BLAZE_MIC_MODE 18 : { 19 : return _mm512_set_pd((*v).eval().value[7] < 0.0 ? 0.0 : 1.0, 20 : (*v).eval().value[6] < 0.0 ? 0.0 : 1.0, 21 : (*v).eval().value[5] < 0.0 ? 0.0 : 1.0, 22 : (*v).eval().value[4] < 0.0 ? 0.0 : 1.0, 23 : (*v).eval().value[3] < 0.0 ? 0.0 : 1.0, 24 : (*v).eval().value[2] < 0.0 ? 0.0 : 1.0, 25 : (*v).eval().value[1] < 0.0 ? 0.0 : 1.0, 26 : (*v).eval().value[0] < 0.0 ? 0.0 : 1.0); 27 : } 28 : #elif BLAZE_AVX_MODE 29 : { 30 : return _mm256_set_pd((*v).eval().value[3] < 0.0 ? 0.0 : 1.0, 31 : (*v).eval().value[2] < 0.0 ? 0.0 : 1.0, 32 : (*v).eval().value[1] < 0.0 ? 0.0 : 1.0, 33 : (*v).eval().value[0] < 0.0 ? 0.0 : 1.0); 34 : } 35 : #elif BLAZE_SSE2_MODE 36 : { 37 : return _mm_set_pd((*v).eval().value[1] < 0.0 ? 0.0 : 1.0, 38 : (*v).eval().value[0] < 0.0 ? 0.0 : 1.0); 39 : } 40 : #else 41 : { 42 : return SIMDdouble{(*v).value < 0.0 ? 0.0 : 1.0}; 43 : } 44 : #endif 45 : 46 0 : BLAZE_ALWAYS_INLINE double step_function(const double v) { 47 : return v < 0.0 ? 0.0 : 1.0; 48 : } 49 : 50 0 : struct StepFunction { 51 0 : explicit inline StepFunction() = default; 52 : 53 : template <typename T> 54 0 : BLAZE_ALWAYS_INLINE decltype(auto) operator()(const T& a) const { 55 : return step_function(a); 56 : } 57 : 58 : template <typename T> 59 0 : BLAZE_ALWAYS_INLINE decltype(auto) load(const T& a) const { 60 : BLAZE_CONSTRAINT_MUST_BE_SIMD_PACK(T); 61 : return step_function(a); 62 : } 63 : }; 64 : } // namespace blaze 65 : 66 : template <typename VT, bool TF> 67 0 : BLAZE_ALWAYS_INLINE decltype(auto) step_function( 68 : const blaze::DenseVector<VT, TF>& vec) { 69 : return map(*vec, blaze::StepFunction{}); 70 : } 71 : 72 : template <typename VT, bool TF> 73 0 : BLAZE_ALWAYS_INLINE decltype(auto) StepFunction( 74 : const blaze::DenseVector<VT, TF>& vec) { 75 : return map(*vec, blaze::StepFunction{}); 76 : }