Make numba caching safe across internal name collisions#2051
Conversation
cf21582 to
f62c492
Compare
ad698a1 to
cdb677c
Compare
|
@aseyboldt is happier with drawing a random number instead of randomly offsetting the counter |
|
now uuid counter, reported to numba to ask for feedback: numba/numba#10486 (comment) And added a numba stricter pin and a job to automatically check for new numba releases and update the bound / test so that we guard against API changes (generally useful, more so now that we are definitily on internal API intrusion) |
pytensor's numba-backed disk cache is unsafe across multiprocessing: sibling child processes freshly compiling same-qualname functions can allocate identical ``FunctionIdentity._unique_ids``, producing byte-identical LLVM mangled symbols with different bodies. Later loads hit LLVM's weak-merge and silently dispatch into the wrong body, corrupting logp/energy in parallel sampling children. Surfaced here as systematic 'Bad initial energy' failures of tests/progress_bar/test_manager.py::test_progressbar_nested_compound on the ubuntu (numba, 3.14, ...) CI job (3/3 runs). Confirmed by wrapping the test with pytensor.config.change_flags(numba__cache=False), which made CI green. Backport the minimal part of pytensor#2051: replace numba's FunctionIdentity._unique_ids with a uuid.uuid4()-based iterator so every UID is globally unique. Applied unconditionally at pymc import time (no-op when numba isn't installed). Drop once pytensor#2051 ships. Refs pymc-devs/pytensor#2051 Refs numba/numba#10486
| # Matches pins of the form "numba>=X.Y,<A.B" with either a main-dep or | ||
| # optional-extra quote style. The cap (group 2) is what we bump. | ||
| PIN_RE = re.compile(r'("numba>=\d+\.\d+,<)(\d+\.\d+)(")') | ||
| PYPI_URL = "https://pypi.org/pypi/numba/json" |
There was a problem hiding this comment.
do we care about conda-forge availability? (i guess not)
There was a problem hiding this comment.
pypi and conda releases should behave the same so I don't think it matters that we check pypi?
|
I pinned an already not latest version of numba to see if the job works. The script should use <= though, as we don't know what the next version will be... |
Numba's per-process FunctionIdentity._unique_ids counter is inherited verbatim across os.fork(), so sibling forks allocate identical uid values to same-qualname functions they fresh-compile next. Their .nbc files then contain linkonce_odr weak symbols with byte-identical mangled names but different bodies; a later process that loads both hits LLVM's weak-merge and dispatches into the wrong body, segfaulting. The fix re-seeds FunctionIdentity._unique_ids to a random 48-bit offset the first time our cache locator is consulted in each process, and again after every fork (via os.register_at_fork). This covers fork, forkserver, and spawn uniformly. Caching still works because .nbc files store the mangled name in their serialized FunctionDescriptor and loading reads the baked name rather than re-deriving it from the current counter. See numba/numba#10486
The previous fix re-seeded FunctionIdentity._unique_ids to a random 48-bit offset on first use and after every fork. Sequential counting from distinct offsets still leaves a nonzero collision probability if the ranges overlap, and the fork hook adds state to maintain. Replace the counter with an iterator that yields uuid.uuid4().int on every next() call. Numba only consumes _unique_ids via next(cls._unique_ids) in FunctionIdentity.from_function, so any iterator of unique ints is a drop-in substitute. uuid.uuid4 reads fresh entropy from os.urandom on every call, so it is intrinsically fork-safe: no os.register_at_fork hook, no re-seed flag, and collision probability is 128-bit negligible. See numba/numba#10486
Pin numba in pyproject.toml (main dep and numba extra) so CI does not silently consume releases that change internals we depend on. Add a weekly GitHub Actions workflow running scripts/bump_numba_upper_bound.py: it rewrites the pin to the canonical "numba>=X.Y,<=A.B.C" form whenever PyPI has a release newer than the cap, or the pin shape has drifted (< cap, no upper bound), and opens a PR. The Tests workflow validates each bump; merging is gated on CI and human review.
60a4230 to
43ea05b
Compare
The patch #1992 turned out not to be sufficient. It depends on whether numba decides to inline our customly-cached disk functions or not. The underlying issue is still numba/numba#10486
Which means pytensor custom caching machinery is not safe inside multiprocessing. This hasn't proven to be such a big deal in PyMC, as we opt to first compile then multiprocess (fine... as long as we do an early call to actually trigger numba cache/compile).
But it's all fragile. We went down #1992 because of #1988. ASV defaults to launching each job on a fork process. It sounded like spawn was immune to this mangling, but seems to be just an accident from numba using a set of flags for the naming, and the set is random across spawns (I don't think the set is that big that it would be hard to see problems).
So the current solution is to hack numba's internal counter for generating unique function ids, the first time and at each subsequent fork. This helper was stable for a while, but maybe it's time to start pinning range of numba versions manually.