Skip to content
2 changes: 1 addition & 1 deletion src/sparse/stdlib_sparse_constants.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
module stdlib_sparse_constants
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp, c_bool
use stdlib_constants
implicit none
public
Expand Down
65 changes: 45 additions & 20 deletions src/sparse/stdlib_sparse_kinds.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,17 @@ contains
! data accessors
!==================================================================

logical(c_bool) elemental function skip(sym,row,col)
integer(ilp), intent(in) :: sym, row, col
skip = (sym == sparse_lower .and. row < col) .or. (sym == sparse_upper .and. row > col)
end function

#:for k1, t1, s1 in (KINDS_TYPES)
pure ${t1}$ function at_value_coo_${s1}$(self,ik,jk) result(val)
class(COO_${s1}$_type), intent(in) :: self
integer(ilp), intent(in) :: ik, jk
integer(ilp) :: k, ik_, jk_
logical :: transpose
logical(c_bool) :: transpose
! naive implementation
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
Expand Down Expand Up @@ -373,14 +378,18 @@ contains
class(COO_${s1}$_type), intent(inout) :: self
${t1}$, intent(in) :: val(:,:)
integer(ilp), intent(in) :: ik(:), jk(:)
integer(ilp) :: k, i, j
integer(ilp) :: k, i, j, row, col
! naive implementation
do k = 1, self%nnz
do i = 1, size(ik)
if( ik(i) /= self%index(1,k) ) cycle
row = ik(i)
if( row /= self%index(1,k) ) cycle
do j = 1, size(jk)
if( jk(j) /= self%index(2,k) ) cycle
col = jk(j)
if( skip(self%storage,row,col) ) cycle
if( col /= self%index(2,k) ) cycle
self%data(k) = self%data(k) + val(i,j)
exit
end do
end do
end do
Expand All @@ -393,7 +402,7 @@ contains
class(CSR_${s1}$_type), intent(in) :: self
integer(ilp), intent(in) :: ik, jk
integer(ilp) :: k, ik_, jk_
logical :: transpose
logical(c_bool) :: transpose
! naive implementation
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
Expand Down Expand Up @@ -431,13 +440,17 @@ contains
class(CSR_${s1}$_type), intent(inout) :: self
${t1}$, intent(in) :: val(:,:)
integer(ilp), intent(in) :: ik(:), jk(:)
integer(ilp) :: k, i, j
integer(ilp) :: k, i, j, row, col
! naive implementation
do i = 1, size(ik)
do k = self%rowptr(ik(i)), self%rowptr(ik(i)+1)-1
row = ik(i)
do k = self%rowptr(row), self%rowptr(row+1)-1
do j = 1, size(jk)
if( jk(j) == self%col(k) ) then
col = jk(j)
if( skip(self%storage,row,col) ) cycle
if( col == self%col(k) ) then
self%data(k) = self%data(k) + val(i,j)
exit
end if
end do
end do
Expand All @@ -451,7 +464,7 @@ contains
class(CSC_${s1}$_type), intent(in) :: self
integer(ilp), intent(in) :: ik, jk
integer(ilp) :: k, ik_, jk_
logical :: transpose
logical(c_bool) :: transpose
! naive implementation
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
Expand Down Expand Up @@ -489,13 +502,17 @@ contains
class(CSC_${s1}$_type), intent(inout) :: self
${t1}$, intent(in) :: val(:,:)
integer(ilp), intent(in) :: ik(:), jk(:)
integer(ilp) :: k, i, j
integer(ilp) :: k, i, j, row, col
! naive implementation
do j = 1, size(jk)
do k = self%colptr(jk(j)), self%colptr(jk(j)+1)-1
col = jk(j)
do k = self%colptr(col), self%colptr(col+1)-1
do i = 1, size(ik)
if( ik(i) == self%row(k) ) then
row = ik(i)
if( skip(self%storage,row,col) ) cycle
if( row == self%row(k) ) then
self%data(k) = self%data(k) + val(i,j)
exit
end if
end do
end do
Expand All @@ -509,7 +526,7 @@ contains
class(ELL_${s1}$_type), intent(in) :: self
integer(ilp), intent(in) :: ik, jk
integer(ilp) :: k, ik_, jk_
logical :: transpose
logical(c_bool) :: transpose
! naive implementation
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
Expand Down Expand Up @@ -547,13 +564,17 @@ contains
class(ELL_${s1}$_type), intent(inout) :: self
${t1}$, intent(in) :: val(:,:)
integer(ilp), intent(in) :: ik(:), jk(:)
integer(ilp) :: k, i, j
integer(ilp) :: k, i, j, row, col
! naive implementation
do k = 1 , self%K
do j = 1, size(jk)
col = jk(j)
do i = 1, size(ik)
if( jk(j) == self%index(ik(i),k) ) then
self%data(ik(i),k) = self%data(ik(i),k) + val(i,j)
row = ik(i)
if( skip(self%storage,row,col) ) cycle
if( col == self%index(row,k) ) then
self%data(row,k) = self%data(row,k) + val(i,j)
exit
end if
end do
end do
Expand All @@ -567,7 +588,7 @@ contains
class(SELLC_${s1}$_type), intent(in) :: self
integer(ilp), intent(in) :: ik, jk
integer(ilp) :: k, ik_, jk_, idx
logical :: transpose
logical(c_bool) :: transpose
! naive implementation
if( (ik<1 .or. ik>self%nrows) .or. (jk<1 .or. jk>self%ncols) ) then
val = ieee_value( 0._${k1}$ , ieee_quiet_nan)
Expand Down Expand Up @@ -608,14 +629,18 @@ contains
class(SELLC_${s1}$_type), intent(inout) :: self
${t1}$, intent(in) :: val(:,:)
integer(ilp), intent(in) :: ik(:), jk(:)
integer(ilp) :: k, i, j, idx
integer(ilp) :: k, i, j, idx, row, col
! naive implementation
do k = 1 , self%chunk_size
do j = 1, size(jk)
col = jk(j)
do i = 1, size(ik)
idx = self%rowptr((ik(i) - 1)/self%chunk_size + 1)
if( jk(j) == self%col(k,idx) ) then
row = ik(i)
idx = self%rowptr((row - 1)/self%chunk_size + 1)
if( skip(self%storage,row,col) ) cycle
if( col == self%col(k,idx) ) then
self%data(k,idx) = self%data(k,idx) + val(i,j)
exit
end if
end do
end do
Expand Down
97 changes: 95 additions & 2 deletions test/linalg/test_linalg_sparse.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ contains
new_unittest('symmetries', test_symmetries), &
new_unittest('diagonal', test_diagonal), &
new_unittest('add_get_values', test_add_get_values), &
new_unittest('sparse_operators', test_sparse_operators) &
new_unittest('sparse_operators', test_sparse_operators), &
new_unittest('add_block_symmetric_skip', test_add_block_symmetric_skip) &
]
end subroutine

Expand Down Expand Up @@ -373,7 +374,6 @@ contains

call check(error, all(CSR%data == COO%data) )
if (allocated(error)) return

err = 0._wp
do i = 1, 5
do j = 1, 5
Expand Down Expand Up @@ -485,8 +485,101 @@ contains
end block
#:endfor
#:endfor

end subroutine

subroutine test_add_block_symmetric_skip(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error
#:for k1, t1, s1 in (KINDS_TYPES)
block
Comment thread
jalvesz marked this conversation as resolved.
integer, parameter :: wp = ${k1}$
integer :: connectivity(3,3)

real(wp) :: dense(5,5), dense_low(5,5), mat(3,3)
type(COO_${s1}$_type) :: COO_full, COO_low
type(CSR_${s1}$_type) :: CSR_full, CSR_low
type(CSC_${s1}$_type) :: CSC_full, CSC_low
real(wp) :: x(5), y(5), y_ref(5)
${t1}$:: err
integer :: i, j, locdof(3)

connectivity(1:3,1) = [1,2,3]
connectivity(1:3,2) = [2,3,4]
connectivity(1:3,3) = [3,4,5]

mat(:,1) = [1,2,3]
mat(:,2) = [2,1,4]
mat(:,3) = [3,4,1]

dense = 0._wp
do i = 1, 3
locdof(1:3) = connectivity(1:3,i)
dense(locdof,locdof) = dense(locdof,locdof) + mat
end do

call dense2coo(dense,COO_full)
call coo2csr(COO_full,CSR_full)
call coo2csc(COO_full,CSC_full)
dense_low = dense
do i = 1, 5
do j = i+1, 5
dense_low(i,j) = 0._wp
end do
end do
call dense2coo(dense_low,COO_low)
COO_low%storage = sparse_lower
call coo2csr(COO_low,CSR_low)
call coo2csc(COO_low,CSC_low)

COO_full%data = 0._wp
COO_low%data = 0._wp
CSR_full%data = 0._wp
CSR_low%data = 0._wp
CSC_full%data = 0._wp
CSC_low%data = 0._wp
do i = 1, 3
locdof(1:3) = connectivity(1:3,i)
call COO_full%add(locdof,locdof,mat)
call COO_low%add(locdof,locdof,mat)
call CSR_full%add(locdof,locdof,mat)
call CSR_low%add(locdof,locdof,mat)
call CSC_full%add(locdof,locdof,mat)
call CSC_low%add(locdof,locdof,mat)
end do

call check(error, all(CSR_full%data == COO_full%data) , "error in full CSR ${s1}$ data" )
if (allocated(error)) return

call check(error, all(CSR_low%data == COO_low%data) , "error in low CSR ${s1}$ data" )
if (allocated(error)) return

x = 1._wp
y_ref = matmul(dense,x)

y = 0._wp
call spmv( CSR_full, x, y )
call check(error, all(y == y_ref) , "error in full CSR ${s1}$ spmv" )
if (allocated(error)) return

y = 0._wp
call spmv( CSR_low, x, y )
call check(error, all(y == y_ref) , "error in low CSR ${s1}$ spmv" )
if (allocated(error)) return

y = 0._wp
call spmv( CSC_full, x, y )
call check(error, all(y == y_ref) , "error in full CSC ${s1}$ spmv" )
if (allocated(error)) return

y = 0._wp
call spmv( CSC_low, x, y )
call check(error, all(y == y_ref) , "error in low CSC ${s1}$ spmv" )
end block
#:endfor

end subroutine

end module


Expand Down
Loading