[Refactor] Harden and speed up the JIT cache-key computation #597
sjfeng1999 wants to merge 6 commits
Per-call cache key (jit_function.py):
- Require __cache_signature__ on every JitArgument; drop the str(ir.Type) fallback so unknown types raise instead of silently colliding under one key.
- Fold module globals into the key: a cross-process-stable snapshot plus Triton-style in-process drift detection that raises on change (no auto-recompile). Composite (name, module) identity so same-named globals across modules don't collide.
- target = (GPUTarget, device_id), read live per call so device switches and arch/env changes participate in the key (device_id via the active DeviceRuntime).
- Whitelisted code-gen env vars re-read live into the key (os._Environ._data fast path with a public-API fallback).
- Performance: memoize the recursive global-ref discovery and the globals key segment; per call only re-snapshots values and runs a lean drift loop.

jit_argument.py:
- Construct JitArgument-annotated params (e.g. Stream) via the annotation; remove the int+Stream special-case and the type fallback.

Tests:
- test_jit_cache_key_completeness.py: env drift, globals snapshot in key, globals drift raises, required __cache_signature__ protocol.
- test_compile_hints.py: _target_ entry is now (GPUTarget, device_id).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Pull request overview
This PR hardens JIT cache-key construction so generated artifacts better reflect codegen-affecting inputs such as environment, target/device, and referenced globals.
Changes:
- Adds env-var, GPU target/device, and globals snapshot segments to JIT cache keys.
- Refactors JitArgument handling to require explicit cache signatures and annotation-driven construction.
- Adds/updates unit tests for cache-key completeness and target key shape.
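For readers unfamiliar with the pattern, the first two changes listed above can be sketched roughly as follows. This is a minimal illustration, not the code in jit_function.py: the `Stream` stand-in, `arg_signature`, `build_key`, the segment names, and the environment variable names in `CODEGEN_ENV_VARS` are all hypothetical.

```python
import os

# Hypothetical whitelist; the real variable names are not shown on this page.
CODEGEN_ENV_VARS = ("EXAMPLE_OPT_LEVEL", "EXAMPLE_DEBUG")


class Stream:
    """Toy stand-in for an argument type that opts in to caching."""

    def __init__(self, handle):
        self.handle = handle

    def __cache_signature__(self):
        # Only the kind of argument participates; the handle value does not.
        return ("Stream",)


def arg_signature(arg):
    """Return the argument's cache signature, or raise if it has none."""
    sig = getattr(arg, "__cache_signature__", None)
    if sig is None:
        # No str(type) fallback: unknown types fail loudly instead of
        # silently sharing one key.
        raise TypeError(f"{type(arg).__name__} does not define __cache_signature__")
    return sig()


def build_key(args, environ=os.environ):
    """Combine a live read of whitelisted env vars with argument signatures."""
    env_part = tuple((name, environ.get(name)) for name in CODEGEN_ENV_VARS)
    arg_part = tuple(arg_signature(a) for a in args)
    return (("_env_", env_part), ("_args_", arg_part))
```

Because the environment is read inside `build_key`, changing a whitelisted variable between calls yields a different key, while an argument without the protocol raises rather than producing a shared key.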
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/flydsl/compiler/jit_function.py | Expands cache-key construction and dependency/global tracking. |
| python/flydsl/compiler/jit_argument.py | Adjusts JitArgument conversion dispatch to prefer annotated argument types. |
| tests/unit/test_jit_cache_key_completeness.py | Adds regression coverage for env drift, globals drift, device id, and cache signature requirements. |
| tests/unit/test_compile_hints.py | Updates target cache-key assertions for (GPUTarget, device_id). |
coderfeli left a comment
I found two issues that seem worth addressing before merging:
- This may regress the launch hot path. `_build_full_cache_key()` runs before checking `_call_state_cache`, and `_resolve_and_make_cache_key()` now constructs the registered JitArgument for raw `torch.Tensor` arguments in order to compute `cache_signature()`. That means even a fully warmed `CallState` cache hit still creates a `TensorAdaptor` and goes through DLPack/adaptor setup just to build the key.

  The CallState fast path still avoids DLPack for argument packing/execution, but it no longer avoids this cache-key construction cost. Since this PR is also intended to speed up key computation, can we restore a lightweight tensor metadata signature path, e.g. `TensorAdaptor.cache_signature_from_tensor()` / the old raw-signature behavior, and only build the full adaptor on miss/compile?
- There seems to be a drift-detection hole for class-bound/inherited JIT methods. `_global_refs_cache` and `_globals_prefix_cache` are keyed by `owner_cls`, but `_used_global_vals` is shared across all owners. If the same inherited `JitFunction` is first called on `Base` and later on `Sub`, and `Sub` overrides a helper that references an additional global, that new global is not in the original baseline. `_check_globals_drift()` then skips it via `_NOT_IN_BASELINE`, while `_globals_prefix_cache[Sub]` can still be memoized from the first `Sub` call.

  After that, mutating the Sub-specific global may neither raise nor update the cached `_globals_` key segment, allowing stale reuse. I think `_used_global_vals` should be keyed by `owner_cls` as well, matching the refs/prefix caches, or missing/new refs should invalidate the owner-specific globals prefix / raise.
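A rough sketch of the second suggestion, keeping one baseline per owner class so that a global referenced only through `Sub` gets its own baseline entry. The `DriftTracker` class, `GlobalsDriftError`, and the shape of `refs` are hypothetical and only illustrate the idea; they are not the PR's actual data structures.

```python
_NOT_IN_BASELINE = object()


class GlobalsDriftError(RuntimeError):
    """Raised when a tracked global changed after the first call."""


class DriftTracker:
    def __init__(self):
        # owner class -> {(name, module): value seen on that owner's first call}
        self._baselines = {}

    def check(self, owner_cls, refs):
        """refs maps (name, module) -> current value for this owner's call graph."""
        baseline = self._baselines.get(owner_cls)
        if baseline is None:
            # First call for this owner: record its own baseline.
            self._baselines[owner_cls] = dict(refs)
            return
        for key, new in refs.items():
            old = baseline.get(key, _NOT_IN_BASELINE)
            if old is _NOT_IN_BASELINE:
                # A ref missing from the baseline is reported, not skipped.
                raise GlobalsDriftError(f"untracked global {key!r} for {owner_cls.__name__}")
            if old != new:
                raise GlobalsDriftError(f"global {key!r} changed for {owner_cls.__name__}")
```

With this layout, a first call on `Base` does not fix the set of globals that a later call on `Sub` is checked against.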
Co-authored-by: Cursor <cursoragent@cursor.com>
    sig = self._sig
    key_parts = [("_target_", self._target)]
    # Re-read device_id and env vars on every call.
    target = (self._backend_target, get_device_runtime().current_device_id())
device_id is folded into _target_, and this same key is stringified for the on-disk artifact (_cache_key_to_str(cache_key) -> cache_manager.get/set, L1463-1464). HSACO is arch-specific, not device-specific (pre-PR _target_ was arch-only and shared across devices). So on a multi-GPU box of the same arch, the first launch on each device misses the disk cache and recompiles, storing a duplicate artifact per device. Suggest scoping device_id to the in-process key only (CallState/func_exe) and keying the disk cache by arch alone.
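One way to picture the suggested split, as a sketch only: the disk key is built from the arch-level target, and the device id is appended solely to the in-process key. The `GPUTarget` stand-in, `make_keys`, and the segment names are hypothetical, not the project's real types.

```python
from typing import NamedTuple


class GPUTarget(NamedTuple):
    """Hypothetical stand-in for the arch-level target description."""
    backend: str
    arch: str


def make_keys(target, device_id, arg_key):
    """Return (disk_key, in_process_key) for one call."""
    # Shared by every device of the same arch, so one artifact is stored.
    disk_key = (("_target_", target), ("_args_", arg_key))
    # Device-specific state (e.g. a loaded executable) is keyed separately.
    in_process_key = disk_key + (("_device_", device_id),)
    return disk_key, in_process_key
```

Two same-arch devices then share a disk entry while keeping distinct in-process entries.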
    old = baseline.get(key, _NOT_IN_BASELINE)
    if old is _NOT_IN_BASELINE:
        continue
    new = _snapshot_global_value(var_dict[name], stable=False) if name in var_dict else None
_check_globals_drift re-snapshots every captured global on every call, and _snapshot_global_value walks builtin containers by value with no size bound (plus a sorted(key=repr) per level). A kernel that captures a large dict/list/set module global therefore pays an O(size) deep-walk on every warm launch. Worth bounding container size here (and documenting that oversized globals aren't deep drift-checked), since this is on the hot path.
If a kernel doesn't use many global variables, this won't hurt cache-key computation performance. The cost is paid only as needed, not on every kernel.
Agreed for scalar / small globals — count isn't the concern. The edge is a single large container global: _snapshot_global_value recurses by value with no size cap (plus a sorted(key=repr) per dict/set level), so one big dict/list/set is O(size) on every warm call regardless of how few globals there are. A size bound (or skip-with-warning above some threshold) would cover that case while keeping the common path free.
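A minimal sketch of what such a bound could look like. The `snapshot` function, the `MAX_ITEMS` threshold, and the `"<oversized>"` marker are hypothetical and do not reproduce `_snapshot_global_value`. The bound here is per container, so deeply nested small containers can still add up.

```python
MAX_ITEMS = 256  # hypothetical threshold


def snapshot(value, max_items=MAX_ITEMS):
    """Snapshot builtin containers by value, but only up to max_items entries."""
    if isinstance(value, (list, tuple, set, frozenset, dict)):
        kind = type(value).__name__
        if len(value) > max_items:
            # Oversized: record only type and length, so content changes
            # inside such a container are NOT detected.
            return ("<oversized>", kind, len(value))
        if isinstance(value, dict):
            items = ((snapshot(k, max_items), snapshot(v, max_items))
                     for k, v in value.items())
            return (kind, tuple(sorted(items, key=repr)))
        if isinstance(value, (set, frozenset)):
            items = (snapshot(v, max_items) for v in value)
            return (kind, tuple(sorted(items, key=repr)))
        return (kind, tuple(snapshot(v, max_items) for v in value))
    # Leaves are kept as-is and compared by equality.
    return value
```

The trade-off is explicit: above the threshold the warm path is O(1) for that global, at the price of drift inside it going unnoticed, which is why the limit would need to be documented.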
… GPUs) The on-disk HSACO is arch-specific, and the in-process artifact / func_exe is reusable across same-arch GPUs (as on main). Folding device_id into the key split the cache per device, recompiling and storing a duplicate disk artifact on each GPU of a multi-GPU box. Revert _target_ to the arch-only GPUTarget so one entry is shared across devices. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>