SpECTRE Documentation Coverage Report
Current view: top level - Utilities - Blas.hpp Hit Total Coverage
Commit: 5b6dac11263b5fb9107cb6ea064c64c61b65a417 Lines: 6 6 100.0 %
Date: 2024-04-19 22:56:45
Legend: Lines: hit not hit

          Line data    Source code
       1           1 : // Distributed under the MIT License.
       2             : // See LICENSE.txt for details.
       3             : 
       4             : /// \file
       5             : /// Declares the interfaces for the BLAS used.
       6             : ///
       7             : /// Wrappers are defined to perform casts from different integer types
       8             : /// when the natural type in C++ differs from the BLAS argument.
       9             : 
      10             : #pragma once
      11             : 
#include <cstddef>

#ifndef SPECTRE_DEBUG
#include <libxsmm.h>
#endif  // ifndef SPECTRE_DEBUG

#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/Gsl.hpp"
      18             : 
      19             : namespace blas_detail {
      20             : extern "C" {
      21             : double ddot_(const int& N, const double* X, const int& INCX, const double* Y,
      22             :              const int& INCY);
      23             : 
      24             : // The final two arguments are the "hidden" lengths of the first two.
      25             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      26             : void dgemm_(const char& TRANSA, const char& TRANSB, const int& M, const int& N,
      27             :             const int& K, const double& ALPHA, const double* A, const int& LDA,
      28             :             const double* B, const int& LDB, const double& BETA,
      29             :             const double* C, const int& LDC, size_t, size_t);
      30             : 
      31             : // The final argument is the "hidden" length of the first one.
      32             : // https://gcc.gnu.org/onlinedocs/gfortran/Argument-passing-conventions.html
      33             : void dgemv_(const char& TRANS, const int& M, const int& N, const double& ALPHA,
      34             :             const double* A, const int& LDA, const double* X, const int& INCX,
      35             :             const double& BETA, double* Y, const int& INCY, size_t);
      36             : }  // extern "C"
      37             : }  // namespace blas_detail
      38             : 
/*!
 * \brief Disable OpenBLAS's internal multithreading, because it conflicts
 * with Charm++ parallelism.
 *
 * Register this function in the `charm_init_node_funcs` of every executable
 * that calls BLAS routines.
 *
 * Only the declaration lives in this header; the definition is in a separate
 * translation unit.
 *
 * Details: https://github.com/xianyi/OpenBLAS/wiki/Faq#multi-threaded
 */
void disable_openblas_multithreading();
      49             : 
      50             : /*!
      51             :  * \ingroup UtilitiesGroup
      52             :  * The dot product of two vectors.
      53             :  *
      54             :  * \param N the length of the vectors.
      55             :  * \param X a pointer to the first element of the first vector.
      56             :  * \param INCX the stride for the elements of the first vector.
      57             :  * \param Y a pointer to the first element of the second vector.
      58             :  * \param INCY the stride for the elements of the second vector.
      59             :  * \return the dot product of the given vectors.
      60             :  */
      61           1 : inline double ddot_(const size_t& N, const double* X, const size_t& INCX,
      62             :                     const double* Y, const size_t& INCY) {
      63             :   // INCX and INCY are allowed to be negative by BLAS, but we never
      64             :   // use them that way.  If needed, they can be changed here, but then
      65             :   // code providing values will also have to be changed to int to
      66             :   // avoid warnings.
      67             :   return blas_detail::ddot_(gsl::narrow_cast<int>(N), X,
      68             :                             gsl::narrow_cast<int>(INCX), Y,
      69             :                             gsl::narrow_cast<int>(INCY));
      70             : }
      71             : 
      72             : /// @{
      73             : /*!
      74             :  * \ingroup UtilitiesGroup
      75             :  * \brief Perform a matrix-matrix multiplication
      76             :  *
      77             :  * Perform the matrix-matrix multiplication
      78             :  * \f[
      79             :  * C = \alpha \mathrm{op}(A) \mathrm{op}(B) + \beta \mathrm{op}(C)
      80             :  * \f]
      81             :  *
      82             :  * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
      83             :  * (transpose of \f$A\f$).
      84             :  *
      85             :  * LIBXSMM, which is much faster than BLAS for small matrices, can be called
      86             :  * instead of BLAS by passing the template parameter `true`.
      87             :  *
      88             :  * \param TRANSA either 'N', 'T' or 'C', transposition of matrix A
      89             :  * \param TRANSB either 'N', 'T' or 'C', transposition of matrix B
      90             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$
      91             :  * \param N Number of columns in \f$\mathrm{op}(B)\f$ and \f$\mathrm{op}(C)\f$
      92             :  * \param K Number of columns in \f$\mathrm{op}(A)\f$
      93             :  * \param ALPHA specifies \f$\alpha\f$
      94             :  * \param A Matrix \f$A\f$
      95             :  * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
      96             :  * \param B Matrix \f$B\f$
      97             :  * \param LDB Specifies first dimension of \f$\mathrm{op}(B)\f$
      98             :  * \param BETA specifies \f$\beta\f$
      99             :  * \param C Matrix \f$C\f$
     100             :  * \param LDC Specifies first dimension of \f$\mathrm{op}(C)\f$
     101             :  * \tparam UseLibXsmm if `true` then use LIBXSMM
     102             :  */
     103             : template <bool UseLibXsmm = false>
     104           1 : inline void dgemm_(const char& TRANSA, const char& TRANSB, const size_t& M,
     105             :                    const size_t& N, const size_t& K, const double& ALPHA,
     106             :                    const double* A, const size_t& LDA, const double* B,
     107             :                    const size_t& LDB, const double& BETA, double* C,
     108             :                    const size_t& LDC) {
     109             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     110             :              'C' == TRANSA or 'c' == TRANSA,
     111             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     112             :          "documentation for help.");
     113             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     114             :              'C' == TRANSB or 'c' == TRANSB,
     115             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     116             :          "documentation for help.");
     117             :   blas_detail::dgemm_(
     118             :       TRANSA, TRANSB, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     119             :       gsl::narrow_cast<int>(K), ALPHA, A, gsl::narrow_cast<int>(LDA), B,
     120             :       gsl::narrow_cast<int>(LDB), BETA, C, gsl::narrow_cast<int>(LDC), 1, 1);
     121             : }
     122             : 
     123             : // libxsmm is disabled in DEBUG builds because backtraces (from, for
     124             : // example, FPEs) do not work when the error occurs in libxsmm code.
     125             : #ifndef SPECTRE_DEBUG
     126             : template <>
     127           1 : inline void dgemm_<true>(const char& TRANSA, const char& TRANSB,
     128             :                          const size_t& M, const size_t& N, const size_t& K,
     129             :                          const double& ALPHA, const double* A,
     130             :                          const size_t& LDA, const double* B, const size_t& LDB,
     131             :                          const double& BETA, double* C, const size_t& LDC) {
     132             :   ASSERT('N' == TRANSA or 'n' == TRANSA or 'T' == TRANSA or 't' == TRANSA or
     133             :              'C' == TRANSA or 'c' == TRANSA,
     134             :          "TRANSA must be upper or lower case N, T, or C. See the BLAS "
     135             :          "documentation for help.");
     136             :   ASSERT('N' == TRANSB or 'n' == TRANSB or 'T' == TRANSB or 't' == TRANSB or
     137             :              'C' == TRANSB or 'c' == TRANSB,
     138             :          "TRANSB must be upper or lower case N, T, or C. See the BLAS "
     139             :          "documentation for help.");
     140             :   const auto m = gsl::narrow_cast<int>(M);
     141             :   const auto n = gsl::narrow_cast<int>(N);
     142             :   const auto k = gsl::narrow_cast<int>(K);
     143             :   const auto lda = gsl::narrow_cast<int>(LDA);
     144             :   const auto ldb = gsl::narrow_cast<int>(LDB);
     145             :   const auto ldc = gsl::narrow_cast<int>(LDC);
     146             :   libxsmm_dgemm(&TRANSA, &TRANSB, &m, &n, &k, &ALPHA, A, &lda, B, &ldb, &BETA,
     147             :                 C, &ldc);
     148             : }
     149             : #endif  // ifndef SPECTRE_DEBUG
     150             : /// @}
     151             : 
     152             : /// @{
     153             : /*!
     154             :  * \ingroup UtilitiesGroup
     155             :  * \brief Perform a matrix-vector multiplication
     156             :  *
     157             :  * \f[
     158             :  * y = \alpha \mathrm{op}(A) x + \beta y
     159             :  * \f]
     160             :  *
     161             :  * where \f$\mathrm{op}(A)\f$ represents either \f$A\f$ or \f$A^{T}\f$
     162             :  * (transpose of \f$A\f$).
     163             :  *
     164             :  * \param TRANS either 'N', 'T' or 'C', transposition of matrix A
     165             :  * \param M Number of rows in \f$\mathrm{op}(A)\f$
     166             :  * \param N Number of columns in \f$\mathrm{op}(A)\f$
     167             :  * \param ALPHA specifies \f$\alpha\f$
     168             :  * \param A Matrix \f$A\f$
     169             :  * \param LDA Specifies first dimension of \f$\mathrm{op}(A)\f$
     170             :  * \param X Vector \f$x\f$
     171             :  * \param INCX Specifies the increment for the elements of \f$x\f$
     172             :  * \param BETA Specifies \f$\beta\f$
     173             :  * \param Y Vector \f$y\f$
     174             :  * \param INCY Specifies the increment for the elements of \f$y\f$
     175             :  */
     176           1 : inline void dgemv_(const char& TRANS, const size_t& M, const size_t& N,
     177             :                    const double& ALPHA, const double* A, const size_t& LDA,
     178             :                    const double* X, const size_t& INCX, const double& BETA,
     179             :                    double* Y, const size_t& INCY) {
     180             :   ASSERT('N' == TRANS or 'n' == TRANS or 'T' == TRANS or 't' == TRANS or
     181             :              'C' == TRANS or 'c' == TRANS,
     182             :          "TRANS must be upper or lower case N, T, or C. See the BLAS "
     183             :          "documentation for help.");
     184             :   // INCX and INCY are allowed to be negative by BLAS, but we never
     185             :   // use them that way.  If needed, they can be changed here, but then
     186             :   // code providing values will also have to be changed to int to
     187             :   // avoid warnings.
     188             :   blas_detail::dgemv_(TRANS, gsl::narrow_cast<int>(M), gsl::narrow_cast<int>(N),
     189             :                       ALPHA, A, gsl::narrow_cast<int>(LDA), X,
     190             :                       gsl::narrow_cast<int>(INCX), BETA, Y,
     191             :                       gsl::narrow_cast<int>(INCY), 1);
     192             : }
     193             : /// @}

Generated by: LCOV version 1.14