Skip to content

Drop _lcm/jaxtyping_patch.py in favour of upstream fix#365

Open
hmgaudecker wants to merge 1 commit into
example/mahler-yum-from-replicationfrom
drop/jaxtyping-patch
Open

Drop _lcm/jaxtyping_patch.py in favour of upstream fix#365
hmgaudecker wants to merge 1 commit into
example/mahler-yum-from-replicationfrom
drop/jaxtyping-patch

Conversation

@hmgaudecker
Copy link
Copy Markdown
Member

Summary

Removes the local monkey-patch at src/_lcm/jaxtyping_patch.py in favour of an upstream fix in jaxtyping that covers all three affected sentinels (_any_dtype, _anonymous_dim, _anonymous_variadic_dim), not just the variadic-dim sentinel we patched locally.

Until the upstream fix is released as jaxtyping >= 0.3.10, pylcm pulls the patched branch directly via a pixi pypi-dependencies git override — mirroring the pattern in place for dags. Once the release lands, the override and the temporary comment can both go.

Changes

  • pyproject.toml: jaxtyping>=0.3.2 → jaxtyping>=0.3.10; pixi pypi-dependencies adds the git override pointing at hmgaudecker/jaxtyping@fix/sentinel-cloudpickle.
  • Deleted src/_lcm/jaxtyping_patch.py.
  • Simplified src/_lcm/__init__.py — dropped the load-bearing-import comment block, since jaxtyping_patch was the only thing needing pre-emption.
  • Regenerated pixi.lock for all four pixi envs.

Background

jaxtyping's _array_types module-level object() sentinels broke under cloudpickle round-trips because plain object() does not survive serialisation with identity intact. The variadic-dim case is what beartype-claw'd pylcm hit through DAG-built cloudpickleable callables; the other two sentinels would fail similarly under cloudpickle, just through different paths. The upstream fix replaces all three with __reduce__-backed singleton classes.

Upstream issue: patrick-kidger/jaxtyping#390. Fix branch: hmgaudecker/jaxtyping@fix/sentinel-cloudpickle (upstream PR pending).

Test plan

  • pixi run -e type-checking ty — clean.
  • prek run --all-files — clean.
  • pixi run --environment tests-cpu tests -n 4 — 1018 passed, 42 skipped. The cloudpickle round-trip tests in tests/simulation/test_simulate_aot.py are the binding ones for what the local patch was protecting.

Stacking

Stacked on #363. Rebase / change base to main once #363 lands; the diff is independent of the mahler-yum example.

jaxtyping's `_array_types` module-level `object()` sentinels broke
under cloudpickle round-trips because plain `object()` does not survive
serialisation with identity intact. The workaround lived as a local
monkey-patch in `_lcm/jaxtyping_patch.py` for the variadic-dim sentinel.

Replace the workaround by depending on a jaxtyping branch that fixes
all three affected sentinels (`_any_dtype`, `_anonymous_dim`,
`_anonymous_variadic_dim`) at the source via `__reduce__`-backed
singleton classes (patrick-kidger/jaxtyping#390 — pending upstream
review). Floor pin bumps to `jaxtyping>=0.3.10`; the pixi
pypi-dependency override pulls the fork branch until upstream releases.

`tests/simulation/test_simulate_aot.py` covers the cloudpickle
round-trip paths and continues to pass without the local patch.
@read-the-docs-community
Copy link
Copy Markdown

@github-actions
Copy link
Copy Markdown

Benchmark comparison (main → HEAD)

Comparing 629ac442 (main) → 78086a37 (HEAD)

Benchmark Statistic before after Ratio Alert
aca-baseline execution time 26.281 s 14.916 s 0.57
peak GPU mem 639 MB 581 MB 0.91
compilation time 297.53 s 320.56 s 1.08
peak CPU mem 7.40 GB 7.98 GB 1.08
aca-baseline-debug execution time 71.906 s
peak GPU mem 581 MB
compilation time 411.45 s
peak CPU mem 7.62 GB
Mahler-Yum execution time 4.679 s 4.678 s 1.00
peak GPU mem 529 MB 520 MB 0.98
compilation time 14.15 s 11.62 s 0.82
peak CPU mem 1.71 GB 1.57 GB 0.92
Precautionary Savings - Solve execution time 46.3 ms 25.4 ms 0.55
peak GPU mem 101 MB 8 MB 0.08
compilation time 2.66 s 1.53 s 0.58
peak CPU mem 1.14 GB 1.13 GB 0.99
Precautionary Savings - Simulate execution time 120.2 ms 94.8 ms 0.79
peak GPU mem 349 MB 162 MB 0.46
compilation time 5.01 s 3.82 s 0.76
peak CPU mem 1.34 GB 1.32 GB 0.98
Precautionary Savings - Solve & Simulate execution time 159.7 ms 123.7 ms 0.77
peak GPU mem 586 MB 588 MB 1.00
compilation time 7.22 s 4.99 s 0.69
peak CPU mem 1.30 GB 1.29 GB 0.99
Precautionary Savings - Solve & Simulate (irreg) execution time 288.0 ms 254.1 ms 0.88
peak GPU mem 2.20 GB 2.20 GB 1.00
compilation time 7.53 s 5.32 s 0.71
peak CPU mem 1.36 GB 1.34 GB 0.99

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant