Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ endif()
# Download dependencies
include(FetchContent)

if (ENABLE_ONEMKL_IESPBLAS)
find_package(MKL REQUIRED)
target_link_libraries(spblas INTERFACE MKL::MKL) # C APIs
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPBLAS_ENABLE_ONEMKL_IESPBLAS")
endif()


if (ENABLE_ONEMKL_SYCL)
find_package(MKL REQUIRED)
target_link_libraries(spblas INTERFACE MKL::MKL_SYCL) # SYCL APIs
Expand Down
4 changes: 4 additions & 0 deletions include/spblas/backend/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#include <spblas/backend/generate.hpp>
#include <spblas/backend/view_customizations.hpp>

#ifdef SPBLAS_ENABLE_ONEMKL_IESPBLAS
#include <spblas/vendor/onemkl_iespblas/onemkl_iespblas.hpp>
#endif

#ifdef SPBLAS_ENABLE_ONEMKL_SYCL
#include <spblas/vendor/onemkl_sycl/onemkl_sycl.hpp>
#endif
Expand Down
16 changes: 16 additions & 0 deletions include/spblas/detail/operation_info_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include <spblas/detail/index.hpp>
#include <spblas/detail/types.hpp>

#ifdef SPBLAS_ENABLE_ONEMKL_IESPBLAS
#include <spblas/vendor/onemkl_iespblas/operation_state_t.hpp>
#endif

#ifdef SPBLAS_ENABLE_ONEMKL_SYCL
#include <spblas/vendor/onemkl_sycl/operation_state_t.hpp>
#endif
Expand All @@ -28,6 +32,13 @@ class operation_info_t {
operation_info_t(index<> result_shape, offset_t result_nnz)
: result_shape_(result_shape), result_nnz_(result_nnz) {}

#ifdef SPBLAS_ENABLE_ONEMKL_IESPBLAS
operation_info_t(index<> result_shape, offset_t result_nnz,
__mkl_iespblas::operation_state_t&& state)
: result_shape_(result_shape), result_nnz_(result_nnz),
state_(std::move(state)) {}
#endif

#ifdef SPBLAS_ENABLE_ONEMKL_SYCL
operation_info_t(index<> result_shape, offset_t result_nnz,
__mkl::operation_state_t&& state)
Expand All @@ -51,6 +62,11 @@ class operation_info_t {
index<> result_shape_;
offset_t result_nnz_;

#ifdef SPBLAS_ENABLE_ONEMKL_IESPBLAS
public:
__mkl_iespblas::operation_state_t state_;
#endif

#ifdef SPBLAS_ENABLE_ONEMKL_SYCL
public:
__mkl::operation_state_t state_;
Expand Down
4 changes: 4 additions & 0 deletions include/spblas/detail/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include <cstddef>
#include <type_traits>

#ifdef SPBLAS_ENABLE_ONEMKL_IESPBLAS
#include <spblas/vendor/onemkl_iespblas/types.hpp>
#endif

#ifdef SPBLAS_ENABLE_ONEMKL_SYCL
#include <spblas/vendor/onemkl_sycl/types.hpp>
#endif
Expand Down
3 changes: 2 additions & 1 deletion include/spblas/spblas.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#if defined(SPBLAS_ENABLE_ONEMKL_SYCL) || defined(SPBLAS_ENABLE_ARMPL)
#if defined(SPBLAS_ENABLE_ONEMKL_SYCL) || defined(SPBLAS_ENABLE_ONEMKL_IESPBLAS) || \
defined(SPBLAS_ENABLE_ARMPL)
#define SPBLAS_VENDOR_BACKEND true
#endif

Expand Down
6 changes: 6 additions & 0 deletions include/spblas/vendor/onemkl_iespblas/algorithms.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#include "spmv_impl.hpp"
#include "spmm_impl.hpp"
#include "spgemm_impl.hpp"
#include "triangular_solve_impl.hpp"
136 changes: 136 additions & 0 deletions include/spblas/vendor/onemkl_iespblas/mkl_wrappers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#pragma once

#include <cstdint>
#include "mkl.h"

#include <spblas/detail/log.hpp>

//
// Add several templated functions for mapping from data_type to C style IE Sparse BLAS APIs
//


namespace spblas {
namespace __mkl_iespblas {

//
// mkl_sparse_create_csr
//
template<class T>
inline sparse_status_t mkl_sparse_create_csr( sparse_matrix_t *csrA, const sparse_index_base_t indexing,
const MKL_INT nrows, const MKL_INT ncols, MKL_INT *rowptr_st,
MKL_INT *rowptr_en, MKL_INT *colind, T *values)
{
log_warning("mkl_sparse_create_csr data types are not supported");
return SPARSE_STATUS_NOT_SUPPORTED;
}

template<>
inline sparse_status_t mkl_sparse_create_csr<float>( sparse_matrix_t *csrA, const sparse_index_base_t indexing,
const MKL_INT nrows, const MKL_INT ncols, MKL_INT *rowptr_st,
MKL_INT *rowptr_en, MKL_INT *colind, float *values)
{
return mkl_sparse_s_create_csr(csrA, indexing, nrows, ncols, rowptr_st, rowptr_en, colind, values);
}

template<>
inline sparse_status_t mkl_sparse_create_csr<double>( sparse_matrix_t *csrA, const sparse_index_base_t indexing,
const MKL_INT nrows, const MKL_INT ncols, MKL_INT *rowptr_st,
MKL_INT *rowptr_en, MKL_INT *colind, double *values)
{
return mkl_sparse_d_create_csr(csrA, indexing, nrows, ncols, rowptr_st, rowptr_en, colind, values);
}


//
// mkl_sparse_export_csr
//

template<class T>
inline sparse_status_t mkl_sparse_export_csr( const sparse_matrix_t csrA, sparse_index_base_t *indexing,
MKL_INT *nrows, MKL_INT *ncols, MKL_INT **rowptr_st,
MKL_INT **rowptr_en, MKL_INT **colind, T **values)
{
log_warning("mkl_sparse_export_csr data types are not supported");
return SPARSE_STATUS_NOT_SUPPORTED;
}

template<>
inline sparse_status_t mkl_sparse_export_csr<float>( const sparse_matrix_t csrA, sparse_index_base_t *indexing,
MKL_INT *nrows, MKL_INT *ncols, MKL_INT **rowptr_st,
MKL_INT **rowptr_en, MKL_INT **colind, float **values)
{
return mkl_sparse_s_export_csr(csrA, indexing, nrows, ncols, rowptr_st, rowptr_en, colind, values);
}

template<>
inline sparse_status_t mkl_sparse_export_csr<double>( const sparse_matrix_t csrA, sparse_index_base_t *indexing,
MKL_INT *nrows, MKL_INT *ncols, MKL_INT **rowptr_st,
MKL_INT **rowptr_en, MKL_INT **colind, double **values)
{
return mkl_sparse_d_export_csr(csrA, indexing, nrows, ncols, rowptr_st, rowptr_en, colind, values);
}


//
// mkl_sparse_mv
//
template<class T>
inline sparse_status_t mkl_sparse_mv( const sparse_operation_t op, const T alpha, const sparse_matrix_t csrA,
const struct matrix_descr descr, const T* x, const T beta, T* y)
{
log_warning("mkl_sparse_mv data types are not supported");
return SPARSE_STATUS_NOT_SUPPORTED;
}

template<>
inline sparse_status_t mkl_sparse_mv<float>( const sparse_operation_t op, const float alpha, const sparse_matrix_t csrA,
const struct matrix_descr descr, const float* x, const float beta, float* y)
{
return mkl_sparse_s_mv(op, alpha, csrA, descr, x, beta, y);
}

template<>
inline sparse_status_t mkl_sparse_mv<double>( const sparse_operation_t op, const double alpha, const sparse_matrix_t csrA,
const struct matrix_descr descr, const double* x, const double beta, double* y)
{
return mkl_sparse_d_mv(op, alpha, csrA, descr, x, beta, y);
}


//
// mkl_sparse_mm
//
template<class T>
inline sparse_status_t mkl_sparse_mm( const sparse_operation_t op, const T alpha, const sparse_matrix_t csrA,
const struct matrix_descr descr, const sparse_layout_t layout,
const T* x, const index_t nrhs, const index_t ldx, const T beta, T* y, const index_t ldy)
{
log_warning("mkl_sparse_mm data types are not supported");
return SPARSE_STATUS_NOT_SUPPORTED;
}

template<>
inline sparse_status_t mkl_sparse_mm<float>( const sparse_operation_t op, const float alpha, const sparse_matrix_t csrA,
const struct matrix_descr descr, const sparse_layout_t layout,
const float* x, const index_t nrhs, const index_t ldx, const float beta, float* y, const index_t ldy)
{
return mkl_sparse_s_mm(op, alpha, csrA, descr, layout, x, nrhs, ldx, beta, y, ldy);
}

template<>
inline sparse_status_t mkl_sparse_mm<double>( const sparse_operation_t op, const double alpha, const sparse_matrix_t csrA,
const struct matrix_descr descr, const sparse_layout_t layout,
const double* x, const index_t nrhs, const index_t ldx, const double beta, double* y, const index_t ldy)
{
return mkl_sparse_d_mm(op, alpha, csrA, descr, layout, x, nrhs, ldx, beta, y, ldy);
}



} // namespace __mkl_iespblas
} // namespace spblas




5 changes: 5 additions & 0 deletions include/spblas/vendor/onemkl_iespblas/onemkl_iespblas.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include "algorithms.hpp"
#include <cstdint>

53 changes: 53 additions & 0 deletions include/spblas/vendor/onemkl_iespblas/operation_state_t.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include "mkl.h"

namespace spblas {

namespace __mkl_iespblas{

struct operation_state_t {
sparse_matrix_t a_handle = nullptr;
sparse_matrix_t b_handle = nullptr;
sparse_matrix_t c_handle = nullptr;

operation_state_t() = default;

operation_state_t(sparse_matrix_t a_handle,
sparse_matrix_t b_handle,
sparse_matrix_t c_handle)
: a_handle(a_handle), b_handle(b_handle), c_handle(c_handle) {}

operation_state_t(operation_state_t&& other) {
*this = std::move(other);
}

operation_state_t& operator=(operation_state_t&& other) {
a_handle = other.a_handle;
b_handle = other.b_handle;
c_handle = other.c_handle;

other.a_handle = other.b_handle = other.c_handle = nullptr;

return *this;
}

operation_state_t(const operation_state_t& other) = delete;

~operation_state_t() {
release_matrix_handle(a_handle);
release_matrix_handle(b_handle);
release_matrix_handle(c_handle);
}

private:
void release_matrix_handle(sparse_matrix_t handle) {
if (handle != nullptr) {
mkl_sparse_destroy(handle);
}
}
};

} // namespace __mkl_iespblas

} // namespace spblas
Loading
Loading