
GalacticDynamics/jax-bounded-while


jax-bounded-while

Bounded while loop in JAX.


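This is a micro-package containing a single function, bounded_while_loop: a reverse-mode-friendly, bounded while_loop implemented via lax.scan.
Unlike jax.lax.while_loop, which supports only forward-mode differentiation, it runs for at most max_steps iterations, so it can also be differentiated with jax.grad.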
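To make the idea concrete, here is a minimal sketch of how a bounded while loop can be built on jax.lax.scan. It is an illustration under our own assumptions, not the package's source: the helper name scan_bounded_while is hypothetical, and it assumes that once cond_fun returns False the remaining iterations leave the carry unchanged. Because lax.scan always runs a fixed number of steps, JAX can differentiate it in reverse mode.

```python
import jax
import jax.numpy as jnp


def scan_bounded_while(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical scan-based bounded while loop (illustration only).

    Runs exactly max_steps scan iterations; once cond_fun(carry) is False,
    the remaining iterations return the carry unchanged.
    """

    def step(carry, _):
        # lax.cond applies body_fun while the condition holds, identity afterwards.
        new_carry = jax.lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)
        return new_carry, None

    final_carry, _ = jax.lax.scan(step, init_val, xs=None, length=max_steps)
    return final_carry


# The scalar loop from the Examples section.
print(scan_bounded_while(lambda x: x < 5, lambda x: x + 1, jnp.asarray(0), 10))  # 5


def doubling(x0):
    # Doubles x0 until it reaches 10.0, for at most 20 iterations.
    return scan_bounded_while(lambda x: x < 10.0, lambda x: x * 2.0, x0, 20)


# Starting from 1.0 the loop doubles 4 times: 1 -> 2 -> 4 -> 8 -> 16.
print(doubling(1.0))  # 16.0
# Reverse mode works because scan has a fixed trip count: d(16 * x0)/dx0 = 16.
print(jax.grad(doubling)(1.0))  # 16.0
```

In this sketch the scan always executes max_steps iterations, so compute time (and the memory saved for reverse mode) scales with max_steps rather than with the number of iterations actually needed; a tight max_steps keeps that overhead small.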

Installation

pip install jax-bounded-while

Examples

Simple loop over a scalar:

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
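For comparison, the same loop can be written with plain jax.lax.while_loop. The sketch below uses only JAX itself; the float-valued doubling function is our own illustrative example, and catching ValueError reflects how current JAX versions reject the reverse-mode transpose of a while loop. The forward result and forward-mode derivative work, but jax.grad does not.

```python
import jax
import jax.numpy as jnp


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


# Unbounded lax.while_loop gives the same forward result as the example above.
print(jax.lax.while_loop(cond_fn, body_fn, jnp.asarray(0)))  # 5


def doubling(x0):
    # Illustrative float loop: double x0 until it reaches 10.0.
    return jax.lax.while_loop(lambda x: x < 10.0, lambda x: x * 2.0, x0)


# Forward mode through lax.while_loop is supported: 1 -> 16, tangent 16.
primal, tangent = jax.jvp(doubling, (jnp.float32(1.0),), (jnp.float32(1.0),))
print(primal, tangent)  # 16.0 16.0

# Reverse mode is not: JAX raises an error when transposing the while loop.
try:
    jax.grad(doubling)(jnp.float32(1.0))
    reverse_mode_ok = True
except ValueError:
    reverse_mode_ok = False
print("reverse mode supported:", reverse_mode_ok)  # reverse mode supported: False
```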

PyTree carry (tuple):

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))
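With a PyTree carry the loop state is a nested structure of arrays, so a scan-based loop can also mask each leaf separately. The sketch below is a hypothetical variant (masked_bounded_while is our own name, not the package's code) that swaps lax.cond for jax.tree_util.tree_map plus jnp.where. It also shows the behaviour we assume when max_steps is too small: the loop stops after max_steps iterations even though cond_fn is still True.

```python
import jax
import jax.numpy as jnp


def masked_bounded_while(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical bounded while loop that masks every PyTree leaf."""

    def step(carry, _):
        keep_going = cond_fun(carry)
        proposed = body_fun(carry)  # always evaluated, even after the loop has ended
        # Per leaf: take the new value while the condition holds, else keep the old one.
        new_carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), proposed, carry
        )
        return new_carry, None

    final_carry, _ = jax.lax.scan(step, init_val, xs=None, length=max_steps)
    return final_carry


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2.0  # float y so that it can be differentiated


init = (jnp.asarray(0), jnp.asarray(1.0))

# Enough steps: stops because cond_fn became False, giving x = 3, y = 8.0.
x, y = masked_bounded_while(cond_fn, body_fn, init, max_steps=5)
print(int(x), float(y))  # 3 8.0

# Too few steps: truncated after 2 iterations, giving x = 2, y = 4.0.
x_short, y_short = masked_bounded_while(cond_fn, body_fn, init, max_steps=2)
print(int(x_short), float(y_short))  # 2 4.0


def final_y(y0):
    return masked_bounded_while(cond_fn, body_fn, (jnp.asarray(0), y0), max_steps=5)[1]


# Reverse mode through the tuple carry: y_final = 8 * y0, so the gradient is 8.
print(jax.grad(final_y)(1.0))  # 8.0
```

Unlike lax.cond, this masking form evaluates body_fun on every iteration, including those after the loop has logically finished. That is harmless here, but if the body produced inf or NaN past the stopping point, those values could leak into reverse-mode gradients through jnp.where.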

