SpECTRE Documentation Coverage Report
Current view: top level - Utilities - Blas.hpp Hit Total Coverage
Commit: 1f2210958b4f38fdc0400907ee7c6d5af5111418 Lines: 9 9 100.0 %
Date: 2025-12-05 05:03:31
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares the interfaces for the BLAS used.
       6             : ///
       7             : /// Wrappers are defined to perform casts from different integer types
       8             : /// when the natural type in C++ differs from the BLAS argument.
       9             : 
      10             : #pragma once
      11             : 
      12             : #include <complex>
      13             : 
      14             : #ifndef SPECTRE_DEBUG
      15             : #include <libxsmm.h>
      16             : #endif  // ifndef SPECTRE_DEBUG
      17             : #include <gsl/gsl_cblas.h>
      18             : 
      19             : #include "Utilities/ErrorHandling/Assert.hpp"
      20             : #include "Utilities/Gsl.hpp"
      21             : 
      22             : namespace blas_detail {
      23             : extern "C" {
      24             : double ddot_(const int& N, const double* X, const int& INCX, const double* Y,
      25             :              const int& INCY);
      26             : 
      27             : // The final two arguments are the "hidden" string lengths of TRANSA, TRANSB.
      28             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      29             : void dgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
      30             :             const int& K, const double& ALPHA, const double* A, const int& LDA,
      31             :             const double* B, const int& LDB, const double& BETA,
      32             :             const double* C, const int& LDC, size_t, size_t);
      33             : void zgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
      34             :             const int& K, const std::complex<double>& ALPHA,
      35             :             const std::complex<double>* A, const int& LDA,
      36             :             const std::complex<double>* B, const int& LDB,
      37             :             const std::complex<double>& BETA, const std::complex<double>* C,
      38             :             const int& LDC, size_t, size_t);
      39             : 
      40             : // The final argument is the "hidden" string length of TRANS.
      41             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      42             : void dgemv_(const char& TRANS, const int& M, const int& N, const double& ALPHA,
      43             :             const double* A, const int& LDA, const double* X, const int& INCX,
      44             :             const double& BETA, double* Y, const int& INCY, size_t);
      45             : }  // extern "C"
      46             : }  // namespace blas_detail
      47             : 
      48             : /*!
      49             :  * \brief Disable OpenBLAS multithreading since it conflicts with Charm++
      50             :  * parallelism
      51             :  *
      52             :  * Add this function to the `charm_init_node_funcs` of any executable that uses
      53             :  * BLAS routines.
      54             :  *
      55             :  * Details: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
      56             :  */
      57           1 : void disable_openblas_multithreading();
      58             : 
      59             : /// @{
      60             : /*!
      61             :  * \ingroup UtilitiesGroup
      62             :  * \brief The dot product of two real (double-precision) vectors.
      63             :  *
      64             :  * \param N the number of elements in each vector.
      65             :  * \param X a pointer to the first element of the first vector.
      66             :  * \param INCX the stride for the elements of the first vector.
      67             :  * \param Y a pointer to the first element of the second vector.
      68             :  * \param INCY the stride for the elements of the second vector.
      69             :  * \return the dot product of the given vectors.
      70             :  */
      71           1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
      72             :                     const double* Y, const size_t& INCY) {
      73             :   // BLAS allows negative INCX and INCY, but we never use them that way, so
      74             :   // this wrapper takes size_t strides and narrows them to int.  If negative
      75             :   // strides are ever needed, change the types here; callers must then pass
      76             :   // int as well to avoid warnings.
      77             :   return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
      78             :                             gsl::narrow_cast<int>(INCX), Y,
      79             :                             gsl::narrow_cast<int>(INCY));
      80             : }
      81             : /// The unconjugated complex dot product \f$x \cdot y\f$. See `zdotc_` for
      82             : /// the conjugated complex dot product, which is the standard dot product on
      83             : /// the vector space of complex numbers.
      84           1 : inline std::complex<double> zdotu_(const size_t& N,
      85             :                                    const std::complex<double>* X,
      86             :                                    const size_t& INCX,
      87             :                                    const std::complex<double>* Y,
      88             :                                    const size_t& INCY) {
      89             :   // The complex result of the BLAS zdot* functions is sometimes returned by
      90             :   // value and sometimes returned by reference, depending on the Fortran
      91             :   // compiler settings. The cblas `_sub` interface instead writes the result
      92             :   // through a pointer, which gives consistent behavior.
      93             :   std::complex<double> result;
      94             :   cblas_zdotu_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
      95             :                   gsl::narrow_cast<int>(INCY), &result);
      96             :   return result;
      97             : }
      98             : /// The conjugated complex dot product \f$\bar{x} \cdot y\f$. This is the
      99             : /// standard dot product on the vector space of complex numbers.
     100           1 : inline std::complex<double> zdotc_(const size_t& N,
     101             :                                    const std::complex<double>* X,
     102             :                                    const size_t& INCX,
     103             :                                    const std::complex<double>* Y,
     104             :                                    const size_t& INCY) {
     105             :   std::complex<double> result;
     106             :   cblas_zdotc_sub(gsl::narrow_cast<int>(N), X, gsl::narrow_cast<int>(INCX), Y,
     107             :                   gsl::narrow_cast<int>(INCY), &result);
     108             :   return result;
     109             : }
     110             : /// @}
     111             : 
     112             : /// @{
     113             : /*!
     114             :  * \ingroup UtilitiesGroup
     115             :  * \brief Perform a matrix-matrix multiplication
     116             :  *
     117             :  * Perform the matrix-matrix multiplication
     118             :  * \f[
     119             :  * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta C
     120             :  * \f]
     121             :  *
     122             :  * where \f$\mathrm{op}(X)\f$ is \f$X\f$, \f$X^{T}\f$ (transpose), or, for
     123             :  * 'C', the conjugate transpose (equal to the transpose for real matrices).
     124             :  *
     125             :  * LIBXSMM, which is much faster than BLAS for small matrices, is called by
     126             :  * `dgemm_<true>` when `SPECTRE_DEBUG` is not defined; otherwise BLAS is used.
     127             :  *
     128             :  * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
     129             :  * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
     130             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$ and \f$C\f$
     131             :  * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$C\f$
     132             :  * \param K Columns in \f$\mathrm{op}(A)\f$; rows in \f$\mathrm{op}(B)\f$
     133             :  * \param ALPHA specifies \f$\alpha\f$
     134             :  * \param A Matrix \f$A\f$
     135             :  * \param LDA Specifies first dimension of \f$A\f$ as stored
     136             :  * \param B Matrix \f$B\f$
     137             :  * \param LDB Specifies first dimension of \f$B\f$ as stored
     138             :  * \param BETA specifies \f$\beta\f$
     139             :  * \param C Matrix \f$C\f$ (overwritten with the result)
     140             :  * \param LDC Specifies first dimension of \f$C\f$ as stored
     141             :  * \tparam UseLibXsmm if `true` then use LIBXSMM
     142             :  */
     143             : template <bool UseLibXsmm = false>
     144           1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
     145             :                    const size_t& N, const size_t& K, const double& ALPHA,
     146             :                    const double* A, const size_t& LDA, const double* B,
     147             :                    const size_t& LDB, const double& BETA, double* C,
     148             :                    const size_t& LDC) {
     149             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     150             :              'C' == TRANSA or 'c' == TRANSA,
     151             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     152             :          "documentation for help.");
     153             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     154             :              'C' == TRANSB or 'c' == TRANSB,
     155             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     156             :          "documentation for help.");
     157             : 
     158             :   // Some BLAS implementations (e.g. Accelerate.framework on macOS) abort on a
     159             :   // zero-sized dimension, so skip the call to behave the same everywhere.
     160             :   // NOTE(review): reference BLAS still scales C by BETA if K == 0 -- confirm.
     161             :   if (M == 0 or N == 0 or K == 0) {
     162             :     return;
     163             :   }
     164             : 
     165             :   blas_detail::dgemm_(
     166             :       TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     167             :       gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
     168             :       gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
     169             : }
     170             : template <bool UseLibXsmm = false>
     171           1 : inline void zgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
     172             :                    const size_t& N, const size_t& K,
     173             :                    const std::complex<double>& ALPHA,
     174             :                    const std::complex<double>* A, const size_t& LDA,
     175             :                    const std::complex<double>* B, const size_t& LDB,
     176             :                    const std::complex<double>& BETA, std::complex<double>* C,
     177             :                    const size_t& LDC) {
     178             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     179             :              'C' == TRANSA or 'c' == TRANSA,
     180             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     181             :          "documentation for help.");
     182             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     183             :              'C' == TRANSB or 'c' == TRANSB,
     184             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     185             :          "documentation for help.");
     186             : 
     187             :   // Some BLAS implementations (e.g. Accelerate.framework on macOS) abort on a
     188             :   // zero-sized dimension, so skip the call to behave the same everywhere.
     189             :   // NOTE(review): reference BLAS still scales C by BETA if K == 0 -- confirm.
     190             :   if (M == 0 or N == 0 or K == 0) {
     191             :     return;
     192             :   }
     193             : 
     194             :   blas_detail::zgemm_(
     195             :       TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     196             :       gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
     197             :       gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
     198             : }
     199             : 
     200             : // libxsmm is disabled in DEBUG builds because backtraces (from, for
     201             : // example, FPEs) do not work when the error occurs in libxsmm code.
     202             : #ifndef SPECTRE_DEBUG
     203             : template <>
     204           1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
     205             :                          const size_t& M, const size_t& N, const size_t& K,
     206             :                          const double& ALPHA, const double* A,
     207             :                          const size_t& LDA, const double* B, const size_t& LDB,
     208             :                          const double& BETA, double* C, const size_t& LDC) {
     209             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     210             :              'C' == TRANSA or 'c' == TRANSA,
     211             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     212             :          "documentation for help.");
     213             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     214             :              'C' == TRANSB or 'c' == TRANSB,
     215             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     216             :          "documentation for help.");
     217             : 
     218             :   // Some BLAS implementations (e.g. Accelerate.framework on macOS) abort on a
     219             :   // zero-sized dimension, so skip the call to behave the same everywhere.
     220             :   // NOTE(review): reference BLAS still scales C by BETA if K == 0 -- confirm.
     221             :   if (M == 0 or N == 0 or K == 0) {
     222             :     return;
     223             :   }
     224             : 
     225             :   const auto m = gsl::narrow_cast<int>(M);
     226             :   const auto n = gsl::narrow_cast<int>(N);
     227             :   const auto k = gsl::narrow_cast<int>(K);
     228             :   const auto lda = gsl::narrow_cast<int>(LDA);
     229             :   const auto ldb = gsl::narrow_cast<int>(LDB);
     230             :   const auto ldc = gsl::narrow_cast<int>(LDC);
     231             :   libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
     232             :                 C, &ldc);
     233             : }
     234             : #endif  // ifndef SPECTRE_DEBUG
     235             : /// @}
     236             : 
     237             : /// @{
     238             : /*!
     239             :  * \ingroup UtilitiesGroup
     240             :  * \brief Perform a matrix-vector multiplication
     241             :  *
     242             :  * \f[
     243             :  * y = \alpha \mathrm{op}(A) x + \beta y
     244             :  * \f]
     245             :  *
     246             :  * where \f$\mathrm{op}(A)\f$ is \f$A\f$, \f$A^{T}\f$ (transpose), or, for
     247             :  * 'C', the conjugate transpose (equal to the transpose for real matrices).
     248             :  *
     249             :  * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
     250             :  * \param M Number of rows in \f$A\f$
     251             :  * \param N Number of columns in \f$A\f$
     252             :  * \param ALPHA specifies \f$\alpha\f$
     253             :  * \param A Matrix \f$A\f$
     254             :  * \param LDA Specifies first dimension of \f$A\f$ as stored
     255             :  * \param X Vector \f$x\f$
     256             :  * \param INCX Specifies the increment for the elements of \f$x\f$
     257             :  * \param BETA Specifies \f$\beta\f$
     258             :  * \param Y Vector \f$y\f$ (overwritten with the result)
     259             :  * \param INCY Specifies the increment for the elements of \f$y\f$
     260             :  */
     261           1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
     262             :                    const double& ALPHA, const double* A, const size_t& LDA,
     263             :                    const double* X, const size_t& INCX, const double& BETA,
     264             :                    double* Y, const size_t& INCY) {
     265             :   ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
     266             :              'C' == TRANS or 'c' == TRANS,
     267             :          "TRANS must be upper or lower case N, T, or C. See the BLAS "
     268             :          "documentation for help.");
     269             : 
     270             :   // On some BLAS implementations (e.g. Accelerate.framework on macOS) a call
     271             :   // with a zero-sized dimension aborts instead of acting as a no-op.  Return
     272             :   // early (leaving Y untouched) so we behave consistently across platforms.
     273             :   if (M == 0 or N == 0) {
     274             :     return;
     275             :   }
     276             : 
     277             :   // BLAS allows negative INCX and INCY, but we never use them that way, so
     278             :   // this wrapper takes size_t strides and narrows them to int.  If negative
     279             :   // strides are ever needed, change the types here; callers must then pass
     280             :   // int as well to avoid warnings.
     281             :   blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     282             :                       ALPHA, A, gsl::narrow_cast<int>(LDA), X,
     283             :                       gsl::narrow_cast<int>(INCX), BETA, Y,
     284             :                       gsl::narrow_cast<int>(INCY), 1);
     285             : }
     286             : /// @}

Generated by: LCOV version 1.14