Skip to content

Commit 9fbf5cb

Browse files
committed
Simplify lift_subtensor_through_alloc
1 parent 9948bde commit 9fbf5cb

1 file changed

Lines changed: 16 additions & 25 deletions

File tree

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,9 @@ def local_subtensor_of_batch_dims(fgraph, node):
320320
321321
Bail on boolean masks and non-consecutive advanced indexing — numpy hoists
322322
those advanced groups to position 0, which would misalign the lifted
323-
indices. On a broadcast (length-1) axis of an input, replace the advanced
324-
index with length-1 zeros so the lifted input still broadcasts correctly.
323+
indices. On a broadcast (length-1) axis of an input, the index is dropped
324+
(every in-bounds index selects the same single element there), and an Alloc restores the full output shape
325+
when a dropped index was what determined it.
325326
"""
326327
elem, *idx = node.inputs
327328

@@ -733,28 +734,18 @@ def lift_subtensor_through_alloc(fgraph, node):
733734
if _non_consecutive_adv_indexing(indices):
734735
return None
735736

736-
val_indexer: list = []
737-
dangerous_index_reaches_val = False
738-
for axis, idx in enumerate(indices):
739-
if axis < n_added_dims:
740-
# Axis was added by Alloc; index doesn't reach val.
741-
continue
742-
val_static_dim = val.type.shape[axis - n_added_dims]
743-
if val_static_dim == 1:
744-
# Broadcast val dim: slices stay (Alloc broadcasts on top);
745-
# advanced indices become length-1 zeros for squeeze.
746-
if isinstance(idx, slice):
747-
val_indexer.append(slice(None))
748-
else:
749-
val_indexer.append(np.zeros((1,) * idx.type.ndim, dtype=np.int64))
750-
continue
751-
val_indexer.append(idx)
752-
if not _index_provably_smaller(idx, val_static_dim):
753-
# Per-axis check; doesn't account for net effect across all axes.
754-
dangerous_index_reaches_val = True
755-
756-
nw_val = _canonical_indexing(val, val_indexer)
757-
new_shape = indexed_result_shape(alloc_dims, indices)
737+
# Indices on Alloc-added dims don't reach val; the rest line up with val's dims.
738+
val_indexer = indices[n_added_dims:]
739+
dangerous_index_reaches_val = any(
740+
not val.type.broadcastable[axis]
741+
# Per-axis check; doesn't account for net effect across all axes.
742+
and not _index_provably_smaller(idx, val.type.shape[axis])
743+
for axis, idx in enumerate(val_indexer)
744+
)
745+
746+
# On broadcast val dims the index is neutralized (advanced indices dropped,
747+
# shrinking slices made full); the trailing Alloc broadcasts val back up.
748+
nw_val = _canonical_indexing(val, val_indexer, drop_broadcasted_index=True)
758749
drops_alloc = nw_val.type.broadcastable == node.outputs[0].type.broadcastable
759750

760751
if dangerous_index_reaches_val and not drops_alloc:
@@ -763,7 +754,7 @@ def lift_subtensor_through_alloc(fgraph, node):
763754
if drops_alloc:
764755
result = nw_val
765756
else:
766-
result = alloc(nw_val, *new_shape)
757+
result = alloc(nw_val, *indexed_result_shape(alloc_dims, indices))
767758

768759
copy_stack_trace(node.outputs[0], result)
769760
return [result]

0 commit comments

Comments
 (0)