70 commits
680bb1e
Merge branch 'main' of https://github.com/CDCgov/PyRenew
cdc-mitzimorris Sep 15, 2025
2cb876b
update
cdc-mitzimorris Sep 18, 2025
60db8df
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Sep 22, 2025
32a5314
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 5, 2025
d6213f2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 8, 2025
96f27c9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 17, 2025
1cb6fa2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 24, 2025
f62e1e4
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 4, 2025
0c6785d
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 22, 2025
1ee62b9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Jan 29, 2026
0629461
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 4, 2026
efeadee
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
371ba98
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
0304bed
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 6, 2026
ffeea65
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
50e7261
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
dae6af8
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 10, 2026
5cb3097
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 11, 2026
1d80ccc
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 11, 2026
e73b401
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 12, 2026
b1473b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 18, 2026
0b929b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 18, 2026
3ee00a7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 24, 2026
307982a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 24, 2026
b862bc6
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 26, 2026
2c665a5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 11, 2026
60d6458
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 12, 2026
ec8c464
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 19, 2026
c018bf7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 24, 2026
d0207dd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 4, 2026
f3c706a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 9, 2026
684c6c5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 10, 2026
ca2454f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 13, 2026
0f38afc
merge
cdc-mitzimorris Apr 14, 2026
d8e7a57
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 16, 2026
7e9b5fe
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
e1d8014
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
83ddbf0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
69ea4ea
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
555e87b
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 28, 2026
fa5a7cb
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 4, 2026
69cdab0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 6, 2026
c28a89f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 6, 2026
fd091ca
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 7, 2026
8cee471
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 7, 2026
b2a1e1a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 11, 2026
2006afd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 12, 2026
a31ec85
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 13, 2026
0db35a9
implementation and unit tests for centered versions of temporal proce…
cdc-mitzimorris May 19, 2026
a92d58b
checkpointing - test cleanup
cdc-mitzimorris May 19, 2026
9911f4a
checkpointing
cdc-mitzimorris May 19, 2026
7795672
fix unit test
cdc-mitzimorris May 19, 2026
b7030a3
benchmark test suite
cdc-mitzimorris May 19, 2026
c0d4684
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2026
fe81470
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 19, 2026
ee0c276
lint fix
cdc-mitzimorris May 19, 2026
1e99920
more unit tests
cdc-mitzimorris May 19, 2026
c8a8764
checkpointing
cdc-mitzimorris May 26, 2026
0c852ab
refactoring benchmarks
cdc-mitzimorris May 26, 2026
c67fe92
Day-of-week effects applied on observation time axis (i.e., not befor…
cdc-mitzimorris May 19, 2026
197d9da
simplify benchmarks
cdc-mitzimorris May 27, 2026
3fe2f2b
remove dependency on R forecasttools package by substituting local po…
cdc-mitzimorris May 27, 2026
fec5f4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
f910b2b
remove dependency on R forecasttools package by substituting local po…
cdc-mitzimorris May 27, 2026
a3e34ab
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 27, 2026
7d9619a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
360e8f0
more informative benchmark outputs
cdc-mitzimorris May 27, 2026
ab623e6
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 27, 2026
07f5bb6
checkpointing
cdc-mitzimorris May 27, 2026
72a4f19
fixing real data loading
cdc-mitzimorris May 27, 2026
3 changes: 3 additions & 0 deletions .gitignore
@@ -38,6 +38,9 @@
# !your_data_file.csv
# !your_data_directory/

# Benchmark outputs
benchmarks/results/


#####
# Python
6 changes: 6 additions & 0 deletions _typos.toml
@@ -4,3 +4,9 @@
arange = "arange"
lod = "lod"
dows = "dows"
ND = "ND"

[default.extend-identifiers]
# NumPyro's Distribution base class spells this with a typo; we must
# match the upstream attribute name for `has_rsample` to work correctly.
reparametrized_params = "reparametrized_params"
188 changes: 188 additions & 0 deletions benchmarks/README.md
@@ -0,0 +1,188 @@
# PyRenew benchmarks

Opt-in MCMC performance experiments.
The suite is a CLI entry point under `benchmarks/suites/`.
Run from the repository root.

Benchmarks are not part of CI.
Use `test/` for correctness checks and this suite for sampler comparisons.

## Layout

```
benchmarks/
├── core/
│ ├── signals.py SignalSeries, DatasetBundle, DatasetProvider
│ ├── datasets.py SyntheticProvider over pyrenew/datasets/
│ ├── real_data.py RealDataProvider over CDC NHSN + NSSP feeds
│ ├── reference_data.py Static location names and populations
│ ├── priors.py benchmark-local priors for real-data builds
│ ├── models.py H+E model builder (weekly hospital + daily ED)
│ ├── runner.py fit_and_measure and ArviZ-free FitMetrics computation
│ └── reporting.py stdout tables and CSV / JSON / Markdown writers
├── suites/
│ └── rt_params.py centered vs non-centered weekly Rt parameterization
├── diagnose.py single-fit diagnostic harness
└── results/ output (gitignored)
```

The suite asks the dataset provider for the H+E bundle and builds the model under each parameterization; the runner then fits each model and collects metrics.
The `DatasetProvider` protocol in `core/signals.py` is the seam where real reporting inputs replace `SyntheticProvider` without touching the suite.

## rt_params suite

Compares the `innovation` (non-centered, NCP) and `state` (centered, CP) parameterizations of the inner `DifferencedAR1` weekly $\mathcal{R}(t)$ process, on the H+E model: weekly-aggregated hospital admissions plus daily ED visits.
Each fit uses one parameterization; the suite always runs both so the matched pair can be compared.

### Run

```bash
python -m benchmarks.suites.rt_params --quick
```

`--quick` overrides the sampler to 50 warmup, 50 samples, 1 chain.
Drop it for a full run.

```bash
python -m benchmarks.suites.rt_params --prior both --repeats 3
```

Useful options:

| Option | Effect |
| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `--data-source` | `synthetic` (built-in fixtures) or `real` (CDC-internal NHSN/NSSP feeds; requires `cfa-stf-routine-forecasting` access and `--as-of`). |
| `--disease <name>` | Disease for `--data-source real`: `COVID-19`, `Influenza`, or `RSV`. |
| `--location <abbr>` | Location abbreviation for `--data-source real`, e.g. `US` or `CA`. |
| `--as-of YYYY-MM-DD` | Vintage date for `--data-source real`. Required for real data. |
| `--training-days N` | Training window length for `--data-source real`. Default: 150. |
| `--omit-last-days N` | Trailing days omitted from `--data-source real` to buffer right truncation. Default: 2. |
| `--dry-run-data` | Load and summarize selected data, then exit before model fitting. Useful for checking real-data access and signal noise. |
| `--prior <kind>` | `tight` (sd=0.01, autoreg=0.9), `loose` (sd=0.10, autoreg=0.5), `both`, or an explicit `sd,autoreg` pair (e.g. `0.05,0.7`). Repeatable. Default: `tight`. |
| `--repeats N` | Refit each cell `N` times with `seed + i` to estimate sampler noise. |
| `--num-warmup`, `--num-samples`, `--num-chains` | NUTS controls. `--num-chains` defaults to `min(4, os.cpu_count())`. |
| `--seed` | Base seed (default 42). |
| `--output-dir` | Where to write artifacts. Default `benchmarks/results/`. |
| `--no-write` | Skip artifact files; print summary only. |

On import, the suite sets `XLA_FLAGS=--xla_force_host_platform_device_count=N` (where `N = min(8, os.cpu_count())`) so JAX exposes enough logical devices for parallel chains, and `JAX_ENABLE_X64=true`.
If you set either variable yourself before invocation, it is honored.
x64 is required: in float32 the renewal recursion loses precision and NUTS diverges (in a float32 run at 500 warmup, 500 samples, and 4 chains, a full chain diverged; under x64 none did).
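
The "honored if already set" behavior can be sketched with `os.environ.setdefault`. This is a minimal illustration, not the suite's actual code: the function name `configure_jax_env` is hypothetical, and the only details taken from the text above are the two variable names and the `min(8, os.cpu_count())` cap.

```python
import os


def configure_jax_env(max_devices: int = 8) -> dict[str, str]:
    """Set JAX-related environment defaults without overriding user values.

    Hypothetical helper: a sketch of the behavior described in the README,
    not the implementation used by the suite.
    """
    # os.cpu_count() can return None, so fall back to a single device.
    n_devices = min(max_devices, os.cpu_count() or 1)
    # setdefault leaves any value exported before invocation untouched.
    os.environ.setdefault(
        "XLA_FLAGS", f"--xla_force_host_platform_device_count={n_devices}"
    )
    os.environ.setdefault("JAX_ENABLE_X64", "true")
    return {key: os.environ[key] for key in ("XLA_FLAGS", "JAX_ENABLE_X64")}
```

For the defaults to take effect, a helper like this would need to run before JAX is first imported in the process.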

### Real data on CDC infrastructure

Real-data mode is intended for CDC environments that can import `cfa-stf-routine-forecasting` and access the internal feeds used by `cfa.stf.data`.
PyRenew does not depend on those internal packages for normal use; the `cfa.stf.*` imports happen only when `--data-source real` loads a bundle.

Start with a data-only dry run:

```bash
python -m benchmarks.suites.rt_params \
--data-source real \
--disease RSV \
--location US \
--as-of 2025-01-15 \
--training-days 150 \
--omit-last-days 2 \
--dry-run-data
```

This fetches NHSN weekly hospital admissions and NSSP daily ED visits, prints date ranges, missingness, and basic count summaries, then exits before model building or MCMC.

Then run a smoke benchmark:

```bash
python -m benchmarks.suites.rt_params \
--data-source real \
--disease RSV \
--location US \
--as-of 2025-01-15 \
--training-days 150 \
--omit-last-days 2 \
--quick
```

The H+E real-data builder uses benchmark-local priors (`core/priors.py`) mirroring the production prior subset needed for initial infections and ED day-of-week effects.
Location metadata and population totals are static benchmark inputs in `core/reference_data.py`.
Generation interval and infection-to-observation delay PMFs are pulled from the CDC NNH parameter catalog through `cfa.stf.data`, so they remain disease-specific and vintage-aware.
Real-data mode currently does not apply ED right truncation PMFs; use `--omit-last-days` to leave a reporting buffer.

### Output files

Written to `--output-dir` with prefix `rt_params_`:

| File | Contents |
| -------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| `rt_params_runs.csv` | One row per fit, with full config and metrics. |
| `rt_params_candidates.csv` | One row per parameterization, averaged over repeats. |
| `rt_params_pairs.csv` | One row per matched state-vs-innovation pair, with `<metric>_innov`, `<metric>_state`, `<metric>_ratio` columns. |
| `rt_params_parameters.csv` | One row per scalar posterior site element per fit, with posterior mean, ESS, and R-hat. |
| `rt_params_runs.json` | All of the above plus a header (suite name, x64 flag, timestamp). |
| `rt_params_report.md` | Compact Markdown report (per-parameterization table and pairwise table). |

Column convention: `_innov` and `_state` carry the per-side values, and `_ratio` columns are state-benefit ratios.
For higher-is-better metrics such as ESS-per-second, `_ratio` is `state / innovation`.
For lower-is-better metrics such as wall time, `_ratio` is `innovation / state`.
In all cases, `_ratio > 1` favors the state parameterization.
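
The ratio convention above can be written as a small helper. This is a sketch for reading the CSV columns, assuming nothing about how `core/reporting.py` computes them; the function name `state_benefit_ratio` is hypothetical.

```python
def state_benefit_ratio(innov: float, state: float, higher_is_better: bool) -> float:
    """Return a ratio in which values above 1 favor the state parameterization.

    Hypothetical helper illustrating the `_ratio` column convention.
    """
    if higher_is_better:
        # e.g. ESS per second: more is better, so state goes in the numerator.
        return state / innov
    # e.g. wall time: less is better, so innovation goes in the numerator.
    return innov / state
```

For example, a state fit that doubles ESS-per-second and a state fit that halves wall time both yield a ratio of 2 under this convention.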

### Reading the metrics

Per fit:

- **Wall time**: total seconds for warmup + sampling, after JIT, with `jax.block_until_ready` so the work is fully complete.
- **ESS/s Rt (median / min)**: effective samples per wall-second on the Rt trajectory.
Median summarizes typical timepoints; min identifies the worst-mixing timepoint that limits downstream inference.
- **Divergences**: total NUTS divergences across all chains and draws.
A saturated tree depth can mask divergences; read this count alongside tree depth.
- **Tree depth (mean / max)**: log2 of NUTS leapfrog steps.
NumPyro defaults to `max_tree_depth=10`.
A mean near the ceiling indicates the sampler is running out of budget per draw.
- **E-BFMI (min)**: minimum across chains of the energy Bayesian fraction of missing information.
Heuristic thresholds: >=0.3 acceptable, <0.3 warning, <0.1 strong pathology indicator.
- **R-hat Rt (max)**: max split R-hat across timepoints of the Rt trajectory.
Requires more than one chain.
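
As an illustration of the E-BFMI entry above, the following sketch uses the commonly used per-chain estimator (mean squared successive energy difference divided by the energy variance) and the heuristic thresholds listed there. It is an assumption that the runner computes E-BFMI this way; the function names are hypothetical.

```python
from statistics import fmean, pvariance


def ebfmi(energy: list[float]) -> float:
    """Per-chain E-BFMI: mean squared successive difference over variance."""
    diffs = [b - a for a, b in zip(energy, energy[1:])]
    return fmean([d * d for d in diffs]) / pvariance(energy)


def min_ebfmi(energy_by_chain: list[list[float]]) -> float:
    """Minimum E-BFMI across chains, matching the 'min' summary above."""
    return min(ebfmi(chain) for chain in energy_by_chain)


def classify_ebfmi(value: float) -> str:
    """Apply the heuristic thresholds from the list above."""
    if value < 0.1:
        return "strong pathology indicator"
    if value < 0.3:
        return "warning"
    return "acceptable"
```

The input is the per-draw Hamiltonian energy of each chain; how those energies are extracted from the sampler is left out of this sketch.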

### Suite design

The suite varies two axes:

1. **Parameterization**: `innovation` (non-centered) and `state` (centered) modes of the inner `DifferencedAR1`.
2. **Prior regime**: tight $(\sigma = 0.01, \phi = 0.9)$ or loose $(\sigma = 0.10, \phi = 0.5)$, where $\sigma$ is the weekly per-step innovation SD and $\phi$ the autoregressive coefficient.
The cumulative variance of $\log \mathcal{R}(T)$ is far more sensitive to $\phi$ than to $\sigma$.
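
To explore how $\sigma$ and $\phi$ enter the cumulative variance, here is a sketch under simplifying assumptions that are not taken from the PyRenew implementation: the weekly first differences follow $d_t = \phi\, d_{t-1} + \sigma \varepsilon_t$ with $d_0 = 0$ and standard normal $\varepsilon_t$, and $\log \mathcal{R}(0)$ is fixed. Under those assumptions $\operatorname{Var}[\log \mathcal{R}(T)] = \sigma^2 \sum_{k=1}^{T} \left( \sum_{j=0}^{k-1} \phi^j \right)^2$. The function name is hypothetical.

```python
def cumulative_log_rt_variance(sigma: float, phi: float, n_steps: int) -> float:
    """Variance of log R(T) after n_steps, under the stated assumptions.

    Assumes AR(1) first differences started at zero and a fixed initial level;
    this is an illustration, not PyRenew's DifferencedAR1 code.
    """
    total = 0.0
    for k in range(1, n_steps + 1):
        # An innovation with k steps remaining contributes
        # 1 + phi + ... + phi**(k - 1) to the final level.
        weight = sum(phi**j for j in range(k))
        total += weight**2
    return sigma**2 * total


# Compare the two prior regimes over 18 weekly steps (an illustrative horizon).
tight = cumulative_log_rt_variance(sigma=0.01, phi=0.9, n_steps=18)
loose = cumulative_log_rt_variance(sigma=0.10, phi=0.5, n_steps=18)
```

In this expression $\sigma$ only rescales the result by $\sigma^2$, while $\phi$ changes every weight in the sum, and for long horizons the sum grows roughly like $T / (1 - \phi)^2$.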

The latent $\mathcal{R}(t)$ runs at weekly cadence, matching the production HEW model and the weekly forecasting setting.
Production treats both hyperparameters as inferred (`eta_sd ~ TruncatedNormal(0.15, 0.05)`, `autoreg_rt ~ Beta(2, 40)`); the benchmark fixes them to isolate the parameterization axis.

## Diagnostics

`benchmarks/diagnose.py` builds one model on one dataset under one config and reports the data-side summary, the priors `build_he_model` selects and the initial scale they imply, prior-predictive ranges, whether the initial potential energy and gradient (under the sampler's `init_to_sample` strategy) are finite, and optionally a short NUTS run with its divergence count.

Its `--real-i0`, `--real-dow`, `--real-trunc`, and `--all-real` flags force the real-data priors onto the synthetic bundle one at a time, so a real-data sampler failure can be bisected off the CDC VM.
`--data-source real` runs the same diagnostics against a live bundle.

```bash
python -m benchmarks.diagnose --all-real --mcmc
python -m benchmarks.diagnose --real-i0
```

## Adding a benchmark

1. Add a model builder to `benchmarks/core/models.py` that returns a `BuiltFit`.
Reuse `BuildConfig` if the new model fits the existing axes.
2. If the model needs a new dataset, add a builder to `benchmarks/core/datasets.py` and expose it through `SyntheticProvider`.
3. Add or extend a suite module in `benchmarks/suites/` with a `main()` CLI.
Use `fit_and_measure`, `print_pairwise_tables`, and `write_results` from `benchmarks.core`.

## Wiring real data

`benchmarks.core.signals.DatasetProvider` is a `Protocol`.
Implement it for a reporting source and pass the provider to the suite; the model builder and runner do not change.
The expected payload is a `DatasetBundle` whose `signals` mapping carries one `SignalSeries` per observation source.
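
A provider that satisfies the protocol structurally might look like the sketch below. The dataclasses are simplified stand-ins, not the classes in `core/signals.py`: their field names mirror the constructor calls in `benchmarks/core/datasets.py`, but only a subset of fields is included, plain lists replace JAX arrays, and `InMemoryProvider` is hypothetical.

```python
from dataclasses import dataclass, field
from datetime import date
from typing import Any, Protocol


@dataclass(frozen=True)
class SignalSeries:  # simplified stand-in, not the real class
    name: str
    values: list[float]
    cadence: str
    start_date: date
    extras: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class DatasetBundle:  # simplified stand-in with a subset of fields
    name: str
    population_size: float
    obs_start_date: date
    signals: dict[str, SignalSeries]


class DatasetProvider(Protocol):
    """The two methods that SyntheticProvider implements."""

    def list_datasets(self) -> list[str]: ...

    def get(self, name: str) -> DatasetBundle: ...


class InMemoryProvider:
    """Hypothetical provider that serves pre-built bundles from a dict."""

    def __init__(self, bundles: dict[str, DatasetBundle]) -> None:
        self._bundles = dict(bundles)

    def list_datasets(self) -> list[str]:
        return sorted(self._bundles)

    def get(self, name: str) -> DatasetBundle:
        if name not in self._bundles:
            raise KeyError(f"Unknown dataset {name!r}")
        return self._bundles[name]


def first_bundle(provider: DatasetProvider) -> DatasetBundle:
    """Consumer that depends only on the protocol, as a suite would."""
    return provider.get(provider.list_datasets()[0])
```

Because `Protocol` typing is structural, `InMemoryProvider` does not need to inherit from `DatasetProvider` for `first_bundle` to accept it.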

`benchmarks/core/real_data.py` provides `RealDataProvider`, a concrete implementation over the CDC NHSN (weekly hospital admissions) and NSSP (daily ED visits) feeds.
Construct it with a mapping of dataset name to `RealDataSpec` (disease, location, `as_of` vintage, training window) and request bundles by name, exactly as with `SyntheticProvider`.

`RealDataProvider` reads live H+E feeds through `cfa.stf.data` (from `cfa-stf-routine-forecasting`) and requires valid Azure credentials at call time.
It does not call the R `forecasttools` package for benchmark setup; location names and populations come from `benchmarks/core/reference_data.py`.
PyRenew intentionally does **not** declare `cfa-stf-routine-forecasting` as a dependency: the `cfa.stf.*` imports live inside the provider's function bodies, so `real_data.py` imports cleanly without it and the synthetic path is unaffected.
To use `RealDataProvider`, install `cfa-stf-routine-forecasting` into your own environment separately.
11 changes: 11 additions & 0 deletions benchmarks/__init__.py
@@ -0,0 +1,11 @@
"""PyRenew benchmark suites.

Run a suite as a module, for example:

    python -m benchmarks.suites.rt_params --quick

Suites read datasets through :mod:`benchmarks.core.datasets` and build models
through :mod:`benchmarks.core.models`. The signal data interface lives in
:mod:`benchmarks.core.signals` and is the seam where real reporting inputs
can be substituted for the synthetic providers in the future.
"""
1 change: 1 addition & 0 deletions benchmarks/core/__init__.py
@@ -0,0 +1 @@
"""Benchmark engine: signals, datasets, models, metrics, runner, reporting."""
101 changes: 101 additions & 0 deletions benchmarks/core/datasets.py
@@ -0,0 +1,101 @@
"""Synthetic dataset provider wrapping ``pyrenew/datasets/``.

Each :class:`DatasetBundle` exposed here is paired with one model builder in
:mod:`benchmarks.core.models`. The pairing is implicit: a suite chooses a
model, and the model's builder calls a specific dataset by name.

A real-data provider would implement the same :class:`DatasetProvider`
protocol; suites would not change.
"""

from __future__ import annotations

from datetime import date

import jax.numpy as jnp

from benchmarks.core.signals import (
    DatasetBundle,
    DatasetProvider,
    SignalSeries,
)
from pyrenew.datasets import (
    load_example_infection_admission_interval,
    load_synthetic_daily_ed_visits,
    load_synthetic_true_parameters,
    load_synthetic_weekly_hospital_admissions,
)

GEN_INT_PMF: jnp.ndarray = jnp.array(
    [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683]
)

SYNTHETIC_HE_WEEKLY_HOSPITAL = "synthetic_he_weekly_hospital"


def _build_synthetic_he_weekly_hospital() -> DatasetBundle:  # numpydoc ignore=RT01
    """Build the synthetic H+E bundle with weekly-aggregated hospital admissions."""
    weekly_hosp = load_synthetic_weekly_hospital_admissions()
    daily_ed = load_synthetic_daily_ed_visits()
    true_params = load_synthetic_true_parameters()
    hosp_delay_pmf = jnp.array(
        load_example_infection_admission_interval()["probability_mass"].to_numpy()
    )
    ed_delay_pmf = jnp.array(true_params["ed_visits"]["delay_pmf"])
    ed_dow = jnp.array(true_params["ed_visits"]["day_of_week_effects"])

    obs_start = date(2023, 11, 5)
    hospital = SignalSeries(
        name="hospital",
        values=jnp.array(
            weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32
        ),
        cadence="weekly",
        start_date=obs_start,
        extras={"delay_pmf": hosp_delay_pmf, "aggregation": "weekly"},
    )
    ed_visits = SignalSeries(
        name="ed_visits",
        values=jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32),
        cadence="daily",
        start_date=obs_start,
        extras={"delay_pmf": ed_delay_pmf, "day_of_week_effects": ed_dow},
    )
    return DatasetBundle(
        name=SYNTHETIC_HE_WEEKLY_HOSPITAL,
        population_size=float(weekly_hosp["pop"][0]),
        obs_start_date=obs_start,
        n_days_post_init=126,
        signals={"hospital": hospital, "ed_visits": ed_visits},
        gen_int_pmf=GEN_INT_PMF,
        fixed_params={"i0_per_capita": true_params["i0_per_capita"]},
    )


_BUILDERS = {
    SYNTHETIC_HE_WEEKLY_HOSPITAL: _build_synthetic_he_weekly_hospital,
}


class SyntheticProvider(DatasetProvider):
    """Provider that wraps the built-in synthetic fixtures in ``pyrenew/datasets/``.

    Bundles are cached on first request so repeated suite candidates do not
    re-read the CSV files.
    """

    def __init__(self) -> None:
        """Create an empty cache."""
        self._cache: dict[str, DatasetBundle] = {}

    def list_datasets(self) -> list[str]:  # numpydoc ignore=RT01
        """Return the dataset names this provider exposes."""
        return list(_BUILDERS)

    def get(self, name: str) -> DatasetBundle:  # numpydoc ignore=RT01
        """Return the named dataset bundle, building and caching on first request."""
        if name not in _BUILDERS:
            raise KeyError(f"Unknown dataset {name!r}. Available: {sorted(_BUILDERS)}")
        if name not in self._cache:
            self._cache[name] = _BUILDERS[name]()
        return self._cache[name]