
[Refactor] Harden and speed up the JIT cache-key computation #597

Open
sjfeng1999 wants to merge 6 commits into main from pr/enh-cache-key

Conversation

@sjfeng1999
Collaborator

Per-call cache key (jit_function.py):

  • Require __cache_signature__ on every JitArgument; drop the str(ir.Type) fallback so unknown types raise instead of silently colliding under one key.
  • Fold module globals into the key: a cross-process-stable snapshot plus Triton-style in-process drift detection that raises on change (no auto-recompile). Composite (name, module) identity so same-named globals across modules don't collide.
  • target = (GPUTarget, device_id), read live per call so device switches and arch/env changes participate in the key (device_id via the active DeviceRuntime).
  • Whitelisted code-gen env vars re-read live into the key (os._Environ._data fast path with a public-API fallback).
  • Performance: memoize the recursive global-ref discovery and the globals key segment; per call only re-snapshots values and runs a lean drift loop.
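
The first and fourth bullets above can be illustrated with a minimal sketch. This is not the code in jit_function.py: the names `CODEGEN_ENV_VARS`, `arg_signature`, `build_cache_key`, and `FakeTensor` are hypothetical, and the whitelist contents are placeholders, since the PR text does not list the real variables. It only shows the shape of the idea: a missing `__cache_signature__` raises rather than falling back to a type string, and whitelisted env vars are re-read on every call.

```python
import os

# Hypothetical whitelist; the real code-gen env vars are not listed in this PR text.
CODEGEN_ENV_VARS = ("EXAMPLE_OPT_LEVEL", "EXAMPLE_DEBUG")


def arg_signature(arg):
    """Return the argument's cache signature, refusing types that do not define one."""
    sig_fn = getattr(type(arg), "__cache_signature__", None)
    if sig_fn is None:
        # No str(type) fallback: two unknown types must never share one key.
        raise TypeError(f"{type(arg).__name__} does not define __cache_signature__")
    return sig_fn(arg)


def build_cache_key(args, environ=None):
    """Build a hashable key from argument signatures and live env-var values."""
    env = os.environ if environ is None else environ
    # Re-read on every call so a changed variable produces a different key.
    env_part = tuple((name, env.get(name)) for name in CODEGEN_ENV_VARS)
    arg_part = tuple(arg_signature(a) for a in args)
    return (("_env_", env_part), ("_args_", arg_part))


class FakeTensor:
    """Stand-in argument type that implements the signature protocol."""

    def __init__(self, dtype, ndim):
        self.dtype = dtype
        self.ndim = ndim

    def __cache_signature__(self):
        return ("tensor", self.dtype, self.ndim)
```

In this sketch, passing an object without `__cache_signature__` (for example a bare `object()`) raises `TypeError` at key-construction time, which is the behavior the description asks for in place of a silent collision.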

jit_argument.py:

  • Construct JitArgument-annotated params (e.g. Stream) via the annotation; remove the int+Stream special-case and the type fallback.

Tests:

  • test_jit_cache_key_completeness.py: env drift, globals snapshot in key, globals drift raises, required __cache_signature__ protocol.
  • test_compile_hints.py: _target_ entry is now (GPUTarget, device_id).


Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Contributor

Copilot AI left a comment


Pull request overview

This PR hardens JIT cache-key construction so generated artifacts better reflect codegen-affecting inputs such as environment, target/device, and referenced globals.

Changes:

  • Adds env-var, GPU target/device, and globals snapshot segments to JIT cache keys.
  • Refactors JitArgument handling to require explicit cache signatures and annotation-driven construction.
  • Adds/updates unit tests for cache-key completeness and target key shape.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| python/flydsl/compiler/jit_function.py | Expands cache-key construction and dependency/global tracking. |
| python/flydsl/compiler/jit_argument.py | Adjusts JitArgument conversion dispatch to prefer annotated argument types. |
| tests/unit/test_jit_cache_key_completeness.py | Adds regression coverage for env drift, globals drift, device id, and cache signature requirements. |
| tests/unit/test_compile_hints.py | Updates target cache-key assertions for (GPUTarget, device_id). |


Comment thread python/flydsl/compiler/jit_function.py Outdated
Comment thread tests/unit/test_jit_cache_key_completeness.py Outdated
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

Comment thread python/flydsl/compiler/jit_function.py Outdated
Collaborator

@coderfeli coderfeli left a comment


I found two issues that seem worth addressing before merging:

  1. This may regress the launch hot path. _build_full_cache_key() runs before checking _call_state_cache, and _resolve_and_make_cache_key() now constructs the registered JitArgument for raw torch.Tensor arguments in order to compute __cache_signature__(). That means even a fully warmed CallState cache hit still creates a TensorAdaptor and goes through DLPack/adaptor setup just to build the key.

The CallState fast path still avoids DLPack for argument packing/execution, but it no longer avoids this cache-key construction cost. Since this PR is also intended to speed up key computation, can we restore a lightweight tensor metadata signature path, e.g. TensorAdaptor.cache_signature_from_tensor() / the old raw-signature behavior, and only build the full adaptor on miss/compile?

  2. There seems to be a drift-detection hole for class-bound/inherited JIT methods. _global_refs_cache and _globals_prefix_cache are keyed by owner_cls, but _used_global_vals is shared across all owners. If the same inherited JitFunction is first called on Base and later on Sub, and Sub overrides a helper that references an additional global, that new global is not in the original baseline. _check_globals_drift() then skips it via _NOT_IN_BASELINE, while _globals_prefix_cache[Sub] can still be memoized from the first Sub call.

After that, mutating the Sub-specific global may neither raise nor update the cached _globals_ key segment, allowing stale reuse. I think _used_global_vals should be keyed by owner_cls as well, matching the refs/prefix caches, or missing/new refs should invalidate the owner-specific globals prefix / raise.
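
The second issue can be made concrete with a small sketch of a drift baseline keyed by owner class. It is an illustration under assumptions, not the PR's implementation: `PerOwnerDriftTracker` and `GlobalsDriftError` are hypothetical names, and plain values stand in for whatever the real snapshot function produces. Each owner class gets its own baseline, so a global that only Sub references is recorded on Sub's first call and checked on later ones.

```python
_NOT_IN_BASELINE = object()


class GlobalsDriftError(RuntimeError):
    """Raised when a tracked global changed after it was first observed."""


class PerOwnerDriftTracker:
    """Keeps one baseline of observed global values per owner class."""

    def __init__(self):
        # owner_cls -> {(name, module): value seen on that owner's first call}
        self._baselines = {}

    def check(self, owner_cls, refs, current_values):
        """Record unseen refs for this owner; raise if a recorded value changed.

        refs: iterable of (name, module) keys referenced through owner_cls.
        current_values: mapping from those keys to their current values.
        """
        baseline = self._baselines.setdefault(owner_cls, {})
        for key in refs:
            new = current_values.get(key)
            old = baseline.get(key, _NOT_IN_BASELINE)
            if old is _NOT_IN_BASELINE:
                # First sighting for this owner: record instead of skipping.
                baseline[key] = new
            elif old != new:
                raise GlobalsDriftError(f"global {key!r} changed: {old!r} -> {new!r}")
```

With a single shared baseline populated by a first call on Base, a key referenced only through Sub would stay outside the baseline and be skipped. With the per-owner mapping, mutating that key after Sub's first call raises on Sub's next call while calls through Base are unaffected.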

Co-authored-by: Cursor <cursoragent@cursor.com>
@sjfeng1999 sjfeng1999 closed this Jun 1, 2026
@sjfeng1999 sjfeng1999 reopened this Jun 1, 2026
@sjfeng1999
Collaborator Author

  1. Not an issue for the pre-compile path. flyc.compile(...) followed by compiled_fn(*args) runs CompiledFunction.__call__ → CallState.__call__, which never builds a cache key: no TensorAdaptor/DLPack, no globals snapshot, no hipGetDevice. All of that work is paid once at compile time, so there is no per-launch overhead.
  2. Already fixed on this branch: _used_global_vals is now keyed by owner_cls.

Comment thread python/flydsl/compiler/jit_function.py Outdated
sig = self._sig
key_parts = [("_target_", self._target)]
# Re-read device_id and env vars on every call.
target = (self._backend_target, get_device_runtime().current_device_id())
Collaborator


device_id is folded into _target_, and this same key is stringified for the on-disk artifact (_cache_key_to_str(cache_key) -> cache_manager.get/set, L1463-1464). HSACO is arch-specific, not device-specific (pre-PR _target_ was arch-only and shared across devices). So on a multi-GPU box of the same arch, the first launch on each device misses the disk cache and recompiles, storing a duplicate artifact per device. Suggest scoping device_id to the in-process key only (CallState/func_exe) and keying the disk cache by arch alone.
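
A rough sketch of the suggested split follows, assuming a two-level lookup. `GPUTargetSketch`, `disk_key`, `in_process_key`, and `TwoLevelCache` are hypothetical names, the arch string used in the usage note is only an example value, and a dict stands in for the real cache manager. The point is that the disk key depends on the arch alone, while device_id appears only in the in-process key.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GPUTargetSketch:
    """Stand-in for an arch-only target description."""
    arch: str


def disk_key(target, arg_sig):
    # Arch-specific only: the compiled artifact is shared by same-arch devices.
    return ("_target_", target.arch, arg_sig)


def in_process_key(target, device_id, arg_sig):
    # The device id participates only in the in-memory lookup.
    return (disk_key(target, arg_sig), ("_device_", device_id))


class TwoLevelCache:
    def __init__(self):
        self.disk = {}     # stands in for the on-disk artifact store
        self.loaded = {}   # per-device in-process entries
        self.compiles = 0  # counts simulated compilations

    def get(self, target, device_id, arg_sig):
        mem_key = in_process_key(target, device_id, arg_sig)
        if mem_key in self.loaded:
            return self.loaded[mem_key]
        dkey = disk_key(target, arg_sig)
        if dkey not in self.disk:
            # Miss at both levels: compile once per arch, not once per device.
            self.compiles += 1
            self.disk[dkey] = f"artifact<{target.arch}>"
        entry = (self.disk[dkey], device_id)
        self.loaded[mem_key] = entry
        return entry
```

Calling `get` for devices 0 and 1 with the same `GPUTargetSketch("gfx942")` and the same signature leaves one disk entry and two in-process entries, with `compiles` equal to 1.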

old = baseline.get(key, _NOT_IN_BASELINE)
if old is _NOT_IN_BASELINE:
    continue
new = _snapshot_global_value(var_dict[name], stable=False) if name in var_dict else None
Collaborator


_check_globals_drift re-snapshots every captured global on every call, and _snapshot_global_value walks builtin containers by value with no size bound (plus a sorted(key=repr) per level). A kernel that captures a large dict/list/set module global therefore pays an O(size) deep-walk on every warm launch. Worth bounding container size here (and documenting that oversized globals aren't deep drift-checked), since this is on the hot path.

Collaborator Author

@sjfeng1999 sjfeng1999 Jun 1, 2026


If a kernel doesn't use many global variables, this won't hurt cache-key computation performance. The cost is paid only as needed, not on every kernel.

Collaborator


Agreed for scalar / small globals — count isn't the concern. The edge is a single large container global: _snapshot_global_value recurses by value with no size cap (plus a sorted(key=repr) per dict/set level), so one big dict/list/set is O(size) on every warm call regardless of how few globals there are. A size bound (or skip-with-warning above some threshold) would cover that case while keeping the common path free.
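
One possible shape for such a bound is sketched below. It is an assumption-laden illustration, not the behavior of `_snapshot_global_value`: the threshold `MAX_SNAPSHOT_ITEMS`, the helper names, and the identity-only fallback are all hypothetical. The walk charges each container's length against a shared budget before descending, so a single huge container is rejected in O(1) instead of being walked.

```python
import warnings

MAX_SNAPSHOT_ITEMS = 1024  # hypothetical threshold, not taken from the PR

_CONTAINERS = (dict, list, tuple, set, frozenset)


class _Oversized(Exception):
    """Internal signal that the item budget was exhausted."""


def _walk(value, budget):
    if not isinstance(value, _CONTAINERS):
        return value
    # Charge the whole container up front so a huge one is rejected immediately.
    budget[0] -= len(value)
    if budget[0] < 0:
        raise _Oversized
    if isinstance(value, dict):
        items = [(_walk(k, budget), _walk(v, budget)) for k, v in value.items()]
        return ("dict", tuple(sorted(items, key=repr)))
    if isinstance(value, (set, frozenset)):
        return ("set", tuple(sorted((_walk(v, budget) for v in value), key=repr)))
    return (type(value).__name__, tuple(_walk(v, budget) for v in value))


def snapshot_global_value(value, max_items=MAX_SNAPSHOT_ITEMS):
    """Snapshot by value up to max_items container elements, else by identity."""
    try:
        return _walk(value, [max_items])
    except _Oversized:
        warnings.warn(
            "global too large to deep-snapshot; tracking by identity only",
            stacklevel=2,
        )
        # In-place mutation of an oversized container is NOT detected here.
        return ("oversized", type(value).__name__, id(value))
```

The trade-off in this sketch is explicit: small globals keep full by-value drift detection, while an oversized container is only checked for rebinding to a different object, which is the limitation the review comment suggests documenting.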

… GPUs)

The on-disk HSACO is arch-specific, and the in-process artifact / func_exe
is reusable across same-arch GPUs (as on main). Folding device_id into the
key split the cache per device, recompiling and storing a duplicate disk
artifact on each GPU of a multi-GPU box. Revert _target_ to the arch-only
GPUTarget so one entry is shared across devices.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>