From c55e179e7702d87f08ba799ff6b9838c95238452 Mon Sep 17 00:00:00 2001 From: Spencer Patty Date: Fri, 27 Mar 2026 15:07:40 -0700 Subject: [PATCH 1/6] expand reference and onemkl_sycl backend to use operation_info_t, and introduce xyz_inspect for many cases, also add two new examples which uses them --- examples/CMakeLists.txt | 5 + examples/spmm_csr.cpp | 55 ++++++++ examples/sptrsv_csr.cpp | 64 +++++++++ include/spblas/algorithms/multiply.hpp | 27 ++++ include/spblas/algorithms/multiply_impl.hpp | 54 +++++++- .../spblas/algorithms/triangular_solve.hpp | 14 +- .../algorithms/triangular_solve_impl.hpp | 40 ++++++ .../detail/create_matrix_handle.hpp | 13 +- .../spblas/vendor/onemkl_sycl/spgemm_impl.hpp | 50 ++++++- .../spblas/vendor/onemkl_sycl/spmm_impl.hpp | 115 +++++++++++++++- .../spblas/vendor/onemkl_sycl/spmv_impl.hpp | 55 +++++++- .../onemkl_sycl/triangular_solve_impl.hpp | 126 +++++++++++++++--- 12 files changed, 585 insertions(+), 33 deletions(-) create mode 100644 examples/spmm_csr.cpp create mode 100644 examples/sptrsv_csr.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index fcf3a82..9d99ea6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,6 +11,11 @@ if (SPBLAS_CPU_BACKEND) add_example(simple_sptrsv) add_example(spmm_csc) add_example(matrix_opt_example) + if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND ) + # needs CPU + matrix_opt + operation_info_t to run + add_example(sptrsv_csr) # needs triangular_solve{_inspect} to run + add_example(spmm_csr) # needs multiply{_inspect} to run + endif() endif() # GPU examples diff --git a/examples/spmm_csr.cpp b/examples/spmm_csr.cpp new file mode 100644 index 0000000..697f42b --- /dev/null +++ b/examples/spmm_csr.cpp @@ -0,0 +1,55 @@ +#include + +#include +#include + +int main(int argc, char** argv) { + using namespace spblas; + namespace md = spblas::__mdspan; + + using T = float; + + spblas::index_t m = 10; + spblas::index_t n = 10; + spblas::index_t k = 10; + spblas::index_t nnz_in = 20; + + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n\t### Running Advanced SpMM Example:"); + fmt::print("\n\t###"); + fmt::print("\n\t### Y = alpha * A * X"); + fmt::print("\n\t###"); + fmt::print("\n\t### with "); + fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k, + nnz_in); + fmt::print("\n\t### x, a dense matrix, of size ({}, {})", k, n); + fmt::print("\n\t### y, a dense vector, of size ({}, {})", m, n); + fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", + sizeof(spblas::index_t)); + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n"); + + auto&& [values, rowptr, colind, shape, nnz] = generate_csr(m, k, nnz_in); + + csr_view a(values, rowptr, colind, shape, nnz); + matrix_opt a_opt(a); + + std::vector x_values(k * n, 1); + std::vector y_values(m * n, 0); + + md::mdspan x(x_values.data(), k, n); + md::mdspan y(y_values.data(), m, n); + + + // Y = A * X + auto state = multiply_inspect(a_opt, x, y); + multiply(state, a_opt, x, y); + + fmt::print("{}\n", spblas::__backend::values(y)); + + fmt::print("\tExample is completed!\n"); + + return 0; +} diff --git a/examples/sptrsv_csr.cpp b/examples/sptrsv_csr.cpp new file mode 100644 index 0000000..7abfe24 --- /dev/null +++ b/examples/sptrsv_csr.cpp @@ -0,0 +1,64 @@ +#include + +#include +#include + +int main(int argc, char** argv) { + using namespace spblas; + + using T = float; + + spblas::index_t m = 100; + spblas::index_t nnz_in = 20; + + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n\t### Running Full SpTRSV Example:"); + fmt::print("\n\t###"); + fmt::print("\n\t### solve for x: A * x = alpha * b"); + fmt::print("\n\t###"); + fmt::print("\n\t### with "); + fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, m, + nnz_in); + fmt::print("\n\t### x, a dense vector, of size ({}, {})", m, 1); + fmt::print("\n\t### b, a dense vector, of size ({}, {})", m, 1); + fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", + sizeof(spblas::index_t)); + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n"); + + auto&& [values, rowptr, colind, shape, nnz] = + generate_csr(m, m, nnz_in); + + // scale values of matrix to make the implicit unit diagonal matrix + // be diagonally dominant, so it is solveable + T scale_factor = 1e-3f; + std::transform(values.begin(), values.end(), values.begin(), + [scale_factor](T val) { return scale_factor * val; }); + + csr_view a(values, rowptr, colind, shape, nnz); + + matrix_opt a_opt(a); + + // Scale every value of `a` by 5 in place. + // scale(5.f, a); + + std::vector x(m, 0); + std::vector b(m, 1); + + T alpha = 1.2f; + auto b_scaled = scaled(alpha, b); + + // solve for x: lower(A) * x = alpha * b + triangular_solve_inspect(a_opt, spblas::upper_triangle_t{}, + spblas::implicit_unit_diagonal_t{}, b_scaled, x); + + triangular_solve(a_opt, spblas::upper_triangle_t{}, + spblas::implicit_unit_diagonal_t{}, b_scaled, x); + + + fmt::print("\tExample is completed!\n"); + + return 0; +} diff --git a/include/spblas/algorithms/multiply.hpp b/include/spblas/algorithms/multiply.hpp index f15748e..1ce86d2 100644 --- a/include/spblas/algorithms/multiply.hpp +++ b/include/spblas/algorithms/multiply.hpp @@ -5,18 +5,45 @@ namespace spblas { +// SpMV variants +template +operation_info_t multiply_inspect(A&& a, B&& b, C&& c); + +template +void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c); + template void multiply(A&& a, B&& b, C&& c); +template +void multiply(operation_into_t& info, A&& a, B&& b, C&& c); + + +// SpMM variants template void multiply(A&& a, B&& b, C&& c); +template +void multiply(operation_info_t& info, A&& a, B&& b, C&& c); + +// SpMM and SpGEMM multiply_inspect variants template operation_info_t multiply_inspect(A&& a, B&& b, C&& c); template void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c); + +// SpGEMM variants +template +operation_info_t multiply_compute(ExecutionPolicy &&policy, A&& a, B&& b, C&& c); + +template +void multiply_compute(ExecutionPolicy &&policy, operation_info_t& info, A&& a, B&& b, C&& c); + +template +void multiply_fill(ExecutionPolicy &&policy, operation_info_t& info, A&& a, B&& b, C&& c); + template operation_info_t multiply_compute(A&& a, B&& b, C&& c); diff --git a/include/spblas/algorithms/multiply_impl.hpp b/include/spblas/algorithms/multiply_impl.hpp index d56da6b..eb02e40 100644 --- a/include/spblas/algorithms/multiply_impl.hpp +++ b/include/spblas/algorithms/multiply_impl.hpp @@ -15,6 +15,19 @@ namespace spblas { +// SpMV inspect +template +operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { + log_trace(""); + return operation_info_t{}; +} + +// SpMV inspect +template +operation_info_t multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); +} + // C = AB // SpMV template @@ -39,6 +52,16 @@ void multiply(A&& a, B&& b, C&& c) { }); } +// C = AB +// SpMV with info input +template + requires(__backend::lookupable && __backend::lookupable) +void multiply(operation_info_t &info, A&& a, B&& b, C&& c) { + log_trace(""); + multiply(std::forward(a), std::forward(b), std::forward(c)); +} + + // C = AB // SpMM template @@ -52,37 +75,63 @@ void multiply(A&& a, B&& b, C&& c) { "multiply: matrix dimensions are incompatible."); } + // initializes c to zero so we can use += everywhere __backend::for_each(c, [](auto&& e) { auto&& [_, v] = e; v = 0; }); + // traverses elements of a and performs appropriate + // multiplication with B rows __backend::for_each(a, [&](auto&& e) { auto&& [idx, a_v] = e; auto&& [i, k] = idx; - for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { + for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { // b_row __backend::lookup(c, i, j) += a_v * __backend::lookup(b, k, j); } }); } +// C = AB +// SpMM with info +template + requires(__backend::lookupable && __backend::lookupable) +void multiply(operation_info_t &info, A&& a, B&& b, C&& c) { + log_trace(""); + multiply(std::forward(a), std::forward(b), std::forward(c)); +} + + +// C = AB +// SpMM or SpGEMM multiply_inspect variants end up here template operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { + log_trace(""); return operation_info_t{}; } +// C = AB +// SpMM or SpGEMM multiply_inspect variants end up here template -void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c){}; +void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c){ + log_trace(""); +}; + +// C = AB +// SpGEMM compute stage with CSR output template requires(__backend::row_iterable && __backend::row_iterable && __detail::is_csr_view_v) void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); auto new_info = multiply_compute(std::forward(a), std::forward(b), std::forward(c)); info.update_impl_(new_info.result_shape(), new_info.result_nnz()); } +// C = AB +// SpGEMM compute stage with CSC output template requires(__backend::column_iterable && __backend::column_iterable && __detail::is_csc_view_v) @@ -93,6 +142,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { } // C = AB +// SpGEMM fill stage with CSR or CSC output template void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) { log_trace(""); diff --git a/include/spblas/algorithms/triangular_solve.hpp b/include/spblas/algorithms/triangular_solve.hpp index 5bf1d88..948f0d0 100644 --- a/include/spblas/algorithms/triangular_solve.hpp +++ b/include/spblas/algorithms/triangular_solve.hpp @@ -3,13 +3,17 @@ #include #include -template -void triangular_matrix_vector_solve(ExecutionPolicy&& exec, InMat A, Triangle t, - DiagonalStorage d, InVec b, OutVec x); - namespace spblas { + +template +void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); + + +template +operation_info_t triangular_solve_inspect(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); + + template void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); diff --git a/include/spblas/algorithms/triangular_solve_impl.hpp b/include/spblas/algorithms/triangular_solve_impl.hpp index 52be891..e207fee 100644 --- a/include/spblas/algorithms/triangular_solve_impl.hpp +++ b/include/spblas/algorithms/triangular_solve_impl.hpp @@ -8,10 +8,39 @@ namespace spblas { +// X = inv(A) B +// SpTRSV inspect stage +template + requires(__backend::row_iterable && __backend::lookupable && + __backend::lookupable) +operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { + log_trace(""); + static_assert(std::is_same_v || + std::is_same_v); + assert(__backend::shape(a)[0] == __backend::shape(a)[1]); + + return operation_info_t{}; +} + +// X = inv(A) B +// SpTRSV inspect stage +template + requires(__backend::row_iterable && __backend::lookupable && + __backend::lookupable) +void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { + log_trace(""); + static_assert(std::is_same_v || + std::is_same_v); + assert(__backend::shape(a)[0] == __backend::shape(a)[1]); +} + +// X = inv(A) B +// SpTRSV solve stage template requires(__backend::row_iterable && __backend::lookupable && __backend::lookupable) void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { + log_trace(""); static_assert(std::is_same_v || std::is_same_v); assert(__backend::shape(a)[0] == __backend::shape(a)[1]); @@ -62,4 +91,15 @@ void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { } } +// X = inv(A) B +// SpTRSV solve stage with info +template + requires(__backend::row_iterable && __backend::lookupable && + __backend::lookupable) +void triangular_solve(operation_info_t& info, A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { + log_trace(""); + triangular_solve(std::forward(a), std::forward(t), std::forward(d), std::forward(b), std::forward(x)); +} + + } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp index 2413766..616020f 100644 --- a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp +++ b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp @@ -17,8 +17,13 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, oneapi::mkl::sparse::matrix_handle_t handle = nullptr; oneapi::mkl::sparse::init_matrix_handle(&handle); + oneapi::mkl::sparse::set_csr_data( - q, handle, m.shape()[0], m.shape()[1], oneapi::mkl::index_base::zero, + q, handle, m.shape()[0], m.shape()[1], +#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) + m.size(), // nnz added in 2025.3, and without deprecated +#endif + oneapi::mkl::index_base::zero, m.rowptr().data(), m.colind().data(), m.values().data()) .wait(); @@ -33,7 +38,11 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, oneapi::mkl::sparse::init_matrix_handle(&handle); oneapi::mkl::sparse::set_csr_data( - q, handle, m.shape()[1], m.shape()[0], oneapi::mkl::index_base::zero, + q, handle, m.shape()[1], m.shape()[0], +#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) + m.size(), // nnz added in 2025.3, and without deprecated +#endif + oneapi::mkl::index_base::zero, m.colptr().data(), m.rowind().data(), m.values().data()) .wait(); diff --git a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp index 4ee63c9..35afc89 100644 --- a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp @@ -29,6 +29,10 @@ namespace spblas { + +// +// multiply_compute -- csr/csc * csr/csc -> csr with ExecutionPolicy +// template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && @@ -68,6 +72,9 @@ operation_info_t oneapi::mkl::sparse::set_csr_data( q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1], +#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) + __backend::size(c), // nnz added in 2025.3, and without deprecated +#endif oneapi::mkl::index_base::zero, c_rowptr, (I*) nullptr, (T*) nullptr) .wait(); @@ -117,8 +124,36 @@ operation_info_t __mkl::operation_state_t{__detail::has_matrix_opt(a) ? nullptr : a_handle, __detail::has_matrix_opt(b) ? nullptr : b_handle, c_handle, nullptr, descr, (void*) c_rowptr, q}}; -} +} // multiply_compute +// +// multiply_compute -- csr/csc * csr/csc -> csr with ExecutionPolicy +// +template + requires(__detail::has_csr_base || __detail::has_csc_base) && + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::is_csr_view_v +void + multiply_compute(ExecutionPolicy&& policy, operation_info_t &info, A&& a, B&& b, C&& c) { + log_trace(""); + + auto tmp_info = multiply_compute(std::forward(policy), std::forward(a), std::forward(b), std::forward(c)); + + // fill the normal bucket of state stuf based on creating model for now. + info.update_impl_(tmp_info.result_shape(), tmp_info.result_nnz()); + info.state_.a_handle = tmp_info.state_.a_handle; + info.state_.b_handle = tmp_info.state_.b_handle; + info.state_.c_handle = tmp_info.state_.c_handle; + info.state_.descr = tmp_info.state_.descr; + info.state_.c_rowptr = tmp_info.state_.c_rowptr; + info.state_.q = tmp_info.state_.q; + +} // multiply_compute + + +// +// multiply_fill -- csr/csc * csr/csc -> csr with ExecutionPolicy +// template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && @@ -155,6 +190,9 @@ void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, auto ev_setC = oneapi::mkl::sparse::set_csr_data( q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1], +#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) + __backend::size(c), // nnz added in 2025.3, and without deprecated +#endif oneapi::mkl::index_base::zero, c_rowptr, c.colind().data(), c.values().data()); @@ -186,6 +224,15 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { std::forward(c)); } +template + requires(__detail::has_csr_base || __detail::has_csc_base) && + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::is_csr_view_v +void multiply_compute(operation_info_t & info, A&& a, B&& b, C&& c) { + return multiply_compute(mkl::par, std::forward(info), std::forward(a), std::forward(b), + std::forward(c)); +} + template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && @@ -203,6 +250,7 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { return multiply_compute(transposed(b), transposed(a), transposed(c)); } + template requires((__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && diff --git a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp index 82d1415..1b6ee6a 100644 --- a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp @@ -37,9 +37,60 @@ template __mdspan::layout_right> && std::is_same_v::layout_type, __mdspan::layout_right>) -void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { +void multiply_inspect(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, Y&& y) { + log_trace(""); + if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense matrices."); + } + + if (__detail::has_matrix_opt(a)) { + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); + + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_transpose = __mkl::get_transpose(a); + + auto x_base = __detail::get_ultimate_base(x); + + oneapi::mkl::sparse::optimize_gemm(q, oneapi::mkl::layout::row_major, a_transpose, + oneapi::mkl::transpose::nontrans, a_handle, static_cast(x_base.extent(1))) + .wait(); + } + else { + // do nothing, since it would be immediately discarded + log_info("No work done, since no matrix_opt to store optimized results into!"); + } +} // multiply_inspect + +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +operation_info_t multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { + log_trace(""); + operation_info_t info{}; + + multiply_inspect(std::forward(policy), info, std::forward(a), std::forward(x), std::forward(y)); + + return info; +} + + +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, Y&& y) { log_trace(""); - auto x_base = __detail::get_ultimate_base(x); if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { throw std::runtime_error( @@ -55,6 +106,8 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { auto a_handle = __mkl::get_matrix_handle(q, a); auto a_transpose = __mkl::get_transpose(a); + auto x_base = __detail::get_ultimate_base(x); + oneapi::mkl::sparse::gemm(q, oneapi::mkl::layout::row_major, a_transpose, oneapi::mkl::transpose::nontrans, alpha, a_handle, x_base.data_handle(), x_base.extent(1), @@ -66,6 +119,60 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { } } + +// +// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution policy +// +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +operation_info_t multiply_inspect(A&& a, X&& x, Y&& y) { + auto info = multiply_inspect(mkl::par, std::forward(a), + std::forward(x), std::forward(y)); + return info; +} + +// +// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution policy +// +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply_inspect(operation_info_t& info, A&& a, X&& x, Y&& y) { + multiply_inspect(mkl::par, info, std::forward(a), std::forward(x), + std::forward(y)); +} + + +// +// multiply - CSR/CSC with row major dense matrix rhs without execution policy +// +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) { + multiply(mkl::par, info, std::forward(a), std::forward(x), + std::forward(y)); +} + +// +// multiply - CSR/CSC with row major dense matrix rhs without execution policy or state object +// template requires( (__detail::has_csr_base || __detail::has_csc_base) && @@ -75,8 +182,10 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) void multiply(A&& a, X&& x, Y&& y) { - multiply(mkl::par, std::forward(a), std::forward(x), + operation_info_t info{}; + multiply(mkl::par, info, std::forward(a), std::forward(x), std::forward(y)); } + } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp index c6b73c1..6a2e9d0 100644 --- a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp @@ -28,6 +28,43 @@ namespace spblas { + +// +// multiply_inspect with CSR/CSC and single rhs +// +template + requires((__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_contiguous_range_base && + __ranges::contiguous_range) +void multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { + log_trace(""); + + if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } + + if (__detail::has_matrix_opt(a)) { + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); + + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_transpose = __mkl::get_transpose(a); + + oneapi::mkl::sparse::optimize_gemv(q, a_transpose, a_handle).wait(); + } + else { + // do nothing, since it would be trashed immediately after + log_info("No work done, since no matrix_opt to store optimized results into!"); + + } + +} // multiply_inspect + + +// +// multiply with CSR/CSC and single rhs +// template requires((__detail::has_csr_base || __detail::has_csc_base) && __detail::has_contiguous_range_base && @@ -45,7 +82,6 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { tensor_scalar_t alpha = alpha_optional.value_or(1); auto a_data = __detail::get_ultimate_base(a).values().data(); - auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); @@ -60,6 +96,23 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { } } + +// +// multiply_inspect -- CSR/CSC + single rhs vector +// with no ExecutionPolicy +// +template + requires((__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_contiguous_range_base && + __ranges::contiguous_range) +void multiply_inspect(A&& a, X&& x, Y&& y) { + multiply_inspect(mkl::par, std::forward(a), std::forward(x), + std::forward(y)); +} + +// +// multiply -- CSR/CSC + single rhs vector +// template requires((__detail::has_csr_base || __detail::has_csc_base) && __detail::has_contiguous_range_base && diff --git a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp index 4d9bd05..4da55b5 100644 --- a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp @@ -26,56 +26,144 @@ namespace spblas { // lower + conjtrans (D+U)^H -> conjtrans + upper (D+U)^H // -template +// +// CSR triangular solve inspection step +// +template requires __detail::has_csr_base && __detail::has_contiguous_range_base && __ranges::contiguous_range -void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, - X&& x) { +void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x) { log_trace(""); static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v || std::is_same_v); - auto a_base = __detail::get_ultimate_base(a); - auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(b) || __detail::is_conjugated(x)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } using T = tensor_scalar_t; using I = tensor_index_t; using O = tensor_offset_t; - auto alpha_optional = __detail::get_scaling_factor(a, b); - T alpha = alpha_optional.value_or(1); + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); - sycl::queue q(sycl::cpu_selector_v); + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_op = __mkl::get_transpose(a); - oneapi::mkl::sparse::matrix_handle_t a_handle = nullptr; - oneapi::mkl::sparse::init_matrix_handle(&a_handle); + auto uplo_val = + std::is_same_v + ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; // someday apply mapping with op - oneapi::mkl::sparse::set_csr_data( - q, a_handle, __backend::shape(a_base)[0], __backend::shape(a_base)[1], - oneapi::mkl::index_base::zero, a_base.rowptr().data(), - a_base.colind().data(), a_base.values().data()) - .wait(); + auto diag_val = std::is_same_v + ? oneapi::mkl::diag::nonunit + : oneapi::mkl::diag::unit; + + oneapi::mkl::sparse::optimize_trsv(q, uplo_val, a_op, diag_val, a_handle).wait(); + + if (!__detail::has_matrix_opt(a)) { + oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); + } + +} + + +// +// CSR triangular solve execution step +// +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(ExecutionPolicy&& policy, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x) { + log_trace(""); + static_assert(std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); + + if (__detail::is_conjugated(b) || __detail::is_conjugated(x)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } - auto op = oneapi::mkl::transpose::nontrans; + using T = tensor_scalar_t; + using I = tensor_index_t; + using O = tensor_offset_t; + + auto alpha_optional = __detail::get_scaling_factor(a, b); + tensor_scalar_t alpha = alpha_optional.value_or(1); + + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); + + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_op = __mkl::get_transpose(a); auto uplo_val = std::is_same_v ? oneapi::mkl::uplo::upper - : oneapi::mkl::uplo::lower; // someday apply mapping with op + : oneapi::mkl::uplo::lower; auto diag_val = std::is_same_v ? oneapi::mkl::diag::nonunit : oneapi::mkl::diag::unit; - oneapi::mkl::sparse::trsv(q, uplo_val, op, diag_val, alpha, a_handle, + auto b_base = __detail::get_ultimate_base(b); + + oneapi::mkl::sparse::trsv(q, uplo_val, a_op, diag_val, alpha, a_handle, __ranges::data(b_base), __ranges::data(x)) .wait(); - oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); + if (!__detail::has_matrix_opt(a)) { + oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); + } } // triangular_solve + + +// +// CSR triangular_solve_inspect with no exception policy +// +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve_inspect(A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x) { + triangular_solve_inspect(mkl::par, + std::forward(a), + std::forward(uplo), + std::forward(diag), + std::forward(b), + std::forward(x) ); +} // triangular_solve_inspect + + +// +// CSR triangular_solve with no exception policy +// +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x) { + triangular_solve(mkl::par, + std::forward(a), + std::forward(uplo), + std::forward(diag), + std::forward(b), + std::forward(x) ); +} // triangular_solve + + } // namespace spblas From e4382f73c2b733828c168241d03cb918c70426d8 Mon Sep 17 00:00:00 2001 From: Spencer Patty Date: Fri, 27 Mar 2026 15:47:41 -0700 Subject: [PATCH 2/6] run pre-commit run --all-files to do clang formatting --- examples/CMakeLists.txt | 2 +- examples/spmm_csr.cpp | 1 - examples/sptrsv_csr.cpp | 1 - include/spblas/algorithms/multiply.hpp | 11 +++-- include/spblas/algorithms/multiply_impl.hpp | 11 ++--- .../spblas/algorithms/triangular_solve.hpp | 9 ++-- .../algorithms/triangular_solve_impl.hpp | 14 ++++-- .../detail/create_matrix_handle.hpp | 25 +++++----- .../spblas/vendor/onemkl_sycl/spgemm_impl.hpp | 34 +++++++------ .../spblas/vendor/onemkl_sycl/spmm_impl.hpp | 43 +++++++++------- .../spblas/vendor/onemkl_sycl/spmv_impl.hpp | 16 +++--- .../onemkl_sycl/triangular_solve_impl.hpp | 49 ++++++++----------- 12 files changed, 107 insertions(+), 109 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 9d99ea6..ead25c4 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,7 +11,7 @@ if (SPBLAS_CPU_BACKEND) add_example(simple_sptrsv) add_example(spmm_csc) add_example(matrix_opt_example) - if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND ) + if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND ) # needs CPU + matrix_opt + operation_info_t to run add_example(sptrsv_csr) # needs triangular_solve{_inspect} to run add_example(spmm_csr) # needs multiply{_inspect} to run diff --git a/examples/spmm_csr.cpp b/examples/spmm_csr.cpp index 697f42b..57ceeb9 100644 --- a/examples/spmm_csr.cpp +++ b/examples/spmm_csr.cpp @@ -42,7 +42,6 @@ int main(int argc, char** argv) { md::mdspan x(x_values.data(), k, n); md::mdspan y(y_values.data(), m, n); - // Y = A * X auto state = multiply_inspect(a_opt, x, y); multiply(state, a_opt, x, y); diff --git a/examples/sptrsv_csr.cpp b/examples/sptrsv_csr.cpp index 7abfe24..ab74c92 100644 --- a/examples/sptrsv_csr.cpp +++ b/examples/sptrsv_csr.cpp @@ -57,7 +57,6 @@ int main(int argc, char** argv) { triangular_solve(a_opt, spblas::upper_triangle_t{}, spblas::implicit_unit_diagonal_t{}, b_scaled, x); - fmt::print("\tExample is completed!\n"); return 0; diff --git a/include/spblas/algorithms/multiply.hpp b/include/spblas/algorithms/multiply.hpp index 1ce86d2..be4b255 100644 --- a/include/spblas/algorithms/multiply.hpp +++ b/include/spblas/algorithms/multiply.hpp @@ -18,7 +18,6 @@ void multiply(A&& a, B&& b, C&& c); template void multiply(operation_into_t& info, A&& a, B&& b, C&& c); - // SpMM variants template void multiply(A&& a, B&& b, C&& c); @@ -33,16 +32,18 @@ operation_info_t multiply_inspect(A&& a, B&& b, C&& c); template void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c); - // SpGEMM variants template -operation_info_t multiply_compute(ExecutionPolicy &&policy, A&& a, B&& b, C&& c); +operation_info_t multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b, + C&& c); template -void multiply_compute(ExecutionPolicy &&policy, operation_info_t& info, A&& a, B&& b, C&& c); +void multiply_compute(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c); template -void multiply_fill(ExecutionPolicy &&policy, operation_info_t& info, A&& a, B&& b, C&& c); +void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c); template operation_info_t multiply_compute(A&& a, B&& b, C&& c); diff --git a/include/spblas/algorithms/multiply_impl.hpp b/include/spblas/algorithms/multiply_impl.hpp index eb02e40..856c97b 100644 --- a/include/spblas/algorithms/multiply_impl.hpp +++ b/include/spblas/algorithms/multiply_impl.hpp @@ -56,12 +56,11 @@ void multiply(A&& a, B&& b, C&& c) { // SpMV with info input template requires(__backend::lookupable && __backend::lookupable) -void multiply(operation_info_t &info, A&& a, B&& b, C&& c) { +void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { log_trace(""); multiply(std::forward(a), std::forward(b), std::forward(c)); } - // C = AB // SpMM template @@ -81,7 +80,7 @@ void multiply(A&& a, B&& b, C&& c) { v = 0; }); - // traverses elements of a and performs appropriate + // traverses elements of a and performs appropriate // multiplication with B rows __backend::for_each(a, [&](auto&& e) { auto&& [idx, a_v] = e; @@ -96,12 +95,11 @@ void multiply(A&& a, B&& b, C&& c) { // SpMM with info template requires(__backend::lookupable && __backend::lookupable) -void multiply(operation_info_t &info, A&& a, B&& b, C&& c) { +void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { log_trace(""); multiply(std::forward(a), std::forward(b), std::forward(c)); } - // C = AB // SpMM or SpGEMM multiply_inspect variants end up here template @@ -113,11 +111,10 @@ operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { // C = AB // SpMM or SpGEMM multiply_inspect variants end up here template -void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c){ +void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) { log_trace(""); }; - // C = AB // SpGEMM compute stage with CSR output template diff --git a/include/spblas/algorithms/triangular_solve.hpp b/include/spblas/algorithms/triangular_solve.hpp index 948f0d0..88821df 100644 --- a/include/spblas/algorithms/triangular_solve.hpp +++ b/include/spblas/algorithms/triangular_solve.hpp @@ -5,14 +5,13 @@ namespace spblas { - template -void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); - +void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x); template -operation_info_t triangular_solve_inspect(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); - +operation_info_t triangular_solve_inspect(A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x); template void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); diff --git a/include/spblas/algorithms/triangular_solve_impl.hpp b/include/spblas/algorithms/triangular_solve_impl.hpp index e207fee..4b4f0f3 100644 --- a/include/spblas/algorithms/triangular_solve_impl.hpp +++ b/include/spblas/algorithms/triangular_solve_impl.hpp @@ -13,7 +13,8 @@ namespace spblas { template requires(__backend::row_iterable && __backend::lookupable && __backend::lookupable) -operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { +operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d, + B&& b, X&& x) { log_trace(""); static_assert(std::is_same_v || std::is_same_v); @@ -27,7 +28,8 @@ operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d, template requires(__backend::row_iterable && __backend::lookupable && __backend::lookupable) -void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { +void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle t, + DiagonalStorage d, B&& b, X&& x) { log_trace(""); static_assert(std::is_same_v || std::is_same_v); @@ -96,10 +98,12 @@ void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { template requires(__backend::row_iterable && __backend::lookupable && __backend::lookupable) -void triangular_solve(operation_info_t& info, A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { +void triangular_solve(operation_info_t& info, A&& a, Triangle t, + DiagonalStorage d, B&& b, X&& x) { log_trace(""); - triangular_solve(std::forward(a), std::forward(t), std::forward(d), std::forward(b), std::forward(x)); + triangular_solve(std::forward(a), std::forward(t), + std::forward(d), std::forward(b), + std::forward(x)); } - } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp index 616020f..1e72b90 100644 --- a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp +++ b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp @@ -17,14 +17,15 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, oneapi::mkl::sparse::matrix_handle_t handle = nullptr; oneapi::mkl::sparse::init_matrix_handle(&handle); - oneapi::mkl::sparse::set_csr_data( - q, handle, m.shape()[0], m.shape()[1], -#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) + q, handle, m.shape()[0], m.shape()[1], +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) m.size(), // nnz added in 2025.3, and without deprecated -#endif - oneapi::mkl::index_base::zero, - m.rowptr().data(), m.colind().data(), m.values().data()) +#endif + oneapi::mkl::index_base::zero, m.rowptr().data(), m.colind().data(), + m.values().data()) .wait(); return handle; @@ -38,12 +39,14 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, oneapi::mkl::sparse::init_matrix_handle(&handle); oneapi::mkl::sparse::set_csr_data( - q, handle, m.shape()[1], m.shape()[0], -#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) + q, handle, m.shape()[1], m.shape()[0], +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) m.size(), // nnz added in 2025.3, and without deprecated -#endif - oneapi::mkl::index_base::zero, - m.colptr().data(), m.rowind().data(), m.values().data()) +#endif + oneapi::mkl::index_base::zero, m.colptr().data(), m.rowind().data(), + m.values().data()) .wait(); return handle; diff --git a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp index 35afc89..1b0b952 100644 --- a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp @@ -29,7 +29,6 @@ namespace spblas { - // // multiply_compute -- csr/csc * csr/csc -> csr with ExecutionPolicy // @@ -72,9 +71,11 @@ operation_info_t oneapi::mkl::sparse::set_csr_data( q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1], -#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) __backend::size(c), // nnz added in 2025.3, and without deprecated -#endif +#endif oneapi::mkl::index_base::zero, c_rowptr, (I*) nullptr, (T*) nullptr) .wait(); @@ -133,23 +134,24 @@ template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && __detail::is_csr_view_v -void - multiply_compute(ExecutionPolicy&& policy, operation_info_t &info, A&& a, B&& b, C&& c) { +void multiply_compute(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c) { log_trace(""); - auto tmp_info = multiply_compute(std::forward(policy), std::forward(a), std::forward(b), std::forward(c)); + auto tmp_info = multiply_compute(std::forward(policy), + std::forward(a), std::forward(b), + std::forward(c)); // fill the normal bucket of state stuf based on creating model for now. info.update_impl_(tmp_info.result_shape(), tmp_info.result_nnz()); info.state_.a_handle = tmp_info.state_.a_handle; info.state_.b_handle = tmp_info.state_.b_handle; info.state_.c_handle = tmp_info.state_.c_handle; - info.state_.descr = tmp_info.state_.descr; + info.state_.descr = tmp_info.state_.descr; info.state_.c_rowptr = tmp_info.state_.c_rowptr; - info.state_.q = tmp_info.state_.q; - -} // multiply_compute + info.state_.q = tmp_info.state_.q; +} // multiply_compute // // multiply_fill -- csr/csc * csr/csc -> csr with ExecutionPolicy @@ -190,9 +192,11 @@ void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, auto ev_setC = oneapi::mkl::sparse::set_csr_data( q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1], -#if defined(__INTEL_MKL__) && ( (__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || (__INTEL_MKL__ > 2025 ) ) +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) __backend::size(c), // nnz added in 2025.3, and without deprecated -#endif +#endif oneapi::mkl::index_base::zero, c_rowptr, c.colind().data(), c.values().data()); @@ -228,8 +232,9 @@ template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && __detail::is_csr_view_v -void multiply_compute(operation_info_t & info, A&& a, B&& b, C&& c) { - return multiply_compute(mkl::par, std::forward(info), std::forward(a), std::forward(b), +void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { + return multiply_compute(mkl::par, std::forward(info), + std::forward(a), std::forward(b), std::forward(c)); } @@ -250,7 +255,6 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { return multiply_compute(transposed(b), transposed(a), transposed(c)); } - template requires((__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && diff --git a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp index 1b6ee6a..ca106e5 100644 --- a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp @@ -37,7 +37,8 @@ template __mdspan::layout_right> && std::is_same_v::layout_type, __mdspan::layout_right>) -void multiply_inspect(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, Y&& y) { +void multiply_inspect(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + X&& x, Y&& y) { log_trace(""); if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { throw std::runtime_error( @@ -53,13 +54,15 @@ void multiply_inspect(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X auto x_base = __detail::get_ultimate_base(x); - oneapi::mkl::sparse::optimize_gemm(q, oneapi::mkl::layout::row_major, a_transpose, - oneapi::mkl::transpose::nontrans, a_handle, static_cast(x_base.extent(1))) + oneapi::mkl::sparse::optimize_gemm( + q, oneapi::mkl::layout::row_major, a_transpose, + oneapi::mkl::transpose::nontrans, a_handle, + static_cast(x_base.extent(1))) .wait(); - } - else { + } else { // do nothing, since it would be immediately discarded - log_info("No work done, since no matrix_opt to store optimized results into!"); + log_info( + "No work done, since no matrix_opt to store optimized results into!"); } } // multiply_inspect @@ -71,16 +74,17 @@ template __mdspan::layout_right> && std::is_same_v::layout_type, __mdspan::layout_right>) -operation_info_t multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { +operation_info_t multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, + Y&& y) { log_trace(""); operation_info_t info{}; - multiply_inspect(std::forward(policy), info, std::forward(a), std::forward(x), std::forward(y)); + multiply_inspect(std::forward(policy), info, + std::forward(a), std::forward(x), std::forward(y)); return info; } - template requires( (__detail::has_csr_base || __detail::has_csc_base) && @@ -89,7 +93,8 @@ template __mdspan::layout_right> && std::is_same_v::layout_type, __mdspan::layout_right>) -void multiply(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, Y&& y) { +void multiply(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, + Y&& y) { log_trace(""); if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { @@ -119,9 +124,9 @@ void multiply(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, Y& } } - // -// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution policy +// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution +// policy // template requires( @@ -132,13 +137,14 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) operation_info_t multiply_inspect(A&& a, X&& x, Y&& y) { - auto info = multiply_inspect(mkl::par, std::forward(a), - std::forward(x), std::forward(y)); + auto info = multiply_inspect(mkl::par, std::forward(a), std::forward(x), + std::forward(y)); return info; } // -// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution policy +// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution +// policy // template requires( @@ -150,10 +156,9 @@ template __mdspan::layout_right>) void multiply_inspect(operation_info_t& info, A&& a, X&& x, Y&& y) { multiply_inspect(mkl::par, info, std::forward(a), std::forward(x), - std::forward(y)); + std::forward(y)); } - // // multiply - CSR/CSC with row major dense matrix rhs without execution policy // @@ -171,7 +176,8 @@ void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) { } // -// multiply - CSR/CSC with row major dense matrix rhs without execution policy or state object +// multiply - CSR/CSC with row major dense matrix rhs without execution policy +// or state object // template requires( @@ -187,5 +193,4 @@ void multiply(A&& a, X&& x, Y&& y) { std::forward(y)); } - } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp index 6a2e9d0..e6e1c90 100644 --- a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp @@ -28,7 +28,6 @@ namespace spblas { - // // multiply_inspect with CSR/CSC and single rhs // @@ -52,16 +51,14 @@ void multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { auto a_transpose = __mkl::get_transpose(a); oneapi::mkl::sparse::optimize_gemv(q, a_transpose, a_handle).wait(); - } - else { + } else { // do nothing, since it would be trashed immediately after - log_info("No work done, since no matrix_opt to store optimized results into!"); - + log_info( + "No work done, since no matrix_opt to store optimized results into!"); } } // multiply_inspect - // // multiply with CSR/CSC and single rhs // @@ -96,9 +93,8 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { } } - // -// multiply_inspect -- CSR/CSC + single rhs vector +// multiply_inspect -- CSR/CSC + single rhs vector // with no ExecutionPolicy // template @@ -107,11 +103,11 @@ template __ranges::contiguous_range) void multiply_inspect(A&& a, X&& x, Y&& y) { multiply_inspect(mkl::par, std::forward(a), std::forward(x), - std::forward(y)); + std::forward(y)); } // -// multiply -- CSR/CSC + single rhs vector +// multiply -- CSR/CSC + single rhs vector // template requires((__detail::has_csr_base || __detail::has_csc_base) && diff --git a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp index 4da55b5..295ff0b 100644 --- a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp @@ -29,7 +29,8 @@ namespace spblas { // // CSR triangular solve inspection step // -template +template requires __detail::has_csr_base && __detail::has_contiguous_range_base && __ranges::contiguous_range @@ -65,19 +66,19 @@ void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, Triangle uplo, ? oneapi::mkl::diag::nonunit : oneapi::mkl::diag::unit; - oneapi::mkl::sparse::optimize_trsv(q, uplo_val, a_op, diag_val, a_handle).wait(); + oneapi::mkl::sparse::optimize_trsv(q, uplo_val, a_op, diag_val, a_handle) + .wait(); if (!__detail::has_matrix_opt(a)) { oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); } - } - // // CSR triangular solve execution step // -template +template requires __detail::has_csr_base && __detail::has_contiguous_range_base && __ranges::contiguous_range @@ -107,10 +108,9 @@ void triangular_solve(ExecutionPolicy&& policy, A&& a, Triangle uplo, auto a_handle = __mkl::get_matrix_handle(q, a); auto a_op = __mkl::get_transpose(a); - auto uplo_val = - std::is_same_v - ? oneapi::mkl::uplo::upper - : oneapi::mkl::uplo::lower; + auto uplo_val = std::is_same_v + ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; auto diag_val = std::is_same_v ? oneapi::mkl::diag::nonunit @@ -128,8 +128,6 @@ void triangular_solve(ExecutionPolicy&& policy, A&& a, Triangle uplo, } // triangular_solve - - // // CSR triangular_solve_inspect with no exception policy // @@ -137,17 +135,14 @@ template requires __detail::has_csr_base && __detail::has_contiguous_range_base && __ranges::contiguous_range -void triangular_solve_inspect(A&& a, Triangle uplo, - DiagonalStorage diag, B&& b, X&& x) { - triangular_solve_inspect(mkl::par, - std::forward(a), - std::forward(uplo), - std::forward(diag), - std::forward(b), - std::forward(x) ); +void triangular_solve_inspect(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + X&& x) { + triangular_solve_inspect(mkl::par, std::forward(a), + std::forward(uplo), + std::forward(diag), + std::forward(b), std::forward(x)); } // triangular_solve_inspect - // // CSR triangular_solve with no exception policy // @@ -155,15 +150,11 @@ template requires __detail::has_csr_base && __detail::has_contiguous_range_base && __ranges::contiguous_range -void triangular_solve(A&& a, Triangle uplo, - DiagonalStorage diag, B&& b, X&& x) { - triangular_solve(mkl::par, - std::forward(a), - std::forward(uplo), - std::forward(diag), - std::forward(b), - std::forward(x) ); +void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + X&& x) { + triangular_solve(mkl::par, std::forward(a), std::forward(uplo), + std::forward(diag), std::forward(b), + std::forward(x)); } // triangular_solve - } // namespace spblas From cd8842215893a5721d72cf0b1f2836697d138779 Mon Sep 17 00:00:00 2001 From: Spencer Patty Date: Fri, 27 Mar 2026 16:06:16 -0700 Subject: [PATCH 3/6] Add log_trace("") to many of the new apis for future debugging Co-authored-by: Spencer Patty --- include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp | 1 + include/spblas/vendor/onemkl_sycl/spmm_impl.hpp | 4 ++++ include/spblas/vendor/onemkl_sycl/spmv_impl.hpp | 1 + 3 files changed, 6 insertions(+) diff --git a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp index 1b0b952..11e265a 100644 --- a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp @@ -233,6 +233,7 @@ template (__detail::has_csr_base || __detail::has_csc_base) && __detail::is_csr_view_v void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); return multiply_compute(mkl::par, std::forward(info), std::forward(a), std::forward(b), std::forward(c)); diff --git a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp index ca106e5..441ce28 100644 --- a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp @@ -137,6 +137,7 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) operation_info_t multiply_inspect(A&& a, X&& x, Y&& y) { + log_trace(""); auto info = multiply_inspect(mkl::par, std::forward(a), std::forward(x), std::forward(y)); return info; @@ -155,6 +156,7 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) void multiply_inspect(operation_info_t& info, A&& a, X&& x, Y&& y) { + log_trace(""); multiply_inspect(mkl::par, info, std::forward(a), std::forward(x), std::forward(y)); } @@ -171,6 +173,7 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) { + log_trace(""); multiply(mkl::par, info, std::forward(a), std::forward(x), std::forward(y)); } @@ -188,6 +191,7 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) void multiply(A&& a, X&& x, Y&& y) { + log_trace(""); operation_info_t info{}; multiply(mkl::par, info, std::forward(a), std::forward(x), std::forward(y)); diff --git a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp index e6e1c90..6377397 100644 --- a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp @@ -102,6 +102,7 @@ template __detail::has_contiguous_range_base && __ranges::contiguous_range) void multiply_inspect(A&& a, X&& x, Y&& y) { + log_trace(""); multiply_inspect(mkl::par, std::forward(a), std::forward(x), std::forward(y)); } From 669b8d7685cc0f768751786d17de572968aee6e2 Mon Sep 17 00:00:00 2001 From: Spencer Patty Date: Fri, 10 Apr 2026 16:21:53 -0700 Subject: [PATCH 4/6] Update examples/CMakeLists.txt Co-authored-by: Yu-Hsiang M. Tsai <19565938+yhmtsai@users.noreply.github.com> --- examples/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index ead25c4..fb4b05b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,7 +11,7 @@ if (SPBLAS_CPU_BACKEND) add_example(simple_sptrsv) add_example(spmm_csc) add_example(matrix_opt_example) - if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND ) + if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND) # needs CPU + matrix_opt + operation_info_t to run add_example(sptrsv_csr) # needs triangular_solve{_inspect} to run add_example(spmm_csr) # needs multiply{_inspect} to run From 8632f30f6be3be53c73f33ac64d9eb7ab52c6afc Mon Sep 17 00:00:00 2001 From: Spencer Patty Date: Fri, 10 Apr 2026 16:22:18 -0700 Subject: [PATCH 5/6] Update examples/spmm_csr.cpp Co-authored-by: Yu-Hsiang M. Tsai <19565938+yhmtsai@users.noreply.github.com> --- examples/spmm_csr.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/spmm_csr.cpp b/examples/spmm_csr.cpp index 57ceeb9..442345a 100644 --- a/examples/spmm_csr.cpp +++ b/examples/spmm_csr.cpp @@ -18,7 +18,6 @@ int main(int argc, char** argv) { "######################"); fmt::print("\n\t### Running Advanced SpMM Example:"); fmt::print("\n\t###"); - fmt::print("\n\t### Y = alpha * A * X"); fmt::print("\n\t###"); fmt::print("\n\t### with "); fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k, From 1515c9c796b14fd210dc95d9f37fcd5f0dae48f6 Mon Sep 17 00:00:00 2001 From: Spencer Patty Date: Fri, 10 Apr 2026 16:22:42 -0700 Subject: [PATCH 6/6] Update include/spblas/algorithms/multiply_impl.hpp Co-authored-by: Yu-Hsiang M. Tsai <19565938+yhmtsai@users.noreply.github.com> --- include/spblas/algorithms/multiply_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/spblas/algorithms/multiply_impl.hpp b/include/spblas/algorithms/multiply_impl.hpp index 856c97b..40b03f5 100644 --- a/include/spblas/algorithms/multiply_impl.hpp +++ b/include/spblas/algorithms/multiply_impl.hpp @@ -85,7 +85,7 @@ void multiply(A&& a, B&& b, C&& c) { __backend::for_each(a, [&](auto&& e) { auto&& [idx, a_v] = e; auto&& [i, k] = idx; - for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { // b_row + for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { __backend::lookup(c, i, j) += a_v * __backend::lookup(b, k, j); } });