Make distributed V-arrays survive solve + simulate at production scale#364
Open
hmgaudecker wants to merge 16 commits into
Open
Make distributed V-arrays survive solve + simulate at production scale#364hmgaudecker wants to merge 16 commits into
hmgaudecker wants to merge 16 commits into
Conversation
`solve` lowers each unique `max_Q_over_a` with the regime's declared V-array sharding as `out_shardings`, so the compiled XLA program produces V already partitioned across the right devices. This matches the sharding the next-period consumer's `next_regime_to_V_arr` slot was lowered against, eliminating a runtime sharding mismatch on distributed-grid models without a post-hoc `device_put` reshard. Adds `test_solution_with_distributed_and_batched_grid` as a regression test for the case that surfaced the issue: a state grid with both `distributed=True` and `batch_size > 0`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Wraps the per-(regime, period) `validate_V` and `log_nan_in_V` calls
in `if validation_enabled(logger):` so `log_level="off"` does zero
V-array work — no NaN/Inf reductions, no bool host transfers, and
the jit-wrapped helpers never get traced/compiled.
`validate_V` and `log_nan_in_V` keep their original signatures
("do the work; caller chose to call"). `_simulate_regime_in_period`
gains a `logger` param so it can gate `validate_V` at its own call
site without breaking the fail-fast order (validate before
`calculate_next_states` consumes the V-array).
Adds `test_simulate_with_log_level_off_skips_nan_value_function_check`
exercising the off-mode skip on a NaN-bearing V.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Adds module-level `v_array_has_nan` / `v_array_has_inf` helpers decorated with `@jax.jit` and routes the solve diagnostic accumulator, `validate_V`, `log_nan_in_V`, and the post-loop NaN/Inf scan through them. With the reduction inside the XLA compiled graph, GSPMD partitions it across the V-array's devices (per-device any → all-reduce → mesh-replicated scalar) and `isnan`/`isinf` fuses into the `any` reduction in one pass. The eager-dispatch alternative `jnp.any(jnp.isnan(V_arr))` can fall through to gathering V onto the default device before reducing — a path that exhausts GPU memory on a sharded production-scale V-array. Adds two regression tests on 4 CPU devices verifying that both helpers return a fully replicated 0-d scalar on a `NamedSharding`- sharded input. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Documentation build overview
34 files changed ·
|
hmgaudecker
commented
May 23, 2026
| logger.info(" lowering ...") | ||
| start = time.monotonic() | ||
| lowered[func_id] = jax.jit(func).lower(**lower_args) | ||
| lowered[func_id] = jax.jit( |
Member
Author
There was a problem hiding this comment.
I would have never, ever found this one myself. Manifested itself as an OOM on the 80GB GPU...
The shock-integration productmap in `get_Q_and_F` and `get_compute_intermediates` was passing `dict.fromkeys(stochastic_variables, 0)` unconditionally, so any user-configured `batch_size>0` on the stochastic grid was silently dropped — the full cartesian product of stochastic shocks was materialized in one kernel. New `_get_stochastic_batch_sizes` helper reads `batch_size` from each stochastic variable's grid in `v_interpolation_info.discrete_states` (both `DiscreteGrid` and `_ContinuousStochasticProcess` expose the attribute), so configuring `batch_size>0` now chunks the inner shock-integration loop as documented. Unit test on the helper plus integration test that solves a stochastic model with `batch_size=0` vs `batch_size=1` and confirms identical V_arrs (correctness preserved across the chunked-vs-unchunked dispatch). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Benchmark comparison (main → HEAD)Comparing
|
The simulate-side `argmax_and_max_Q_over_a` builds a Q tensor whose intermediate broadcasts across `(n_subjects × actions × shocks × interpolation × next-regimes)` simultaneously. At production grid sizes this can reach tens of TB per device — XLA's HLO rematerialization can't bring it inside per-device HBM, and the simulate-side AOT compile fails to allocate. New `Model(subjects_batch_size=B)` knob (default `0` — current behaviour) chunks each device's local subject shard into Python-level batches of size `B`. `simulation_spacemap` swaps `vmap_1d` for a new `chunked_map_1d` helper built on `jax.lax.map(..., batch_size=B)` — within each chunk JAX vmaps; across chunks it scans. JAX handles non-divisible per-device remainders. Per-iteration intermediate shrinks by `n_per_device / B`, which lets production runs fit. `0` (default) keeps the existing single-vmap path bit-equivalent. Threaded through Model → build_regimes_and_template → process_regimes → _build_argmax_and_max_Q_over_a_per_period → simulation_spacemap. The chunked dispatch is float-tolerance equivalent to the vmapped baseline but not bit-exact: XLA reduction order changes with the chunk boundary, and the resulting 1-ULP V drift can flip individual `argmax` choices when two actions are nearly tied. This compounds through stochastic transitions across periods — same caveat as running on different hardware. Regression test asserts the period-0 V_arr invariant (initial state identical, so chunked vs vmapped Q evaluations agree within atol) across `subjects_batch_size=0` vs `subjects_batch_size=2`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The `subject_array_sharding` constraint required callers to size `n_subjects` as a multiple of `n_devices` when any grid was `distributed=True`. That coupling propagated into data pipelines — N_DRAWS_PER_INDIVIDUAL on the caller side was picked specifically to ensure divisibility against a hardware count the model layer shouldn't know about. `pad_initial_conditions_for_devices` now pads the leading axis of `initial_conditions` up to the next multiple of `n_devices`, duplicating the last subject. Pad rows pass validation automatically (the last real row already did), simulate through their assigned regime, and are trimmed by `trim_pad_from_raw_results` before `SimulationResult` is constructed — the user never sees them. `generate_simulation_keys` takes a new optional `original_n_subjects` argument so per-subject PRNG keys are split based on the user's real subject count, with pad slots replicating the last real subject's key. This keeps RNG draws for the real subjects device-count-invariant: a run with `n_subjects=N` produces the same outcomes on 1 device, 4 devices, or 5 devices (in the absence of XLA reduction-order drift). The AOT compile cache keys on the padded shape — two different user inputs that pad to the same dispatched shape share one compiled program — while the mismatch warning compares user-facing counts so the diagnostic still reads in user terms. The defensive `subject_array_sharding` divisibility check stays in place: any future code path that constructs a sharding without going through `pad_initial_conditions_for_devices` still fails loudly rather than misbehaving silently. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Benchmark comparison (main → HEAD)Comparing
|
…imulate
Two CPU-multi-device coverage tests for the AOT simulate path that the production stack hits but the existing distributed-test set didn't cover:
- test_aot_compiled_simulation_with_subjects_batch_size_on_distributed_grid: exercises Model(subjects_batch_size=B) (jax.lax.map chunked dispatch) combined with distributed grids and AOT compile via n_subjects. Same input contract as the single-vmap path; the assertion is that the per-subject sharding survives the chunked map.
- test_simulate_with_partial_distribution_accepts_sharded_inputs: simulates partially_distributed_model end-to-end so the cross-regime hand-off from distributed working_life to undistributed retirement actually fires (the existing fixture had only a solve test). Confirms the undistributed regime's compiled program accepts the sharded per-subject arrays produced by the distributed regime.
Both pass at HEAD on the 4-device CPU fixture, so they don't reproduce the multi-GPU sharding-mismatch ValueError seen on Marvin (compiled for SingleDeviceSharding(CudaDevice(0)), called with NamedSharding(Mesh('X':3))). They lock in the design intent so a future regression on this combo fails CI before it ships.
`subject_array_sharding` now operates over all regimes: when any grid in any regime is `distributed=True`, every regime's per-subject placeholder arrays lower against the same `NamedSharding`. The AOT-compiled simulate programs on either side of every regime transition then accept the inputs runtime hands them, instead of falling back to `SingleDeviceSharding` for regimes that declare no distributed grid of their own. A new `validate_sharding_consistency` runs at `Model(...)` and rejects models that declare a state name with disagreeing `distributed` flags across regimes — a setup that would otherwise compile two AOT programs expecting different shardings for the same per-subject input. New exception `ShardingConsistencyError` aggregates every conflicting state name into one report. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…id health probs. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Six coupled changes make
distributed=Truegrids work end-to-end through solve and simulate at production scale (Marvin 4×A100 aca-baseline). Stacks on #361.Solve emits V already-sharded. Each
max_Q_over_ais lowered with the regime's declared V-array sharding asout_shardings. XLA produces V partitioned across the right devices in one go, matching what the next-period consumer'snext_regime_to_V_arrslot was lowered against. Eliminates the runtime sharding-mismatch on grids that combinedistributed=Trueandbatch_size > 0. Regression:test_solution_with_distributed_and_batched_grid.NaN/Inf reductions stay sharded. Module-level
v_array_has_nan/v_array_has_infare jit-decorated. The solve diagnostic accumulator,validate_V,log_nan_in_V, and the post-loop scan all route through them. GSPMD partitions the reduction per-device → all-reduce → mesh-replicated scalar;isnan/isinffuses into theanyreduction. The eager alternativejnp.any(jnp.isnan(V_arr))could fall through to gathering V onto the default device — exhausts GPU memory at production scale. Regression: two checks on 4 CPU devices that both helpers return a fully replicated 0-d scalar on aNamedSharding-sharded input.log_level="off"does zero per-(regime, period) V work. Thevalidate_Vandlog_nan_in_Vcalls in_simulate_regime_in_periodand the period loop are gated behindvalidation_enabled(logger). Helpers keep their "do the work; caller chose to call" signatures._simulate_regime_in_periodcarries the logger so it gatesvalidate_Vbeforecalculate_next_statesconsumes V. Regression:test_simulate_with_log_level_off_skips_nan_value_function_check.Stochastic-grid
batch_sizeis honored._get_stochastic_batch_sizesreadsbatch_sizefrom each stochastic variable's grid (bothDiscreteGridand_ContinuousStochasticProcessexpose it) and passes the dict into the shock-integration productmap inget_Q_and_F/get_compute_intermediates. User-configuredbatch_size>0actually chunks the inner shock-integration loop. Regression: helper unit + solvebatch_size=0vsbatch_size=1identical V_arrs.Model(subjects_batch_size=B)chunks each device's subject shard. Simulate-sideargmax_and_max_Q_over_abuilds a Q intermediate that broadcasts across(n_subjects × actions × shocks × interpolation × next-regimes)simultaneously — tens of TB per device at production grid sizes; XLA HLO rematerialization can't bring it inside HBM. New knob (default0= single-vmap path, current behaviour) routes throughchunked_map_1dbuilt onjax.lax.map(..., batch_size=B): vmaps within a chunk, scans across chunks. JAX handles non-divisible per-device remainders. Per-iteration intermediate shrinks byn_per_device / B. ThreadedModel → process_regimes → _build_argmax_and_max_Q_over_a_per_period → simulation_spacemap. Float-tolerance equivalent to the baseline but not bit-exact: XLA reduction order changes with chunk boundary. Regression: period-0 V_arr invariant acrosssubjects_batch_size=0vs=2.n_subjectsno longer needs to be a multiple ofn_devices.pad_initial_conditions_for_devicespads the leading axis up to the next multiple ofn_devicesby duplicating the last subject; pad rows simulate through their assigned regime,trim_pad_from_raw_resultsslices them off beforeSimulationResult.generate_simulation_keystakesoriginal_n_subjectsso per-subject PRNG keys split from the user's real count — pad slots replicate the last real key — keeping RNG draws for real subjects device-count-invariant (same outcomes on 1, 4, or 5 devices, modulo XLA reduction-order drift). The AOT cache keys on the padded shape (different user inputs that pad alike share one compiled program); the mismatch warning compares user-facing counts. The defensivesubject_array_shardingdivisibility check stays in place as a tripwire for code paths that bypass the padder.Test plan
pixi run -e tests-cpu pytest tests -n 4— full suite greenpixi run -e tests-cpu pytest tests/test_distributed.py— passes on 4 CPU devicespixi run -e type-checking ty— cleanprek run --all-files— cleanlog_level="debug", distributed assets) reaches estimation green