Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/spblas/algorithms/multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c);
template <matrix A, matrix B, matrix C>
void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c);

template <matrix A, matrix B, matrix C>
void multiply_fill_update(operation_info_t& info, A&& a, B&& b, C&& c);

} // namespace spblas
47 changes: 47 additions & 0 deletions include/spblas/algorithms/multiply_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <spblas/backend/csr_builder.hpp>
#include <spblas/backend/spa_accumulator.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/view_inspectors.hpp>

#include <algorithm>

Expand Down Expand Up @@ -147,6 +148,45 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) {
return operation_info_t{__backend::shape(c), nnz};
}

// C = AB
// SpGEMM (Gustavson's Algorithm) on existing C values
template <matrix A, matrix B, matrix C>
requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
__detail::is_csr_view_v<C>)
void multiply_update(A&& a, B&& b, C&& c) {
log_trace("");
if (__backend::shape(a)[0] != __backend::shape(c)[0] ||
__backend::shape(b)[1] != __backend::shape(c)[1] ||
__backend::shape(a)[1] != __backend::shape(b)[0]) {
throw std::invalid_argument(
"multiply: matrix dimensions are incompatible.");
}

using T = tensor_scalar_t<C>;
using I = tensor_index_t<C>;
using O = tensor_offset_t<C>;

auto c_base = __detail::get_ultimate_base(c);
const auto c_rowptr = c_base.rowptr();
const auto c_colind = c_base.colind();
const auto c_values = c_base.values();

for (auto&& [i, a_row] : __backend::rows(a)) {
std::unordered_map<I, O> c_columns;
const auto c_begin = c_rowptr[i];
const auto c_end = c_rowptr[i + 1];
for (auto c_nz : __ranges::views::iota(c_begin, c_end)) {
c_columns.emplace(c_colind[c_nz], c_nz);
c_values[c_nz] = 0;
}
for (auto&& [k, a_v] : a_row) {
for (auto&& [j, b_v] : __backend::lookup_row(b, k)) {
c_values[c_columns[j]] += a_v * b_v;
}
}
}
}

template <matrix A, matrix B, matrix C>
requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
__detail::is_csr_view_v<C>)
Expand All @@ -163,4 +203,11 @@ void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) {
multiply(a, b, c);
}

// C = AB after multiply_fill(info, A, B, C) was called previously
template <matrix A, matrix B, matrix C>
void multiply_fill_update(operation_info_t info, A&& a, B&& b, C&& c) {
log_trace("");
multiply_update(a, b, c);
}

} // namespace spblas
45 changes: 45 additions & 0 deletions test/gtest/spgemm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,51 @@ TEST(CsrView, SpGEMM) {
}
}

TEST(CsrView, SpGEMMUpdate) {
using T = float;
using I = spblas::index_t;
using O = spblas::offset_t;

for (auto&& [m, k, nnz] : util::dims) {
for (auto&& n : {m, k}) {
auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] =
spblas::generate_csr<T, I>(m, k, nnz);

auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] =
spblas::generate_csr<T, I>(k, n, nnz);

spblas::csr_view<T, I> a(a_values, a_rowptr, a_colind, a_shape, a_nnz);
spblas::csr_view<T, I> b(b_values, b_rowptr, b_colind, b_shape, b_nnz);

std::vector<I> c_rowptr(m + 1);

spblas::csr_view<T, I> c(nullptr, c_rowptr.data(), nullptr, {m, n}, 0);

auto info = spblas::multiply_compute(a, b, c);

std::vector<T> c_values(info.result_nnz());
std::vector<T> c_ref_values(info.result_nnz());
std::vector<I> c_colind(info.result_nnz());

spblas::csr_view<T, I> c_ref(c_ref_values.data(), c_rowptr.data(),
c_colind.data(), {m, n}, info.result_nnz());
c.update(c_values, c_rowptr, c_colind);

spblas::__ranges::transform(a_values, a_values.begin(),
[](auto value) { return 1 / value; });
spblas::__ranges::transform(b_values, b_values.begin(),
[](auto value) { return 1 / value; });

spblas::multiply_fill(info, a, b, c_ref);
spblas::multiply_fill_update(info, a, b, c);

for (auto i : spblas::__ranges::views::iota(O{}, c.size())) {
EXPECT_EQ_(c_values[i], c_ref_values[i]);
}
}
}
}

TEST(CsrView, SpGEMM_AScaled) {
using T = float;
using I = spblas::index_t;
Expand Down
Loading