Fix local_subtensor_of_batch_dims when encountering stale broadcast types #2167

Merged: ricardoV94 merged 5 commits into pymc-devs:main from jaanerik:fix-local-subtensor-of-batch-dims-collapse on Jun 1, 2026

Conversation

@jaanerik
Contributor

@jaanerik jaanerik commented May 25, 2026

Description

The local_subtensor_of_batch_dims rewrite crashes with a TypeError when it attempts to lift a subtensor through an Elemwise node that is in a "stale" state.

A stale state occurs mid-optimization when upstream rewrites have made an Elemwise node's inputs broadcastable (length-1), but the node's output Type has not yet been updated to reflect this (remaining non-broadcastable).
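
To make the stale state concrete, here is a minimal pure-Python sketch. It does not use PyTensor: `elemwise_out_broadcastable` and `is_stale` are hypothetical helpers, and broadcast patterns are modeled as tuples of booleans (True meaning the dim is statically known to be length 1), assuming all inputs have the same number of dims.

```python
def elemwise_out_broadcastable(input_patterns):
    """An output dim is broadcastable only if every input is broadcastable there."""
    return tuple(all(dims) for dims in zip(*input_patterns))


def is_stale(recorded_out_pattern, input_patterns):
    """True when the recorded output pattern disagrees with the current inputs."""
    return recorded_out_pattern != elemwise_out_broadcastable(input_patterns)


# When the node was built, neither input was known to be length 1 on dim 0.
recorded_out = elemwise_out_broadcastable([(False, False), (False, False)])

# An upstream rewrite later swaps in inputs that are length 1 on dim 0,
# but the recorded output pattern is left untouched (hypothetical scenario).
new_inputs = [(True, False), (True, False)]

stale = is_stale(recorded_out, new_inputs)
```

In this model `stale` is True: the recorded pattern still claims dim 0 is non-broadcastable while both inputs are length 1 there, which is the mismatch the description refers to.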

Checklist

Type of change

  • Bug fix

Issues

@jaanerik jaanerik force-pushed the fix-local-subtensor-of-batch-dims-collapse branch 3 times, most recently from 70780a2 to 7475ed4 Compare May 25, 2026 14:29

if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
# This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
if not dim_bcast_out and all(dim_bcast_inputs):
Member


We need to think a bit better about the patch.

The rewrite could do x[idx] -> alloc(x, idx.shape[dim]) in this case. The issue is that it is missing that last step: after correctly figuring out that the idx is never needed, we may still need to provide the shape with an alloc.

Also there's a branch above that is not great:

if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
    # No need to worry about implicit broadcasting.
    indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]

It tries to check whether there's any broadcasting going on, so in the cases where we know x is length 1 it just blindly lifts. That's not good either, because we end up with a wasteful x[idx] + x[idx], expanding size 1 dims for no reason.

It would be better to fix these two scenarios. We never want to end up with a worse graph, but we also don't want the bail-out that's currently proposed.
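
The two scenarios in this comment can be illustrated with NumPy as a stand-in for the symbolic graph. This is only a sketch of the arithmetic, not the rewrite itself: `np.broadcast_to` plays the role that alloc plays in the comment, and the array values are made up for illustration.

```python
import numpy as np

x = np.arange(3.0).reshape(1, 3)  # length 1 on dim 0
idx = np.zeros(4, dtype=int)      # the only valid index on a length-1 dim is 0

# Blind lift: each input is expanded from 1 row to 4 rows before the add.
blind = x[idx] + x[idx]

# Alternative: add once on the length-1 dim, then broadcast the result
# to the shape implied by the index (the "alloc" step).
cheap = np.broadcast_to(x + x, (idx.shape[0], x.shape[1]))

assert blind.shape == (4, 3)
assert np.array_equal(blind, cheap)
```

Both forms give the same values, but the second does the addition on one row instead of four, which is why lifting blindly produces a worse graph.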

Contributor Author

@jaanerik jaanerik May 26, 2026


The new commit hopefully fixes those. Also, Claude thought it worth noting that:

  • alloc uses *old_out.shape rather than per-dim idx.shape[d]
  • a fast inp[idx_tuple] path is kept when no per-dim adjustment is needed, to sidestep _canonical_indexing's ScalarType AttributeError (separate latent bug — idx.type.broadcastable at line 107 doesn't exist on ScalarType).

Comment on lines +439 to +440
if new_out.type.broadcastable != old_out.type.broadcastable:
    new_out = alloc(new_out, *old_out.shape)
Member


There's a new helper, broadcast_like_elemwise or something similar, that you should use. You should never rewrite new_x = alloc(..., x.shape), because x.shape references the old x you are replacing, so depending on rewrite ordering/subset you can end up with an infinite loop of rewrites: f(shape(x)) -> f(shape(f(shape(x)))).
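
The self-reference problem described in this comment can be mimicked with a toy term rewriter. This is a deliberately simplified model, not how PyTensor applies rewrites: expressions are nested tuples of the form `(op, *args)`, variables are strings, and `contains`, `substitute`, `depth`, and `run` are all hypothetical helpers.

```python
def contains(expr, name):
    """Return True if the variable `name` appears anywhere in `expr`."""
    if expr == name:
        return True
    return isinstance(expr, tuple) and any(contains(e, name) for e in expr[1:])


def substitute(expr, name, replacement):
    """Replace every occurrence of the variable `name` with `replacement`."""
    if expr == name:
        return replacement
    if isinstance(expr, tuple):
        return (expr[0], *(substitute(e, name, replacement) for e in expr[1:]))
    return expr


def depth(expr):
    """Nesting depth of an expression; plain variables have depth 0."""
    if not isinstance(expr, tuple):
        return 0
    return 1 + max(depth(e) for e in expr[1:])


def run(graph, replacement, max_passes=5):
    """Keep replacing "x" while it is present, up to `max_passes` times."""
    passes = 0
    while contains(graph, "x") and passes < max_passes:
        graph = substitute(graph, "x", replacement)
        passes += 1
    return graph, passes


# Replacement whose shape argument still refers to the variable being replaced.
bad_graph, bad_passes = run(("f", "x"), ("alloc", "new_x", ("shape", "x")))

# Replacement whose shape comes from a variable that is not being replaced.
good_graph, good_passes = run(("f", "x"), ("alloc", "new_x", ("shape", "inp")))
```

With the self-referencing replacement the loop only stops because of the pass limit and the expression keeps growing; with the second replacement it stops after a single pass because "x" no longer appears.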

@ricardoV94 ricardoV94 force-pushed the fix-local-subtensor-of-batch-dims-collapse branch from 3ee8418 to 1e6a82b Compare June 1, 2026 18:04
@ricardoV94 ricardoV94 added bug Something isn't working graph rewriting labels Jun 1, 2026
@ricardoV94
Member

@jaanerik I simplified the implementation a bit. I realized the canonical_index helper already had most of the logic.

I also cleaned up the tests a bit.

And I fixed the reference to the original variable by reusing/extending a helper we already had.

Comment thread pytensor/tensor/rewriting/subtensor_lift.py Outdated
Comment thread tests/tensor/rewriting/test_subtensor_lift.py Outdated
@ricardoV94 ricardoV94 force-pushed the fix-local-subtensor-of-batch-dims-collapse branch from 1e6a82b to 70abc60 Compare June 1, 2026 20:43
jaanerik and others added 4 commits June 1, 2026 22:46
Instead of bailing out when every input is length 1 on a dim but the
Elemwise output type stale-claims non-bcast, lift through and recover
the original shape with an alloc. Also avoid the prior fast path that
wastefully expanded size 1 dims when an advanced index landed on a dim
where all inputs were length 1 (would emit x[idx] + x[idx] expanding
size 1 to size K for no reason).

The per-dim loop now:
- skips slice(None) and dims with no broadcast input,
- applies basic indices (slice/scalar) as-is on all-bcast dims since
  they can't expand size 1,
- otherwise replaces the index on bcast inputs with a size 1 stand-in.

After building the lifted Elemwise, broadcast back to old_out.shape via
alloc whenever the bcast pattern differs from the original, which
covers both the stale-type case and the collapsed-all-bcast case.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
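
The per-dim handling described in the commit message above can be sketched with NumPy for the simplest case: one integer-array index on one axis of a two-input addition. This is an illustrative analogue under stated assumptions, not the merged implementation: `lifted_add` is a hypothetical name, both inputs are assumed to have the same number of dims, and `idx` is assumed to be valid for the un-lifted result.

```python
import numpy as np


def lifted_add(x, y, idx, axis=0):
    """Compute (x + y) indexed by integer array `idx` along `axis`
    by indexing the inputs instead of the output."""

    def index_input(inp):
        if inp.shape[axis] == 1:
            # Size 1 stand-in: leave the length-1 dim alone instead of
            # expanding it to len(idx).
            return inp
        return np.take(inp, idx, axis=axis)

    out = index_input(x) + index_input(y)

    # Recover the shape the un-lifted result would have had (the "alloc" step).
    expected_shape = list(np.broadcast_shapes(x.shape, y.shape))
    expected_shape[axis] = len(idx)
    return np.broadcast_to(out, expected_shape)


x = np.arange(3.0).reshape(1, 3)   # length 1 on dim 0
y = np.arange(12.0).reshape(4, 3)  # length 4 on dim 0

mixed = lifted_add(x, y, np.array([3, 0, 0]))
collapsed = lifted_add(x, x, np.array([0, 0]))
```

In the mixed case only `y` is actually indexed. In the collapsed case neither input is indexed and the final broadcast alone restores the expected shape, which corresponds to the all-bcast situation the commit message mentions.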
@ricardoV94 ricardoV94 force-pushed the fix-local-subtensor-of-batch-dims-collapse branch from 70abc60 to 9948bde Compare June 1, 2026 20:46
@ricardoV94 ricardoV94 force-pushed the fix-local-subtensor-of-batch-dims-collapse branch from 9fbf5cb to 59bf5d1 Compare June 1, 2026 21:27
@ricardoV94 ricardoV94 merged commit 1df3475 into pymc-devs:main Jun 1, 2026
66 checks passed

Labels

bug Something isn't working graph rewriting

Development

Successfully merging this pull request may close these issues.

BUG: TypeError in local_subtensor_of_batch_dims when encountering stale broadcast types

2 participants