Reduce NonSequential JIT compile times via jax.lax.scan and optimizer state packing#5
Draft
Conversation
…king

- Add scan-based forward pass for chains of identical modules in NonSequentialModel
- Pack optimizer states by dtype in training.py to reduce while_loop carry size
- Add benchmark test for compile time (test_compile_time.py)

Co-authored-by: pdcook <16090923+pdcook@users.noreply.github.com>
Copilot AI changed the title **[WIP] Optimize JAX compile times for NonSequential models** to **Reduce NonSequential JIT compile times via jax.lax.scan and optimizer state packing** on Mar 3, 2026
JIT compile times for `NonSequential.train()` scale poorly with model size due to Python `for` loops and `jax.tree.map`s over modules and optimizer states, each of which gets unrolled during tracing. For a 16-module chain of `AffineObservablePMM`s (~25k params), compile time drops from ~9.5s to ~4.3s (~55% reduction).

**`jax.lax.scan` for chains of identical modules (`nonsequentialmodel.py`)**

When all modules in execution order share the same type, param shapes, and state shapes, and form a linear chain, `_get_callable()` now returns a scan-based callable that traces the module body once instead of N times. Detection happens via `_can_use_scan()`; arbitrary DAG topologies fall back to the original for-loop.

**Optimizer state packing by dtype (`training.py`)**

Packs the N flat parameter arrays into 1–2 concatenated arrays (one per unique dtype). This reduces the `lax.while_loop` carry size from `4*N` arrays (e.g. 192) to `4*num_dtypes` (e.g. 8), and collapses N individual Adam updates into 1–2. Non-trainable params are handled via `stop_gradient` masking in the loss function. A default identity callback skips the unnecessary unpack/repack.

**Benchmark test (`tests/test_compile_time.py`)**

Constructs a 16-module `NonSequentialModel` with >20k trainable floats and asserts that compile time stays under 30s.
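To illustrate the scan-based forward pass described above, here is a minimal sketch (not the repo's actual code; `chain_forward_unrolled`, `chain_forward_scan`, and the tanh-affine module body are hypothetical stand-ins). The key idea is that identical modules let us stack per-module params along a leading axis so `jax.lax.scan` traces the body once:

```python
import jax
import jax.numpy as jnp

def chain_forward_unrolled(params_list, x):
    """Baseline: a Python for-loop traces each module body separately,
    so trace size (and compile time) grows linearly with chain length."""
    for w, b in params_list:
        x = jnp.tanh(x @ w + b)  # stand-in for one module's forward pass
    return x

def chain_forward_scan(stacked_params, x):
    """Scan-based variant: the module body is traced exactly once;
    scan iterates over the stacked leading axis at run time."""
    def body(carry, params):
        w, b = params
        return jnp.tanh(carry @ w + b), None
    out, _ = jax.lax.scan(body, x, stacked_params)
    return out

# Stack per-module params (identical shapes required) along a new leading axis.
params_list = [(jnp.eye(4), jnp.zeros(4)) for _ in range(16)]
stacked = (jnp.stack([w for w, _ in params_list]),
           jnp.stack([b for _, b in params_list]))

x = jnp.ones(4)
a = chain_forward_unrolled(params_list, x)
b = chain_forward_scan(stacked, x)
print(jnp.allclose(a, b))  # True: both variants compute the same chain
```

This is also why the optimization only applies to linear chains of same-shaped modules: `scan` requires every slice of its `xs` argument to have identical structure.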
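The dtype-based packing can be sketched as follows (again an illustrative assumption, not the code in `training.py`; `pack_by_dtype`, `unpack`, and the `(dtype, offset, size)` layout are hypothetical). Concatenating the N flat arrays into one buffer per dtype is what shrinks the while-loop carry from `4*N` arrays to `4*num_dtypes` (params, grads, and the two Adam moments, each packed once per dtype):

```python
import jax.numpy as jnp

def pack_by_dtype(flat_arrays):
    """Concatenate N flat arrays into one buffer per unique dtype,
    recording (dtype, offset, size) so the originals can be recovered."""
    bufs, layout = {}, []
    for a in flat_arrays:
        group = bufs.setdefault(a.dtype, [])
        offset = sum(b.size for b in group)
        layout.append((a.dtype, offset, a.size))
        group.append(a)
    return {dt: jnp.concatenate(g) for dt, g in bufs.items()}, layout

def unpack(packed, layout):
    """Invert pack_by_dtype using the recorded slices."""
    return [packed[dt][off:off + size] for dt, off, size in layout]

arrays = [jnp.ones(3, jnp.float32),
          jnp.zeros(5, jnp.float32),
          jnp.arange(4, dtype=jnp.int32)]
packed, layout = pack_by_dtype(arrays)
print(len(packed))  # 2 — one buffer per unique dtype
restored = unpack(packed, layout)
print(all(jnp.array_equal(a, b) for a, b in zip(arrays, restored)))  # True
```

With everything of one dtype living in a single buffer, an elementwise optimizer like Adam runs as one vectorized update per dtype instead of N small ones.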
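The `stop_gradient` masking for non-trainable params can be sketched like this (a minimal illustration under assumed names; `mask_frozen` and the toy loss are not from the repo). Entries drawn through `stop_gradient` receive zero gradient, so frozen parameters stay fixed even though they live inside the packed buffer:

```python
import jax
import jax.numpy as jnp

def mask_frozen(packed, trainable_mask):
    """Route frozen entries of a packed buffer through stop_gradient
    so gradients at those positions are zeroed."""
    return jnp.where(trainable_mask, packed,
                     jax.lax.stop_gradient(packed))

def loss(packed, mask):
    p = mask_frozen(packed, mask)
    return jnp.sum(p ** 2)  # toy loss: gradient is 2*p where trainable

packed = jnp.array([1.0, 2.0, 3.0])
mask = jnp.array([True, False, True])
g = jax.grad(loss)(packed, mask)
print(g)  # [2. 0. 6.] — zero gradient at the frozen position
```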
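A compile-time benchmark along the lines of `tests/test_compile_time.py` might look like this (a sketch with a toy stand-in for the model; `make_train_step` is hypothetical). Using `jax.jit(...).lower(...).compile()` forces ahead-of-time compilation, so the timing excludes actual execution:

```python
import time
import jax
import jax.numpy as jnp

def make_train_step(n_modules):
    """Toy stand-in for a 16-module model step: an unrolled chain whose
    trace size, and hence compile time, grows with n_modules."""
    def step(params, x):
        for w in params:
            x = jnp.tanh(x @ w)
        return jnp.sum(x)
    return jax.jit(step)

params = [jnp.eye(8) for _ in range(16)]
x = jnp.ones(8)
step = make_train_step(16)

t0 = time.perf_counter()
step.lower(params, x).compile()  # AOT compile; no execution timed
elapsed = time.perf_counter() - t0
assert elapsed < 30.0, f"compile took {elapsed:.1f}s"
```

A generous 30s bound (matching the PR's threshold) keeps the test robust on slow CI machines while still catching pathological regressions in trace size.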