This is a micro-package containing a single function, `bounded_while_loop`.
It provides a reverse-mode-friendly, bounded `while_loop` implemented via `lax.scan`.

Install with:

```bash
pip install jax-bounded-while
```

Simple loop over a scalar:
```python
import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop

def cond_fn(x):
    return x < 5

def body_fn(x):
    return x + 1

result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
```

PyTree carry (tuple):
```python
import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop

def cond_fn(state):
    x, _ = state
    return x < 3

def body_fn(state):
    x, y = state
    return x + 1, y * 2

result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, ...), Array(8, ...))
```