-
Notifications
You must be signed in to change notification settings - Fork 186
Add _vectorize_node dispatchers for sparse ops #2190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| from pytensor import config | ||
| from pytensor.gradient import grad_not_implemented | ||
| from pytensor.graph import Apply, Op | ||
| from pytensor.graph.replace import _vectorize_node | ||
| from pytensor.link.c.op import COp | ||
| from pytensor.sparse.type import SparseTensorType | ||
| from pytensor.tensor.shape import specify_broadcastable | ||
|
|
@@ -2067,3 +2068,63 @@ def perform(self, node, inputs, outputs): | |
|
|
||
|
|
||
| usmm = Usmm() | ||
|
|
||
|
|
||
| @_vectorize_node.register(StructuredDot) | ||
| def _vectorize_structured_dot(op, node, batch_a, batch_b): | ||
| """Batch StructuredDot(sparse_const, dense): (m,k)@(B...,k,n) -> (B...,m,n). | ||
|
|
||
| The sparse left input is required to stay unbatched (no batched-sparse | ||
| type in scipy). The dense right input may gain any number of leading | ||
| batch dims; we use a moveaxis+reshape trick to fold them through the | ||
| existing StructuredDot 2D matmul. | ||
| """ | ||
| a, b = node.inputs | ||
| if batch_a is not a: | ||
|
Member
Review comment: `batch_a` need not be the same object as `a`; just check `batch_a.ndim == 2`?
||
| # Caller batched the sparse input — structurally unsupported. | ||
| raise NotImplementedError( | ||
| "Cannot vectorize StructuredDot when the sparse (left) input is batched; " | ||
| "scipy has no batched-sparse type. Rewrite the model to keep the sparse " | ||
| "matrix constant across the batch." | ||
| ) | ||
|
|
||
| extra = batch_b.type.ndim - b.type.ndim # number of batch dims added to b | ||
| if extra == 0: | ||
| # nothing to vectorize — just rebuild the op | ||
| return op.make_node(batch_a, batch_b) | ||
|
|
||
| # b is (B1,...,BN, k, n). Move k to front: (k, B1,...,BN, n). | ||
| # Use moveaxis on the second-to-last axis (-2 = k after the batch dims). | ||
| k_axis = -2 | ||
| moved = ptb.moveaxis(batch_b, k_axis, 0) # (k, B1,...,BN, n) | ||
|
|
||
| # Compose the trailing shape (B1*...*BN*n) symbolically. | ||
|
Member
Review comment: Use `pt.join_dims`?
||
| shape = moved.shape | ||
| k = shape[0] | ||
| trailing = shape[1:] | ||
| flat_trailing = ptm.prod(trailing) | ||
| flat_b = moved.reshape((k, flat_trailing)) # (k, B*n) | ||
|
|
||
| # StructuredDot returns (m, B*n) | ||
| flat_out = op.make_node(batch_a, flat_b).outputs[0] # (m, B*n) | ||
|
|
||
| # Reshape back to (m, B1,...,BN, n). | ||
| m = flat_out.shape[0] | ||
| target_shape = ptb.concatenate([ptb.stack([m]), ptb.stack(list(trailing))]) | ||
| unflat = flat_out.reshape(target_shape, ndim=batch_b.type.ndim) | ||
|
Member
Review comment: Use `pt.split_dims`?
||
| # Move m back into the (-2) slot: (B1,...,BN, m, n). | ||
| out = ptb.moveaxis(unflat, 0, -2) | ||
| return out.owner | ||
|
Member
Review comment: You no longer need to artificially return an `Apply`; that will actually be deprecated. Just return the variables (maybe in a list).
||
|
|
||
|
|
||
| def _vectorize_sparse_unsupported(op, node, *batched_inputs): | ||
| raise NotImplementedError( | ||
| f"Cannot vectorize {type(op).__name__}: scipy has no batched-sparse " | ||
|
Member
Review comment: I'm not sure this advice is actionable. Blockwise will still fail even if the sparse input has no batch dims, no? At the very least it tries to add dummy expand dims. There's an issue open about this, IIRC.
||
| "representation, so the sparse operand cannot be broadcast against a " | ||
| "batched dense input. Rewrite the model to keep the sparse matrix " | ||
| "constant across the batch." | ||
| ) | ||
|
|
||
|
|
||
| for _op_cls in (AddSD, SparseDenseMultiply): | ||
| _vectorize_node.register(_op_cls)(_vectorize_sparse_unsupported) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| import pytensor.tensor as pt | ||
| from pytensor.compile import get_default_mode | ||
| from pytensor.configdefaults import config | ||
| from pytensor.graph.replace import vectorize_graph | ||
| from pytensor.link.numba import NumbaLinker | ||
| from pytensor.scalar import upcast | ||
| from pytensor.sparse.basic import ( | ||
|
|
@@ -1471,3 +1472,73 @@ def test_grad(self): | |
| ConjTester = elemwise_checker(psm.conjugate, np.conj, grad_test=False) | ||
|
|
||
| NegTester = elemwise_checker(psm.neg, np.negative, name="TestNeg") | ||
|
|
||
|
|
||
| class TestVectorizeSparse: | ||
| def _const_csr(self, n=3): | ||
| return as_sparse_variable(scipy_sparse.csr_matrix(np.eye(n, dtype="float64"))) | ||
|
|
||
| @pytest.mark.parametrize("n_batch_dims", [1, 2]) | ||
| def test_structured_dot_batches_dense_input(self, n_batch_dims): | ||
| # StructuredDot(sparse_const, dense): batch only the dense (right) input. | ||
| S = self._const_csr(3) | ||
| x = pt.matrix("x") # core (3, n) | ||
| y = structured_dot(S, x) | ||
|
|
||
| batch_shape = (2, 4)[:n_batch_dims] | ||
| xb = pt.tensor("xb", shape=(None,) * (n_batch_dims + 2)) | ||
| yb = vectorize_graph(y, {x: xb}) | ||
|
|
||
| assert yb.type.ndim == n_batch_dims + 2 | ||
|
|
||
| rng = np.random.default_rng(123) | ||
| xb_val = rng.normal(size=(*batch_shape, 3, 5)).astype("float64") | ||
| out = pytensor.function([xb], yb)(xb_val) | ||
|
|
||
| S_dense = np.eye(3) | ||
|
Member
Review comment: Isn't this just `expected = np.eye(3) @ xb_val`?
||
| expected = np.empty((*batch_shape, 3, 5)) | ||
| for idx in np.ndindex(*batch_shape): | ||
| expected[idx] = S_dense @ xb_val[idx] | ||
| np.testing.assert_allclose(out, expected) | ||
|
|
||
| def test_structured_dot_no_batch_is_noop(self): | ||
|
Member
Review comment: Remove this test in favor of an `n_batch_dims=0` case in the parametrization above?
||
| # Replacing the dense input with another un-batched dense input must not | ||
| # wrap anything: the result is a plain StructuredDot again. | ||
| S = self._const_csr(3) | ||
| x = pt.matrix("x") | ||
| y = structured_dot(S, x) | ||
| new_x = pt.matrix("new_x") | ||
|
|
||
| new_y = vectorize_graph(y, {x: new_x}) | ||
| assert isinstance(new_y.owner.op, StructuredDot) | ||
|
|
||
| def test_structured_dot_batched_sparse_raises(self): | ||
| # Batching the sparse (left) input is structurally unsupported. | ||
| x = pt.matrix("x") | ||
| S_sparse = self._const_csr(3) | ||
| y = structured_dot(S_sparse, x) | ||
|
|
||
| S_batched = csr_matrix(name="S_batched") | ||
| with pytest.raises( | ||
| NotImplementedError, match="sparse \\(left\\) input is batched" | ||
| ): | ||
| vectorize_graph(y, {S_sparse: S_batched, x: pt.tensor3("xb")}) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "build", | ||
| [ | ||
| pytest.param(lambda a, d: add(a, d), id="AddSD"), | ||
| pytest.param(lambda a, d: multiply(a, d), id="SparseDenseMultiply"), | ||
| ], | ||
| ) | ||
| def test_sparse_dense_ops_batched_dense_raises(self, build): | ||
| # These ops take a dense input that *can* be batched. Batching it must | ||
| # raise a descriptive NotImplementedError rather than the cryptic | ||
| # as_sparse_variable TypeError from the Blockwise fallback. | ||
| a = self._const_csr(3) | ||
| d = pt.matrix("d") | ||
| out = build(a, d) | ||
|
|
||
| new_d = pt.tensor3("new_d") | ||
| with pytest.raises(NotImplementedError, match="no batched-sparse"): | ||
| vectorize_graph(out, {d: new_d}) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!