Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares the interfaces for the BLAS used.
6 : ///
7 : /// Wrappers are defined to perform casts from different integer types
8 : /// when the natural type in C++ differs from the BLAS argument.
9 :
10 : #pragma once
11 :
12 : #include <complex>
13 :
14 : #ifndef SPECTRE_DEBUG
15 : #include <libxsmm.h>
16 : #endif // ifndef SPECTRE_DEBUG
17 : #include <gsl/gsl_cblas.h>
18 :
19 : #include "Utilities/ErrorHandling/Assert.hpp"
20 : #include "Utilities/Gsl.hpp"
21 :
namespace blas_detail {
extern "C" {
// Raw declarations of the Fortran BLAS symbols used by the wrappers in this
// file. Every scalar is taken by reference because the Fortran calling
// convention passes arguments by address.

// Real dot product.
double ddot_(const int& N,
             const double* X, const int& INCX,
             const double* Y, const int& INCY);

// Matrix-matrix products. Each CHARACTER argument (TRANSA, TRANSB) has a
// "hidden" length that is appended by value at the end of the argument list:
// https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
//
// NOTE(review): C is the output matrix of GEMM but is declared as a pointer
// to const here; confirm whether that is intentional before changing it.
void dgemm_(const char& TRANSA, const char& TRANSB,
            const int& M, const int& N, const int& K,
            const double& ALPHA,
            const double* A, const int& LDA,
            const double* B, const int& LDB,
            const double& BETA,
            const double* C, const int& LDC,
            size_t TRANSA_length, size_t TRANSB_length);
void zgemm_(const char& TRANSA, const char& TRANSB,
            const int& M, const int& N, const int& K,
            const std::complex<double>& ALPHA,
            const std::complex<double>* A, const int& LDA,
            const std::complex<double>* B, const int& LDB,
            const std::complex<double>& BETA,
            const std::complex<double>* C, const int& LDC,
            size_t TRANSA_length, size_t TRANSB_length);

// Matrix-vector product. The trailing argument is the "hidden" length of the
// single CHARACTER argument TRANS:
// https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
void dgemv_(const char& TRANS,
            const int& M, const int& N,
            const double& ALPHA,
            const double* A, const int& LDA,
            const double* X, const int& INCX,
            const double& BETA,
            double* Y, const int& INCY,
            size_t TRANS_length);
}  // extern "C"
}  // namespace blas_detail
47 :
/*!
 * \brief Turn off the multithreading done internally by OpenBLAS, because it
 * conflicts with Charm++ parallelism.
 *
 * Any executable that calls BLAS routines should list this function in its
 * `charm_init_node_funcs`.
 *
 * Background: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
 */
void disable_openblas_multithreading();
58 :
59 : /// @{
60 : /*!
61 : * \ingroup UtilitiesGroup
62 : * The dot product of two vectors.
63 : *
64 : * \param N the length of the vectors.
65 : * \param X a pointer to the first element of the first vector.
66 : * \param INCX the stride for the elements of the first vector.
67 : * \param Y a pointer to the first element of the second vector.
68 : * \param INCY the stride for the elements of the second vector.
69 : * \return the dot product of the given vectors.
70 : */
71 1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
72 : const double* Y, const size_t& INCY) {
73 : // INCX and INCY are allowed to be negative by BLAS, but we never
74 : // use them that way. If needed, they can be changed here, but then
75 : // code providing values will also have to be changed to int to
76 : // avoid warnings.
77 : return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
78 : gsl::narrow_cast<int>(INCX), Y,
79 : gsl::narrow_cast<int>(INCY));
80 : }
81 : /// The unconjugated complex dot product $x \cdot y$. See `zdotc_` for the
82 : /// conjugated complex dot product, which is the standard dot product on the
83 : /// vector space of complex numbers.
84 1 : inline std::complex<double> zdotu_(const size_t& N,
85 : const std::complex<double>* X,
86 : const size_t& INCX,
87 : const std::complex<double>* Y,
88 : const size_t& INCY) {
89 : // The complex result of the BLAS zdot* functions is sometimes returned by
90 : // value and sometimes returned by reference, depending on the Fortran
91 : // compiler settings. By using the cblas interface we ensure a consistent
92 : // behavior.
93 : std::complex<double> result;
94 : cblas_zdotu_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
95 : gsl::narrow_cast<int>(INCY), &result);
96 : return result;
97 : }
98 : /// The conjugated complex dot product $\bar{x} \cdot y$. This is the standard
99 : /// dot product on the vector space of complex numbers.
100 1 : inline std::complex<double> zdotc_(const size_t& N,
101 : const std::complex<double>* X,
102 : const size_t& INCX,
103 : const std::complex<double>* Y,
104 : const size_t& INCY) {
105 : std::complex<double> result;
106 : cblas_zdotc_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
107 : gsl::narrow_cast<int>(INCY), &result);
108 : return result;
109 : }
110 : /// @}
111 :
112 : /// @{
113 : /*!
114 : * \ingroup UtilitiesGroup
115 : * \brief Perform a matrix-matrix multiplication
116 : *
117 : * Perform the matrix-matrix multiplication
118 : * \f[
119 : * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta \mathrm{op}(C)
120 : * \f]
121 : *
122 : * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
123 : * (transpose of \f$A\f$).
124 : *
125 : * LIBXSMM, which is much faster than BLAS for small matrices, can be called
126 : * instead of BLAS by passing the template parameter `true`.
127 : *
128 : * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
129 : * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
130 : * \param M Number of rows in \f$\mathrm{op}(A)\f$
131 : * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$\mathrm{op}(C)\f$
132 : * \param K Number of columns in \f$\mathrm{op}(A)\f$
133 : * \param ALPHA specifies \f$\alpha\f$
134 : * \param A Matrix \f$A\f$
135 : * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
136 : * \param B Matrix \f$B\f$
137 : * \param LDB Specifies first dimension of \f$\mathrm{op}(B)\f$
138 : * \param BETA specifies \f$\beta\f$
139 : * \param C Matrix \f$C\f$
140 : * \param LDC Specifies first dimension of \f$\mathrm{op}(C)\f$
141 : * \tparam UseLibXsmm if `true` then use LIBXSMM
142 : */
143 : template <bool UseLibXsmm = false>
144 1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
145 : const size_t& N, const size_t& K, const double& ALPHA,
146 : const double* A, const size_t& LDA, const double* B,
147 : const size_t& LDB, const double& BETA, double* C,
148 : const size_t& LDC) {
149 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
150 : 'C' == TRANSA or 'c' == TRANSA,
151 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
152 : "documentation for help.");
153 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
154 : 'C' == TRANSB or 'c' == TRANSB,
155 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
156 : "documentation for help.");
157 :
158 : // On some BLAS implementations (e.g. Accelerate.framework on macOS) a call
159 : // with a zero-sized dimension aborts instead of acting as a no-op. Treat
160 : // these cases explicitly so we behave consistently across platforms.
161 : if (M == 0 or N == 0 or K == 0) {
162 : return;
163 : }
164 :
165 : blas_detail::dgemm_(
166 : TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
167 : gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
168 : gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
169 : }
170 : template <bool UseLibXsmm = false>
171 1 : inline void zgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
172 : const size_t& N, const size_t& K,
173 : const std::complex<double>& ALPHA,
174 : const std::complex<double>* A, const size_t& LDA,
175 : const std::complex<double>* B, const size_t& LDB,
176 : const std::complex<double>& BETA, std::complex<double>* C,
177 : const size_t& LDC) {
178 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
179 : 'C' == TRANSA or 'c' == TRANSA,
180 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
181 : "documentation for help.");
182 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
183 : 'C' == TRANSB or 'c' == TRANSB,
184 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
185 : "documentation for help.");
186 :
187 : // On some BLAS implementations (e.g. Accelerate.framework on macOS) a call
188 : // with a zero-sized dimension aborts instead of acting as a no-op. Treat
189 : // these cases explicitly so we behave consistently across platforms.
190 : if (M == 0 or N == 0 or K == 0) {
191 : return;
192 : }
193 :
194 : blas_detail::zgemm_(
195 : TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
196 : gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
197 : gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
198 : }
199 :
200 : // libxsmm is disabled in DEBUG builds because backtraces (from, for
201 : // example, FPEs) do not work when the error occurs in libxsmm code.
202 : #ifndef SPECTRE_DEBUG
203 : template <>
204 1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
205 : const size_t& M, const size_t& N, const size_t& K,
206 : const double& ALPHA, const double* A,
207 : const size_t& LDA, const double* B, const size_t& LDB,
208 : const double& BETA, double* C, const size_t& LDC) {
209 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
210 : 'C' == TRANSA or 'c' == TRANSA,
211 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
212 : "documentation for help.");
213 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
214 : 'C' == TRANSB or 'c' == TRANSB,
215 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
216 : "documentation for help.");
217 :
218 : // On some BLAS implementations (e.g. Accelerate.framework on macOS) a call
219 : // with a zero-sized dimension aborts instead of acting as a no-op. Treat
220 : // these cases explicitly so we behave consistently across platforms.
221 : if (M == 0 or N == 0 or K == 0) {
222 : return;
223 : }
224 :
225 : const auto m = gsl::narrow_cast<int>(M);
226 : const auto n = gsl::narrow_cast<int>(N);
227 : const auto k = gsl::narrow_cast<int>(K);
228 : const auto lda = gsl::narrow_cast<int>(LDA);
229 : const auto ldb = gsl::narrow_cast<int>(LDB);
230 : const auto ldc = gsl::narrow_cast<int>(LDC);
231 : libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
232 : C, &ldc);
233 : }
234 : #endif // ifndef SPECTRE_DEBUG
235 : /// @}
236 :
237 : /// @{
238 : /*!
239 : * \ingroup UtilitiesGroup
240 : * \brief Perform a matrix-vector multiplication
241 : *
242 : * \f[
243 : * y = \alpha \mathrm{op}(A) x + \beta y
244 : * \f]
245 : *
246 : * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
247 : * (transpose of \f$A\f$).
248 : *
249 : * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
250 : * \param M Number of rows in \f$\mathrm{op}(A)\f$
251 : * \param N Number of columns in \f$\mathrm{op}(A)\f$
252 : * \param ALPHA specifies \f$\alpha\f$
253 : * \param A Matrix \f$A\f$
254 : * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
255 : * \param X Vector \f$x\f$
256 : * \param INCX Specifies the increment for the elements of \f$x\f$
257 : * \param BETA Specifies \f$\beta\f$
258 : * \param Y Vector \f$y\f$
259 : * \param INCY Specifies the increment for the elements of \f$y\f$
260 : */
261 1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
262 : const double& ALPHA, const double* A, const size_t& LDA,
263 : const double* X, const size_t& INCX, const double& BETA,
264 : double* Y, const size_t& INCY) {
265 : ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
266 : 'C' == TRANS or 'c' == TRANS,
267 : "TRANS must be upper or lower case N, T, or C. See the BLAS "
268 : "documentation for help.");
269 :
270 : // On some BLAS implementations (e.g. Accelerate.framework on macOS) a call
271 : // with a zero-sized dimension aborts instead of acting as a no-op. Treat
272 : // these cases explicitly so we behave consistently across platforms.
273 : if (M == 0 or N == 0) {
274 : return;
275 : }
276 :
277 : // INCX and INCY are allowed to be negative by BLAS, but we never
278 : // use them that way. If needed, they can be changed here, but then
279 : // code providing values will also have to be changed to int to
280 : // avoid warnings.
281 : blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
282 : ALPHA, A, gsl::narrow_cast<int>(LDA), X,
283 : gsl::narrow_cast<int>(INCX), BETA, Y,
284 : gsl::narrow_cast<int>(INCY), 1);
285 : }
286 : /// @}
|