Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions pytensor/sparse/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytensor import config
from pytensor.gradient import grad_not_implemented
from pytensor.graph import Apply, Op
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.sparse.type import SparseTensorType
from pytensor.tensor.shape import specify_broadcastable
Expand Down Expand Up @@ -2067,3 +2068,63 @@ def perform(self, node, inputs, outputs):


usmm = Usmm()


@_vectorize_node.register(StructuredDot)
def _vectorize_structured_dot(op, node, batch_a, batch_b):
    """Vectorize ``StructuredDot`` over leading batch dims of the dense input.

    Maps ``(m, k) @ (B1, ..., BN, k, n)`` to ``(B1, ..., BN, m, n)``.

    The sparse left operand has to stay unbatched (scipy has no
    batched-sparse type). The dense right operand may carry any number of
    leading batch dimensions: they are folded into the column axis so the
    ordinary 2D ``StructuredDot`` can be applied once, and are unfolded again
    afterwards.

    Parameters
    ----------
    op
        The ``StructuredDot`` op instance being vectorized.
    node
        The original apply node; ``node.inputs`` is ``(sparse, dense)``.
    batch_a
        Replacement for the sparse (left) input.
    batch_b
        Replacement for the dense (right) input, possibly with extra leading
        batch dimensions.

    Returns
    -------
    The apply node that owns the vectorized output.

    Raises
    ------
    NotImplementedError
        If ``batch_a`` is not the very same variable as the original sparse
        input.
    """
    sparse_in, dense_in = node.inputs

    # NOTE(review): a reviewer pointed out that ``batch_a`` need not be the
    # same variable as the original input and suggested checking
    # ``batch_a.ndim == 2`` instead of identity.
    if batch_a is not sparse_in:
        # Caller batched the sparse input: structurally unsupported.
        raise NotImplementedError(
            "Cannot vectorize StructuredDot when the sparse (left) input is batched; "
            "scipy has no batched-sparse type. Rewrite the model to keep the sparse "
            "matrix constant across the batch."
        )

    # Number of batch dimensions that were prepended to the dense input.
    n_batch_dims = batch_b.type.ndim - dense_in.type.ndim
    if n_batch_dims == 0:
        # Nothing to vectorize: rebuild the plain op on the new inputs.
        return op.make_node(batch_a, batch_b)

    # (B1, ..., BN, k, n) -> (k, B1, ..., BN, n): bring the contracted axis
    # (second to last) to the front.
    k_first = ptb.moveaxis(batch_b, -2, 0)

    # Collapse everything after k into a single axis of length B1*...*BN*n.
    # NOTE(review): a reviewer asked whether ``pt.join_dims`` could replace
    # this manual prod/reshape.
    k_first_shape = k_first.shape
    batch_and_n = k_first_shape[1:]
    dense_2d = k_first.reshape((k_first_shape[0], ptm.prod(batch_and_n)))

    # Plain 2D product: (m, k) @ (k, B*n) -> (m, B*n).
    result_2d = op.make_node(batch_a, dense_2d).outputs[0]

    # Unfold the column axis back to (m, B1, ..., BN, n).
    # NOTE(review): a reviewer asked whether ``pt.split_dims`` could replace
    # this manual shape concatenation/reshape.
    unfolded_shape = ptb.concatenate(
        [ptb.stack([result_2d.shape[0]]), ptb.stack(list(batch_and_n))]
    )
    result_nd = result_2d.reshape(unfolded_shape, ndim=batch_b.type.ndim)

    # Move m back into the second-to-last slot: (B1, ..., BN, m, n).
    return ptb.moveaxis(result_nd, 0, -2).owner
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You no longer need to artificially return an Apply node; that will actually be deprecated. Just return the output variables (maybe in a list).



def _vectorize_sparse_unsupported(op, node, *batched_inputs):
    """Vectorization handler for sparse/dense ops that cannot be batched.

    Parameters
    ----------
    op
        The op being vectorized; its class name is reported in the error.
    node
        The original apply node. ``node.inputs`` are compared against
        ``batched_inputs`` to find out whether any batch dimension was added.
    *batched_inputs
        Replacement inputs, in the same order as ``node.inputs``.

    Returns
    -------
    The node built by ``op.make_node(*batched_inputs)`` when no replacement
    input has more dimensions than the input it replaces.

    Raises
    ------
    NotImplementedError
        If at least one replacement input gained dimensions.
    """
    gained_batch_dims = any(
        new.type.ndim > old.type.ndim
        for old, new in zip(node.inputs, batched_inputs)
    )
    if not gained_batch_dims:
        # Nothing was batched, so there is nothing unsupported to report:
        # rebuild the op on the new inputs, as the StructuredDot handler does
        # when no batch dimension was added.
        return op.make_node(*batched_inputs)

    # The message deliberately gives no "keep the sparse matrix constant"
    # advice: this error is raised even when the sparse operand is already
    # unbatched, so that advice could not be acted on.
    raise NotImplementedError(
        f"Cannot vectorize {type(op).__name__}: scipy has no batched-sparse "
        "representation, so the sparse operand cannot be broadcast against a "
        "batched dense input. Batching the dense operand of this op is not "
        "supported."
    )


# Sparse/dense ops that share the "cannot be vectorized" handler.
for _op_cls in (AddSD, SparseDenseMultiply):
    _vectorize_node.register(_op_cls, _vectorize_sparse_unsupported)
71 changes: 71 additions & 0 deletions tests/sparse/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytensor.tensor as pt
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.graph.replace import vectorize_graph
from pytensor.link.numba import NumbaLinker
from pytensor.scalar import upcast
from pytensor.sparse.basic import (
Expand Down Expand Up @@ -1471,3 +1472,73 @@ def test_grad(self):
ConjTester = elemwise_checker(psm.conjugate, np.conj, grad_test=False)

NegTester = elemwise_checker(psm.neg, np.negative, name="TestNeg")


class TestVectorizeSparse:
    """Tests for ``vectorize_graph`` applied through sparse/dense math ops."""

    def _const_csr(self, n=3):
        """Return the ``n x n`` float64 identity as a sparse CSR variable."""
        return as_sparse_variable(scipy_sparse.csr_matrix(np.eye(n, dtype="float64")))

    @pytest.mark.parametrize("n_batch_dims", [0, 1, 2])
    def test_structured_dot_batches_dense_input(self, n_batch_dims):
        # StructuredDot(sparse_const, dense): batch only the dense (right) input.
        S = self._const_csr(3)
        # The dtype is pinned to float64 so the float64 test values below are
        # accepted regardless of the configured floatX.
        x = pt.matrix("x", dtype="float64")  # core (3, n)
        y = structured_dot(S, x)

        batch_shape = (2, 4)[:n_batch_dims]
        xb = pt.tensor("xb", shape=(None,) * (n_batch_dims + 2), dtype="float64")
        yb = vectorize_graph(y, {x: xb})

        assert yb.type.ndim == n_batch_dims + 2

        rng = np.random.default_rng(123)
        xb_val = rng.normal(size=(*batch_shape, 3, 5)).astype("float64")
        out = pytensor.function([xb], yb)(xb_val)

        # matmul broadcasts the (3, 3) matrix over the leading batch dims.
        expected = np.eye(3) @ xb_val
        np.testing.assert_allclose(out, expected)

    def test_structured_dot_no_batch_is_noop(self):
        # Replacing the dense input with another un-batched dense input must not
        # wrap anything: the result is a plain StructuredDot again. This is kept
        # next to the n_batch_dims=0 case above because it checks the op type
        # rather than the numerical result.
        S = self._const_csr(3)
        x = pt.matrix("x")
        y = structured_dot(S, x)
        new_x = pt.matrix("new_x")

        new_y = vectorize_graph(y, {x: new_x})
        assert isinstance(new_y.owner.op, StructuredDot)

    def test_structured_dot_batched_sparse_raises(self):
        # Batching the sparse (left) input is structurally unsupported.
        x = pt.matrix("x")
        S_sparse = self._const_csr(3)
        y = structured_dot(S_sparse, x)

        S_batched = csr_matrix(name="S_batched")
        with pytest.raises(
            NotImplementedError, match=r"sparse \(left\) input is batched"
        ):
            vectorize_graph(y, {S_sparse: S_batched, x: pt.tensor3("xb")})

    @pytest.mark.parametrize(
        "build",
        [
            pytest.param(lambda a, d: add(a, d), id="AddSD"),
            pytest.param(lambda a, d: multiply(a, d), id="SparseDenseMultiply"),
        ],
    )
    def test_sparse_dense_ops_batched_dense_raises(self, build):
        # These ops take a dense input that *can* be batched. Batching it must
        # raise a descriptive NotImplementedError rather than the cryptic
        # as_sparse_variable TypeError from the Blockwise fallback.
        a = self._const_csr(3)
        d = pt.matrix("d")
        out = build(a, d)

        new_d = pt.tensor3("new_d")
        with pytest.raises(NotImplementedError, match="no batched-sparse"):
            vectorize_graph(out, {d: new_d})
Loading