Skip to content

Comments

ENH: Support dynamic slice indexing in JAX backend via lax.dynamic_slice#1905

Open
ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
ayulockedin:subtensor-fix
Open

ENH: Support dynamic slice indexing in JAX backend via lax.dynamic_slice#1905
ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
ayulockedin:subtensor-fix

Conversation

@ayulockedin
Copy link
Contributor

Description

This PR addresses the fundamental incompatibility between PyTensor's dynamic graph indexing and XLA's strict static compilation requirements, specifically unblocking time-series sliding windows and sequential loops (like those required for the Scalable Online SSM project in pymc-extras).

The Problem:
When compiling a dynamic slice like x[i : i+3] where i is a runtime variable, PyTensor's static shape inference evaluates the sliced dimension as (None,). Because JAX's lax.dynamic_slice strictly requires a static integer for the slice size, the JAX dispatcher falls back to standard __getitem__ indexing, which immediately crashes with a JitTracer error.

The Solution:
Instead of adding new Ops or modifying core PyTensor shape inference, I implemented a static length prover (_get_static_length) directly inside jax_funcify_Subtensor:

  1. It detects if the idx_list contains JAX Tracers.
  2. It traverses the symbolic PyTensor graph, unwrapping ScalarFromTensor and Elemwise operations to algebraically deduce the constant difference between start and stop (e.g., proving (i+3) - i == 3).
  3. If a static size is successfully proven, it feeds that integer directly to jax.lax.dynamic_slice, preserving XLA compatibility.
  4. It squeezes out any integer-indexed dimensions to maintain exact PyTensor dimensionality semantics.

Addressing the Semantic Trade-off (@ricardoV94):
Regarding the out-of-bounds semantics (e.g., pt.zeros(5)[10:13]): In pure NumPy semantics, out-of-bounds slices shrink the array size. However, XLA mathematically cannot JIT-compile a function that returns dynamically sized arrays based on runtime variables.

Because of this, jax.lax.dynamic_slice intentionally deviates from NumPy by clipping out-of-bounds indices so the returned slice always strictly matches the requested static size. We face a binary choice for the JAX backend here:

  1. We strictly enforce NumPy shrinking semantics, which means dynamic slicing with Tracers will never compile in JAX.
  2. We accept JAX's clipping semantics for out-of-bounds dynamic slices. This safely unlocks compilation for dynamic time-series algorithms where indices are structurally guaranteed to be in-bounds anyway.

Given the constraints of XLA, I believe adopting the clipping semantics specifically for the JAX dispatcher when Tracers are present is the necessary path forward.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ayulockedin
Copy link
Contributor Author

@ricardoV94 can i have your thoughts on this take a look at it when u have a moment

i = pt.iscalar("i")

# Test a dynamic start index with a statically provable length of 3
out = x[i : i + 3]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not statically provable, because i may be less than 3 units away from the edge

Copy link
Contributor Author

@ayulockedin ayulockedin Feb 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 You are absolutely right. In pure NumPy semantics, the slice truncates near the edge, so the output length isn't strictly guaranteed.

However, XLA strictly requires static shapes and cannot compile dynamically sized outputs. To work around this, jax.lax.dynamic_slice intentionally deviates from NumPy by clipping out-of-bounds indices to guarantee the requested static size.

For the Scalable Online SSM project (and most sliding-window algorithms in pymc-extras), the models structurally guarantee that i stays within bounds, so we wouldn't actually hit this edge case in practice.

Since strictly enforcing NumPy's truncating semantics means dynamic slices with Tracers can never compile in JAX, would it be acceptable to adopt JAX's clipping behavior here as a pragmatic trade-off? If you would prefer a different architectural approach for handling dynamic sliding windows in JAX, I am completely open to pivoting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 If we want to maintain strict NumPy compliance, I could investigate, but that might significantly complicate the JAX JIT graph. What do you think is the best path forward?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants