Line data Source code
1 1 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : /// \file
5 : /// Declares the interfaces for the BLAS used.
6 : ///
7 : /// Wrappers are defined to perform casts from different integer types
8 : /// when the natural type in C++ differs from the BLAS argument.
9 :
10 : #pragma once
11 :
12 : #ifndef SPECTRE_DEBUG
13 : #include <libxsmm.h>
14 : #endif // ifndef SPECTRE_DEBUG
15 :
16 : #include "Utilities/ErrorHandling/Assert.hpp"
17 : #include "Utilities/Gsl.hpp"
18 :
namespace blas_detail {
// Declarations of the Fortran BLAS entry points that the wrappers in this
// file forward to. Sizes and strides are `int` here, scalars are taken by
// reference and arrays by pointer.
extern "C" {
// Dot product of the strided vectors X and Y, each holding N elements.
double ddot_(const int& N, const double* X, const int& INCX, const double* Y,
             const int& INCY);

// General matrix-matrix product. BLAS writes the result into C, so C is
// declared as a pointer to non-const data (A and B are only read).
//
// The final two arguments are the "hidden" lengths of the first two.
// https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
void dgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
            const int& K, const double& ALPHA, const double* A, const int& LDA,
            const double* B, const int& LDB, const double& BETA, double* C,
            const int& LDC, size_t, size_t);

// General matrix-vector product. The result is written into Y.
//
// The final argument is the "hidden" length of the first one.
// https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
void dgemv_(const char& TRANS, const int& M, const int& N, const double& ALPHA,
            const double* A, const int& LDA, const double* X, const int& INCX,
            const double& BETA, double* Y, const int& INCY, size_t);
}  // extern "C"
}  // namespace blas_detail
38 :
/*!
 * \brief Turn off OpenBLAS's own multithreading, which conflicts with the
 * parallelism provided by Charm++
 *
 * Every executable that calls BLAS routines should list this function in its
 * `charm_init_node_funcs`.
 *
 * Background: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
 */
void disable_openblas_multithreading();
49 :
50 : /*!
51 : * \ingroup UtilitiesGroup
52 : * The dot product of two vectors.
53 : *
54 : * \param N the length of the vectors.
55 : * \param X a pointer to the first element of the first vector.
56 : * \param INCX the stride for the elements of the first vector.
57 : * \param Y a pointer to the first element of the second vector.
58 : * \param INCY the stride for the elements of the second vector.
59 : * \return the dot product of the given vectors.
60 : */
61 1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
62 : const double* Y, const size_t& INCY) {
63 : // INCX and INCY are allowed to be negative by BLAS, but we never
64 : // use them that way. If needed, they can be changed here, but then
65 : // code providing values will also have to be changed to int to
66 : // avoid warnings.
67 : return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
68 : gsl::narrow_cast<int>(INCX), Y,
69 : gsl::narrow_cast<int>(INCY));
70 : }
71 :
72 : /// @{
73 : /*!
74 : * \ingroup UtilitiesGroup
75 : * \brief Perform a matrix-matrix multiplication
76 : *
77 : * Perform the matrix-matrix multiplication
78 : * \f[
79 : * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta \mathrm{op}(C)
80 : * \f]
81 : *
82 : * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
83 : * (transpose of \f$A\f$).
84 : *
85 : * LIBXSMM, which is much faster than BLAS for small matrices, can be called
86 : * instead of BLAS by passing the template parameter `true`.
87 : *
88 : * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
89 : * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
90 : * \param M Number of rows in \f$\mathrm{op}(A)\f$
91 : * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$\mathrm{op}(C)\f$
92 : * \param K Number of columns in \f$\mathrm{op}(A)\f$
93 : * \param ALPHA specifies \f$\alpha\f$
94 : * \param A Matrix \f$A\f$
95 : * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
96 : * \param B Matrix \f$B\f$
97 : * \param LDB Specifies first dimension of \f$\mathrm{op}(B)\f$
98 : * \param BETA specifies \f$\beta\f$
99 : * \param C Matrix \f$C\f$
100 : * \param LDC Specifies first dimension of \f$\mathrm{op}(C)\f$
101 : * \tparam UseLibXsmm if `true` then use LIBXSMM
102 : */
103 : template <bool UseLibXsmm = false>
104 1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
105 : const size_t& N, const size_t& K, const double& ALPHA,
106 : const double* A, const size_t& LDA, const double* B,
107 : const size_t& LDB, const double& BETA, double* C,
108 : const size_t& LDC) {
109 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
110 : 'C' == TRANSA or 'c' == TRANSA,
111 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
112 : "documentation for help.");
113 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
114 : 'C' == TRANSB or 'c' == TRANSB,
115 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
116 : "documentation for help.");
117 : blas_detail::dgemm_(
118 : TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
119 : gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
120 : gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
121 : }
122 :
123 : // libxsmm is disabled in DEBUG builds because backtraces (from, for
124 : // example, FPEs) do not work when the error occurs in libxsmm code.
125 : #ifndef SPECTRE_DEBUG
126 : template <>
127 1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
128 : const size_t& M, const size_t& N, const size_t& K,
129 : const double& ALPHA, const double* A,
130 : const size_t& LDA, const double* B, const size_t& LDB,
131 : const double& BETA, double* C, const size_t& LDC) {
132 : ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
133 : 'C' == TRANSA or 'c' == TRANSA,
134 : "TRANSA must be upper or lower case N, T, or C. See the BLAS "
135 : "documentation for help.");
136 : ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
137 : 'C' == TRANSB or 'c' == TRANSB,
138 : "TRANSB must be upper or lower case N, T, or C. See the BLAS "
139 : "documentation for help.");
140 : const auto m = gsl::narrow_cast<int>(M);
141 : const auto n = gsl::narrow_cast<int>(N);
142 : const auto k = gsl::narrow_cast<int>(K);
143 : const auto lda = gsl::narrow_cast<int>(LDA);
144 : const auto ldb = gsl::narrow_cast<int>(LDB);
145 : const auto ldc = gsl::narrow_cast<int>(LDC);
146 : libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
147 : C, &ldc);
148 : }
149 : #endif // ifndef SPECTRE_DEBUG
150 : /// @}
151 :
152 : /// @{
153 : /*!
154 : * \ingroup UtilitiesGroup
155 : * \brief Perform a matrix-vector multiplication
156 : *
157 : * \f[
158 : * y = \alpha \mathrm{op}(A) x + \beta y
159 : * \f]
160 : *
161 : * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
162 : * (transpose of \f$A\f$).
163 : *
164 : * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
165 : * \param M Number of rows in \f$\mathrm{op}(A)\f$
166 : * \param N Number of columns in \f$\mathrm{op}(A)\f$
167 : * \param ALPHA specifies \f$\alpha\f$
168 : * \param A Matrix \f$A\f$
169 : * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
170 : * \param X Vector \f$x\f$
171 : * \param INCX Specifies the increment for the elements of \f$x\f$
172 : * \param BETA Specifies \f$\beta\f$
173 : * \param Y Vector \f$y\f$
174 : * \param INCY Specifies the increment for the elements of \f$y\f$
175 : */
176 1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
177 : const double& ALPHA, const double* A, const size_t& LDA,
178 : const double* X, const size_t& INCX, const double& BETA,
179 : double* Y, const size_t& INCY) {
180 : ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
181 : 'C' == TRANS or 'c' == TRANS,
182 : "TRANS must be upper or lower case N, T, or C. See the BLAS "
183 : "documentation for help.");
184 : // INCX and INCY are allowed to be negative by BLAS, but we never
185 : // use them that way. If needed, they can be changed here, but then
186 : // code providing values will also have to be changed to int to
187 : // avoid warnings.
188 : blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
189 : ALPHA, A, gsl::narrow_cast<int>(LDA), X,
190 : gsl::narrow_cast<int>(INCX), BETA, Y,
191 : gsl::narrow_cast<int>(INCY), 1);
192 : }
193 : /// @}
|