Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 135 additions & 1 deletion pytensor/link/jax/dispatch/subtensor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import jax
import jax.numpy as jnp
from jax import lax

from pytensor.graph.basic import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
get_idx_list,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
Expand All @@ -31,15 +38,142 @@
"""


def _get_static_length(start_var, stop_var):
    """
    Try to prove that ``stop_var - start_var`` is a compile-time constant.

    PyTensor's static shape inference reports ``(None,)`` for slices whose
    bounds are symbolic even when the *difference* of the bounds is fixed
    (e.g. ``x[i:i + 3]``).  This inspects the graphs feeding each bound,
    peeling a single constant add/sub offset off a shared symbolic base,
    and returns the provable length.

    Parameters
    ----------
    start_var, stop_var
        PyTensor variables used as the slice bounds, or ``None``.

    Returns
    -------
    int or None
        ``stop - start`` when it can be proven constant, otherwise ``None``.
    """

    def as_static_int(var):
        # Return the Python int held by a scalar Constant, or None when
        # `var` is not a constant or its data cannot be read as a scalar int
        # (e.g. non-scalar arrays, objects without .item()).
        if isinstance(var, Constant):
            try:
                return int(var.data.item())
            except (AttributeError, TypeError, ValueError):
                return None
        return None

    def extract_offset(var):
        # Decompose `var` into (symbolic_base, constant_offset).
        # A pure constant decomposes into (None, value); anything we do not
        # recognize decomposes into (var, 0) so base identity can still be
        # compared by the caller.

        # Unwrap ScalarFromTensor (matched by class name to avoid importing
        # the scalar Op here) to reach the underlying tensor arithmetic.
        if (
            getattr(var, "owner", None) is not None
            and var.owner.op.__class__.__name__ == "ScalarFromTensor"
        ):
            var = var.owner.inputs[0]

        # Pure constant bound.
        const_val = as_static_int(var)
        if const_val is not None:
            return None, const_val

        # Recognize an Elemwise add/sub with exactly one constant operand,
        # i.e. (base + c), (c + base), or (base - c).
        if getattr(var, "owner", None) is not None and isinstance(
            var.owner.op, Elemwise
        ):
            scalar_op = getattr(var.owner.op, "scalar_op", None)
            op_name = getattr(scalar_op, "name", "") if scalar_op is not None else ""

            if op_name == "add":
                const_vals = [as_static_int(inp) for inp in var.owner.inputs]
                known = [c for c in const_vals if c is not None]
                symbolic = [
                    inp
                    for inp, c in zip(var.owner.inputs, const_vals)
                    if c is None
                ]
                if len(known) == 1 and len(symbolic) == 1:
                    return symbolic[0], known[0]

            elif op_name == "sub":
                lhs, rhs = var.owner.inputs
                rhs_val = as_static_int(rhs)
                # Only (symbolic - constant): a constant minus a symbolic
                # base does not cancel against the other bound.
                if rhs_val is not None and as_static_int(lhs) is None:
                    return lhs, -rhs_val

        return var, 0

    # A missing bound (e.g. x[:n] or x[i:]) cannot be proven static here.
    if start_var is None or stop_var is None:
        return None

    start_base, start_off = extract_offset(start_var)
    stop_base, stop_off = extract_offset(stop_var)

    # If both bounds share the same symbolic base, the base cancels and the
    # length is the difference of the constant offsets.
    if start_base is not None and stop_base is not None and start_base == stop_base:
        return stop_off - start_off

    return None


@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
    """Lower PyTensor (Advanced)Subtensor ops to a JAX callable.

    For basic `Subtensor` ops whose indices turn out to be JAX tracers at
    runtime, attempts to lower to `lax.dynamic_slice` with a statically
    proven slice size (via `_get_static_length`); otherwise falls back to
    plain `__getitem__` indexing.

    NOTE(review): `lax.dynamic_slice` clips out-of-bounds start indices to
    keep the output size fixed, which deviates from NumPy's truncating
    slice semantics near the array edge — confirm this trade-off is
    acceptable for callers.
    """
    idx_list = getattr(op, "idx_list", None)
    # Static output shape from type inference; entries are None where the
    # dimension length could not be inferred.
    out_shape = list(node.outputs[0].type.shape)
    is_basic_subtensor = isinstance(op, Subtensor)

    # Extract original PyTensor symbolic variables to deduce static slice lengths
    pt_idx_list = list(get_idx_list(node.inputs, idx_list))

    def subtensor(x, *ilists):
        # Reconstruct concrete index objects (ints, slices, arrays) from the
        # runtime inputs and the op's static index template.
        indices = indices_from_subtensor(ilists, idx_list)
        # The dynamic-slice path below only handles single-index expressions
        # such as x[i:i+3]; multi-index expressions use plain __getitem__.
        if len(indices) == 1:
            idx_iter = indices if isinstance(indices, tuple) else (indices,)

            # Detect abstract (traced) index values: either the index itself
            # or a slice bound being a Tracer means __getitem__ would fail
            # under jit for basic slicing.
            has_tracer = False
            for idx in idx_iter:
                if isinstance(idx, jax.core.Tracer):
                    has_tracer = True
                elif isinstance(idx, slice):
                    if isinstance(idx.start, jax.core.Tracer) or isinstance(
                        idx.stop, jax.core.Tracer
                    ):
                        has_tracer = True

            if has_tracer and is_basic_subtensor:
                try:
                    # Build lax.dynamic_slice arguments: one start index and
                    # one (static) size per input dimension.
                    start_indices = []
                    slice_sizes = []
                    # Dimensions indexed by a scalar contribute a size-1 slice
                    # that must be squeezed out afterwards.
                    squeeze_dims = []

                    # Tracks the position in out_shape: only slice dimensions
                    # (not scalar-indexed ones) appear in the output shape.
                    out_dim_idx = 0
                    for i, (idx, pt_idx) in enumerate(zip(idx_iter, pt_idx_list)):
                        if isinstance(idx, slice):
                            # dynamic_slice has no step argument.
                            if idx.step not in (None, 1):
                                raise ValueError(
                                    "Dynamic slicing with step != 1 is not supported by JAX."
                                )

                            start = 0 if idx.start is None else idx.start

                            # Determine slice size
                            size = out_shape[out_dim_idx]
                            if size is None:
                                # Shape inference gave None; try to prove the
                                # length from the symbolic slice bounds
                                # (pt_idx is assumed to be a slice of PyTensor
                                # variables here — TODO confirm).
                                size = _get_static_length(pt_idx.start, pt_idx.stop)
                                if size is None:
                                    raise ValueError(
                                        "Could not prove static slice size for JAX lowering."
                                    )

                            start_indices.append(start)
                            slice_sizes.append(size)
                            out_dim_idx += 1
                        else:
                            # Scalar index: take a width-1 slice at that
                            # position and squeeze the dimension later.
                            start_indices.append(idx)
                            slice_sizes.append(1)
                            squeeze_dims.append(i)

                    # Dimensions not covered by the index expression are kept
                    # whole: start at 0 and span the full length.
                    for i in range(len(start_indices), x.ndim):
                        start_indices.append(0)
                        size = out_shape[out_dim_idx]
                        if size is None:
                            # Unlikely to hit, unless a trailing dimension is
                            # genuinely dynamic; fall back to the runtime shape.
                            size = x.shape[i]
                        slice_sizes.append(size)
                        out_dim_idx += 1

                    sliced = lax.dynamic_slice(x, start_indices, slice_sizes)

                    if squeeze_dims:
                        sliced = jnp.squeeze(sliced, axis=tuple(squeeze_dims))

                    return sliced
                except Exception:
                    # If the prover fails or its assumptions break, fall back
                    # to standard indexing below (which may still fail under
                    # jit with traced indices).
                    # NOTE(review): broad except — consider narrowing so real
                    # bugs in this path are not silently masked.
                    pass

        # Unwrap a 1-tuple so scalar/slice indexing matches NumPy semantics.
        if len(indices) == 1 and isinstance(indices, tuple):
            indices = indices[0]

        return x.__getitem__(indices)
Expand Down
24 changes: 24 additions & 0 deletions tests/link/jax/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,30 @@ def test_arange_of_shape():
compare_jax_and_py([x], [out], [np.zeros((5,))], jax_mode="JAX")


def test_jax_dynamic_subtensor_slice():
import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
i = pt.iscalar("i")

# Test a dynamic start index with a statically provable length of 3
out = x[i : i + 3]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not statically provable, because i may be less than 3 units away from the edge

Copy link
Contributor Author

@ayulockedin ayulockedin Feb 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 You are absolutely right. In pure NumPy semantics, the slice truncates near the edge, so the output length isn't strictly guaranteed.

However, XLA strictly requires static shapes and cannot compile dynamically sized outputs. To work around this, jax.lax.dynamic_slice intentionally deviates from NumPy by clipping out-of-bounds indices to guarantee the requested static size.

For the Scalable Online SSM project (and most sliding-window algorithms in pymc-extras), the models structurally guarantee that i stays within bounds, so we wouldn't actually hit this edge case in practice.

Since strictly enforcing NumPy's truncating semantics means dynamic slices with Tracers can never compile in JAX, would it be acceptable to adopt JAX's clipping behavior here as a pragmatic trade-off? If you would prefer a different architectural approach for handling dynamic sliding windows in JAX, I am completely open to pivoting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 If we want to maintain strict NumPy compliance, I could investigate, but that might significantly complicate the JAX JIT graph. What do you think is the best path forward?


# If the JAX dispatcher fails to prove the static length, this will raise a JitTracer IndexError
f_jax = pytensor.function([x, i], out, mode="JAX")

x_val = np.arange(10, dtype=np.float64)
i_val = 2

result = f_jax(x_val, i_val)
expected = x_val[i_val : i_val + 3]

assert np.array_equal(result, expected)


def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""

Expand Down