Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,25 @@ endif()
# Download dependencies
include(FetchContent)

# Backend capability switches. Both start out OFF; the vendor sections that
# follow turn on whichever ones they provide.
set(SPBLAS_CPU_BACKEND OFF)
set(SPBLAS_GPU_BACKEND OFF)

# oneMKL SYCL backend: counts as both a CPU and a GPU backend.
if(ENABLE_ONEMKL_SYCL)
  set(SPBLAS_CPU_BACKEND ON)
  set(SPBLAS_GPU_BACKEND ON)

  find_package(MKL REQUIRED)
  target_link_libraries(spblas INTERFACE MKL::MKL_SYCL) # SYCL APIs
  string(APPEND CMAKE_CXX_FLAGS " -DSPBLAS_ENABLE_ONEMKL_SYCL")

  # sycl_thrust is pulled from the SparseBLAS/sycl-thrust repository at
  # configure time, tracking its main branch.
  FetchContent_Declare(
    sycl_thrust
    GIT_REPOSITORY https://github.com/SparseBLAS/sycl-thrust.git
    GIT_TAG main)
  FetchContent_MakeAvailable(sycl_thrust)
endif()

if (ENABLE_ARMPL)
set(SPBLAS_CPU_BACKEND ON)
if (NOT DEFINED ENV{ARMPL_DIR})
message(FATAL_ERROR "Environment variable ARMPL_DIR must be set when the ArmPL is enabled.")
endif()
Expand All @@ -36,6 +48,7 @@ if (ENABLE_ARMPL)
endif()

if (ENABLE_AOCLSPARSE)
set(SPBLAS_CPU_BACKEND ON)
if (NOT DEFINED ENV{AOCLSPARSE_DIR})
message(FATAL_ERROR "Environment variable AOCLSPARSE_DIR must be set when the AOCLSPARSE is enabled.")
endif()
Expand Down Expand Up @@ -81,6 +94,15 @@ if (ENABLE_CUSPARSE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPBLAS_ENABLE_CUSPARSE")
endif()

# Fallback: when none of the vendor backends was requested, switch the CPU
# backend on so that the reference implementation is used.
if(NOT (ENABLE_ONEMKL_SYCL
        OR ENABLE_ARMPL
        OR ENABLE_AOCLSPARSE
        OR ENABLE_ROCSPARSE
        OR ENABLE_CUSPARSE))
  set(SPBLAS_CPU_BACKEND ON)
endif()

# turn on/off debug logging
if (LOG_LEVEL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLOG_LEVEL=${LOG_LEVEL}") # SPBLAS_DEBUG | SPBLAS_WARNING | SPBLAS_TRACE | SPBLAS_INFO
Expand Down
21 changes: 12 additions & 9 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@ function(add_example example_name)
target_link_libraries(${example_name} spblas fmt)
endfunction()

if (NOT SPBLAS_GPU_BACKEND)
# CPU examples
if (SPBLAS_CPU_BACKEND)
add_example(simple_spmv)
add_example(simple_spmm)
add_example(simple_spgemm)
add_example(simple_sptrsv)
add_example(matrix_opt_example)
add_example(spmm_csc)
else()
add_subdirectory(device)
add_example(matrix_opt_example)
endif()

if (ENABLE_ROCSPARSE)
add_subdirectory(rocsparse)
endif()
if (ENABLE_CUSPARSE)
add_subdirectory(cusparse)
# GPU examples
if (SPBLAS_GPU_BACKEND)
add_subdirectory(device)
if (ENABLE_CUSPARSE)
add_subdirectory(cusparse)
endif()
if (ENABLE_ROCSPARSE)
add_subdirectory(rocsparse)
endif()
endif()
4 changes: 2 additions & 2 deletions examples/cusparse/cusparse_simple_spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ int main(int argc, char** argv) {
std::span<value_t> y_span(d_y, m);

// y = A * x
spblas::spmv_state_t state;
spblas::multiply(state, a, x_span, y_span);
spblas::operation_info_t info;
spblas::multiply(info, a, x_span, y_span);
Comment on lines +79 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this change? My understanding was that the call would take the following elements:

  • execution policy
  • spmv state object
  • sparse matrix object
  • x and y vector objects

what is operation_info_t ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

operation_info_t is the state. I believe in the current proposal we refer to it as operation_state_t. It just hasn't been renamed here (I can do that in a separate PR). (Perhaps confusingly, there is a non-user-visible operation_state_t class, which is an implementation detail. That's how vendor backends store their data inside the state.)

I think one semi-open question is whether to have different state objects for each type of operation. Personally I think it creates more complexity for the user, as they have to juggle different kinds of state objects.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it should be _state_t, as the paper describes.
I still believe that the first step should be an individual state for each component.
Individual states can be merged later much more easily than one big object can be split up.


CUDA_CHECK(
cudaMemcpy(y.data(), d_y, y.size() * sizeof(value_t), cudaMemcpyDefault));
Expand Down
5 changes: 4 additions & 1 deletion examples/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ function(add_device_example example_name)
add_executable(${example_name} ${example_name}.cpp)
if (ENABLE_ROCSPARSE)
set_source_files_properties(${example_name}.cpp PROPERTIES LANGUAGE HIP)
target_link_libraries(${example_name} roc::rocthrust)
elseif (ENABLE_CUSPARSE)
target_link_libraries(${example_name} Thrust)
elseif (ENABLE_ONEMKL_SYCL)
target_link_libraries(${example_name} sycl_thrust)
else()
message(FATAL_ERROR "Device backend not found.")
endif()
target_link_libraries(${example_name} spblas fmt)
endfunction()

add_device_example(simple_spmv)
add_device_example(device_spmv)
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ int main(int argc, char** argv) {
std::span<value_t> y_span(d_y.data().get(), m);

// y = A * x
spblas::spmv_state_t state;
spblas::multiply(state, a, x_span, y_span);
spblas::multiply(a, x_span, y_span);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where did the state go ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be resolved now---both rocSPARSE and cuSPARSE now take an optional operation_info_t object.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so no state now in this simple example ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, although you can create a state and pass it in if you want:

spblas::operation_info_t state;
spblas::multiply(state, a, x_span, y_span);


thrust::copy(d_y.begin(), d_y.end(), y.begin());

Expand Down
4 changes: 2 additions & 2 deletions examples/rocsparse/rocsparse_simple_spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ int main(int argc, char** argv) {
std::span<value_t> y_span(d_y, m);

// y = A * x
spblas::spmv_state_t state;
spblas::multiply(state, a, x_span, y_span);
spblas::operation_info_t info;
spblas::multiply(info, a, x_span, y_span);

HIP_CHECK(
hipMemcpy(y.data(), d_y, y.size() * sizeof(value_t), hipMemcpyDefault));
Expand Down
25 changes: 25 additions & 0 deletions include/spblas/detail/operation_info_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
#include <spblas/vendor/aoclsparse/operation_state_t.hpp>
#endif

#ifdef SPBLAS_ENABLE_CUSPARSE
#include <spblas/vendor/cusparse/operation_state_t.hpp>
#endif

#ifdef SPBLAS_ENABLE_ROCSPARSE
#include <spblas/vendor/rocsparse/operation_state_t.hpp>
#endif

namespace spblas {

class operation_info_t {
Expand Down Expand Up @@ -53,6 +61,13 @@ class operation_info_t {
state_(std::move(state)) {}
#endif

#ifdef SPBLAS_ENABLE_CUSPARSE
operation_info_t(index<> result_shape, offset_t result_nnz,
__cusparse::operation_state_t&& state)
: result_shape_(result_shape), result_nnz_(result_nnz),
state_(std::move(state)) {}
#endif

void update_impl_(index<> result_shape, offset_t result_nnz) {
result_shape_ = result_shape;
result_nnz_ = result_nnz;
Expand All @@ -76,6 +91,16 @@ class operation_info_t {
public:
__aoclsparse::operation_state_t state_;
#endif

#ifdef SPBLAS_ENABLE_CUSPARSE
public:
__cusparse::operation_state_t state_;
#endif

#ifdef SPBLAS_ENABLE_ROCSPARSE
public:
__rocsparse::operation_state_t state_;
#endif
};

} // namespace spblas
34 changes: 34 additions & 0 deletions include/spblas/vendor/cusparse/detail/abstract_operation_state.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <cusparse.h>
#include <memory>
#include <stdexcept>

namespace spblas {
namespace __cusparse {

class abstract_operation_state_t {
public:
// Common state that all operations need
cusparseHandle_t handle() const {
return handle_;
}

// Make std::default_delete a friend so unique_ptr can delete us
friend struct std::default_delete<abstract_operation_state_t>;

protected:
abstract_operation_state_t() {
cusparseCreate(&handle_);
}

virtual ~abstract_operation_state_t() {
if (handle_) {
cusparseDestroy(handle_);
}
}

cusparseHandle_t handle_;
};

} // namespace __cusparse
} // namespace spblas
41 changes: 41 additions & 0 deletions include/spblas/vendor/cusparse/detail/cusparse_tensors.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once

#include <cusparse.h>

#include <spblas/detail/types.hpp>
#include <spblas/detail/view_inspectors.hpp>
#include <spblas/vendor/cusparse/exception.hpp>
#include <spblas/vendor/cusparse/types.hpp>

namespace spblas {

namespace __cusparse {

// Builds a cuSPARSE sparse-matrix descriptor for the CSR view `m`.
//
// The descriptor is created with cusparseCreateCsr from the view's shape,
// its number of stored values and the data pointers of its rowptr, colind
// and values ranges, using zero-based indexing. A failing cuSPARSE status is
// handed to __cusparse::throw_if_error. The descriptor is returned as-is;
// nothing in this function destroys it.
template <matrix M>
  requires __detail::is_csr_view_v<M>
cusparseSpMatDescr_t create_cusparse_handle(M&& m) {
  using offset_type = tensor_offset_t<M>;
  using index_type = tensor_index_t<M>;
  using scalar_type = tensor_scalar_t<M>;

  auto&& dims = __backend::shape(m);
  const auto nnz = m.values().size();

  cusparseSpMatDescr_t descriptor;
  __cusparse::throw_if_error(cusparseCreateCsr(
      &descriptor, dims[0], dims[1], nnz, m.rowptr().data(),
      m.colind().data(), m.values().data(),
      detail::cusparse_index_type_v<offset_type>,
      detail::cusparse_index_type_v<index_type>, CUSPARSE_INDEX_BASE_ZERO,
      detail::cuda_data_type_v<scalar_type>));
  return descriptor;
}

// Builds a cuSPARSE dense-vector descriptor for the contiguous range `v`.
//
// The descriptor is created with cusparseCreateDnVec from the vector's
// shape, its data pointer and the CUDA data type matching its scalar type.
// A failing cuSPARSE status is handed to __cusparse::throw_if_error. The
// descriptor is returned as-is; nothing in this function destroys it.
template <vector V>
  requires __ranges::contiguous_range<V>
cusparseDnVecDescr_t create_cusparse_handle(V&& v) {
  using scalar_type = tensor_scalar_t<V>;

  const auto length = __backend::shape(v);
  auto values = __ranges::data(v);

  cusparseDnVecDescr_t descriptor;
  __cusparse::throw_if_error(cusparseCreateDnVec(
      &descriptor, length, values, detail::cuda_data_type_v<scalar_type>));
  return descriptor;
}

} // namespace __cusparse

} // namespace spblas
32 changes: 32 additions & 0 deletions include/spblas/vendor/cusparse/detail/get_transpose.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <cusparse.h>
#include <spblas/detail/view_inspectors.hpp>

namespace spblas {
namespace __cusparse {

//
// Maps a sparse matrix view to the cusparseOperation_t to use when the
// matrix is passed to cuSPARSE as CSR storage.
//
// Accepts a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose view.
// Mapping as stated by the original author:
//
// CSR = CSR + NON_TRANSPOSE
// CSR_transpose = CSR + TRANSPOSE
// CSC = CSR + TRANSPOSE
// CSC_transpose = CSR + NON_TRANSPOSE
//
// What the code does: a view for which __detail::has_base<M> holds is
// unwrapped through m.base() and the result is taken from that base; a CSR
// view yields CUSPARSE_OPERATION_NON_TRANSPOSE and a CSC view yields
// CUSPARSE_OPERATION_TRANSPOSE. Any other type is rejected at compile time
// by the static_assert.
//
// NOTE(review): nothing visible here flips the result when unwrapping a
// wrapper, so whether the two *_transpose rows above hold depends on how
// has_base / has_csr_base / has_csc_base are defined elsewhere -- confirm.
//
template <matrix M>
cusparseOperation_t get_transpose(M&& m) {
// Compile-time guard: only views built on a CSR or CSC base are supported.
static_assert(__detail::has_csr_base<M> || __detail::has_csc_base<M>);
if constexpr (__detail::has_base<M>) {
// Wrapper view: defer to the view it wraps.
return get_transpose(m.base());
} else if constexpr (__detail::is_csr_view_v<M>) {
return CUSPARSE_OPERATION_NON_TRANSPOSE;
} else if constexpr (__detail::is_csc_view_v<M>) {
return CUSPARSE_OPERATION_TRANSPOSE;
}
}

} // namespace __cusparse
} // namespace spblas
55 changes: 55 additions & 0 deletions include/spblas/vendor/cusparse/detail/spmv_state_t.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#pragma once

#include <cusparse.h>
#include <memory>

#include "abstract_operation_state.hpp"

namespace spblas {
namespace __cusparse {

class spmv_state_t : public abstract_operation_state_t {
public:
spmv_state_t() = default;
~spmv_state_t() {
if (a_descr_) {
cusparseDestroySpMat(a_descr_);
}
if (b_descr_) {
cusparseDestroyDnVec(b_descr_);
}
if (c_descr_) {
cusparseDestroyDnVec(c_descr_);
}
}

// Accessors for the descriptors
cusparseSpMatDescr_t a_descriptor() const {
return a_descr_;
}
cusparseDnVecDescr_t b_descriptor() const {
return b_descr_;
}
cusparseDnVecDescr_t c_descriptor() const {
return c_descr_;
}

// Setters for the descriptors
void set_a_descriptor(cusparseSpMatDescr_t descr) {
a_descr_ = descr;
}
void set_b_descriptor(cusparseDnVecDescr_t descr) {
b_descr_ = descr;
}
void set_c_descriptor(cusparseDnVecDescr_t descr) {
c_descr_ = descr;
}

private:
cusparseSpMatDescr_t a_descr_ = nullptr;
cusparseDnVecDescr_t b_descr_ = nullptr;
cusparseDnVecDescr_t c_descr_ = nullptr;
};

} // namespace __cusparse
} // namespace spblas
Loading