Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_simplified_shape(x: TensorVariable, *, fgraph) -> tuple:
return tuple(x.shape)


def get_simplified_broadcast_shape(first, *others, fgraph) -> list:
def get_simplified_broadcast_shape(first, *others, fgraph, batch_ndim=None) -> list:
"""Per-axis fold of ``first`` with ``others``, prioritizing non-broadcastable lengths.

The shape entry from ``first`` wins on every axis where ``first`` is not
Expand All @@ -145,23 +145,26 @@ def get_simplified_broadcast_shape(first, *others, fgraph) -> list:
commutative, but the resulting symbolic shape is not — we want the
simplest/most-static expression).

Assumes all inputs have the same ndim (the Elemwise contract).
With ``batch_ndim`` only the leading ``batch_ndim`` dimensions are folded and
returned, so inputs may differ on their trailing (core) dimensions, as with
``Blockwise``. Otherwise all inputs must share ndim (the Elemwise contract).
"""
first_broadcastable = first.type.broadcastable
if batch_ndim is None:
batch_ndim = first.type.ndim

first_broadcastable = first.type.broadcastable[:batch_ndim]
if not (any(first_broadcastable) and others):
return list(get_simplified_shape(first, fgraph=fgraph))
return list(get_simplified_shape(first, fgraph=fgraph))[:batch_ndim]

broadcastable_dims = list(first_broadcastable)
broadcast_shape = list(get_simplified_shape(first, fgraph=fgraph))
broadcast_shape = list(get_simplified_shape(first, fgraph=fgraph))[:batch_ndim]
for other in others:
other_shape = get_simplified_shape(other, fgraph=fgraph)
for i, (other_broadcastable, other_dim_length) in enumerate(
zip(other.type.broadcastable, other_shape, strict=True)
):
if other_broadcastable or not broadcastable_dims[i]:
for i in range(batch_ndim):
if other.type.broadcastable[i] or not broadcastable_dims[i]:
# Doesn't provide any new info
continue
broadcast_shape[i] = other_dim_length
broadcast_shape[i] = other_shape[i]
broadcastable_dims[i] = False # Don't override again
return broadcast_shape

Expand Down
146 changes: 66 additions & 80 deletions pytensor/tensor/rewriting/subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import Dot, dot, minimum
from pytensor.tensor.rewriting.basic import (
broadcasted_by,
get_simplified_broadcast_shape,
register_canonicalize,
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -77,13 +79,19 @@
from pytensor.tensor.variable import TensorConstant, TensorVariable


def _canonical_indexing(var, indices):
def _canonical_indexing(var, indices, drop_broadcasted_index=False):
"""Index ``var``, squeezing indexed broadcast dims whose index has size 1.

On a length-1 dim only zero is a valid index, so the index is
redundant — squeezing is equivalent and simpler. If squeezed
indices contributed unique output dimensions, those are reinserted
via ``expand_dims`` after indexing.

When ``drop_broadcasted_index`` is set, the index is neutralized on *every*
broadcast dim: a size>1 advanced index that would otherwise expand the dim is
dropped, and a slice that would otherwise shrink it (e.g. ``[1:]`` -> size 0)
is replaced by a full slice. This keeps ``var`` small at the cost of an
under-broadcast result; the caller must broadcast it back to the full shape.
"""
squeeze_axes = []
kept_indices = []
Expand All @@ -94,27 +102,35 @@ def _canonical_indexing(var, indices):
zip(var.type.broadcastable, indices, strict=False)
):
if isinstance(idx, slice):
kept_indices.append(idx)
if bcast and drop_broadcasted_index:
# Slicing a length-1 dim can shrink it (e.g. [1:] -> size 0);
# leave it untouched and let the caller broadcast it back.
kept_indices.append(slice(None))
else:
kept_indices.append(idx)
else:
if first_adv_axis is None:
first_adv_axis = axis

# np.ndim works for all supported cases: int, numpy arrays, pytensor variables
idx_ndim = np.ndim(idx)
if bcast:
match idx:
case Variable():
idx_size1 = all(idx.type.broadcastable)
case np.ndarray():
idx_size1 = idx.size == 1
case int() | np.integer():
idx_size1 = True
case _:
raise AssertionError
if drop_broadcasted_index:
drop_idx = True
else:
match idx:
case Variable():
drop_idx = all(idx.type.broadcastable)
case np.ndarray():
drop_idx = idx.size == 1
case int() | np.integer():
drop_idx = True
case _:
raise AssertionError

# idx only contributes dummy dimensions (if any), not actual shape
# It doesn't really matter what the index was, only valid values are zeros.
if idx_size1:
if drop_idx:
max_drop_ndim = max(max_drop_ndim, idx_ndim)
squeeze_axes.append(axis)
continue
Expand Down Expand Up @@ -304,8 +320,9 @@ def local_subtensor_of_batch_dims(fgraph, node):

Bail on boolean masks and non-consecutive advanced indexing — numpy hoists
those advanced groups to position 0, which would misalign the lifted
indices. On a broadcast (length-1) axis of an input, replace the advanced
index with length-1 zeros so the lifted input still broadcasts correctly.
indices. On a broadcast (length-1) axis of an input the index is dropped
(only zero is in bounds there), and an Alloc restores the full output shape
when a dropped index was what determined it.
"""
elem, *idx = node.inputs

Expand Down Expand Up @@ -369,45 +386,14 @@ def local_subtensor_of_batch_dims(fgraph, node):
return [new_elem]

elem_inputs = elem.owner.inputs
elem_bcast = elem.type.broadcastable[:batch_ndim]
if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
# No need to worry about implicit broadcasting.
indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]

else:
# The original indices may not make sense on some of the broadcasted dimensions
new_idxs = [list(idx_tuple) for _ in elem_inputs]
for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
zip(
idx_tuple,
elem_bcast,
*(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs),
# Indices can be shorter than input ndims
strict=False,
)
):
if isinstance(dim_idx, slice) and dim_idx == slice(None):
# Full slice can be safely applied to all inputs
continue

if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
# This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
continue

# Slices stay; advanced indices become length-1 zeros
# that _canonical_indexing will squeeze.
if isinstance(dim_idx, slice):
safe_bcast_dim_idx = slice(None)
else:
safe_bcast_dim_idx = np.zeros((1,) * dim_idx.type.ndim, dtype="int64")
for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
if dim_bcast_inp:
inp_idx[dim] = safe_bcast_dim_idx

indexed_inputs = [
_canonical_indexing(inp, tuple(new_idx))
for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
]
# Drop indices on broadcast input dims instead of applying them: on a length-1
# dim only zero is in bounds, so an advanced index there carries no information
# and would only wastefully expand the dim. The Elemwise broadcasts the small
# inputs back together.
indexed_inputs = [
_canonical_indexing(inp, idx_tuple, drop_broadcasted_index=True)
for inp in elem_inputs
]

[old_out] = node.outputs

Expand All @@ -417,6 +403,16 @@ def local_subtensor_of_batch_dims(fgraph, node):
# Define elemwise operation on indexed inputs
new_out = elem.owner.op(*indexed_inputs)

# The indices dropped on broadcast dims may have been needed to determine the
# output shape, so we use an Alloc to enforce that shape.
if broadcasted_by(new_out, old_out):
batch_shape = get_simplified_broadcast_shape(
*elem_inputs, fgraph=fgraph, batch_ndim=batch_ndim
)
new_batch_shape = indexed_result_shape(batch_shape, idx_tuple)
core_shape = tuple(new_out.shape)[len(new_batch_shape) :]
new_out = alloc(new_out, *new_batch_shape, *core_shape)

# Copy stack trace to new output
copy_stack_trace([old_out, *node.inputs], new_out)

Expand Down Expand Up @@ -738,37 +734,27 @@ def lift_subtensor_through_alloc(fgraph, node):
if _non_consecutive_adv_indexing(indices):
return None

val_indexer: list = []
dangerous_index_reaches_val = False
for axis, idx in enumerate(indices):
if axis < n_added_dims:
# Axis was added by Alloc; index doesn't reach val.
continue
val_static_dim = val.type.shape[axis - n_added_dims]
if val_static_dim == 1:
# Broadcast val dim: slices stay (Alloc broadcasts on top);
# advanced indices become length-1 zeros for squeeze.
if isinstance(idx, slice):
val_indexer.append(slice(None))
else:
val_indexer.append(np.zeros((1,) * idx.type.ndim, dtype=np.int64))
continue
val_indexer.append(idx)
if not _index_provably_smaller(idx, val_static_dim):
# Per-axis check; doesn't account for net effect across all axes.
dangerous_index_reaches_val = True

nw_val = _canonical_indexing(val, val_indexer)
new_shape = indexed_result_shape(alloc_dims, indices)
drops_alloc = nw_val.type.broadcastable == node.outputs[0].type.broadcastable

if dangerous_index_reaches_val and not drops_alloc:
# Indices on Alloc-added dims don't reach val; the rest line up with val's dims.
val_indexer = indices[n_added_dims:]
dangerous_index_reaches_val = any(
not val.type.broadcastable[axis]
# Per-axis check; doesn't account for net effect across all axes.
and not _index_provably_smaller(idx, val.type.shape[axis])
for axis, idx in enumerate(val_indexer)
)

# On broadcast val dims the index is neutralized (advanced indices dropped,
# shrinking slices made full); the trailing Alloc broadcasts val back up.
nw_val = _canonical_indexing(val, val_indexer, drop_broadcasted_index=True)
needs_alloc = broadcasted_by(nw_val, node.outputs[0])

if dangerous_index_reaches_val and needs_alloc:
return None

if drops_alloc:
result = nw_val
if needs_alloc:
result = alloc(nw_val, *indexed_result_shape(alloc_dims, indices))
else:
result = alloc(nw_val, *new_shape)
result = nw_val

copy_stack_trace(node.outputs[0], result)
return [result]
Expand Down
74 changes: 70 additions & 4 deletions tests/tensor/rewriting/test_subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@


class TestLocalSubtensorOfBatchDims:
rewrite_kw = dict(
include=("ShapeOpt", "canonicalize", "specialize"),
exclude=("local_replace_AdvancedSubtensor",),
clone=True,
)

def test_unary_multiple_clients(self):
# Same as test0, but the output of the Elemwise is reused elsewhere,
# so the subtensor should not be lifted.
Expand Down Expand Up @@ -151,7 +157,7 @@ def test_multinary_multiple_clients(self):
# Slice indexing on broadcastable dimension
(
lambda x, y: add(x[None], y[None])[1:],
lambda x, y: add(x[None][1:], y[None][1:]),
lambda x, y: pt.alloc(pt.add(x[None], y[None]), 0, *x.type.shape),
),
(
lambda x, y: add(x[None, :], y[:, None])[1:],
Expand All @@ -169,9 +175,7 @@ def test_elemwise(self, original_fn, expected_fn):
out = original_fn(x, y)
expected_opt_out = expected_fn(x, y)
opt_out = rewrite_graph(out)
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
[expected_opt_out, opt_out], print_type=True
)
assert_equal_computations([opt_out], [expected_opt_out], strict_dtype=False)
eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore")
np.testing.assert_allclose(
opt_out.eval({x: x_test, y: y_test}, **eval_kwargs),
Expand Down Expand Up @@ -248,6 +252,68 @@ def perform(self, node, inputs, output_storage):
expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
assert equal_computations([rewritten_out_sliced], [expected_out_sliced])

@pytest.mark.parametrize(
    "idx",
    [np.zeros(20, dtype="int64"), slice(0, 5)],
    ids=["advanced", "slice"],
)
def test_stale_elemwise_output_type(self, idx):
    """Lift through an Elemwise whose output type is out of date.

    Every Elemwise input is length 1 on the indexed dim, while the Elemwise
    output type still reports that dim as non-broadcastable. The rewrite must
    still apply, deriving the broadcast-back shape from the inputs rather
    than from the Elemwise output type.
    """
    dyn_input = pt.tensor("x_input", shape=(None, 3, 3), dtype="float64")
    static_input = pt.tensor("x_new", shape=(1, 3, 3), dtype="float64")
    passthrough = pt.identity(dyn_input)
    squared = passthrough * passthrough
    indexed = squared[idx]
    fgraph = FunctionGraph([dyn_input, static_input], [indexed], clone=False)

    # Forge the stale state: after this replace the Elemwise inputs are
    # broadcastable on dim 0, but its output type is not updated.
    fgraph.replace(passthrough, static_input)

    # Sanity-check that the graph really is in the stale state.
    stale_elem = indexed.owner.inputs[0]
    assert not stale_elem.type.broadcastable[0]
    for elem_input in stale_elem.owner.inputs:
        assert elem_input.type.broadcastable[0]

    (lifted,) = local_subtensor_of_batch_dims.transform(fgraph, indexed.owner)
    # Replacing checks that the lifted output is type-compatible with the
    # (stale) original output.
    fgraph.replace(indexed, lifted)

    rewritten = rewrite_graph(lifted, **self.rewrite_kw)
    squared_static = pt.sqr(static_input)
    # slice(0, 5) on a length-1 dim stays length 1, so only the advanced
    # index case needs an Alloc to restore the indexed length.
    expected = (
        squared_static
        if isinstance(idx, slice)
        else pt.alloc(squared_static, idx.shape[0], 3, 3)
    )
    assert_equal_computations([rewritten], [expected], strict_dtype=False)

def test_advanced_index_on_broadcast_dim_does_not_expand_inputs(self):
    """Check that length-1 inputs are not expanded by an advanced index.

    The indexed dim is length 1 in every input, so the index must not be
    applied to the inputs (that would grow them from 1 to K and repeat the
    computation). The expected graph keeps the addition on the length-1
    inputs and produces the K-sized dim with a single Alloc.
    """
    lhs = pt.tensor("x", shape=(1, 3), dtype="float64")
    rhs = pt.tensor("y", shape=(1, 3), dtype="float64")
    gather_idx = np.zeros(5, dtype="int64")
    indexed_sum = (lhs + rhs)[gather_idx]

    rewritten = rewrite_graph(indexed_sum, **self.rewrite_kw)
    expected = pt.alloc(lhs + rhs, 5, 3)
    assert_equal_computations([rewritten], [expected], strict_dtype=False)

def test_broadcast_dim_does_not_block_lift(self):
    """Check that a broadcast dim needing an Alloc does not prevent the lift.

    The advanced index sits on the broadcast (length-1) dim, while the
    integer index on the other dim shrinks it. The expected graph applies
    ``exp`` to the single selected row and restores the length-3 dim with
    an Alloc.
    """
    tall = pt.matrix("x", shape=(1_000_000, 1))
    col_idx = np.array([0, 0, 0])
    indexed_exp = pt.exp(tall)[0, col_idx]

    rewritten = rewrite_graph(indexed_exp, **self.rewrite_kw)
    expected = pt.alloc(pt.exp(tall[0]), 3)
    assert_equal_computations([rewritten], [expected], strict_dtype=False)


def test_local_subtensor_of_dot():
m1 = matrix()
Expand Down
Loading