SpECTRE Documentation Coverage Report
Current view: top level - Utilities - Blas.hpp Hit Total Coverage
Commit: e88482f6bea8064ca40dbebfd81596847f6f1cd9 Lines: 9 9 100.0 %
Date: 2024-10-21 20:40:03
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares the interfaces for the BLAS used.
       6             : ///
       7             : /// Wrappers are defined to perform casts from different integer types
       8             : /// when the natural type in C++ differs from the BLAS argument.
       9             : 
      10             : #pragma once
      11             : 
      12             : #include <complex>
      13             : 
      14             : #ifndef SPECTRE_DEBUG
      15             : #include <libxsmm.h>
      16             : #endif  // ifndef SPECTRE_DEBUG
      17             : #include <gsl/gsl_cblas.h>
      18             : 
      19             : #include "Utilities/ErrorHandling/Assert.hpp"
      20             : #include "Utilities/Gsl.hpp"
      21             : 
      22             : namespace blas_detail {
      23             : extern "C" {
      24             : double ddot_(const int& N, const double* X, const int& INCX, const double* Y,
      25             :              const int& INCY);
      26             : 
      27             : // The final two arguments are the "hidden" lengths of the first two.
      28             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      29             : void dgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
      30             :             const int& K, const double& ALPHA, const double* A, const int& LDA,
      31             :             const double* B, const int& LDB, const double& BETA,
      32             :             const double* C, const int& LDC, size_t, size_t);
      33             : void zgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
      34             :             const int& K, const std::complex<double>& ALPHA,
      35             :             const std::complex<double>* A, const int& LDA,
      36             :             const std::complex<double>* B, const int& LDB,
      37             :             const std::complex<double>& BETA, const std::complex<double>* C,
      38             :             const int& LDC, size_t, size_t);
      39             : 
      40             : // The final argument is the "hidden" length of the first one.
      41             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      42             : void dgemv_(const char& TRANS, const int& M, const int& N, const double& ALPHA,
      43             :             const double* A, const int& LDA, const double* X, const int& INCX,
      44             :             const double& BETA, double* Y, const int& INCY, size_t);
      45             : }  // extern "C"
      46             : }  // namespace blas_detail
      47             : 
/*!
 * \brief Disable OpenBLAS multithreading since it conflicts with Charm++
 * parallelism
 *
 * Only the declaration appears in this header; the definition lives in a
 * separate translation unit. Add this function to the `charm_init_node_funcs`
 * of any executable that calls BLAS routines.
 *
 * Details: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
 */
void disable_openblas_multithreading();
      58             : 
      59             : /// @{
      60             : /*!
      61             :  * \ingroup UtilitiesGroup
      62             :  * The dot product of two vectors.
      63             :  *
      64             :  * \param N the length of the vectors.
      65             :  * \param X a pointer to the first element of the first vector.
      66             :  * \param INCX the stride for the elements of the first vector.
      67             :  * \param Y a pointer to the first element of the second vector.
      68             :  * \param INCY the stride for the elements of the second vector.
      69             :  * \return the dot product of the given vectors.
      70             :  */
      71           1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
      72             :                     const double* Y, const size_t& INCY) {
      73             :   // INCX and INCY are allowed to be negative by BLAS, but we never
      74             :   // use them that way.  If needed, they can be changed here, but then
      75             :   // code providing values will also have to be changed to int to
      76             :   // avoid warnings.
      77             :   return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
      78             :                             gsl::narrow_cast<int>(INCX), Y,
      79             :                             gsl::narrow_cast<int>(INCY));
      80             : }
      81             : /// The unconjugated complex dot product $x \cdot y$. See `zdotc_` for the
      82             : /// conjugated complex dot product, which is the standard dot product on the
      83             : /// vector space of complex numbers.
      84           1 : inline std::complex<double> zdotu_(const size_t& N,
      85             :                                    const std::complex<double>* X,
      86             :                                    const size_t& INCX,
      87             :                                    const std::complex<double>* Y,
      88             :                                    const size_t& INCY) {
      89             :   // The complex result of the BLAS zdot* functions is sometimes returned by
      90             :   // value and sometimes returned by reference, depending on the Fortran
      91             :   // compiler settings. By using the cblas interface we ensure a consistent
      92             :   // behavior.
      93             :   std::complex<double> result;
      94             :   cblas_zdotu_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
      95             :                   gsl::narrow_cast<int>(INCY), &result);
      96             :   return result;
      97             : }
      98             : /// The conjugated complex dot product $\bar{x} \cdot y$. This is the standard
      99             : /// dot product on the vector space of complex numbers.
     100           1 : inline std::complex<double> zdotc_(const size_t& N,
     101             :                                    const std::complex<double>* X,
     102             :                                    const size_t& INCX,
     103             :                                    const std::complex<double>* Y,
     104             :                                    const size_t& INCY) {
     105             :   std::complex<double> result;
     106             :   cblas_zdotc_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
     107             :                   gsl::narrow_cast<int>(INCY), &result);
     108             :   return result;
     109             : }
     110             : /// @}
     111             : 
     112             : /// @{
     113             : /*!
     114             :  * \ingroup UtilitiesGroup
     115             :  * \brief Perform a matrix-matrix multiplication
     116             :  *
     117             :  * Perform the matrix-matrix multiplication
     118             :  * \f[
     119             :  * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta \mathrm{op}(C)
     120             :  * \f]
     121             :  *
     122             :  * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
     123             :  * (transpose of \f$A\f$).
     124             :  *
     125             :  * LIBXSMM, which is much faster than BLAS for small matrices, can be called
     126             :  * instead of BLAS by passing the template parameter `true`.
     127             :  *
     128             :  * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
     129             :  * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
     130             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$
     131             :  * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$\mathrm{op}(C)\f$
     132             :  * \param K Number of columns in \f$\mathrm{op}(A)\f$
     133             :  * \param ALPHA specifies \f$\alpha\f$
     134             :  * \param A Matrix \f$A\f$
     135             :  * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
     136             :  * \param B Matrix \f$B\f$
     137             :  * \param LDB Specifies first dimension of \f$\mathrm{op}(B)\f$
     138             :  * \param BETA specifies \f$\beta\f$
     139             :  * \param C Matrix \f$C\f$
     140             :  * \param LDC Specifies first dimension of \f$\mathrm{op}(C)\f$
     141             :  * \tparam UseLibXsmm if `true` then use LIBXSMM
     142             :  */
     143             : template <bool UseLibXsmm = false>
     144           1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
     145             :                    const size_t& N, const size_t& K, const double& ALPHA,
     146             :                    const double* A, const size_t& LDA, const double* B,
     147             :                    const size_t& LDB, const double& BETA, double* C,
     148             :                    const size_t& LDC) {
     149             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     150             :              'C' == TRANSA or 'c' == TRANSA,
     151             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     152             :          "documentation for help.");
     153             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     154             :              'C' == TRANSB or 'c' == TRANSB,
     155             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     156             :          "documentation for help.");
     157             :   blas_detail::dgemm_(
     158             :       TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     159             :       gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
     160             :       gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
     161             : }
     162             : template <bool UseLibXsmm = false>
     163           1 : inline void zgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
     164             :                    const size_t& N, const size_t& K,
     165             :                    const std::complex<double>& ALPHA,
     166             :                    const std::complex<double>* A, const size_t& LDA,
     167             :                    const std::complex<double>* B, const size_t& LDB,
     168             :                    const std::complex<double>& BETA, std::complex<double>* C,
     169             :                    const size_t& LDC) {
     170             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     171             :              'C' == TRANSA or 'c' == TRANSA,
     172             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     173             :          "documentation for help.");
     174             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     175             :              'C' == TRANSB or 'c' == TRANSB,
     176             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     177             :          "documentation for help.");
     178             :   blas_detail::zgemm_(
     179             :       TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     180             :       gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
     181             :       gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
     182             : }
     183             : 
     184             : // libxsmm is disabled in DEBUG builds because backtraces (from, for
     185             : // example, FPEs) do not work when the error occurs in libxsmm code.
     186             : #ifndef SPECTRE_DEBUG
     187             : template <>
     188           1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
     189             :                          const size_t& M, const size_t& N, const size_t& K,
     190             :                          const double& ALPHA, const double* A,
     191             :                          const size_t& LDA, const double* B, const size_t& LDB,
     192             :                          const double& BETA, double* C, const size_t& LDC) {
     193             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     194             :              'C' == TRANSA or 'c' == TRANSA,
     195             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     196             :          "documentation for help.");
     197             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     198             :              'C' == TRANSB or 'c' == TRANSB,
     199             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     200             :          "documentation for help.");
     201             :   const auto m = gsl::narrow_cast<int>(M);
     202             :   const auto n = gsl::narrow_cast<int>(N);
     203             :   const auto k = gsl::narrow_cast<int>(K);
     204             :   const auto lda = gsl::narrow_cast<int>(LDA);
     205             :   const auto ldb = gsl::narrow_cast<int>(LDB);
     206             :   const auto ldc = gsl::narrow_cast<int>(LDC);
     207             :   libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
     208             :                 C, &ldc);
     209             : }
     210             : #endif  // ifndef SPECTRE_DEBUG
     211             : /// @}
     212             : 
     213             : /// @{
     214             : /*!
     215             :  * \ingroup UtilitiesGroup
     216             :  * \brief Perform a matrix-vector multiplication
     217             :  *
     218             :  * \f[
     219             :  * y = \alpha \mathrm{op}(A) x + \beta y
     220             :  * \f]
     221             :  *
     222             :  * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
     223             :  * (transpose of \f$A\f$).
     224             :  *
     225             :  * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
     226             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$
     227             :  * \param N Number of columns in \f$\mathrm{op}(A)\f$
     228             :  * \param ALPHA specifies \f$\alpha\f$
     229             :  * \param A Matrix \f$A\f$
     230             :  * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
     231             :  * \param X Vector \f$x\f$
     232             :  * \param INCX Specifies the increment for the elements of \f$x\f$
     233             :  * \param BETA Specifies \f$\beta\f$
     234             :  * \param Y Vector \f$y\f$
     235             :  * \param INCY Specifies the increment for the elements of \f$y\f$
     236             :  */
     237           1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
     238             :                    const double& ALPHA, const double* A, const size_t& LDA,
     239             :                    const double* X, const size_t& INCX, const double& BETA,
     240             :                    double* Y, const size_t& INCY) {
     241             :   ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
     242             :              'C' == TRANS or 'c' == TRANS,
     243             :          "TRANS must be upper or lower case N, T, or C. See the BLAS "
     244             :          "documentation for help.");
     245             :   // INCX and INCY are allowed to be negative by BLAS, but we never
     246             :   // use them that way.  If needed, they can be changed here, but then
     247             :   // code providing values will also have to be changed to int to
     248             :   // avoid warnings.
     249             :   blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     250             :                       ALPHA, A, gsl::narrow_cast<int>(LDA), X,
     251             :                       gsl::narrow_cast<int>(INCX), BETA, Y,
     252             :                       gsl::narrow_cast<int>(INCY), 1);
     253             : }
     254             : /// @}

Generated by: LCOV version 1.14