
Reduce NonSequential JIT compile times via jax.lax.scan and optimizer state packing #5

Draft
Copilot wants to merge 2 commits into main from copilot/reduce-nonsequential-compile-time

Conversation


Copilot AI commented Mar 3, 2026

JIT compile times for NonSequential.train() scale poorly with model size: Python for loops and jax.tree.map calls over modules and optimizer states are unrolled during tracing, so trace size grows linearly with module count. With these changes, for a 16-module chain of AffineObservablePMMs (~25k params), compile time drops from ~9.5s to ~4.3s (~55% reduction).

jax.lax.scan for chains of identical modules (nonsequentialmodel.py)

When all modules in execution order share the same type, parameter shapes, and state shapes, and form a linear chain, _get_callable() now returns a scan-based callable that traces the module body once instead of N times. Eligibility is detected via _can_use_scan(); otherwise it falls back to the original for-loop, which still handles arbitrary DAG topologies.

```python
# Instead of tracing 16 separate module calls:
for mod_path, module_callable, ... in zip(execution_order, ...):
    out, state = module_callable(params, data, ...)

# Traces one body, applies it N times via scan:
final_data, new_states = jax.lax.scan(scan_body, data, (stacked_params, stacked_states))
```
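A minimal, self-contained sketch of the scan pattern above, with illustrative names (`module_body`, `stacked_params`, a tanh-affine step standing in for AffineObservablePMM — none of these are the PR's actual API): all per-module parameters are stacked along a leading axis so `lax.scan` traces the body once and slices one module's params per step.

```python
import jax
import jax.numpy as jnp

def module_body(data, params):
    # One module step: a toy affine transform with a tanh nonlinearity.
    w, b = params
    return jnp.tanh(data @ w + b)

def scan_body(carry, per_module_params):
    # Carry the activations through the chain; no per-step output is collected.
    out = module_body(carry, per_module_params)
    return out, None

n_modules, dim = 16, 4
key = jax.random.PRNGKey(0)
# Stack all 16 modules' params on a leading axis so scan can slice one per step.
stacked_params = (
    jax.random.normal(key, (n_modules, dim, dim)) * 0.1,
    jnp.zeros((n_modules, dim)),
)

x = jnp.ones((2, dim))
final_data, _ = jax.lax.scan(scan_body, x, stacked_params)
print(final_data.shape)  # (2, 4)
```

The compile-time win comes from tracing `scan_body` exactly once; with a Python loop, XLA would see 16 copies of the same computation.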

Optimizer state packing by dtype (training.py)

Packs the N flat parameter arrays into 1–2 concatenated arrays (one per unique dtype). This reduces the lax.while_loop carry from 4*N arrays (e.g. 192) to 4*num_dtypes (e.g. 8), and collapses N individual Adam updates into 1–2. Non-trainable params are handled via stop_gradient masking in the loss function, and the default identity callback skips the unnecessary unpack/repack round trip.
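The packing step could look roughly like the sketch below (the `pack_by_dtype`/`unpack_by_dtype` names are hypothetical, not the PR's actual helpers): ravel each parameter array, group by dtype, concatenate into one buffer per dtype, and record a layout so the originals can be reconstructed.

```python
import math
import jax.numpy as jnp

def pack_by_dtype(flat_params):
    """Concatenate arrays into one flat buffer per unique dtype."""
    chunks, layout = {}, []
    for arr in flat_params:
        chunks.setdefault(arr.dtype, []).append(arr.ravel())
        layout.append((arr.dtype, arr.shape))  # remember dtype + shape for unpacking
    packed = {dt: jnp.concatenate(parts) for dt, parts in chunks.items()}
    return packed, layout

def unpack_by_dtype(packed, layout):
    """Invert pack_by_dtype using the recorded (dtype, shape) layout."""
    offsets = {dt: 0 for dt in packed}
    out = []
    for dt, shape in layout:
        n = math.prod(shape)
        start = offsets[dt]
        out.append(packed[dt][start:start + n].reshape(shape))
        offsets[dt] = start + n
    return out

params = [jnp.ones((3, 2)), jnp.zeros(5), jnp.arange(4, dtype=jnp.int32)]
packed, layout = pack_by_dtype(params)
print(sorted(str(k) for k in packed))  # ['float32', 'int32']
restored = unpack_by_dtype(packed, layout)
print(all((a == b).all() for a, b in zip(params, restored)))  # True
```

With one buffer per dtype, the optimizer carries and updates 1–2 arrays instead of N, which is what shrinks the while_loop carry described above.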

Benchmark test (tests/test_compile_time.py)

Constructs a 16-module NonSequentialModel with >20k trainable floats and asserts that compile time stays under 30s.
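A compile-time assertion of this kind can be sketched with JAX's ahead-of-time lowering API; the toy `train_step` below stands in for NonSequential.train(), and the 30s budget mirrors the PR's test.

```python
import time
import jax
import jax.numpy as jnp

def train_step(params, x):
    # Toy stand-in for the real training step being benchmarked.
    return jnp.sum(jnp.tanh(x @ params))

params = jnp.ones((8, 8))
x = jnp.ones((2, 8))

jitted = jax.jit(train_step)
start = time.perf_counter()
# Lower and compile explicitly so we time compilation, not execution.
jitted.lower(params, x).compile()
elapsed = time.perf_counter() - start
print(f"compile time: {elapsed:.3f}s")
assert elapsed < 30.0
```

Timing `.lower(...).compile()` rather than the first traced call keeps device execution and data transfer out of the measurement.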




- Add scan-based forward pass for chains of identical modules in NonSequentialModel
- Pack optimizer states by dtype in training.py to reduce while_loop carry size
- Add benchmark test for compile time (test_compile_time.py)

Co-authored-by: pdcook <16090923+pdcook@users.noreply.github.com>
Copilot AI changed the title [WIP] Optimize JAX compile times for NonSequential models Reduce NonSequential JIT compile times via jax.lax.scan and optimizer state packing Mar 3, 2026