Add _vectorize_node dispatchers for sparse ops#2190
Conversation
1851ec5 to
0659b8f
Compare
| usmm = Usmm() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- |
There was a problem hiding this comment.
please remove these global comments/separators
| # That contradicts the sparse-input contract enforced by as_sparse_variable, | ||
| # so every sparse op needs a custom dispatcher (or a clear NotImplementedError). | ||
| @_vectorize_node.register(StructuredDot) | ||
| def _vectorize_structured_dot(op, node, batch_a, batch_b): |
|
looks good, just need style cleanup |
The default Blockwise-based fallback in pytensor/graph/replace.py wraps ops in Blockwise and rebuilds their make_node with dense dummy core inputs, which contradicts the sparse-input contract enforced by as_sparse_variable. As a result, vectorize_graph crashes with "Variable type field must be a SparseTensorType" the moment it encounters any sparse op — pmx-extras pathfinder hits this whenever a PyMC model uses a sparse projection (e.g. a sum-to-zero constraint encoded as pt.dot(flat, as_sparse_variable(csr))). This patch: - Registers an explicit dispatcher for StructuredDot that batches the dense (right) input via a moveaxis+reshape trick while keeping the sparse (left) input unbatched (scipy has no batched-sparse type). Raises NotImplementedError with a clear message if the caller tries to batch the sparse input. - Registers NotImplementedError stubs for the other sparse ops likely to appear in user graphs (TrueDot, AddSS, AddSSData, AddSD, SparseSparseMultiply, SparseDenseMultiply) so callers see a descriptive error instead of the cryptic as_sparse_variable TypeError from the Blockwise fallback.
Add TestVectorizeSparse covering the StructuredDot dispatcher (batched dense input, no-batch no-op, batched-sparse error) and the AddSD / SparseDenseMultiply NotImplementedError stubs. The structured_dot test reproduces the original "Variable type field must be a SparseTensorType" crash without the dispatcher (issue pymc-devs#2189). Drop the NotImplementedError stubs for the all-sparse-input ops (TrueDot, AddSS, AddSSData, SparseSparseMultiply): a sparse input can never become batched, so vectorize_graph never dispatches to them. Keep only the reachable AddSD / SparseDenseMultiply, and reword the error since AddSD's output is dense, not sparse. Move the _vectorize_node import to the top of the module (no circular import) to satisfy ruff E402. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
0659b8f to
be177e0
Compare
be177e0 to
ddd2049
Compare
| existing StructuredDot 2D matmul. | ||
| """ | ||
| a, b = node.inputs | ||
| if batch_a is not a: |
There was a problem hiding this comment.
a need not be batch_a, just check batch_a.ndim==2?
| k_axis = -2 | ||
| moved = ptb.moveaxis(batch_b, k_axis, 0) # (k, B1,...,BN, n) | ||
|
|
||
| # Compose the trailing shape (B1*...*BN*n) symbolically. |
| # Reshape back to (m, B1,...,BN, n). | ||
| m = flat_out.shape[0] | ||
| target_shape = ptb.concatenate([ptb.stack([m]), ptb.stack(list(trailing))]) | ||
| unflat = flat_out.reshape(target_shape, ndim=batch_b.type.ndim) |
| unflat = flat_out.reshape(target_shape, ndim=batch_b.type.ndim) | ||
| # Move m back into the (-2) slot: (B1,...,BN, m, n). | ||
| out = ptb.moveaxis(unflat, 0, -2) | ||
| return out.owner |
There was a problem hiding this comment.
you no longer need to artificially return an apply, that will actually be deprecated. Just return the variables (maybe in a list)
|
|
||
| def _vectorize_sparse_unsupported(op, node, *batched_inputs): | ||
| raise NotImplementedError( | ||
| f"Cannot vectorize {type(op).__name__}: scipy has no batched-sparse " |
There was a problem hiding this comment.
I'm not sure this advice is actionable. Blockwise will still fail even if sparse input has no batch dims no? At the very least it tries to add dummy expand dims. There's an issue open about this IIRC
| xb_val = rng.normal(size=(*batch_shape, 3, 5)).astype("float64") | ||
| out = pytensor.function([xb], yb)(xb_val) | ||
|
|
||
| S_dense = np.eye(3) |
There was a problem hiding this comment.
isn't this just expected = np.eye(3) @ xb_val?
| expected[idx] = S_dense @ xb_val[idx] | ||
| np.testing.assert_allclose(out, expected) | ||
|
|
||
| def test_structured_dot_no_batch_is_noop(self): |
There was a problem hiding this comment.
Remove in favor of batch_dims=0 parametrization above?
Description
Related Issue
Checklist
Type of change