Skip to content

Commit 8df47f9

Browse files
committed
Implement GetItem2Lists and GetItem2ListsGrad sparse Ops in Numba backend
1 parent 25a3e0c commit 8df47f9

File tree

3 files changed

+129
-27
lines changed

3 files changed

+129
-27
lines changed

pytensor/link/numba/dispatch/sparse/basic.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
ColScaleCSC,
1818
CSMProperties,
1919
DenseFromSparse,
20+
GetItem2Lists,
21+
GetItem2ListsGrad,
2022
GetItemList,
2123
GetItemListGrad,
2224
HStack,
@@ -357,7 +359,7 @@ def get_item_list_csc(x, idx):
357359
@register_funcify_default_op_cache_key(GetItemListGrad)
358360
def numba_funcify_GetItemListGrad(op, node, **kwargs):
359361
output_format = node.outputs[0].type.format
360-
out_dtype = np.dtype(node.outputs[0].type.dtype)
362+
out_dtype = node.outputs[0].type.dtype
361363

362364
@numba_basic.numba_njit
363365
def get_item_list_grad_csr(x, idxs, gz):
@@ -461,47 +463,50 @@ def get_item_list_grad_csc(x, idx, gz):
461463
return get_item_list_grad_csr(x, idx, gz).tocsc()
462464

463465
return get_item_list_grad_csc
464-
<<<<<<< HEAD
465-
=======
466466

467467

468468
@register_funcify_default_op_cache_key(GetItem2Lists)
469469
def numba_funcify_GetItem2Lists(op, node, **kwargs):
470-
out_dtype = np.dtype(node.outputs[0].type.dtype)
470+
out_dtype = node.outputs[0].type.dtype
471471

472472
@numba_basic.numba_njit
473473
def get_item_2lists(x, ind1, ind2):
474-
x_csr = x.tocsr()
475-
n_rows, n_cols = x_csr.shape
474+
# Reproduces SciPy and NumPy when running:
475+
# np.asarray(x[ind1, ind2]).flatten()
476476

477477
if ind1.shape != ind2.shape:
478478
raise ValueError("shape mismatch in row/column indices")
479479

480+
# Output vector contains as many elements as the length of the index lists.
480481
out_size = ind1.shape[0]
481482
out = np.zeros(out_size, dtype=out_dtype)
482483

483-
x_data = x_csr.data
484+
x_csr = x.tocsr()
484485
x_indices = x_csr.indices.view(np.uint32)
485486
x_indptr = x_csr.indptr.view(np.uint32)
487+
n_rows, n_cols = x_csr.shape
486488

487489
for i in range(out_size):
490+
# Normalize row index
488491
row_idx = ind1[i]
489492
if row_idx < 0:
490493
row_idx += n_rows
491494
if row_idx < 0 or row_idx >= n_rows:
492495
raise IndexError("row index out of bounds")
493496

497+
# Normalize column index
494498
col_idx = ind2[i]
495499
if col_idx < 0:
496500
col_idx += n_cols
497501
if col_idx < 0 or col_idx >= n_cols:
498502
raise IndexError("column index out of bounds")
499503

500-
col_idx_u32 = np.uint32(col_idx)
504+
row_idx = np.uint32(row_idx)
505+
col_idx = np.uint32(col_idx)
501506
for data_idx in range(x_indptr[row_idx], x_indptr[row_idx + 1]):
502-
if x_indices[data_idx] == col_idx_u32:
507+
if x_indices[data_idx] == col_idx:
503508
# Duplicate sparse entries must accumulate like scipy indexing.
504-
out[i] += x_data[data_idx]
509+
out[i] += x_csr.data[data_idx]
505510

506511
return out
507512

@@ -511,31 +516,42 @@ def get_item_2lists(x, ind1, ind2):
511516
@register_funcify_default_op_cache_key(GetItem2ListsGrad)
512517
def numba_funcify_GetItem2ListsGrad(op, node, **kwargs):
513518
output_format = node.outputs[0].type.format
514-
out_dtype = np.dtype(node.outputs[0].type.dtype)
519+
out_dtype = node.outputs[0].type.dtype
515520

516521
@numba_basic.numba_njit
517522
def get_item_2lists_grad_csr(x, ind1, ind2, gz):
518-
n_rows, n_cols = x.shape
519-
n_assignments = ind1.shape[0]
523+
# Reproduces SciPy when running:
524+
# y = [csc|csr]_matrix(x.shape)
525+
# for i in range(len(ind1)):
526+
# y[(ind1[i], ind2[i])] = gz[i]
527+
#
528+
# Note that gz is a dense vector.
520529

521-
if ind2.shape[0] != n_assignments:
530+
if ind1.shape != ind2.shape:
522531
raise ValueError("shape mismatch in row/column indices")
532+
533+
n_assignments = ind1.shape[0]
523534
if gz.shape[0] < n_assignments:
524535
raise IndexError("gradient index out of bounds")
525536

526-
norm_row = np.empty(n_assignments, dtype=np.int32)
527-
norm_col = np.empty(n_assignments, dtype=np.int32)
537+
# Vectors with normalized (non-negative) row and column indices
538+
norm_row = np.empty(n_assignments, dtype=np.uint32)
539+
norm_col = np.empty(n_assignments, dtype=np.uint32)
528540

541+
n_rows, n_cols = x.shape
542+
# Maps original rows to values in [0, ..., touched_n_rows - 1]
529543
row_to_pos = np.full(n_rows, -1, dtype=np.int32)
530544
touched_n_rows = 0
531545

532546
for i in range(n_assignments):
547+
# Normalize row idx
533548
row_idx = ind1[i]
534549
if row_idx < 0:
535550
row_idx += n_rows
536551
if row_idx < 0 or row_idx >= n_rows:
537552
raise IndexError("row index out of bounds")
538553

554+
# Normalize column idx
539555
col_idx = ind2[i]
540556
if col_idx < 0:
541557
col_idx += n_cols
@@ -552,40 +568,46 @@ def get_item_2lists_grad_csr(x, ind1, ind2, gz):
552568
# Build row-wise buffers for touched rows. Repeated writes overwrite values.
553569
row_data = np.zeros((touched_n_rows, n_cols), dtype=out_dtype)
554570
row_mask = np.zeros((touched_n_rows, n_cols), dtype=np.bool_)
571+
row_nnz = np.zeros(touched_n_rows, dtype=np.int32)
555572

556573
for i in range(n_assignments):
557574
row_pos = row_to_pos[norm_row[i]]
558575
col_idx = norm_col[i]
576+
if not row_mask[row_pos, col_idx]:
577+
row_nnz[row_pos] += 1
578+
row_mask[row_pos, col_idx] = True
559579
row_data[row_pos, col_idx] = gz[i]
560-
row_mask[row_pos, col_idx] = True
561580

581+
# Build output indptr.
582+
# For touched rows add row_nnz[row_pos] to total_nnz.
583+
# For untouched rows, carry forward the previous total_nnz count.
562584
out_indptr = np.empty(n_rows + 1, dtype=np.int32)
563585
out_indptr[0] = 0
564586

565587
total_nnz = 0
566588
for row_idx in range(n_rows):
567589
row_pos = row_to_pos[row_idx]
568590
if row_pos >= 0:
569-
row_nnz = 0
570-
for col_idx in range(n_cols):
571-
if row_mask[row_pos, col_idx]:
572-
row_nnz += 1
573-
total_nnz += row_nnz
591+
total_nnz += row_nnz[row_pos]
574592
out_indptr[row_idx + 1] = total_nnz
575593

594+
# Build output data and indices, which need the total number of non-zero elements.
576595
out_data = np.empty(total_nnz, dtype=out_dtype)
577596
out_indices = np.empty(total_nnz, dtype=np.int32)
578-
out_pos = 0
579597

598+
# Populate indices and data by storing col_idx and value (row_data[row_pos, col_idx])
599+
# for touched rows/columns.
580600
for row_idx in range(n_rows):
581601
row_pos = row_to_pos[row_idx]
582602
if row_pos < 0:
583603
continue
604+
605+
dst = out_indptr[row_idx]
584606
for col_idx in range(n_cols):
585607
if row_mask[row_pos, col_idx]:
586-
out_indices[out_pos] = col_idx
587-
out_data[out_pos] = row_data[row_pos, col_idx]
588-
out_pos += 1
608+
out_indices[dst] = col_idx
609+
out_data[dst] = row_data[row_pos, col_idx]
610+
dst += 1
589611

590612
return sp.sparse.csr_matrix(
591613
(out_data, out_indices, out_indptr), shape=(n_rows, n_cols)
@@ -599,4 +621,3 @@ def get_item_2lists_grad_csc(x, ind1, ind2, gz):
599621
return get_item_2lists_grad_csr(x, ind1, ind2, gz).tocsc()
600622

601623
return get_item_2lists_grad_csc
602-
>>>>>>> fb1d09134 (Better comments for GetItemList and GetItemListGrad)

pytensor/sparse/basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,8 @@ def make_node(self, x, ind1, ind2):
926926
assert x.format in ("csr", "csc")
927927
ind1 = ptb.as_tensor_variable(ind1)
928928
ind2 = ptb.as_tensor_variable(ind2)
929+
assert ind1.ndim == 1
930+
assert ind2.ndim == 1
929931
assert ind1.dtype in integer_dtypes
930932
assert ind2.dtype in integer_dtypes
931933

tests/link/numba/sparse/test_basic.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,82 @@ def test_sparse_get_item_list_grad_wrong_index(format):
513513

514514
with pytest.raises(IndexError):
515515
fn(x_test, idx_test, gz_test)
516+
517+
518+
@pytest.mark.parametrize("format", ("csr", "csc"))
519+
def test_sparse_get_item_2lists(format):
520+
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
521+
ind1 = pt.ivector("ind1")
522+
ind2 = pt.ivector("ind2")
523+
z = ps.get_item_2lists(x, ind1, ind2)
524+
525+
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
526+
ind1_test = np.asarray([0, 0, 3, 5], dtype=np.int32)
527+
ind2_test = np.asarray([0, 4, 2, 1], dtype=np.int32)
528+
529+
compare_numba_and_py_sparse([x, ind1, ind2], z, [x_test, ind1_test, ind2_test])
530+
531+
532+
@pytest.mark.parametrize("format", ("csr", "csc"))
533+
@pytest.mark.parametrize(
534+
("ind1_test", "ind2_test"),
535+
[
536+
(np.asarray([0, 6], dtype=np.int32), np.asarray([0, 3], dtype=np.int32)),
537+
(np.asarray([0, 3], dtype=np.int32), np.asarray([0, 5], dtype=np.int32)),
538+
],
539+
)
540+
def test_sparse_get_item_2lists_wrong_index(format, ind1_test, ind2_test):
541+
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
542+
ind1 = pt.ivector("ind1")
543+
ind2 = pt.ivector("ind2")
544+
z = ps.get_item_2lists(x, ind1, ind2)
545+
fn = function([x, ind1, ind2], z, mode="NUMBA")
546+
547+
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
548+
549+
with pytest.raises(IndexError):
550+
fn(x_test, ind1_test, ind2_test)
551+
552+
553+
@pytest.mark.parametrize("format", ("csr", "csc"))
554+
def test_sparse_get_item_2lists_grad(format):
555+
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
556+
ind1 = pt.ivector("ind1")
557+
ind2 = pt.ivector("ind2")
558+
gz = pt.vector(name="gz", shape=(4,), dtype=config.floatX)
559+
z = ps.get_item_2lists_grad(x, ind1, ind2, gz)
560+
561+
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
562+
ind1_test = np.asarray([0, 2, 5, 2], dtype=np.int32)
563+
ind2_test = np.asarray([1, 0, 4, 0], dtype=np.int32)
564+
gz_test = np.asarray([0.5, -1.25, 2.0, 4.5], dtype=config.floatX)
565+
566+
with pytest.warns(sp.sparse.SparseEfficiencyWarning):
567+
# GetItem2ListsGrad.perform does sparse item assignment into an initially empty
568+
# sparse matrix, which changes sparsity structure incrementally.
569+
compare_numba_and_py_sparse(
570+
[x, ind1, ind2, gz], z, [x_test, ind1_test, ind2_test, gz_test]
571+
)
572+
573+
574+
@pytest.mark.parametrize("format", ("csr", "csc"))
575+
@pytest.mark.parametrize(
576+
("ind1_test", "ind2_test"),
577+
[
578+
(np.asarray([0, 6], dtype=np.int32), np.asarray([0, 3], dtype=np.int32)),
579+
(np.asarray([0, 3], dtype=np.int32), np.asarray([0, 5], dtype=np.int32)),
580+
],
581+
)
582+
def test_sparse_get_item_2lists_grad_wrong_index(format, ind1_test, ind2_test):
583+
x = ps.matrix(format, name="x", shape=(6, 5), dtype=config.floatX)
584+
ind1 = pt.ivector("ind1")
585+
ind2 = pt.ivector("ind2")
586+
gz = pt.vector(name="gz", shape=(2,), dtype=config.floatX)
587+
z = ps.get_item_2lists_grad(x, ind1, ind2, gz)
588+
fn = function([x, ind1, ind2, gz], z, mode="NUMBA")
589+
590+
x_test = sp.sparse.random(6, 5, density=0.4, format=format, dtype=config.floatX)
591+
gz_test = np.asarray([1.0, -2.0], dtype=config.floatX)
592+
593+
with pytest.raises(IndexError):
594+
fn(x_test, ind1_test, ind2_test, gz_test)

0 commit comments

Comments
 (0)