From b6e6b36106ec8c9e709987ad705be66693781119 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Thu, 21 Aug 2025 16:14:02 +0200 Subject: [PATCH 1/3] implement rocsparse spsv --- include/spblas/vendor/rocsparse/rocsparse.hpp | 1 + include/spblas/vendor/rocsparse/trisolve.hpp | 118 ++++++++++++++++++ test/gtest/CMakeLists.txt | 2 +- test/gtest/device/triangular_solve_test.cpp | 116 +++++++++++++++++ 4 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 include/spblas/vendor/rocsparse/trisolve.hpp create mode 100644 test/gtest/device/triangular_solve_test.cpp diff --git a/include/spblas/vendor/rocsparse/rocsparse.hpp b/include/spblas/vendor/rocsparse/rocsparse.hpp index 014b2ba..66a4cd9 100644 --- a/include/spblas/vendor/rocsparse/rocsparse.hpp +++ b/include/spblas/vendor/rocsparse/rocsparse.hpp @@ -2,3 +2,4 @@ #include "multiply.hpp" #include "multiply_spgemm.hpp" +#include "trisolve.hpp" diff --git a/include/spblas/vendor/rocsparse/trisolve.hpp b/include/spblas/vendor/rocsparse/trisolve.hpp new file mode 100644 index 0000000..3120acc --- /dev/null +++ b/include/spblas/vendor/rocsparse/trisolve.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "exception.hpp" +#include "hip_allocator.hpp" +#include "types.hpp" + +namespace spblas { +class triangular_solve_state_t { +public: + triangular_solve_state_t() + : triangular_solve_state_t(rocsparse::hip_allocator{}) {} + + triangular_solve_state_t(rocsparse::hip_allocator alloc) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { + rocsparse_handle handle; + __rocsparse::throw_if_error(rocsparse_create_handle(&handle)); + if (auto stream = alloc.stream()) { + rocsparse_set_stream(handle, stream); + } + handle_ = handle_manager(handle, [](rocsparse_handle handle) { + __rocsparse::throw_if_error(rocsparse_destroy_handle(handle)); + }); + } + + triangular_solve_state_t(rocsparse::hip_allocator alloc, + rocsparse_handle handle) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { + handle_ = handle_manager(handle, [](rocsparse_handle handle) { + // it is provided by user, we do not delete it at all. + }); + } + + ~triangular_solve_state_t() { + alloc_.deallocate(workspace_); + } + + template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range + void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + C&& c) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + using matrix_type = decltype(a_base); + using value_type = typename matrix_type::scalar_type; + const auto diag_type = std::is_same_v + ? rocsparse_diag_type_non_unit + : rocsparse_diag_type_unit; + const auto fill_mode = std::is_same_v + ? rocsparse_fill_mode_upper + : rocsparse_fill_mode_lower; + + auto a_descr = __rocsparse::create_rocsparse_handle(a_base); + auto b_descr = __rocsparse::create_rocsparse_handle(b_base); + auto c_descr = __rocsparse::create_rocsparse_handle(c); + + __rocsparse::throw_if_error(rocsparse_spmat_set_attribute( + a_descr, rocsparse_spmat_fill_mode, &fill_mode, sizeof(fill_mode))); + __rocsparse::throw_if_error(rocsparse_spmat_set_attribute( + a_descr, rocsparse_spmat_diag_type, &diag_type, sizeof(diag_type))); + value_type alpha = 1.0; + size_t buffer_size = 0; + auto handle = this->handle_.get(); + __rocsparse::throw_if_error(rocsparse_spsv( + handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, + detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, + rocsparse_spsv_stage_buffer_size, &buffer_size, nullptr)); + if (buffer_size > this->buffer_size_) { + this->alloc_.deallocate(workspace_, this->buffer_size_); + this->buffer_size_ = buffer_size; + workspace_ = this->alloc_.allocate(buffer_size); + } + __rocsparse::throw_if_error(rocsparse_spsv( + handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, + detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, + rocsparse_spsv_stage_preprocess, &buffer_size, this->workspace_)); + __rocsparse::throw_if_error(rocsparse_spsv( + handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, + detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, + rocsparse_spsv_stage_compute, &buffer_size, this->workspace_)); + __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(a_descr)); + __rocsparse::throw_if_error(rocsparse_destroy_dnvec_descr(b_descr)); + __rocsparse::throw_if_error(rocsparse_destroy_dnvec_descr(c_descr)); + } + +private: + using handle_manager = + std::unique_ptr::element_type, + std::function>; + handle_manager handle_; + rocsparse::hip_allocator alloc_; + std::uint64_t buffer_size_; + char* workspace_; +}; + +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(triangular_solve_state_t& trisolve_handle, A&& a, + Triangle uplo, DiagonalStorage diag, B&& b, C&& c) { + trisolve_handle.triangular_solve(a, uplo, diag, b, c); +} + +} // namespace spblas diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index b458a17..820dacb 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -22,7 +22,7 @@ endif() # GPU tests if (SPBLAS_GPU_BACKEND) if (ENABLE_ROCSPARSE) - set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp) + set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp device/triangular_solve_test.cpp) set_source_files_properties(${GPUTEST_SOURCES} PROPERTIES LANGUAGE HIP) else () set(GPUTEST_SOURCES device/spmv_test.cpp) diff --git a/test/gtest/device/triangular_solve_test.cpp b/test/gtest/device/triangular_solve_test.cpp new file mode 100644 index 0000000..2afa362 --- /dev/null +++ b/test/gtest/device/triangular_solve_test.cpp @@ -0,0 +1,116 @@ +#include + +#include "../util.hpp" +#include + +#include + +template +void reference_triangular_solve(spblas::csr_view a, Triangle t, + DiagonalStorage d, B&& b, X&& x) { + auto&& values = a.values(); + auto&& colind = a.colind(); + auto&& rowptr = a.rowptr(); + auto shape = a.shape(); + + if constexpr (std::is_same_v) { + // backward solve + for (I row = shape[0]; row-- > 0;) { + T tmp = b[row]; + T diag_val = 0.0; + for (I j = rowptr[row]; j < rowptr[row + 1]; j++) { + I col = colind[j]; + if (col > row) { + T a_val = values[j]; + T x_val = x[col]; + tmp -= a_val * x_val; // b - U*x + } else if (col == row) { + diag_val = values[j]; + } + } + if constexpr (std::is_same_v) { + x[row] = tmp / diag_val; // ( b - U*x) / d + } else { + x[row] = tmp; // ( b- U*x) / 1 + } + } + } else if constexpr (std::is_same_v) { + // Forward Solve + for (I row = 0; row < shape[0]; row++) { + T tmp = b[row]; + T diag_val = 0.0; + for (I j = rowptr[row]; j < rowptr[row + 1]; ++j) { + I col = colind[j]; + if (col < row) { + T a_val = values[j]; + T x_val = x[col]; + tmp -= a_val * x_val; // b - L*x + } else if (col == row) { + diag_val = values[j]; + } + } + if constexpr (std::is_same_v) { + x[row] = tmp / diag_val; // ( b - L*x) / d + } else { + x[row] = tmp; // ( b- L*x) / 1 + } + } + } +} + +template +void triangular_solve_test(Triangle t, DiagonalStorage d) { + for (auto&& [m, n, nnz] : util::square_dims) { + // generate problem on host + auto [values, rowptr, colind, shape, _] = + spblas::generate_csr(m, n, nnz); + spblas::csr_view a(values, rowptr, colind, shape, nnz); + std::vector x(n, 1); + std::vector b(m, 1); + T scale_factor = 1e-3f; + std::transform(values.begin(), values.end(), values.begin(), + [scale_factor](T val) { return scale_factor * val; }); + // setup the problem on device + thrust::device_vector d_b(b); + thrust::device_vector d_x(x); + thrust::device_vector d_values(values); + thrust::device_vector d_rowptr(rowptr); + thrust::device_vector d_colind(colind); + spblas::csr_view d_a(d_values.data().get(), d_rowptr.data().get(), + d_colind.data().get(), shape, nnz); + std::span b_span(d_b.data().get(), m); + std::span x_span(d_x.data().get(), n); + + spblas::triangular_solve_state_t state; + spblas::triangular_solve(state, d_a, Triangle{}, DiagonalStorage{}, b_span, + x_span); + thrust::copy(d_x.begin(), d_x.end(), x.begin()); + + std::vector x_ref(m, 0); + reference_triangular_solve(a, Triangle{}, DiagonalStorage{}, b, x_ref); + + for (std::size_t i = 0; i < x.size(); i++) { + EXPECT_EQ_(x[i], x_ref[i]); + } + } +} + +TEST(CsrView, TriangularSolveLowerImplicit) { + using T = float; + using I = spblas::index_t; + + triangular_solve_test(spblas::lower_triangle_t{}, + spblas::implicit_unit_diagonal_t{}); +} + +TEST(CsrView, TriangularSolveUpperImplicit) { + using T = float; + using I = spblas::index_t; + + triangular_solve_test(spblas::upper_triangle_t{}, + spblas::implicit_unit_diagonal_t{}); +} From e7b375fc70df99e61f8e023b11a061894e179c67 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 10 Apr 2026 18:26:32 +0200 Subject: [PATCH 2/3] use the unified state/info for the trisolve --- .../vendor/rocsparse/operation_state_t.hpp | 1 + include/spblas/vendor/rocsparse/trisolve.hpp | 46 +++++++++---------- .../vendor/rocsparse/unified_trisolve.hpp | 24 ++++++++++ test/gtest/device/triangular_solve_test.cpp | 2 +- 4 files changed, 48 insertions(+), 25 deletions(-) create mode 100644 include/spblas/vendor/rocsparse/unified_trisolve.hpp diff --git a/include/spblas/vendor/rocsparse/operation_state_t.hpp b/include/spblas/vendor/rocsparse/operation_state_t.hpp index 5fd6f27..99a9270 100644 --- a/include/spblas/vendor/rocsparse/operation_state_t.hpp +++ b/include/spblas/vendor/rocsparse/operation_state_t.hpp @@ -9,6 +9,7 @@ namespace __rocsparse { class operation_state_t { public: operation_state_t() = default; + operation_state_t(std::unique_ptr&& state) : state_(std::move(state)) {} diff --git a/include/spblas/vendor/rocsparse/trisolve.hpp b/include/spblas/vendor/rocsparse/trisolve.hpp index 3120acc..b0a28ae 100644 --- a/include/spblas/vendor/rocsparse/trisolve.hpp +++ b/include/spblas/vendor/rocsparse/trisolve.hpp @@ -12,35 +12,21 @@ #include #include +#include "detail/abstract_operation_state.hpp" +#include "detail/rocsparse_tensors.hpp" #include "exception.hpp" #include "hip_allocator.hpp" #include "types.hpp" namespace spblas { -class triangular_solve_state_t { +class triangular_solve_state_t + : public __rocsparse::abstract_operation_state_t { public: triangular_solve_state_t() : triangular_solve_state_t(rocsparse::hip_allocator{}) {} triangular_solve_state_t(rocsparse::hip_allocator alloc) - : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { - rocsparse_handle handle; - __rocsparse::throw_if_error(rocsparse_create_handle(&handle)); - if (auto stream = alloc.stream()) { - rocsparse_set_stream(handle, stream); - } - handle_ = handle_manager(handle, [](rocsparse_handle handle) { - __rocsparse::throw_if_error(rocsparse_destroy_handle(handle)); - }); - } - - triangular_solve_state_t(rocsparse::hip_allocator alloc, - rocsparse_handle handle) - : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { - handle_ = handle_manager(handle, [](rocsparse_handle handle) { - // it is provided by user, we do not delete it at all. - }); - } + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) {} ~triangular_solve_state_t() { alloc_.deallocate(workspace_); @@ -73,7 +59,7 @@ class triangular_solve_state_t { a_descr, rocsparse_spmat_diag_type, &diag_type, sizeof(diag_type))); value_type alpha = 1.0; size_t buffer_size = 0; - auto handle = this->handle_.get(); + auto handle = this->handle(); __rocsparse::throw_if_error(rocsparse_spsv( handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, @@ -97,10 +83,6 @@ class triangular_solve_state_t { } private: - using handle_manager = - std::unique_ptr::element_type, - std::function>; - handle_manager handle_; rocsparse::hip_allocator alloc_; std::uint64_t buffer_size_; char* workspace_; @@ -115,4 +97,20 @@ void triangular_solve(triangular_solve_state_t& trisolve_handle, A&& a, trisolve_handle.triangular_solve(a, uplo, diag, b, c); } +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(operation_info_t& info, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, C&& c) { + // Get or create state + auto state = info.state_.get_state(); + if (!state) { + info.state_ = __rocsparse::operation_state_t( + std::make_unique()); + state = info.state_.get_state(); + } + state->triangular_solve(a, uplo, diag, b, c); +} + } // namespace spblas diff --git a/include/spblas/vendor/rocsparse/unified_trisolve.hpp b/include/spblas/vendor/rocsparse/unified_trisolve.hpp new file mode 100644 index 0000000..7ee3b38 --- /dev/null +++ b/include/spblas/vendor/rocsparse/unified_trisolve.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "operation_state_t.hpp" +#include "trisolve.hpp" +#include + +namespace spblas { + +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(operation_info_t& info, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, C&& c) { + // Get or create state + auto state = info.state_.get_state(); + if (!state) { + info.state_ = __rocsparse::operation_state_t( + std::make_unique()); + state = info.state_.get_state(); + } + state->triangular_solve(a, uplo, diag, b, c); +} + +} // namespace spblas diff --git a/test/gtest/device/triangular_solve_test.cpp b/test/gtest/device/triangular_solve_test.cpp index 2afa362..baefa6a 100644 --- a/test/gtest/device/triangular_solve_test.cpp +++ b/test/gtest/device/triangular_solve_test.cpp @@ -85,7 +85,7 @@ void triangular_solve_test(Triangle t, DiagonalStorage d) { std::span b_span(d_b.data().get(), m); std::span x_span(d_x.data().get(), n); - spblas::triangular_solve_state_t state; + spblas::operation_info_t state; spblas::triangular_solve(state, d_a, Triangle{}, DiagonalStorage{}, b_span, x_span); thrust::copy(d_x.begin(), d_x.end(), x.begin()); From 88d5187776bf26693699bfd163bd2915fbe22d54 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Mon, 13 Apr 2026 14:47:15 +0200 Subject: [PATCH 3/3] use rocm 6.3.2 --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d17a73..4f36e4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -122,8 +122,8 @@ jobs: - name: CMake shell: bash -l {0} run: | - module load cmake - cmake -B build -DENABLE_ROCSPARSE=ON -DCMAKE_PREFIX_PATH=/opt/rocm + module load cmake rocm/6.3.2 + cmake -B build -DENABLE_ROCSPARSE=ON - name: Build shell: bash -l {0} run: |