@@ -320,8 +320,9 @@ def local_subtensor_of_batch_dims(fgraph, node):
320320
321321 Bail on boolean masks and non-consecutive advanced indexing — numpy hoists
322322 those advanced groups to position 0, which would misalign the lifted
323- indices. On a broadcast (length-1) axis of an input, replace the advanced
324- index with length-1 zeros so the lifted input still broadcasts correctly.
323+ indices. On a broadcast (length-1) axis of an input, the index is dropped
324+ (only zero is in bounds there), and an Alloc restores the full output shape
325+ when a dropped index was what determined it.
325326 """
326327 elem , * idx = node .inputs
327328
@@ -733,28 +734,18 @@ def lift_subtensor_through_alloc(fgraph, node):
733734 if _non_consecutive_adv_indexing (indices ):
734735 return None
735736
736- val_indexer : list = []
737- dangerous_index_reaches_val = False
738- for axis , idx in enumerate (indices ):
739- if axis < n_added_dims :
740- # Axis was added by Alloc; index doesn't reach val.
741- continue
742- val_static_dim = val .type .shape [axis - n_added_dims ]
743- if val_static_dim == 1 :
744- # Broadcast val dim: slices stay (Alloc broadcasts on top);
745- # advanced indices become length-1 zeros for squeeze.
746- if isinstance (idx , slice ):
747- val_indexer .append (slice (None ))
748- else :
749- val_indexer .append (np .zeros ((1 ,) * idx .type .ndim , dtype = np .int64 ))
750- continue
751- val_indexer .append (idx )
752- if not _index_provably_smaller (idx , val_static_dim ):
753- # Per-axis check; doesn't account for net effect across all axes.
754- dangerous_index_reaches_val = True
755-
756- nw_val = _canonical_indexing (val , val_indexer )
757- new_shape = indexed_result_shape (alloc_dims , indices )
737+ # Indices on Alloc-added dims don't reach val; the rest line up with val's dims.
738+ val_indexer = indices [n_added_dims :]
739+ dangerous_index_reaches_val = any (
740+ not val .type .broadcastable [axis ]
741+ # Per-axis check; doesn't account for net effect across all axes.
742+ and not _index_provably_smaller (idx , val .type .shape [axis ])
743+ for axis , idx in enumerate (val_indexer )
744+ )
745+
746+ # On broadcast val dims the index is neutralized (advanced indices dropped,
747+ # shrinking slices made full); the trailing Alloc broadcasts val back up.
748+ nw_val = _canonical_indexing (val , val_indexer , drop_broadcasted_index = True )
758749 drops_alloc = nw_val .type .broadcastable == node .outputs [0 ].type .broadcastable
759750
760751 if dangerous_index_reaches_val and not drops_alloc :
@@ -763,7 +754,7 @@ def lift_subtensor_through_alloc(fgraph, node):
763754 if drops_alloc :
764755 result = nw_val
765756 else :
766- result = alloc (nw_val , * new_shape )
757+ result = alloc (nw_val , * indexed_result_shape ( alloc_dims , indices ) )
767758
768759 copy_stack_trace (node .outputs [0 ], result )
769760 return [result ]
0 commit comments