SpECTRE Documentation Coverage Report
Current view: top level - Utilities - Blas.hpp Hit Total Coverage
Commit: 817e13c5144619b701c7cd870655d8dbf94ab8ce Lines: 8 8 100.0 %
Date: 2024-07-19 22:17:05
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares the interfaces for the BLAS used.
       6             : ///
       7             : /// Wrappers are defined to perform casts from different integer types
       8             : /// when the natural type in C++ differs from the BLAS argument.
       9             : 
      10             : #pragma once
      11             : 
      12             : #include <complex>
      13             : 
      14             : #ifndef SPECTRE_DEBUG
      15             : #include <libxsmm.h>
      16             : #endif  // ifndef SPECTRE_DEBUG
      17             : #include <gsl/gsl_cblas.h>
      18             : 
      19             : #include "Utilities/ErrorHandling/Assert.hpp"
      20             : #include "Utilities/Gsl.hpp"
      21             : 
      22             : namespace blas_detail {
      23             : extern "C" {
      24             : double ddot_(const int& N, const double* X, const int& INCX, const double* Y,
      25             :              const int& INCY);
      26             : 
      27             : // The final two arguments are the "hidden" lengths of the first two.
      28             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      29             : void dgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
      30             :             const int& K, const double& ALPHA, const double* A, const int& LDA,
      31             :             const double* B, const int& LDB, const double& BETA,
      32             :             const double* C, const int& LDC, size_t, size_t);
      33             : 
      34             : // The final argument is the "hidden" length of the first one.
      35             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      36             : void dgemv_(const char& TRANS, const int& M, const int& N, const double& ALPHA,
      37             :             const double* A, const int& LDA, const double* X, const int& INCX,
      38             :             const double& BETA, double* Y, const int& INCY, size_t);
      39             : }  // extern "C"
      40             : }  // namespace blas_detail
      41             : 
      42             : /*!
      43             :  * \brief Disable OpenBLAS multithreading since it conflicts with Charm++
      44             :  * parallelism
      45             :  *
      46             :  * Add this function to the `charm_init_node_funcs` of any executable that uses
      47             :  * BLAS routines.
      48             :  *
      49             :  * Details: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
      50             :  */
      51           1 : void disable_openblas_multithreading();
      52             : 
      53             : /// @{
      54             : /*!
      55             :  * \ingroup UtilitiesGroup
      56             :  * The dot product of two vectors.
      57             :  *
      58             :  * \param N the length of the vectors.
      59             :  * \param X a pointer to the first element of the first vector.
      60             :  * \param INCX the stride for the elements of the first vector.
      61             :  * \param Y a pointer to the first element of the second vector.
      62             :  * \param INCY the stride for the elements of the second vector.
      63             :  * \return the dot product of the given vectors.
      64             :  */
      65           1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
      66             :                     const double* Y, const size_t& INCY) {
      67             :   // INCX and INCY are allowed to be negative by BLAS, but we never
      68             :   // use them that way.  If needed, they can be changed here, but then
      69             :   // code providing values will also have to be changed to int to
      70             :   // avoid warnings.
      71             :   return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
      72             :                             gsl::narrow_cast<int>(INCX), Y,
      73             :                             gsl::narrow_cast<int>(INCY));
      74             : }
      75             : /// The unconjugated complex dot product $x \cdot y$. See `zdotc_` for the
      76             : /// conjugated complex dot product, which is the standard dot product on the
      77             : /// vector space of complex numbers.
      78           1 : inline std::complex<double> zdotu_(const size_t& N,
      79             :                                    const std::complex<double>* X,
      80             :                                    const size_t& INCX,
      81             :                                    const std::complex<double>* Y,
      82             :                                    const size_t& INCY) {
      83             :   // The complex result of the BLAS zdot* functions is sometimes returned by
      84             :   // value and sometimes returned by reference, depending on the Fortran
      85             :   // compiler settings. By using the cblas interface we ensure a consistent
      86             :   // behavior.
      87             :   std::complex<double> result;
      88             :   cblas_zdotu_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
      89             :                   gsl::narrow_cast<int>(INCY), &result);
      90             :   return result;
      91             : }
      92             : /// The conjugated complex dot product $\bar{x} \cdot y$. This is the standard
      93             : /// dot product on the vector space of complex numbers.
      94           1 : inline std::complex<double> zdotc_(const size_t& N,
      95             :                                    const std::complex<double>* X,
      96             :                                    const size_t& INCX,
      97             :                                    const std::complex<double>* Y,
      98             :                                    const size_t& INCY) {
      99             :   std::complex<double> result;
     100             :   cblas_zdotc_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
     101             :                   gsl::narrow_cast<int>(INCY), &result);
     102             :   return result;
     103             : }
     104             : /// @}
     105             : 
     106             : /// @{
     107             : /*!
     108             :  * \ingroup UtilitiesGroup
     109             :  * \brief Perform a matrix-matrix multiplication
     110             :  *
     111             :  * Perform the matrix-matrix multiplication
     112             :  * \f[
     113             :  * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta \mathrm{op}(C)
     114             :  * \f]
     115             :  *
     116             :  * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
     117             :  * (transpose of \f$A\f$).
     118             :  *
     119             :  * LIBXSMM, which is much faster than BLAS for small matrices, can be called
     120             :  * instead of BLAS by passing the template parameter `true`.
     121             :  *
     122             :  * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
     123             :  * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
     124             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$
     125             :  * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$\mathrm{op}(C)\f$
     126             :  * \param K Number of columns in \f$\mathrm{op}(A)\f$
     127             :  * \param ALPHA specifies \f$\alpha\f$
     128             :  * \param A Matrix \f$A\f$
     129             :  * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
     130             :  * \param B Matrix \f$B\f$
     131             :  * \param LDB Specifies first dimension of \f$\mathrm{op}(B)\f$
     132             :  * \param BETA specifies \f$\beta\f$
     133             :  * \param C Matrix \f$C\f$
     134             :  * \param LDC Specifies first dimension of \f$\mathrm{op}(C)\f$
     135             :  * \tparam UseLibXsmm if `true` then use LIBXSMM
     136             :  */
     137             : template <bool UseLibXsmm = false>
     138           1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
     139             :                    const size_t& N, const size_t& K, const double& ALPHA,
     140             :                    const double* A, const size_t& LDA, const double* B,
     141             :                    const size_t& LDB, const double& BETA, double* C,
     142             :                    const size_t& LDC) {
     143             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     144             :              'C' == TRANSA or 'c' == TRANSA,
     145             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     146             :          "documentation for help.");
     147             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     148             :              'C' == TRANSB or 'c' == TRANSB,
     149             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     150             :          "documentation for help.");
     151             :   blas_detail::dgemm_(
     152             :       TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     153             :       gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
     154             :       gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
     155             : }
     156             : 
     157             : // libxsmm is disabled in DEBUG builds because backtraces (from, for
     158             : // example, FPEs) do not work when the error occurs in libxsmm code.
     159             : #ifndef SPECTRE_DEBUG
     160             : template <>
     161           1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
     162             :                          const size_t& M, const size_t& N, const size_t& K,
     163             :                          const double& ALPHA, const double* A,
     164             :                          const size_t& LDA, const double* B, const size_t& LDB,
     165             :                          const double& BETA, double* C, const size_t& LDC) {
     166             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     167             :              'C' == TRANSA or 'c' == TRANSA,
     168             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     169             :          "documentation for help.");
     170             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     171             :              'C' == TRANSB or 'c' == TRANSB,
     172             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     173             :          "documentation for help.");
     174             :   const auto m = gsl::narrow_cast<int>(M);
     175             :   const auto n = gsl::narrow_cast<int>(N);
     176             :   const auto k = gsl::narrow_cast<int>(K);
     177             :   const auto lda = gsl::narrow_cast<int>(LDA);
     178             :   const auto ldb = gsl::narrow_cast<int>(LDB);
     179             :   const auto ldc = gsl::narrow_cast<int>(LDC);
     180             :   libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
     181             :                 C, &ldc);
     182             : }
     183             : #endif  // ifndef SPECTRE_DEBUG
     184             : /// @}
     185             : 
     186             : /// @{
     187             : /*!
     188             :  * \ingroup UtilitiesGroup
     189             :  * \brief Perform a matrix-vector multiplication
     190             :  *
     191             :  * \f[
     192             :  * y = \alpha \mathrm{op}(A) x + \beta y
     193             :  * \f]
     194             :  *
     195             :  * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
     196             :  * (transpose of \f$A\f$).
     197             :  *
     198             :  * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
     199             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$
     200             :  * \param N Number of columns in \f$\mathrm{op}(A)\f$
     201             :  * \param ALPHA specifies \f$\alpha\f$
     202             :  * \param A Matrix \f$A\f$
     203             :  * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
     204             :  * \param X Vector \f$x\f$
     205             :  * \param INCX Specifies the increment for the elements of \f$x\f$
     206             :  * \param BETA Specifies \f$\beta\f$
     207             :  * \param Y Vector \f$y\f$
     208             :  * \param INCY Specifies the increment for the elements of \f$y\f$
     209             :  */
     210           1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
     211             :                    const double& ALPHA, const double* A, const size_t& LDA,
     212             :                    const double* X, const size_t& INCX, const double& BETA,
     213             :                    double* Y, const size_t& INCY) {
     214             :   ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
     215             :              'C' == TRANS or 'c' == TRANS,
     216             :          "TRANS must be upper or lower case N, T, or C. See the BLAS "
     217             :          "documentation for help.");
     218             :   // INCX and INCY are allowed to be negative by BLAS, but we never
     219             :   // use them that way.  If needed, they can be changed here, but then
     220             :   // code providing values will also have to be changed to int to
     221             :   // avoid warnings.
     222             :   blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     223             :                       ALPHA, A, gsl::narrow_cast<int>(LDA), X,
     224             :                       gsl::narrow_cast<int>(INCX), BETA, Y,
     225             :                       gsl::narrow_cast<int>(INCY), 1);
     226             : }
     227             : /// @}

Generated by: LCOV version 1.14