Line data Source code
1 0 : // Distributed under the MIT License. 2 : // See LICENSE.txt for details. 3 : 4 : #pragma once 5 : 6 : #include <algorithm> 7 : #include <blaze/math/Column.h> 8 : #include <blaze/math/Matrix.h> 9 : #include <blaze/math/typetraits/IsDenseMatrix.h> 10 : #include <blaze/math/typetraits/IsSparseMatrix.h> 11 : #include <cstddef> 12 : #include <tuple> 13 : 14 : #include "Utilities/EqualWithinRoundoff.hpp" 15 : #include "Utilities/Gsl.hpp" 16 : #include "Utilities/TypeTraits/CreateIsCallable.hpp" 17 : 18 0 : namespace LinearSolver::Serial { 19 : 20 : namespace detail { 21 : CREATE_IS_CALLABLE(reset) 22 : CREATE_IS_CALLABLE_V(reset) 23 : } // namespace detail 24 : 25 : /*! 26 : * \brief Construct an explicit matrix representation by "sniffing out" the 27 : * linear operator, i.e. feeding it unit vectors. 28 : * 29 : * \param matrix Output buffer for the operator matrix. Must be sized correctly 30 : * on entry. Can be any dense or sparse Blaze matrix. 31 : * \param operand_buffer Memory buffer that can hold operand data for the 32 : * `linear_operator`. Must be sized correctly on entry, and must be filled with 33 : * zeros. 34 : * \param result_buffer Memory buffer that can hold the result of the 35 : * `linear_operator` applied to the operand. Must be sized correctly on entry. 36 : * \param linear_operator The linear operator of which the matrix representation 37 : * should be constructed. See `::LinearSolver::Serial::LinearSolver::solve` for 38 : * requirements on the linear operator. 39 : * \param operator_args These arguments are passed along to the 40 : * `linear_operator` when it is applied to an operand. 41 : */ 42 : template <typename LinearOperator, typename OperandType, typename ResultType, 43 : typename MatrixType, typename... OperatorArgs> 44 1 : void build_matrix(const gsl::not_null<MatrixType*> matrix, 45 : const gsl::not_null<OperandType*> operand_buffer, 46 : const gsl::not_null<ResultType*> result_buffer, 47 : const LinearOperator& linear_operator, 48 : const std::tuple<OperatorArgs...>& operator_args = {}) { 49 : static_assert( 50 : blaze::IsSparseMatrix_v<MatrixType> or blaze::IsDenseMatrix_v<MatrixType>, 51 : "Unexpected matrix type"); 52 : if constexpr (blaze::IsSparseMatrix_v<MatrixType>) { 53 : matrix->reset(); 54 : } 55 : size_t i = 0; 56 : // Re-using the iterators for all operator invocations 57 : auto result_iterator_begin = result_buffer->begin(); 58 : auto result_iterator_end = result_buffer->end(); 59 : for (double& unit_vector_data : *operand_buffer) { 60 : // Set a 1 at the unit vector location i 61 : unit_vector_data = 1.; 62 : // Invoke the operator on the unit vector 63 : std::apply( 64 : linear_operator, 65 : std::tuple_cat(std::forward_as_tuple(result_buffer, *operand_buffer), 66 : operator_args)); 67 : // Set the unit vector back to zero 68 : unit_vector_data = 0.; 69 : // Reset the iterator by calling its `reset` member function or by 70 : // re-creating it 71 : if constexpr (detail::is_reset_callable_v< 72 : decltype(result_iterator_begin)>) { 73 : result_iterator_begin.reset(); 74 : } else { 75 : result_iterator_begin = result_buffer->begin(); 76 : result_iterator_end = result_buffer->end(); 77 : } 78 : // Store the result in column i of the matrix 79 : auto col = column(*matrix, i); 80 : if constexpr (blaze::IsSparseMatrix_v<MatrixType>) { 81 : size_t k = 0; 82 : while (result_iterator_begin != result_iterator_end) { 83 : if (not equal_within_roundoff(*result_iterator_begin, 0.)) { 84 : col[k] = *result_iterator_begin; 85 : } 86 : ++result_iterator_begin; 87 : ++k; 88 : } 89 : } else { 90 : std::copy(result_iterator_begin, result_iterator_end, col.begin()); 91 : } 92 : ++i; 93 : } 94 : } 95 : 96 : } // namespace LinearSolver::Serial