Skip to content

Commit eeb3a67

Browse files
committed
add vector descriptor helper
1 parent 4595c8d commit eeb3a67

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

include/spblas/vendor/rocsparse/descriptor.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,17 @@ rocsparse_spmat_descr create_matrix_descr(mat&& a) {
2929
return descr;
3030
}
3131

32+
// create dense vector from mdspan
33+
template <vector vec>
34+
requires __ranges::contiguous_range<vec>
35+
rocsparse_dnvec_descr create_vector_descr(vec&& v) {
36+
using vector_type = std::remove_cvref_t<vec>;
37+
rocsparse_dnvec_descr descr;
38+
throw_if_error(rocsparse_create_dnvec_descr(
39+
&descr, v.size(), v.data(),
40+
to_rocsparse_datatype<typename vector_type::value_type>()));
41+
return descr;
42+
}
43+
3244
} // namespace __rocsparse
3345
} // namespace spblas

include/spblas/vendor/rocsparse/multiply.hpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,8 @@ class spmv_state_t {
6161
auto handle = this->handle_.get();
6262

6363
rocsparse_spmat_descr mat = __rocsparse::create_matrix_descr(a_base);
64-
rocsparse_dnvec_descr vecb;
65-
rocsparse_dnvec_descr vecc;
66-
__rocsparse::throw_if_error(rocsparse_create_dnvec_descr(
67-
&vecb, b_base.size(), b_base.data(),
68-
to_rocsparse_datatype<typename input_type::value_type>()));
69-
__rocsparse::throw_if_error(rocsparse_create_dnvec_descr(
70-
&vecc, c.size(), c.data(),
71-
to_rocsparse_datatype<typename output_type::value_type>()));
64+
rocsparse_dnvec_descr vecb = __rocsparse::create_vector_descr(b_base);
65+
rocsparse_dnvec_descr vecc = __rocsparse::create_vector_descr(c);
7266
value_type alpha_val = alpha;
7367
value_type beta = 0.0;
7468
long unsigned int buffer_size = 0;

0 commit comments

Comments
 (0)