Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares the interfaces for the BLAS used.
6 : ///
7 : /// Wrappers are defined to perform casts from different integer types
8 : /// when the natural type in C++ differs from the BLAS argument.
9 :
10 : #pragma once
11 :
12 : #include <complex>
13 :
14 : #ifndef SPECTRE_DEBUG
15 : #include <libxsmm.h>
16 : #endif // ifndef SPECTRE_DEBUG
17 : #include <gsl/gsl_cblas.h>
18 :
19 : #include "Utilities/ErrorHandling/Assert.hpp"
20 : #include "Utilities/Gsl.hpp"
21 :
22 : namespace blas_detail {
// Raw declarations of the Fortran BLAS symbols that the `size_t`-based
// wrappers in this header forward to. Flags, dimensions, strides, and scalar
// factors are taken by const reference and arrays by pointer; only the
// trailing "hidden" string lengths are taken by value.
23 : extern "C" {
// Real dot product; the result is returned by value.
24 : double ddot_(const int& N, const double* X, const int& INCX, const double* Y,
25 : const int& INCY);
26 :
27 : // The final two arguments are the "hidden" lengths of the first two
// (the character arguments TRANSA and TRANSB).
28 : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
// NOTE(review): `C` is declared pointer-to-const in dgemm_ and zgemm_ even
// though the wrappers pass their non-const output matrix for it, whereas
// dgemv_ below declares its output `Y` as non-const. Presumably the routine
// writes through `C` -- confirm whether a non-const pointer would be the
// more accurate declaration.
29 : void dgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
30 : const int& K, const double& ALPHA, const double* A, const int& LDA,
31 : const double* B, const int& LDB, const double& BETA,
32 : const double* C, const int& LDC, size_t, size_t);
33 : void zgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
34 : const int& K, const std::complex<double>& ALPHA,
35 : const std::complex<double>* A, const int& LDA,
36 : const std::complex<double>* B, const int& LDB,
37 : const std::complex<double>& BETA, const std::complex<double>* C,
38 : const int& LDC, size_t, size_t);
39 :
40 : // The final argument is the "hidden" length of the first one
// (the character argument TRANS).
41 : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
42 : void dgemv_(const char& TRANS, const int& M, const int& N, const double& ALPHA,
43 : const double* A, const int& LDA, const double* X, const int& INCX,
44 : const double& BETA, double* Y, const int& INCY, size_t);
45 : } // extern "C"
46 : } // namespace blas_detail
47 :
48 : /*!
49 : * \brief Disable OpenBLAS multithreading since it conflicts with Charm++
50 : * parallelism
51 : *
52 : * Add this function to the `charm_init_node_funcs` of any executable that uses
53 : * BLAS routines.
54 : *
55 : * Details: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
 *
 * This header only declares the function (no arguments, no return value);
 * the definition is not part of the header text shown here.
56 : */
57 1 : void disable_openblas_multithreading();
58 :
59 : /// @{
60 : /*!
61 : * \ingroup UtilitiesGroup
62 : * The dot product of two vectors.
63 : *
 * Thin wrapper that converts the `size_t` arguments to `int` and forwards to
 * the Fortran `ddot_` symbol declared in `blas_detail`.
 *
64 : * \param N the length of the vectors.
65 : * \param X a pointer to the first element of the first vector.
66 : * \param INCX the stride for the elements of the first vector.
67 : * \param Y a pointer to the first element of the second vector.
68 : * \param INCY the stride for the elements of the second vector.
69 : * \return the dot product of the given vectors.
 *
 * NOTE(review): `gsl::narrow_cast` is presumably an unchecked cast, in which
 * case values of `N`, `INCX`, or `INCY` that do not fit in an `int` would be
 * passed on truncated without any diagnostic -- confirm.
70 : */
71 1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
72 : const double* Y, const size_t& INCY) {
73 : // INCX and INCY are allowed to be negative by BLAS, but we never
74 : // use them that way. If needed, they can be changed here, but then
75 : // code providing values will also have to be changed to int to
76 : // avoid warnings.
77 : return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
78 : gsl::narrow_cast<int>(INCX), Y,
79 : gsl::narrow_cast<int>(INCY));
80 : }
81 : /// The unconjugated complex dot product $x \cdot y$. See `zdotc_` for the
82 : /// conjugated complex dot product, which is the standard dot product on the
83 : /// vector space of complex numbers.
///
/// `N` is the number of elements, `X` and `Y` point to the first elements of
/// the two vectors, and `INCX` and `INCY` are their strides. The `size_t`
/// arguments are converted to `int` with `gsl::narrow_cast` before the call.
84 1 : inline std::complex<double> zdotu_(const size_t& N,
85 : const std::complex<double>* X,
86 : const size_t& INCX,
87 : const std::complex<double>* Y,
88 : const size_t& INCY) {
89 : // The complex result of the BLAS zdot* functions is sometimes returned by
90 : // value and sometimes returned by reference, depending on the Fortran
91 : // compiler settings. By using the cblas interface we ensure a consistent
92 : // behavior.
// `result` is the output slot: its address is handed to cblas as the last
// argument and the value is then returned to the caller.
93 : std::complex<double> result;
94 : cblas_zdotu_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
95 : gsl::narrow_cast<int>(INCY), &result);
96 : return result;
97 : }
98 : /// The conjugated complex dot product $\bar{x} \cdot y$. This is the standard
99 : /// dot product on the vector space of complex numbers.
///
/// `N` is the number of elements, `X` and `Y` point to the first elements of
/// the two vectors, and `INCX` and `INCY` are their strides. The `size_t`
/// arguments are converted to `int` with `gsl::narrow_cast` before the call.
100 1 : inline std::complex<double> zdotc_(const size_t& N,
101 : const std::complex<double>* X,
102 : const size_t& INCX,
103 : const std::complex<double>* Y,
104 : const size_t& INCY) {
// The cblas `_sub` routine is used, so the complex result comes back through
// the pointer passed as the last argument rather than as a Fortran return
// value.
105 : std::complex<double> result;
106 : cblas_zdotc_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
107 : gsl::narrow_cast<int>(INCY), &result);
108 : return result;
109 : }
110 : /// @}
111 :
112 : /// @{
113 : /*!
114 : * \ingroup UtilitiesGroup
115 : * \brief Perform a matrix-matrix multiplication
116 : *
117 : * Perform the matrix-matrix multiplication
118 : * \f[
119 : * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta C
120 : * \f]
121 : *
122 : * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
123 : * (transpose of \f$A\f$).
 * Only \f$A\f$ and \f$B\f$ have a transposition flag; \f$C\f$ is passed once
 * and appears untransposed on both sides of the formula.
124 : *
125 : * LIBXSMM, which is much faster than BLAS for small matrices, can be called
126 : * instead of BLAS by passing the template parameter `true`.
 * The LIBXSMM path is the `dgemm_<true>` specialization, which is compiled
 * only when `SPECTRE_DEBUG` is not defined. The primary template does not
 * look at `UseLibXsmm` and always calls BLAS, so `dgemm_<true>` uses BLAS
 * when `SPECTRE_DEBUG` is defined, and `zgemm_` uses BLAS for either value.
127 : *
128 : * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
129 : * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
 * (the lower-case letters are accepted as well)
130 : * \param M Number of rows in \f$\mathrm{op}(A)\f$
131 : * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$C\f$
132 : * \param K Number of columns in \f$\mathrm{op}(A)\f$
133 : * \param ALPHA specifies \f$\alpha\f$
134 : * \param A Matrix \f$A\f$
135 : * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
136 : * \param B Matrix \f$B\f$
137 : * \param LDB Specifies first dimension of \f$\mathrm{op}(B)\f$
138 : * \param BETA specifies \f$\beta\f$
139 : * \param C Matrix \f$C\f$
140 : * \param LDC Specifies first dimension of \f$\mathrm{op}(C)\f$
141 : * \tparam UseLibXsmm if `true` then use LIBXSMM
 *
 * NOTE(review): the reference BLAS documentation describes LDA, LDB, and LDC
 * as the first (leading) dimensions of the arrays A, B, and C as stored, not
 * of op(A), op(B), or op(C) -- confirm and adjust the three entries above.
142 : */
143 : template <bool UseLibXsmm = false>
144 1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
145 : const size_t& N, const size_t& K, const double& ALPHA,
146 : const double* A, const size_t& LDA, const double* B,
147 : const size_t& LDB, const double& BETA, double* C,
148 : const size_t& LDC) {
// Reject anything other than the six accepted flag characters up front.
149 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
150 : 'C' == TRANSA or 'c' == TRANSA,
151 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
152 : "documentation for help.");
153 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
154 : 'C' == TRANSB or 'c' == TRANSB,
155 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
156 : "documentation for help.");
// Sizes are converted from size_t to the int expected by BLAS. The trailing
// `1, 1` are the hidden Fortran lengths of the single-character TRANSA and
// TRANSB arguments.
157 : blas_detail::dgemm_(
158 : TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
159 : gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
160 : gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
161 : }
// Matrix-matrix multiplication for `std::complex<double>` matrices; takes the
// same arguments as `dgemm_` with complex scalars and arrays. `UseLibXsmm` is
// not referenced in the body, so the BLAS `zgemm_` routine is called for
// either value of the template parameter.
162 : template <bool UseLibXsmm = false>
163 1 : inline void zgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
164 : const size_t& N, const size_t& K,
165 : const std::complex<double>& ALPHA,
166 : const std::complex<double>* A, const size_t& LDA,
167 : const std::complex<double>* B, const size_t& LDB,
168 : const std::complex<double>& BETA, std::complex<double>* C,
169 : const size_t& LDC) {
// Reject anything other than the six accepted flag characters up front.
170 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
171 : 'C' == TRANSA or 'c' == TRANSA,
172 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
173 : "documentation for help.");
174 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
175 : 'C' == TRANSB or 'c' == TRANSB,
176 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
177 : "documentation for help.");
// Sizes are converted from size_t to int. The trailing `1, 1` are the hidden
// Fortran lengths of the single-character TRANSA and TRANSB arguments.
178 : blas_detail::zgemm_(
179 : TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
180 : gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
181 : gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
182 : }
183 :
184 : // libxsmm is disabled in DEBUG builds because backtraces (from, for
185 : // example, FPEs) do not work when the error occurs in libxsmm code.
// When SPECTRE_DEBUG is defined this specialization does not exist and
// `dgemm_<true>` resolves to the primary template, which calls BLAS.
186 : #ifndef SPECTRE_DEBUG
187 : template <>
188 1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
189 : const size_t& M, const size_t& N, const size_t& K,
190 : const double& ALPHA, const double* A,
191 : const size_t& LDA, const double* B, const size_t& LDB,
192 : const double& BETA, double* C, const size_t& LDC) {
// Same flag validation as the primary template.
193 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
194 : 'C' == TRANSA or 'c' == TRANSA,
195 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
196 : "documentation for help.");
197 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
198 : 'C' == TRANSB or 'c' == TRANSB,
199 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
200 : "documentation for help.");
// libxsmm_dgemm is called with the address of every scalar argument, so the
// size_t dimensions are first converted into named int locals whose
// addresses can be taken.
201 : const auto m = gsl::narrow_cast<int>(M);
202 : const auto n = gsl::narrow_cast<int>(N);
203 : const auto k = gsl::narrow_cast<int>(K);
204 : const auto lda = gsl::narrow_cast<int>(LDA);
205 : const auto ldb = gsl::narrow_cast<int>(LDB);
206 : const auto ldc = gsl::narrow_cast<int>(LDC);
207 : libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
208 : C, &ldc);
209 : }
210 : #endif // ifndef SPECTRE_DEBUG
211 : /// @}
212 :
213 : /// @{
214 : /*!
215 : * \ingroup UtilitiesGroup
216 : * \brief Perform a matrix-vector multiplication
217 : *
218 : * \f[
219 : * y = \alpha \mathrm{op}(A) x + \beta y
220 : * \f]
221 : *
222 : * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
223 : * (transpose of \f$A\f$).
224 : *
 * Thin wrapper that converts the `size_t` arguments to `int` and forwards to
 * the Fortran `dgemv_` symbol declared in `blas_detail`.
 *
225 : * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
 * (the lower-case letters are accepted as well)
226 : * \param M Number of rows in \f$\mathrm{op}(A)\f$
227 : * \param N Number of columns in \f$\mathrm{op}(A)\f$
228 : * \param ALPHA specifies \f$\alpha\f$
229 : * \param A Matrix \f$A\f$
230 : * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
231 : * \param X Vector \f$x\f$
232 : * \param INCX Specifies the increment for the elements of \f$x\f$
233 : * \param BETA Specifies \f$\beta\f$
234 : * \param Y Vector \f$y\f$
235 : * \param INCY Specifies the increment for the elements of \f$y\f$
 *
 * NOTE(review): the reference BLAS documentation defines M and N as the
 * number of rows and columns of A itself, and LDA as the leading dimension of
 * A as stored, independent of TRANS -- confirm and adjust the `op(A)` wording
 * in the M, N, and LDA entries above.
236 : */
237 1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
238 : const double& ALPHA, const double* A, const size_t& LDA,
239 : const double* X, const size_t& INCX, const double& BETA,
240 : double* Y, const size_t& INCY) {
// Reject anything other than the six accepted flag characters up front.
241 : ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
242 : 'C' == TRANS or 'c' == TRANS,
243 : "TRANS must be upper or lower case N, T, or C. See the BLAS "
244 : "documentation for help.");
245 : // INCX and INCY are allowed to be negative by BLAS, but we never
246 : // use them that way. If needed, they can be changed here, but then
247 : // code providing values will also have to be changed to int to
248 : // avoid warnings.
// The trailing `1` is the hidden Fortran length of the single-character
// TRANS argument.
249 : blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
250 : ALPHA, A, gsl::narrow_cast<int>(LDA), X,
251 : gsl::narrow_cast<int>(INCX), BETA, Y,
252 : gsl::narrow_cast<int>(INCY), 1);
253 : }
254 : /// @}
|