Skip to content

Make distributed V-arrays survive solve + simulate at production scale#364

Open
hmgaudecker wants to merge 16 commits into
refactor/phase-2-api-reorganisationfrom
feat/distributed-V-arrays
Open

Make distributed V-arrays survive solve + simulate at production scale#364
hmgaudecker wants to merge 16 commits into
refactor/phase-2-api-reorganisationfrom
feat/distributed-V-arrays

Conversation

@hmgaudecker
Copy link
Copy Markdown
Member

@hmgaudecker hmgaudecker commented May 23, 2026

Summary

Six coupled changes make distributed=True grids work end-to-end through solve and simulate at production scale (Marvin 4×A100 aca-baseline). Stacks on #361.

Solve emits V already-sharded. Each max_Q_over_a is lowered with the regime's declared V-array sharding as out_shardings. XLA produces V partitioned across the right devices in one go, matching what the next-period consumer's next_regime_to_V_arr slot was lowered against. Eliminates the runtime sharding-mismatch on grids that combine distributed=True and batch_size > 0. Regression: test_solution_with_distributed_and_batched_grid.

NaN/Inf reductions stay sharded. Module-level v_array_has_nan / v_array_has_inf are jit-decorated. The solve diagnostic accumulator, validate_V, log_nan_in_V, and the post-loop scan all route through them. GSPMD partitions the reduction per-device → all-reduce → mesh-replicated scalar; isnan/isinf fuses into the any reduction. The eager alternative jnp.any(jnp.isnan(V_arr)) could fall through to gathering V onto the default device — exhausts GPU memory at production scale. Regression: two checks on 4 CPU devices that both helpers return a fully replicated 0-d scalar on a NamedSharding-sharded input.

log_level="off" does zero per-(regime, period) V work. The validate_V and log_nan_in_V calls in _simulate_regime_in_period and the period loop are gated behind validation_enabled(logger). Helpers keep their "do the work; caller chose to call" signatures. _simulate_regime_in_period carries the logger so it gates validate_V before calculate_next_states consumes V. Regression: test_simulate_with_log_level_off_skips_nan_value_function_check.

Stochastic-grid batch_size is honored. _get_stochastic_batch_sizes reads batch_size from each stochastic variable's grid (both DiscreteGrid and _ContinuousStochasticProcess expose it) and passes the dict into the shock-integration productmap in get_Q_and_F / get_compute_intermediates. User-configured batch_size>0 actually chunks the inner shock-integration loop. Regression: helper unit + solve batch_size=0 vs batch_size=1 identical V_arrs.

Model(subjects_batch_size=B) chunks each device's subject shard. Simulate-side argmax_and_max_Q_over_a builds a Q intermediate that broadcasts across (n_subjects × actions × shocks × interpolation × next-regimes) simultaneously — tens of TB per device at production grid sizes; XLA HLO rematerialization can't bring it inside HBM. New knob (default 0 = single-vmap path, current behaviour) routes through chunked_map_1d built on jax.lax.map(..., batch_size=B): vmaps within a chunk, scans across chunks. JAX handles non-divisible per-device remainders. Per-iteration intermediate shrinks by n_per_device / B. Threaded Model → process_regimes → _build_argmax_and_max_Q_over_a_per_period → simulation_spacemap. Float-tolerance equivalent to the baseline but not bit-exact: XLA reduction order changes with chunk boundary. Regression: period-0 V_arr invariant across subjects_batch_size=0 vs =2.

n_subjects no longer needs to be a multiple of n_devices. pad_initial_conditions_for_devices pads the leading axis up to the next multiple of n_devices by duplicating the last subject; pad rows simulate through their assigned regime, trim_pad_from_raw_results slices them off before SimulationResult. generate_simulation_keys takes original_n_subjects so per-subject PRNG keys split from the user's real count — pad slots replicate the last real key — keeping RNG draws for real subjects device-count-invariant (same outcomes on 1, 4, or 5 devices, modulo XLA reduction-order drift). The AOT cache keys on the padded shape (different user inputs that pad alike share one compiled program); the mismatch warning compares user-facing counts. The defensive subject_array_sharding divisibility check stays in place as a tripwire for code paths that bypass the padder.

Test plan

  • pixi run -e tests-cpu pytest tests -n 4 — full suite green
  • pixi run -e tests-cpu pytest tests/test_distributed.py — passes on 4 CPU devices
  • pixi run -e type-checking ty — clean
  • prek run --all-files — clean
  • Marvin sim run (aca-estimation, 4×A100, log_level="debug", distributed assets) reaches estimation green

hmgaudecker and others added 3 commits May 23, 2026 06:29
`solve` lowers each unique `max_Q_over_a` with the regime's declared
V-array sharding as `out_shardings`, so the compiled XLA program
produces V already partitioned across the right devices. This matches
the sharding the next-period consumer's `next_regime_to_V_arr` slot
was lowered against, eliminating a runtime sharding mismatch on
distributed-grid models without a post-hoc `device_put` reshard.

Adds `test_solution_with_distributed_and_batched_grid` as a regression
test for the case that surfaced the issue: a state grid with both
`distributed=True` and `batch_size > 0`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Wraps the per-(regime, period) `validate_V` and `log_nan_in_V` calls
in `if validation_enabled(logger):` so `log_level="off"` does zero
V-array work — no NaN/Inf reductions, no bool host transfers, and
the jit-wrapped helpers never get traced/compiled.

`validate_V` and `log_nan_in_V` keep their original signatures
("do the work; caller chose to call"). `_simulate_regime_in_period`
gains a `logger` param so it can gate `validate_V` at its own call
site without breaking the fail-fast order (validate before
`calculate_next_states` consumes the V-array).

Adds `test_simulate_with_log_level_off_skips_nan_value_function_check`
exercising the off-mode skip on a NaN-bearing V.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Adds module-level `v_array_has_nan` / `v_array_has_inf` helpers
decorated with `@jax.jit` and routes the solve diagnostic accumulator,
`validate_V`, `log_nan_in_V`, and the post-loop NaN/Inf scan through
them. With the reduction inside the XLA compiled graph, GSPMD
partitions it across the V-array's devices (per-device any →
all-reduce → mesh-replicated scalar) and `isnan`/`isinf` fuses into
the `any` reduction in one pass. The eager-dispatch alternative
`jnp.any(jnp.isnan(V_arr))` can fall through to gathering V onto the
default device before reducing — a path that exhausts GPU memory on
a sharded production-scale V-array.

Adds two regression tests on 4 CPU devices verifying that both
helpers return a fully replicated 0-d scalar on a `NamedSharding`-
sharded input.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@read-the-docs-community
Copy link
Copy Markdown

read-the-docs-community Bot commented May 23, 2026

Copy link
Copy Markdown
Member Author

@hmgaudecker hmgaudecker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Autoreview.

@hmgaudecker hmgaudecker requested a review from mj023 May 23, 2026 05:07
logger.info(" lowering ...")
start = time.monotonic()
lowered[func_id] = jax.jit(func).lower(**lower_args)
lowered[func_id] = jax.jit(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have never, ever found this one myself. Manifested itself as an OOM on the 80GB GPU...

hmgaudecker and others added 4 commits May 23, 2026 08:22
The shock-integration productmap in `get_Q_and_F` and
`get_compute_intermediates` was passing `dict.fromkeys(stochastic_variables, 0)`
unconditionally, so any user-configured `batch_size>0` on the stochastic
grid was silently dropped — the full cartesian product of stochastic shocks
was materialized in one kernel.

New `_get_stochastic_batch_sizes` helper reads `batch_size` from each
stochastic variable's grid in `v_interpolation_info.discrete_states` (both
`DiscreteGrid` and `_ContinuousStochasticProcess` expose the attribute),
so configuring `batch_size>0` now chunks the inner shock-integration loop
as documented.

Unit test on the helper plus integration test that solves a stochastic
model with `batch_size=0` vs `batch_size=1` and confirms identical V_arrs
(correctness preserved across the chunked-vs-unchunked dispatch).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 23, 2026

Benchmark comparison (main → HEAD)

Comparing 629ac442 (main) → ef557810 (HEAD)

Benchmark Statistic before after Ratio Alert
aca-baseline execution time 26.281 s 14.521 s 0.55
peak GPU mem 639 MB 3.79 GB 5.93
compilation time 297.53 s 280.61 s 0.94
peak CPU mem 7.40 GB 7.49 GB 1.01
aca-baseline-debug execution time 75.059 s
peak GPU mem 581 MB
compilation time 378.53 s
peak CPU mem 7.53 GB
Mahler-Yum execution time 4.679 s 4.223 s 0.90
peak GPU mem 529 MB 529 MB 1.00
compilation time 14.15 s 12.32 s 0.87
peak CPU mem 1.71 GB 1.67 GB 0.98
Precautionary Savings - Solve execution time 46.3 ms 26.6 ms 0.57
peak GPU mem 101 MB 101 MB 1.00
compilation time 2.66 s 2.08 s 0.78
peak CPU mem 1.14 GB 1.11 GB 0.98
Precautionary Savings - Simulate execution time 120.2 ms 91.1 ms 0.76
peak GPU mem 349 MB 349 MB 1.00
compilation time 5.01 s 4.36 s 0.87
peak CPU mem 1.34 GB 1.30 GB 0.97
Precautionary Savings - Solve & Simulate execution time 159.7 ms 124.0 ms 0.78
peak GPU mem 586 MB 586 MB 1.00
compilation time 7.22 s 5.86 s 0.81
peak CPU mem 1.30 GB 1.27 GB 0.98
Precautionary Savings - Solve & Simulate (irreg) execution time 288.0 ms 220.4 ms 0.77
peak GPU mem 2.20 GB 2.20 GB 1.00
compilation time 7.53 s 6.14 s 0.82
peak CPU mem 1.36 GB 1.33 GB 0.98

hmgaudecker and others added 2 commits May 23, 2026 15:44
The simulate-side `argmax_and_max_Q_over_a` builds a Q tensor whose
intermediate broadcasts across `(n_subjects × actions × shocks ×
interpolation × next-regimes)` simultaneously. At production grid
sizes this can reach tens of TB per device — XLA's HLO
rematerialization can't bring it inside per-device HBM, and the
simulate-side AOT compile fails to allocate.

New `Model(subjects_batch_size=B)` knob (default `0` — current
behaviour) chunks each device's local subject shard into Python-level
batches of size `B`. `simulation_spacemap` swaps `vmap_1d` for a new
`chunked_map_1d` helper built on `jax.lax.map(..., batch_size=B)` —
within each chunk JAX vmaps; across chunks it scans. JAX handles
non-divisible per-device remainders.

Per-iteration intermediate shrinks by `n_per_device / B`, which lets
production runs fit. `0` (default) keeps the existing single-vmap
path bit-equivalent.

Threaded through Model → build_regimes_and_template → process_regimes
→ _build_argmax_and_max_Q_over_a_per_period → simulation_spacemap.

The chunked dispatch is float-tolerance equivalent to the vmapped
baseline but not bit-exact: XLA reduction order changes with the
chunk boundary, and the resulting 1-ULP V drift can flip individual
`argmax` choices when two actions are nearly tied. This compounds
through stochastic transitions across periods — same caveat as
running on different hardware.

Regression test asserts the period-0 V_arr invariant (initial state
identical, so chunked vs vmapped Q evaluations agree within atol)
across `subjects_batch_size=0` vs `subjects_batch_size=2`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The `subject_array_sharding` constraint required callers to size
`n_subjects` as a multiple of `n_devices` when any grid was
`distributed=True`. That coupling propagated into data pipelines —
N_DRAWS_PER_INDIVIDUAL on the caller side was picked specifically to
ensure divisibility against a hardware count the model layer
shouldn't know about.

`pad_initial_conditions_for_devices` now pads the leading axis of
`initial_conditions` up to the next multiple of `n_devices`,
duplicating the last subject. Pad rows pass validation automatically
(the last real row already did), simulate through their assigned
regime, and are trimmed by `trim_pad_from_raw_results` before
`SimulationResult` is constructed — the user never sees them.

`generate_simulation_keys` takes a new optional `original_n_subjects`
argument so per-subject PRNG keys are split based on the user's real
subject count, with pad slots replicating the last real subject's key.
This keeps RNG draws for the real subjects device-count-invariant: a
run with `n_subjects=N` produces the same outcomes on 1 device, 4
devices, or 5 devices (in the absence of XLA reduction-order drift).

The AOT compile cache keys on the padded shape — two different user
inputs that pad to the same dispatched shape share one compiled
program — while the mismatch warning compares user-facing counts so
the diagnostic still reads in user terms.

The defensive `subject_array_sharding` divisibility check stays in
place: any future code path that constructs a sharding without going
through `pad_initial_conditions_for_devices` still fails loudly
rather than misbehaving silently.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown

Benchmark comparison (main → HEAD)

Comparing 629ac442 (main) → 073c6974 (HEAD)

Benchmark Statistic before after Ratio Alert
aca-baseline execution time 26.281 s 14.967 s 0.57
peak GPU mem 639 MB 581 MB 0.91
compilation time 297.53 s 281.70 s 0.95
peak CPU mem 7.40 GB 7.49 GB 1.01
aca-baseline-debug execution time 77.575 s
peak GPU mem 614 MB
compilation time 370.03 s
peak CPU mem 7.52 GB
Mahler-Yum execution time 4.679 s 4.222 s 0.90
peak GPU mem 529 MB 529 MB 1.00
compilation time 14.15 s 12.54 s 0.89
peak CPU mem 1.71 GB 1.67 GB 0.98
Precautionary Savings - Solve execution time 46.3 ms 27.5 ms 0.59
peak GPU mem 101 MB 101 MB 1.00
compilation time 2.66 s 2.11 s 0.79
peak CPU mem 1.14 GB 1.12 GB 0.99
Precautionary Savings - Simulate execution time 120.2 ms 91.6 ms 0.76
peak GPU mem 349 MB 349 MB 1.00
compilation time 5.01 s 4.34 s 0.87
peak CPU mem 1.34 GB 1.30 GB 0.98
Precautionary Savings - Solve & Simulate execution time 159.7 ms 124.5 ms 0.78
peak GPU mem 586 MB 586 MB 1.00
compilation time 7.22 s 5.97 s 0.83
peak CPU mem 1.30 GB 1.27 GB 0.97
Precautionary Savings - Solve & Simulate (irreg) execution time 288.0 ms 228.9 ms 0.79
peak GPU mem 2.20 GB 2.20 GB 1.00
compilation time 7.53 s 6.22 s 0.83
peak CPU mem 1.36 GB 1.32 GB 0.98

hmgaudecker and others added 5 commits May 24, 2026 08:54
…imulate

Two CPU-multi-device coverage tests for the AOT simulate path that the production stack hits but the existing distributed-test set didn't cover:

- test_aot_compiled_simulation_with_subjects_batch_size_on_distributed_grid: exercises Model(subjects_batch_size=B) (jax.lax.map chunked dispatch) combined with distributed grids and AOT compile via n_subjects. Same input contract as the single-vmap path; the assertion is that the per-subject sharding survives the chunked map.

- test_simulate_with_partial_distribution_accepts_sharded_inputs: simulates partially_distributed_model end-to-end so the cross-regime hand-off from distributed working_life to undistributed retirement actually fires (the existing fixture had only a solve test). Confirms the undistributed regime's compiled program accepts the sharded per-subject arrays produced by the distributed regime.

Both pass at HEAD on the 4-device CPU fixture, so they don't reproduce the multi-GPU sharding-mismatch ValueError seen on Marvin (compiled for SingleDeviceSharding(CudaDevice(0)), called with NamedSharding(Mesh('X':3))). They lock in the design intent so a future regression on this combo fails CI before it ships.
`subject_array_sharding` now operates over all regimes: when any grid
in any regime is `distributed=True`, every regime's per-subject
placeholder arrays lower against the same `NamedSharding`. The
AOT-compiled simulate programs on either side of every regime
transition then accept the inputs runtime hands them, instead of
falling back to `SingleDeviceSharding` for regimes that declare no
distributed grid of their own.

A new `validate_sharding_consistency` runs at `Model(...)` and rejects
models that declare a state name with disagreeing `distributed` flags
across regimes — a setup that would otherwise compile two AOT programs
expecting different shardings for the same per-subject input. New
exception `ShardingConsistencyError` aggregates every conflicting
state name into one report.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…id health probs.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant