ENH: Support dynamic slice indexing in JAX backend via lax.dynamic_slice #1905
ayulockedin wants to merge 1 commit into pymc-devs:main from
Conversation
@ricardoV94 Can I have your thoughts on this? Please take a look when you have a moment.
i = pt.iscalar("i")

# Test a dynamic start index with a statically provable length of 3
out = x[i : i + 3]
this is not statically provable, because i may be less than 3 units away from the edge
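To make the edge case concrete, here is a minimal NumPy illustration (a standalone sketch, not from the PR's diff):

```python
import numpy as np

x = np.zeros(5)
i = 4
# NumPy silently truncates slices at the array edge, so the "length 3"
# slice x[4:7] actually has length 1, not 3.
print(x[i : i + 3].shape)
```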
@ricardoV94 You are absolutely right. In pure NumPy semantics, the slice truncates near the edge, so the output length isn't strictly guaranteed.
However, XLA strictly requires static shapes and cannot compile dynamically sized outputs. To work around this, jax.lax.dynamic_slice intentionally deviates from NumPy by clipping out-of-bounds indices to guarantee the requested static size.
For the Scalable Online SSM project (and most sliding-window algorithms in pymc-extras), the models structurally guarantee that i stays within bounds, so we wouldn't actually hit this edge case in practice.
Since strictly enforcing NumPy's truncating semantics means dynamic slices with Tracers can never compile in JAX, would it be acceptable to adopt JAX's clipping behavior here as a pragmatic trade-off? If you would prefer a different architectural approach for handling dynamic sliding windows in JAX, I am completely open to pivoting.
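For illustration, a minimal sketch of the clipping behavior described above (standalone JAX, not part of this PR's diff):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)  # [0 1 2 3 4]
# A start of 4 with requested size 3 would run past the edge; dynamic_slice
# clips the start index down to 2 so the output always has static length 3.
out = lax.dynamic_slice(x, (4,), (3,))
print(out)  # [2 3 4]
```

This is exactly the deviation from NumPy under discussion: NumPy would return a length-1 array here, while `dynamic_slice` guarantees the requested static shape.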
@ricardoV94 If we want to maintain strict NumPy compliance, I could investigate enforcing the truncating semantics, but that might significantly complicate the JAX JIT graph. What do you think is the best path forward?
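For reference, the underlying failure can be reproduced directly in JAX, independent of PyTensor (a standalone sketch):

```python
import jax
import jax.numpy as jnp

@jax.jit
def window(x, i):
    # Under jit, i is a tracer with no concrete value, so NumPy-style
    # slicing cannot produce a static output shape and JAX raises an error.
    return x[i : i + 3]

try:
    window(jnp.arange(10), 2)
    failed = False
except Exception:
    failed = True
print(failed)  # True
```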
Description
This PR addresses the fundamental incompatibility between PyTensor's dynamic graph indexing and XLA's strict static compilation requirements, specifically unblocking time-series sliding windows and sequential loops (like those required for the Scalable Online SSM project in pymc-extras).

The Problem:
When compiling a dynamic slice like x[i : i + 3], where i is a runtime variable, PyTensor's static shape inference evaluates the sliced dimension as (None,). Because JAX's lax.dynamic_slice strictly requires a static integer for the slice size, the JAX dispatcher falls back to standard __getitem__ indexing, which immediately crashes with a tracer error under JIT.

The Solution:
Instead of adding new Ops or modifying core PyTensor shape inference, I implemented a static length prover (_get_static_length) directly inside jax_funcify_Subtensor. It:

- detects when idx_list contains JAX Tracers;
- walks ScalarFromTensor and Elemwise operations to algebraically deduce the constant difference between start and stop (e.g., proving (i+3) - i == 3);
- dispatches to jax.lax.dynamic_slice when the length is provable, preserving XLA compatibility.

Addressing the Semantic Trade-off (@ricardoV94):

Regarding the out-of-bounds semantics (e.g., pt.zeros(5)[10:13]): in pure NumPy semantics, out-of-bounds slices shrink the array size. However, XLA mathematically cannot JIT-compile a function that returns dynamically sized arrays based on runtime variables. Because of this, jax.lax.dynamic_slice intentionally deviates from NumPy by clipping out-of-bounds indices so the returned slice always strictly matches the requested static size.

We face a binary choice for the JAX backend here: either strictly enforce NumPy's truncating semantics, which means dynamic slices with Tracers can never compile, or adopt JAX's clipping semantics. Given the constraints of XLA, I believe adopting the clipping semantics specifically for the JAX dispatcher when Tracers are present is the necessary path forward.
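To sketch the proving idea in isolation: the toy helper below models each slice bound as a hypothetical (base, offset) pair, where base stands for a runtime variable and offset is a constant. The real _get_static_length in this PR instead walks PyTensor's ScalarFromTensor and Elemwise nodes, but the algebra is the same.

```python
# Hypothetical sketch of a static length prover; not the PR's actual code.
def static_length(start, stop):
    start_base, start_off = start
    stop_base, stop_off = stop
    if start_base is stop_base:
        # Same runtime variable on both sides: (i + 3) - (i + 0) == 3
        return stop_off - start_off
    return None  # length cannot be statically proven

i = object()  # stands in for the runtime scalar i
print(static_length((i, 0), (i, 3)))          # 3: provable, length 3
print(static_length((i, 0), (object(), 3)))   # None: different bases
```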
Related Issue
Checklist
Type of change