
feat(jit): implement @pl.jit preprocessor over @pl.program pipeline#915

Merged
Hzfengsy merged 7 commits into hw-native-sys:main from huxinyuan1215:feat/jit-preprocessor
Apr 20, 2026

Conversation

huxinyuan1215 (Contributor) commented Apr 8, 2026

Summary

Implements the @pl.jit decorator as specified in issue #878 / #915.

  • python/pypto/jit/decorator.py — JITFunction class with the @pl.jit and @pl.jit.incore decorators. Supports Style A (inline pl.at(level=CORE_GROUP) scope) and Style B (explicit multi-function cross-call). __call__ specializes the function, runs the full pass pipeline, and caches the resulting CompiledProgram in an L1 in-memory cache keyed by (source_hash, shapes, dtypes, scalar_values). Dynamic dims (declared via bind_dynamic) are stored as None in the key, so different concrete values share the same cache entry. platform is extracted from RunConfig at compile time and stored in CompiledProgram so execution uses the correct target.
  • python/pypto/jit/specializer.py — AST transformer that converts @pl.jit source into type-annotated @pl.program DSL source. Inlines static shape dims as ConstInt, preserves dynamic dims as Var, rewrites a.shape[i] references, removes bind_dynamic calls, and promotes pl.dynamic("M") declarations to module level. Entry functions with with pl.incore(): scopes are emitted as FunctionType.Opaque (not Orchestration) so that OutlineIncoreScopes can outline them correctly.
  • python/pypto/jit/cache.py — Cache key construction with PyPTO version stamp for automatic invalidation on upgrades.
  • python/pypto/language/typing/tensor.py — Adds bind_dynamic(dim, dynvar) method to Tensor.
  • python/pypto/language/__init__.py — Exports jit, restores PtrType/Ptr aliases.
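The L1 caching and dynamic-dim masking described above can be illustrated with a small self-contained model; make_key and jit_call are hypothetical stand-ins for sketching the behavior, not the PyPTO API:

```python
# Toy model of the L1 cache keying described above (illustrative only;
# make_key/jit_call are hypothetical names, not the PyPTO API).
def make_key(name, shapes, dyn_dims=frozenset()):
    # Dims marked dynamic are masked to None, so different concrete
    # values of a dynamic dim share one cache entry.
    masked = tuple(
        tuple(None if (i, d) in dyn_dims else s for d, s in enumerate(shape))
        for i, shape in enumerate(shapes)
    )
    return (name, masked)

cache = {}

def jit_call(name, shapes, dyn_dims=frozenset()):
    key = make_key(name, shapes, dyn_dims)
    if key in cache:
        return cache[key], True          # cache hit
    compiled = ("compiled", name, key)   # stand-in for the pass pipeline
    cache[key] = compiled
    return compiled, False               # cache miss: compile and store

_, hit1 = jit_call("add", [(64, 64)])             # miss: first call compiles
_, hit2 = jit_call("add", [(64, 64)])             # hit: same static shapes
_, hit3 = jit_call("add", [(128, 64)])            # miss: new static shape
_, hit4 = jit_call("dyn", [(64, 64)], {(0, 0)})   # miss: first dynamic call
_, hit5 = jit_call("dyn", [(128, 64)], {(0, 0)})  # hit: dynamic dim 0 masked
```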

Bug fixes:

  • Propagate dynamic dims from @pl.jit.incore deps back to entry function so params use Var(M) not ConstInt(64) (_propagate_dynamic_dims_from_deps, _backfill_entry_dynvar_bindings).
  • Emit FunctionType.Opaque (not Orchestration) for entry functions containing with pl.incore(): scopes; OutlineIncoreScopes only processes Opaque functions, so using Orchestration caused tile.alloc to be misplaced.
  • Forward platform from RunConfig through _compile() to ir.compile() so CompiledProgram stores the correct execution platform instead of defaulting to "a2a3sim".
  • test_assemble_acc_mat: upstream expand_mixed_kernel bug produces dangling result__ssa_v0 in AIC body for InCore+matmul without SplitMode; use enable_auto_mapping=True for comparison.

Testing

Unit tests (95 tests, 0 warnings)

Each UT contains a hand-written @pl.program as ground truth; both the JIT output and the reference are run through PassManager and compared via ir.assert_structural_equal.

  • tests/ut/jit/test_specializer.py — AST transformation unit tests
  • tests/ut/jit/test_decorator.py — decorator/cache/Style A+B integration tests
  • tests/ut/jit/test_roundtrip.py — round-trip tests covering all examples/kernels/ programs within @pl.jit scope: elementwise (01), fused ops (02), matmul (03), activation (05), softmax (06), normalization (07), assemble (08); plus dynamic dim test. 04_concat and 09_dyn_valid_shape are intentionally excluded (see file docstring).
  • tests/ut/jit/test_cache.py — cache key construction unit tests
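The comparison pattern above can be sketched with Python ASTs standing in for PyPTO IR (ir.assert_structural_equal compares IR after the pass pipeline; this is only an analogy):

```python
import ast

# Analogy for the round-trip check above: lower both the JIT output and
# the hand-written reference, then compare structurally rather than
# textually. Python's ast module stands in for PyPTO IR + PassManager.
def assert_structural_equal(src_a, src_b):
    if ast.dump(ast.parse(src_a)) != ast.dump(ast.parse(src_b)):
        raise AssertionError("structural mismatch")

assert_structural_equal("x = a + b", "x = a  +  b")  # formatting-insensitive
try:
    assert_structural_equal("x = a + b", "x = a - b")
    raise RuntimeError("expected a mismatch")
except AssertionError:
    pass  # different structure correctly detected
```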

System tests (device execution)

  • tests/st/runtime/test_jit.py — end-to-end @pl.jit compile + execute on real NPU: first-call compile, L1 cache hit on same shape, cache miss on different shape.

Scope vs #878 testing plan

#878 listed "round-trip: create @pl.jit equivalent for every examples/ program". This PR covers all kernel examples within @pl.jit scope. Programs excluded from round-trip tests (04_concat, 09_dyn_valid_shape, examples/models/) use patterns outside the current @pl.jit scope (module-level @pl.function, pl.create_tensor output allocation, pl.tensor.read scalar configs) and are tracked as follow-up work.

Related Issues

Closes #915
Addresses #878


coderabbitai Bot commented Apr 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds a JIT compilation subsystem under python/pypto/jit: package initializer, cache, AST-driven specializer, decorator API (pl.jit / pl.jit.incore), a small language typing hook (Tensor.bind_dynamic), and comprehensive unit/integration tests exercising specialization, caching, dependency discovery, and IR round-trips.

Changes

Cohort / File(s) Summary
JIT Core
python/pypto/jit/__init__.py
New package initializer re-exporting JITFunction and jit and documenting @pl.jit usage.
Cache
python/pypto/jit/cache.py
New two-level cache implementation: TensorCacheInfo/ScalarCacheInfo, CacheKey construction with dynamic-dim masking, compute_source_hash (SHA-256 truncated), and L2 lookup/store under ~/.cache/pypto/jit/.
Decorator / API
python/pypto/jit/decorator.py
New JITFunction and jit singleton: AST-based dependency discovery, arg classification, cache-key assembly, specialization orchestration, parse/compile flow, and caching logic.
Specializer
python/pypto/jit/specializer.py
AST-driven transformer and SpecializeContext: collects bind_dynamic/dynvars, classifies params, rewrites shapes/dtypes/dynvars, emits module-level pl.dynamic(...), and outputs parseable @pl.program source.
Language typing
python/pypto/language/__init__.py, python/pypto/language/typing/tensor.py
pypto.language now re-exports jit/JITFunction; Tensor.bind_dynamic(dim, var) added as a runtime no-op used by the specializer.
Type stubs
python/pypto/pypto_core/__init__.pyi
Added DataType.__hash__ stub to enable hashability.
Tests — Pytest setup
tests/ut/jit/conftest.py, tests/ut/jit/__init__.py
Test fixture ensures project root is on sys.path; added license header module.
Tests — Cache unit tests
tests/ut/jit/test_cache.py
Unit tests for compute_source_hash and make_cache_key: determinism, order-sensitivity, dynamic-dim masking, scalar handling, and key hashability.
Tests — Decorator & integration
tests/ut/jit/test_decorator.py
Tests for jit/jit.incore behavior, metadata preservation, bind_dynamic no-op, cache hit/miss semantics (torch-conditional), dependency discovery (Style B), and structural equality vs hand-written programs.
Tests — Specializer unit tests
tests/ut/jit/test_specializer.py
AST analysis and _BodyTransformer tests: dynamic-dim collection, dynvar extraction, param classification, shape/dtype substitution, dependency-call rewriting, generated source parseability and IR equality.
Tests — Round-trip integration
tests/ut/jit/test_roundtrip.py
Large set of round-trip tests mapping example kernels to @jit.incore implementations and asserting structural equality of produced IR vs reference kernels (torch-conditional).
Config
pyproject.toml
Pyright adjustments: exclude test_roundtrip from checks, extend extraPaths, and silence selected diagnostics.
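The compute_source_hash idea noted in the Cache cohort above can be sketched as follows; the 16-character truncation and the exact input composition are assumptions for illustration, not PyPTO's actual parameters:

```python
import hashlib

# Sketch of a truncated-SHA-256 source hash as described above.
# The truncation length and input ordering are assumptions.
def compute_source_hash(sources, length=16):
    h = hashlib.sha256()
    for src in sources:  # entry source plus dep sources, in order
        h.update(src.encode("utf-8"))
    return h.hexdigest()[:length]

h1 = compute_source_hash(["def f(): pass"])
h2 = compute_source_hash(["def f(): pass"])   # identical source, same hash
h3 = compute_source_hash(["def f(): return 1"])  # changed source, new hash
```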

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant JIT as JITFunction
    participant Cache as Cache Layer
    participant Specializer as Specializer
    participant Parser as Parser
    participant Program as ir.Program

    User->>JIT: call jit_kernel(tensors..., scalars...)
    JIT->>JIT: classify args, extract TensorMeta
    JIT->>JIT: compute source_hash + CacheKey (mask dyn dims)
    JIT->>Cache: l2_lookup(cache_key)

    alt cache hit
        Cache-->>JIT: cached artifact path / Program
        JIT-->>User: return CompiledProgram / Program
    else cache miss
        JIT->>JIT: discover incore deps via AST
        JIT->>Specializer: build contexts (deps first, entry last)
        Specializer->>Specializer: rewrite AST (remove bind_dynamic, emit dynvars, substitute shapes/dtypes)
        Specializer-->>JIT: specialized DSL source
        JIT->>Parser: parse -> ir.Program
        Parser-->>Program: ir.Program
        JIT->>Cache: l2_store(cache_key, compiled_artifacts)
        JIT-->>User: return CompiledProgram / Program
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • lyfne123

Poem

🐰
I nibble AST leaves in moonlight bright,
I mark a dim as dynamic, then hop with delight,
Sources hashed, the cache chest hums below—
Kernels bloom, compiled, and off they go! 🥕✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — docstring coverage is 40.65%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check ✅ — the PR title 'feat(jit): implement @pl.jit preprocessor over @pl.program pipeline' accurately describes the main change: implementing a JIT preprocessor decorator system for specializing Python kernels into @pl.program IR.
  • Linked Issues check ✅ — the PR implements all primary objectives from #915 and #878: @pl.jit decorator with specialization, dynamic dims support, caching with versioning, Style A/B support, tensor metadata extraction, and comprehensive testing (95 tests across 4 test modules validating specialization, decoration, caching, and round-trip IR structural equality).
  • Out of Scope Changes check ✅ — all code changes are directly aligned with PR objectives: new jit package (decorator.py, specializer.py, cache.py), minimal typing additions (Tensor.bind_dynamic), language module exports, pyproject.toml configuration, type stubs, and comprehensive tests. No unrelated changes detected.
  • Description check ✅ — the PR description comprehensively details the implementation of the @pl.jit decorator: module purposes, caching strategy, dynamic dimension handling, bug fixes, and testing scope.


gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a JIT compilation module for PyPTO, providing a @pl.jit decorator that enables automatic specialization and compilation of kernel functions based on tensor shapes and dtypes. The implementation includes a compilation cache, an AST specializer to transform JIT-style code into @pl.program source, and necessary updates to the pl.Tensor API. My review suggests optimizing the _is_tensor function to avoid redundant calls to _get_torch and updating _is_tensor_annotation to consistently handle both native and pl.Tensor types.

Comment on lines +109 to +112

    torch = _get_torch()
    if torch is None:
        return False
    return isinstance(obj, torch.Tensor)

medium

The _is_tensor function calls _get_torch(), which performs a try-except import. This is redundant because _get_torch() is also called inside _extract_tensor_meta and _torch_dtype_to_pypto. Since _get_torch caches its result, call it once and reuse the cached module, or use isinstance directly if torch is already imported at module scope.

    is_scalar_ann = _is_scalar_annotation(outer)
    if is_scalar_ann:
        scalar_dtype_strs[name] = _ast_to_str(inner)
        continue

medium

The _is_tensor_annotation function only checks for the native Tensor type. It should be updated to also check for the custom pl.Tensor type to be consistent with how types are resolved in the DSL (similar to how pl.Scalar and pl.Tuple are handled). This ensures that type resolution correctly identifies tensor annotations regardless of whether the native or custom type is used.

References
  1. When checking for nested tuple types during type resolution, ensure the check covers both native tuple (which may resolve to a list) and the custom pl.Tuple type (which resolves to ir.TupleType).

huxinyuan1215 force-pushed the feat/jit-preprocessor branch from 572087d to 3de82d2 on April 8, 2026 15:24
coderabbitai (Bot) left a comment

Actionable comments posted: 3

🧹 Nitpick comments (3)
python/pypto/jit/decorator.py (2)

580-586: Consider adding strict=True for closure variable extraction.

The co_freevars and __closure__ tuples are guaranteed by Python to have matching lengths, but adding strict=True would make this invariant explicit and catch any future edge cases.

♻️ Optional: add strict=True
-    for name, cell in zip(co_freevars, closure):
+    for name, cell in zip(co_freevars, closure, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/decorator.py` around lines 580 - 586, The loop extracting
closure variables (using co_freevars = getattr(getattr(func, "__code__", None),
"co_freevars", ()) and closure = getattr(func, "__closure__", None) or () and
for name, cell in zip(co_freevars, closure): ...) should use zip(...,
strict=True) to assert the freevar/closure length invariant; update the for-loop
to for name, cell in zip(co_freevars, closure, strict=True): so mismatched
lengths raise immediately and still catch ValueError from cell.cell_contents as
before (ensure runtime Python supports zip(strict=True)).

520-530: Potential silent truncation when dependency has more parameters than positional call args.

If the entry function calls a dependency with keyword arguments or fewer positional args than the dependency expects, zip(dep_param_names, call_args) will silently truncate, leaving some dependency parameters without metadata. This could cause downstream issues in specialization.

Consider adding strict=True or handling the length mismatch explicitly:

♻️ Proposed fix to handle length mismatch
         if call_args is not None:
-            for dep_param, entry_arg in zip(dep_param_names, call_args):
+            for dep_param, entry_arg in zip(dep_param_names, call_args, strict=False):
+                # Note: strict=False is intentional - call_args may be shorter if
+                # some args are passed as kwargs. Unmatched params fall through
+                # to name-based matching below.
                 if entry_arg is None:
                     continue
                 if entry_arg in all_tensor_meta:

Alternatively, if partial positional matching followed by name-based fallback is intended, document this behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/decorator.py` around lines 520 - 530, The loop over
zip(dep_param_names, call_args) can silently truncate when a dependency expects
more parameters than provided as positional args (or when kwargs are used);
update the logic in the block handling dep_param_names/call_args to detect
length mismatches and handle them explicitly: either iterate over
dep_param_names and index into call_args with bounds checks (filling remaining
params from keyword mapping if available) or raise/log an explicit error when
required dep parameters lack metadata, and ensure dep_tensor_meta,
dep_scalar_values, and dep_scalar_dtypes are filled from
entry_scalar_values/entry_scalar_dtypes or all_tensor_meta by parameter name
fallback rather than relying solely on zip truncation. Ensure you touch the code
referencing dep_param_names, call_args, dep_tensor_meta, all_tensor_meta,
entry_scalar_values, and entry_scalar_dtypes.
tests/ut/jit/test_specializer.py (1)

172-180: Add one specialize() assertion for scalar substitution.

These tests prove _classify_params() recognizes pl.INDEX / pl.FP32, but nothing here feeds concrete scalar values into specialize() and checks that something like BLOCK_M or alpha becomes a literal in the generated source. One focused case would cover the biggest remaining gap in the new JIT contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/jit/test_specializer.py` around lines 172 - 180, The test currently
verifies _classify_params() recognizes pl.INDEX/pl.FP32 but doesn't assert that
specialize() substitutes scalar params into the generated source; update
test_scalar_bare_dtype to call specialize() on the parsed function (use concrete
values e.g. BLOCK_M=16 and alpha=0.5), get the generated source from the
specialization result, and assert that the source contains the literal "16" for
BLOCK_M and "0.5" (or equivalent literal) for alpha so the scalar substitution
is validated; keep references to _parse_func(func_def), _classify_params(), and
specialize() when locating code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/jit/specializer.py`:
- Around line 165-186: _collect_dynvar_names currently returns only Python
identifiers which lets specialize recreate pl.dynamic(...) using the local name;
instead capture and return the original literal used in the call (e.g., map
variable name -> literal string when the Call has a Constant/Str arg) so that
specialize can re-emit pl.dynamic(original_literal) instead of
pl.dynamic(variable_name); update _collect_dynvar_names to return a dict[str,
str] (or similar) and change specialize (and the other similar site that handles
dynvars) to consult that mapping when reconstructing pl.dynamic calls.
- Around line 279-295: The constructor stores scalar_values on self._scalars but
nothing uses it; update the _BodyTransformer to actually specialize compile-time
scalars by replacing references to parameter names with their concrete values
when walking the AST (e.g., implement handling in visit_Name / visit_Constant or
add a helper like _specialize_scalar_usage called from existing visit_*
methods). Ensure you read self._scalars (keys are parameter names) and replace
occurrences of pl.INDEX / pl.FP32-style compile-time params with literal values
so the transformed body reflects the scalar and the cache key then matches the
actual IR.

In `@tests/ut/jit/test_roundtrip.py`:
- Around line 66-122: The tests create tile_add/tile_mul JIT functions that
load/store 128x128 tensors but compare against reference IR fixtures
TileAddProgram and TileMulProgram which are 32x32, causing structural
mismatches; update the test functions (test_tile_add_128x128 and
test_tile_mul_128x128) to match the referenced programs by either (A) changing
the loads, stores, and test tensor shapes to 32x32 (adjust tile_a/tile_b load
sizes and created tensors a, b, c) or (B) importing or referencing the correct
128x128 fixtures if those exist; likewise scan the other failing tests that
compare against FusedAddScaleProgram and TileAssembleAccMatProgram and align
their JIT bodies (e.g., apply the 0.5 scale in the fused case and accept a
precomputed acc tile signature in the assemble case) so the orchestrator/tile_*
functions (and their argument shapes) exactly match the structure of the
referenced program symbols before calling ir.assert_structural_equal().


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8fecc048-92c4-4d96-aa29-285e3f44a771

📥 Commits

Reviewing files that changed from the base of the PR and between 08bdedd and 572087d.

📒 Files selected for processing (12)
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • python/pypto/language/__init__.py
  • python/pypto/language/typing/tensor.py
  • tests/ut/jit/__init__.py
  • tests/ut/jit/conftest.py
  • tests/ut/jit/test_cache.py
  • tests/ut/jit/test_decorator.py
  • tests/ut/jit/test_roundtrip.py
  • tests/ut/jit/test_specializer.py


coderabbitai (Bot) left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
python/pypto/jit/specializer.py (2)

279-295: ⚠️ Potential issue | 🟠 Major

Scalar values affect the cache key, but not the generated body.

scalar_values is stored on _BodyTransformer, but no visitor ever reads self._scalars. Different pl.INDEX / pl.FP32 call values therefore generate different cache entries while producing the same IR, and scalar-dependent DSL constructs never get specialized by value. Please inline scalar loads and add a regression in tests/ut/jit/test_specializer.py.

Minimal fix
 class _BodyTransformer(ast.NodeTransformer):
@@
     def _shape_dim_node(self, param_name: str, dim_idx: int) -> ast.expr:
         meta = self._meta[param_name]
         if (param_name, dim_idx) in self._dynamic_dims:
             dv = _dynvar_name_for_dim(param_name, dim_idx, self._dv_names)
             return ast.Name(id=dv, ctx=ast.Load())
         return ast.Constant(value=meta.shape[dim_idx])
+
+    def visit_Name(self, node: ast.Name) -> ast.expr:
+        if isinstance(node.ctx, ast.Load) and node.id in self._scalars:
+            return ast.copy_location(ast.Constant(value=self._scalars[node.id]), node)
+        return node
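The visit_Name approach in the fix above can be demonstrated standalone with plain Python ast, independent of PyPTO's _BodyTransformer:

```python
import ast

# Standalone demo of the visit_Name fix above: replace loads of known
# scalar parameter names with constant nodes so the generated body is
# specialized by value, matching what the cache key already encodes.
class InlineScalars(ast.NodeTransformer):
    def __init__(self, scalars):
        self._scalars = scalars

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load) and node.id in self._scalars:
            return ast.copy_location(ast.Constant(self._scalars[node.id]), node)
        return node  # stores and unknown names are left untouched

tree = ast.parse("y = alpha * x + beta")
tree = ast.fix_missing_locations(InlineScalars({"alpha": 0.5, "beta": 2}).visit(tree))
specialized = ast.unparse(tree)  # "y = 0.5 * x + 2"
```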
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/specializer.py` around lines 279 - 295, The transformer
currently stores scalar_values on _BodyTransformer as self._scalars but never
uses them, causing cache divergence; update _BodyTransformer to inline scalar
constants by replacing loads of scalar DSL variables with their literal values
during IR generation (e.g., implement logic in the visitor used to walk
expressions — such as visit_Name/visit_Attribute or the method that handles DSL
scalar loads — to check self._scalars and return a Constant node for keys
present), remove reliance on scalar_values from any cache key generation if
present, and add a regression test in tests/ut/jit/test_specializer.py that
creates two specializations differing only by scalar_values and asserts they
produce identical generated bodies/IR.

615-623: ⚠️ Potential issue | 🟠 Major

DynVar emission renames user-defined dynamic symbols.

_iter_dynvar_names() only yields the Python binding name, so line 622 reconstructs pl.dynamic() as rows = pl.dynamic("rows") even when the source was rows = pl.dynamic("M"). That silently changes the symbol name and can also merge unrelated dynvars when different functions reuse the same local variable name. Please carry the original literal through specialization and emit that here instead.

Also applies to: 642-649
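A sketch of carrying the original literal through specialization, using plain ast and assuming declarations of the form name = pl.dynamic("symbol") as described above:

```python
import ast

# Map each binding name to the original literal passed to pl.dynamic(),
# so emission can re-use "M" instead of the local name "rows".
# Assumes declarations of the form: <name> = pl.dynamic("<symbol>").
def collect_dynvar_literals(source):
    literals = {}
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Assign)
                and isinstance(node.value, ast.Call)
                and isinstance(node.value.func, ast.Attribute)
                and node.value.func.attr == "dynamic"
                and node.value.args
                and isinstance(node.value.args[0], ast.Constant)):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    literals[target.id] = node.value.args[0].value
    return literals

names = collect_dynvar_literals('rows = pl.dynamic("M")\ncols = pl.dynamic("N")')
```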

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/specializer.py` around lines 615 - 623, The emitted
module-level DynVar declarations are using the Python binding name from
_iter_dynvar_names() which loses the original literal passed to pl.dynamic()
(causing e.g. pl.dynamic("rows") instead of pl.dynamic("M")); update the
pipeline so the original literal is preserved through specialization (e.g.,
store the literal string on the context or the DynVar metadata when you discover
pl.dynamic) and change the emission logic to iterate those preserved literals
(not the local binding names) when building lines (the code around
_iter_dynvar_names and the emission loop that appends f'{dv_varname} =
pl.dynamic("{dv_varname}")'); apply the same fix to the second emission site
mentioned (the 642-649 block) so you emit the exact original string passed to
pl.dynamic() and avoid merging different DynVars that happen to share a local
variable name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/jit/decorator.py`:
- Around line 369-375: The current _get_source_hash and dependency handling only
includes direct deps, so nested `@jit.incore` dependencies (e.g., dep_a calling
dep_b) are not specialized or hashed; update _get_deps to perform a recursive,
deduplicating walk of dependency contexts and change call-site emission (where
contexts and dep_names are built) to include all transitive deps so emitted
contexts list contains every unique dep; ensure _get_source_hash collects
inspect.getsource for self._func plus the full deduped transitive set from
_get_deps and then calls compute_source_hash on that complete list (apply same
recursion fix to the other occurrences referenced around lines 472-490 and
540-549).
- Around line 239-257: The bug is that _extract_call_args_for_dep (and the same
logic at the other occurrence) only inspects node.args so keyword-only calls
like sub(a=a, c=out) yield an empty list which is treated as "found" and causes
_build_dep_context to take the positional branch; update
_extract_call_args_for_dep to also inspect node.keywords and build a mapping:
for each keyword in node.keywords, append keyword.arg (or None for None arg) at
the corresponding logical position or, simpler, if node.args is empty but
node.keywords is non-empty, return None so the caller falls back; apply the same
change to the duplicate block at the other location (lines referenced in the
comment) and ensure callers (_build_dep_context) correctly interpret a None
return as "no mapping".
- Around line 423-440: The entry-level dynamic_dims computed by
_scan_dynamic_dims must be augmented with any dependency-level bind_dynamic
markers before building the cache key and creating the entry specialization:
after calling _scan_dynamic_dims(self._func, param_names) merge in dynamic-dim
mappings produced/propagated by dependent `@jit.incore` functions (i.e. any
dep-derived dynamic dim map you already compute or can obtain from the
dependency analysis) so that dynamic_dims reflects both entry and dependency
bindings; then pass the merged dynamic_dims into make_cache_key(...) and into
the entry specialization/_compile(...) (and into the SpecializeContext creation)
so dependency-level DynVar bindings affect caching and compilation. Apply the
same merge at the other locations noted (the blocks around lines 479-489 and
537-548).


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 905c2132-c18d-427d-b9e6-adf01bc749df

📥 Commits

Reviewing files that changed from the base of the PR and between 572087d and 3de82d2.

📒 Files selected for processing (12)
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • python/pypto/language/__init__.py
  • python/pypto/language/typing/tensor.py
  • tests/ut/jit/__init__.py
  • tests/ut/jit/conftest.py
  • tests/ut/jit/test_cache.py
  • tests/ut/jit/test_decorator.py
  • tests/ut/jit/test_roundtrip.py
  • tests/ut/jit/test_specializer.py
✅ Files skipped from review due to trivial changes (5)
  • tests/ut/jit/__init__.py
  • tests/ut/jit/conftest.py
  • python/pypto/language/typing/tensor.py
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/ut/jit/test_cache.py
  • tests/ut/jit/test_decorator.py

Comment thread python/pypto/jit/decorator.py Outdated
Comment thread python/pypto/jit/decorator.py
Comment thread python/pypto/jit/decorator.py Outdated
@huxinyuan1215 huxinyuan1215 force-pushed the feat/jit-preprocessor branch 6 times, most recently from df06df4 to 44b5260 on April 9, 2026 02:53
@huxinyuan1215 huxinyuan1215 force-pushed the feat/jit-preprocessor branch 3 times, most recently from 183f873 to c9e33d2 on April 19, 2026 05:04
@huxinyuan1215
Contributor Author

@coderabbitai help

@coderabbitai

coderabbitai Bot commented Apr 19, 2026

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

CodeRabbit commands

These commands are invoked using PR/Issue comments.

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai evaluate custom pre-merge check --instructions <custom-checks-instructions> --name <custom-checks-title> [--mode <error|warning>] to validate the custom pre-merge checks instructions. Defaults to error when --mode is omitted.
  • @coderabbitai ignore pre-merge checks to override pre-merge checks and get an approval on PR.
  • @coderabbitai run pre-merge checks to run pre-merge checks on the pull request.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve merge conflict to automatically resolve merge conflicts.
  • @coderabbitai autofix to automatically fix issues identified in unresolved review comments.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai generate configuration to create a PR that adds the current resolved configuration as .coderabbit.yaml (or show it if already present).
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit configuration file (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • You can also validate your configuration using the online YAML validator.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Status, support, documentation and community

  • Visit our status page to check the current availability of CodeRabbit.
  • Create a ticket on our support page for assistance with any issues or questions.
  • Visit our documentation site for detailed information on how to use CodeRabbit.
  • Join our Discord community to connect with other users and get help from the community.
  • Follow us on X/Twitter for updates and announcements.

@huxinyuan1215
Contributor Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Apr 19, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 7

♻️ Duplicate comments (1)
python/pypto/jit/decorator.py (1)

487-492: ⚠️ Potential issue | 🟠 Major

Walk transitive @jit.incore dependencies for emission and hashing.

The current flow still includes only direct deps in the source hash and only emits direct dep contexts; dependency contexts also use dep_names=[]. If dep_a() calls dep_b(), dep_b is not emitted as self.dep_b(...) and its source will not invalidate the cache.

Also applies to: 632-654, 723-732

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/decorator.py` around lines 487 - 492, The source-hash and
emission logic only include direct deps, so update _get_source_hash and the
emission code that calls _get_deps to walk dependencies transitively: change
dependency gathering to perform a recursive traversal starting from self (use
_get_deps recursively or a new helper) to collect a unique set/list of all
transitive Dep/Jit contexts, then for each collected dep call
inspect.getsource(dep._func) when building sources for compute_source_hash;
likewise update the emission logic that currently emits only direct deps and
uses dep_names=[] to emit every transitive dep as an attribute on the root
(e.g., self.dep_b(...)) and populate dep_names with the attribute path for each
emitted dep so transitive calls are emitted and their sources will invalidate
the cache. Ensure you deduplicate cycles and maintain original ordering
semantics when collecting deps.
🧹 Nitpick comments (2)
python/pypto/jit/decorator.py (1)

78-78: Move intentional import suppressions out of inline noqa comments.

If these lazy imports are intentional, add a documented per-file ignore in pyproject.toml instead of embedding repeated inline suppressions.

Based on learnings, inline # noqa comments are not acceptable for intentional patterns; fix root causes in linter configuration such as ruff.toml/per-file ignores.

Also applies to: 184-185, 365-365, 523-523, 593-593, 751-752

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/jit/decorator.py` at line 78, The file
python/pypto/jit/decorator.py contains multiple inline "# noqa" suppressions on
lazy imports like the `import torch` statements (e.g., the import at line with
`import torch  # noqa: PLC0415` and others at the referenced lines), so remove
those inline `# noqa` comments and instead add a documented per-file linter
ignore for this module in your project config (pyproject.toml or ruff.toml)
specifying the rule(s) being suppressed (e.g., PLC0415) and a brief comment
explaining the lazy-import rationale; ensure you update the config entry to
target the file `python/pypto/jit/decorator.py` (covering the occurrences you
noted) and run the linter to confirm no remaining inline suppressions are
needed.
pyproject.toml (1)

118-131: Keep Pyright suppressions narrow.

Excluding the new JIT round-trip test and disabling reportGeneralTypeIssues globally weakens type checking for the whole repo. Prefer fixing the JIT typing gaps or using a documented, narrowly scoped per-file setting so unrelated regressions remain visible.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pyproject.toml` around lines 118 - 131, The global Pyright suppressions
(e.g., reportGeneralTypeIssues = false and other report* flags) and excluding
tests/ut/jit/test_roundtrip.py weaken repo-wide type checking; revert those
globals to their defaults (remove or set reportGeneralTypeIssues and the other
report* flags back to true) and stop excluding the JIT test, then either fix the
JIT typing gaps or add a narrow, file-scoped suppression inside
tests/ut/jit/test_roundtrip.py (for example a top-of-file pyright directive like
"# pyright: reportGeneralTypeIssues=false" or targeted type-ignore comments) so
only that test keeps suppressed checks while the rest of the codebase retains
full reporting.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/jit/cache.py`:
- Around line 87-91: The hash construction in this module currently concatenates
raw source strings (using _PYPTO_VERSION and iterating sources into h.update),
which makes different source lists collide (e.g., ["ab","c"] vs ["a","bc"]); fix
it by framing each source before feeding the hasher—e.g., for each item in
sources prepend its length or a fixed unambiguous delimiter and update the
hasher with that framed byte sequence (use the same encoding as _PYPTO_VERSION)
so the loop that updates h (the hashlib.sha256 instance) uniquely represents
list boundaries.

In `@python/pypto/jit/decorator.py`:
- Around line 335-390: The current _build_dynvar_bindings builds a flat bindings
dict keyed by f"{param}__{dim}" which allows different SpecializeContext entries
(ctx.func_name) to clobber each other; change the key to include the
function/context identity (e.g. incorporate ctx.func_name into the key like
f"{ctx.func_name}__{fn.value.id}__{dim_node.value}") or alternatively return a
per-context mapping (e.g. dict[ctx.func_name] → dict[param__dim → var]) and
update any caller that expects the flat bindings to consult the per-context map
or perform verified call-site backfilling; update uses of bindings and literals
accordingly so lookups are done with the new context-qualified keys or via the
per-context map.
- Around line 266-279: The code currently treats any presence of node.args as
positional-only and returns a list of names, which drops keyword mappings when
calls mix positional and keyword args; update the branch in the function
handling ast.Call nodes so that when node.args and node.keywords are both
present you return a unified dep-param→entry-arg map (the same dict-like list of
2-tuples used for pure-keyword calls) by pairing positional parameter names with
their corresponding node.args values (use arg.id if ast.Name else None) and then
appending the (kw.arg, kw.value.id if ast.Name else None) tuples for
node.keywords; ensure _build_dep_context continues to accept this dict-like list
and that kw.arg None entries are filtered out exactly as the existing
keyword-only branch does.

In `@python/pypto/jit/specializer.py`:
- Around line 156-160: The scanner only reads positional args (call.args[0]) so
keyword-form calls like Tensor.bind_dynamic(dim=0, var=M) are missed; update the
code that extracts dim and var from ast.Call to first check call.keywords for
entries with arg == "dim" and arg == "var" (parsing their .value like
dim_node/var_node and applying the same isinstance(ast.Constant)/int checks),
falling back to call.args[0]/call.args[1] for positional forms; apply the same
keyword-aware logic in the dynvar-binding collection path so the
cache/annotation keys include dims and vars specified as keywords (and keep the
existing result.add((func.value.id, dim_value)) behavior when a valid int dim is
found).
- Around line 673-675: The emitted dynamic var line uses an unescaped string
literal and can produce invalid Python for values like M"0; change the
formatting in the lines.append call to embed an escaped/quoted representation of
dv_literal (from self._dv_literals via dv_varname) by using Python's repr (or
the !r formatter) so the generated line becomes pl.dynamic(<escaped literal>)
rather than interpolating raw dv_literal.
- Around line 383-399: The bug is stale inlining of shape aliases: when a name
has been recorded in self._shape_inlined (e.g., from the loop that inlines
static dims) and later user code assigns to that same name, the inlined mapping
must be removed so subsequent loads aren't rewritten to the old constant; to
fix, update the assignment-handling path (e.g., visit_Assign / any code that
handles ast.Store targets) to detect target names and pop them from
self._shape_inlined (and similarly handle AugAssign/AnnAssign/targets that are
ast.Name) so any reassigned identifier previously inlined is invalidated; apply
the same change for the other inlining site referenced (lines 426-433) where
static dims are recorded.

In `@python/pypto/pypto_core/__init__.pyi`:
- Line 68: The stub declares DataType.__hash__ but the C++ binding in
python/bindings/modules/core.cpp only exposes __eq__, __ne__, __repr__, and
__str__; implement and expose a DataType.__hash__ in the C++ binding (e.g., add
a hash function that returns a stable int based on the DataType identity/content
and register it with pybind as "__hash__") so DataType instances become usable
as dict keys (needed by _DTYPE_TO_PL in python/pypto/jit/specializer.py). Locate
the DataType binding block in modules/core.cpp and add the corresponding
lambda/static method for "__hash__" consistent with the equality implementation
and the pyd binding style already used for __eq__/__ne__.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e7b64ba1-0efb-4bc4-9392-e5627d50b9da

📥 Commits

Reviewing files that changed from the base of the PR and between 572087d and 35afd4a.

📒 Files selected for processing (8)
  • pyproject.toml
  • python/pypto/jit/__init__.py
  • python/pypto/jit/cache.py
  • python/pypto/jit/decorator.py
  • python/pypto/jit/specializer.py
  • python/pypto/language/__init__.py
  • python/pypto/language/typing/tensor.py
  • python/pypto/pypto_core/__init__.pyi
✅ Files skipped from review due to trivial changes (1)
  • python/pypto/jit/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • python/pypto/language/__init__.py

Comment thread python/pypto/jit/cache.py
Comment on lines +87 to +91
    h = hashlib.sha256()
    h.update(_PYPTO_VERSION.encode())
    for src in sources:
        h.update(src.encode())
    return h.hexdigest()[:16]

⚠️ Potential issue | 🟡 Minor

Frame each source before hashing it.

Concatenating raw source strings makes different source lists hash identically before SHA-256, e.g. ["ab", "c"] and ["a", "bc"]. Prefix each source with its length or a delimiter that cannot be ambiguous.

🐛 Proposed fix
     h = hashlib.sha256()
     h.update(_PYPTO_VERSION.encode())
     for src in sources:
-        h.update(src.encode())
+        encoded = src.encode()
+        h.update(len(encoded).to_bytes(8, "big"))
+        h.update(encoded)
     return h.hexdigest()[:16]
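The boundary collision and the length-prefix fix are easy to verify in isolation. A standalone sketch of the two variants (the version-stamp mixing from cache.py is omitted here):

```python
import hashlib

def unframed_hash(sources):
    # Concatenation loses list boundaries: ["ab","c"] and ["a","bc"] collide.
    h = hashlib.sha256()
    for src in sources:
        h.update(src.encode())
    return h.hexdigest()[:16]

def framed_hash(sources):
    h = hashlib.sha256()
    for src in sources:
        encoded = src.encode()
        # 8-byte big-endian length prefix makes each element's extent unambiguous.
        h.update(len(encoded).to_bytes(8, "big"))
        h.update(encoded)
    return h.hexdigest()[:16]
```

Any unambiguous framing works (length prefix, or a delimiter that cannot appear in the payload); the length prefix avoids having to reason about delimiter escaping.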

Comment thread python/pypto/jit/decorator.py
Comment thread python/pypto/jit/decorator.py
Comment thread python/pypto/jit/specializer.py
Comment thread python/pypto/jit/specializer.py
Comment on lines +673 to +675
                    # Use the original string literal if available, else fall back to var name
                    dv_literal = self._dv_literals.get(dv_varname, dv_varname)
                    lines.append(f'{dv_varname} = pl.dynamic("{dv_literal}")')

⚠️ Potential issue | 🟡 Minor

Escape dynvar string literals when regenerating source.

A valid call like pl.dynamic('M"0') currently emits invalid Python: pl.dynamic("M"0"). Use repr/!r when embedding the literal.

🐛 Proposed fix
-                    lines.append(f'{dv_varname} = pl.dynamic("{dv_literal}")')
+                    lines.append(f"{dv_varname} = pl.dynamic({dv_literal!r})")
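The quoting failure and the `!r` fix can be checked by parsing the generated line. A small self-contained demonstration (`dv_varname`/`dv_literal` taken from the comment above):

```python
import ast

dv_varname, dv_literal = "M", 'M"0'

# Naive interpolation produces invalid Python once the literal contains a quote:
# M = pl.dynamic("M"0")
broken = f'{dv_varname} = pl.dynamic("{dv_literal}")'

# Embedding via !r (repr) yields a properly quoted and escaped string literal.
fixed = f"{dv_varname} = pl.dynamic({dv_literal!r})"

def is_valid_python(src):
    try:
        ast.parse(src)
        return True
    except SyntaxError:
        return False
```

`repr` also picks the outer quote character for you, so the fix covers literals containing either quote style.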

        The size in bits of the data type
        """

    def __hash__(self) -> int: ...

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n -C3 'class_<DataType>|__hash__|__eq__|__ne__|Code' \
  python/bindings/modules/core.cpp \
  python/pypto/pypto_core/__init__.pyi

Repository: hw-native-sys/pypto

Length of output: 3315


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1. Check if DataType is used as a dict key or in cache contexts
echo "=== Searching for DataType usage in cache/dict contexts ==="
rg -n 'DataType.*cache|cache.*DataType|dict.*DataType|Set\[.*DataType|Dict.*DataType' --type py python/

echo ""
echo "=== Searching for dataclass definitions with DataType fields ==="
rg -n '@dataclass|DataType' python/pypto/ir/ --type py | head -50

echo ""
echo "=== Verifying DataType::Code() exists and is public ==="
rg -n 'Code\(\)|uint8_t.*Code' include/pypto/ src/ --type cpp | head -20

Repository: hw-native-sys/pypto

Length of output: 7071


Implement DataType.__hash__ in the C++ binding before declaring it in the stub.

The stub declares hashability (line 68), but python/bindings/modules/core.cpp only binds __eq__, __ne__, __repr__, and __str__. The new JIT code uses DataType as dict keys in _DTYPE_TO_PL (python/pypto/jit/specializer.py:93), which requires hashability. Without the implementation, this will fail at runtime with a TypeError.

Suggested fix
       // Operators
       .def("__eq__", &DataType::operator==, nb::arg("other"), "Equality comparison operator")
       .def("__ne__", &DataType::operator!=, nb::arg("other"), "Inequality comparison operator")
+      .def("__hash__", [](const DataType& self) {
+        return std::hash<uint8_t>{}(self.Code());
+      }, "Hash based on the underlying data type code")
       .def("__repr__", &DataType::ToString, "String representation for debugging")
       .def("__str__", &DataType::ToString, "String representation for printing");

@huxinyuan1215 huxinyuan1215 force-pushed the feat/jit-preprocessor branch from 35afd4a to e206de0 on April 19, 2026 05:59
…stamp

- _compile() now calls ir.compile(skip_ptoas=True) to run full pass
  pipeline and generate C++ artifacts without requiring the Ascend
  toolchain locally. Returns output_dir (str) instead of ir.Program.

- __call__ checks L1 (in-memory) then L2 (on-disk) cache before
  recompiling. L2 cache lives under ~/.cache/pypto/jit/<key_hash>/
  and copies compiled artifacts so they survive process restarts and
  directory cleanup.

- compute_source_hash() mixes in pypto.__version__ so that upgrading
  PyPTO automatically invalidates stale L1 and L2 cache entries without
  manual clearing (addresses issue hw-native-sys#878 Q3).

- l2_lookup() / l2_store() added to cache.py with non-fatal error
  handling so L2 write failures never break compilation.

- compile_for_test() added to JITFunction: runs specialization and
  pass pipeline without codegen for unit tests that compare IR via
  ir.assert_structural_equal.

- test_roundtrip.py updated: all 18 test calls use
  orchestrator.compile_for_test() so tests work locally without the
  Ascend toolchain.
@huxinyuan1215 huxinyuan1215 force-pushed the feat/jit-preprocessor branch from e206de0 to 48e9448 on April 19, 2026 06:01
…tion tests

- Extract 'config' kwarg from __call__ before sig.bind() so RunConfig
  is forwarded to CompiledProgram.__call__() without confusing the
  JIT parameter specializer
- Add tests/st/runtime/test_jit.py: system tests that compile and
  execute @pl.jit kernels on device, covering in-place execution,
  cache hit/miss, multi-shape correctness, and bind_dynamic reuse
…=True

When ptoas is available (PTOAS_ROOT env var or PATH), compile with the
full pipeline so CompiledProgram.__call__() can execute on device.
When ptoas is absent (e.g. UT CI runners), fall back to skip_ptoas=True
so unit tests continue to pass without the Ascend toolchain.
JIT execution tests require integration with the PTOTestCase pre-build
harness to avoid runtime g++-15 compilation. Remove for now to unblock
the PR; JIT st coverage can be added in a follow-up.
Verifies end-to-end JIT compile+execute flow on real NPU:
- test_inplace_add: first call compiles and produces correct result
- test_cache_hit_reuses_compiled_program: same shape hits L1 cache
- test_cache_miss_different_shape: different shape triggers recompile
Specializer was emitting @pl.function(type=pl.FunctionType.Orchestration)
for all @pl.jit entry functions.  OutlineIncoreScopes only processes
FunctionType::Opaque functions, so tile.alloc ops inside 'with pl.incore():'
were left in the Orchestration function, causing codegen to fail with
'Misplaced builtin op tile.alloc in Orchestration function'.

Fix: emit Opaque when the entry function contains a pl.incore()/auto_incore()
scope (so the pass can outline and promote it), and Orchestration otherwise
(multi-function style B where the entry has no incore scope).

Also remove the workaround in tests/ut/jit/conftest.py that suppressed pass
verification instruments for JIT tests -- the root cause is now fixed.
@Hzfengsy
Copy link
Copy Markdown
Member

@codex review

Contributor

Copilot AI left a comment

Pull request overview

Implements a new @pl.jit frontend that specializes Python kernel functions into generated @pl.program DSL source, compiles via the existing IR/pass pipeline, and caches compiled artifacts by a shape/dtype/scalar key (with support for runtime-dynamic dimensions via bind_dynamic).

Changes:

  • Add pypto.jit module (decorator, specializer, cache) and export it through pypto.language.
  • Add Tensor.bind_dynamic() API for marking runtime-dynamic dimensions in @pl.jit specialization.
  • Add extensive UT/ST coverage for AST specialization, caching behavior, round-trip structural equality, and device execution.

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
python/pypto/jit/decorator.py Implements JITFunction / pl.jit decorators, dep discovery, specialization + compile + L1 caching, plus test-only compile path.
python/pypto/jit/specializer.py AST-based source-to-DSL transformer and generated @pl.program emitter (incl. dynvar/module-level promotion).
python/pypto/jit/cache.py Cache key construction (and documented L2 helpers).
python/pypto/jit/__init__.py Public exports for the JIT module.
python/pypto/language/typing/tensor.py Adds Tensor.bind_dynamic() API for JIT dynamic-dim marking.
python/pypto/language/__init__.py Re-exports jit/JITFunction from pypto.jit and updates __all__.
python/pypto/pypto_core/__init__.pyi Updates DataType typing stub to include __hash__.
tests/ut/jit/test_specializer.py Unit tests for specialization transforms and parseability.
tests/ut/jit/test_decorator.py Unit tests for decorator behavior, caching, dep discovery, and bind_dynamic no-op.
tests/ut/jit/test_cache.py Unit tests for cache key/source hash behavior.
tests/ut/jit/test_roundtrip.py Round-trip structural-equality tests across many example-kernel patterns, including dynamic dims.
tests/ut/jit/conftest.py JIT test backend setup and pass verification fixture scaffolding.
tests/st/runtime/test_jit.py End-to-end device execution + cache-hit/miss behavior for @pl.jit.
pyproject.toml Pyright configuration tweaks (exclude roundtrip test; extraPaths and additional report suppressions).

c.bind_dynamic(0, M) # dim 0 of c shares the same DynVar
K = a.shape[1] # dim 1 is compile-time constant
...
"""
Copilot AI Apr 20, 2026

bind_dynamic() has only a docstring and no function body (no pass/return), so this file will fail to import with a syntax/indentation error. Add an explicit no-op implementation (e.g., return None) so Tensor.bind_dynamic() is callable at runtime as intended.

Suggested change
"""
"""
return None

Comment thread python/pypto/jit/decorator.py Outdated
Comment on lines +813 to +819
        # OutlineIncoreScopes limitation — that is expected and the cache
        # entry is left empty in that case.
        if key not in self._cache:
            try:
                self._cache[key] = self._compile(tensor_meta, scalar_values, scalar_dtypes, dynamic_dims, pl)
            except Exception:
                pass  # codegen failure for single-function programs is expected
Copilot AI Apr 20, 2026

This try/except Exception: pass in compile_for_test() will silently hide compilation/codegen failures, even unexpected ones, while the docstring claims the cache is populated via ir.compile(). Consider catching only the specific expected failure(s) (or re-raising with context) so tests don't mask real regressions.

Suggested change
        # OutlineIncoreScopes limitation — that is expected and the cache
        # entry is left empty in that case.
        if key not in self._cache:
            try:
                self._cache[key] = self._compile(tensor_meta, scalar_values, scalar_dtypes, dynamic_dims, pl)
            except Exception:
                pass  # codegen failure for single-function programs is expected
        # OutlineIncoreScopes limitation — that specific failure is expected
        # and the cache entry is left empty in that case.
        if key not in self._cache:
            try:
                self._cache[key] = self._compile(tensor_meta, scalar_values, scalar_dtypes, dynamic_dims, pl)
            except Exception as e:
                if "OutlineIncoreScopes" not in str(e):
                    raise RuntimeError(
                        f"@pl.jit function '{self.__name__}': compile_for_test() failed during ir.compile()"
                    ) from e

Comment thread python/pypto/jit/cache.py Outdated
Comment on lines +10 to +15
"""Compilation cache for @pl.jit functions.

L1 cache: in-memory dict on each JITFunction instance.
L2 cache: on-disk under ~/.cache/pypto/jit/<key_hash>/ (survives process restarts).

Cache key encodes source hash, tensor shapes/dtypes, scalar values, and the

Copilot AI Apr 20, 2026


Module docstring claims an on-disk "L2 cache" is implemented, but JITFunction currently only uses the in-memory _cache and never calls l2_lookup/l2_store. Either wire L2 into the decorator or adjust the docstring to avoid documenting behavior that doesn't exist.

Suggested change
"""Compilation cache for @pl.jit functions.
L1 cache: in-memory dict on each JITFunction instance.
L2 cache: on-disk under ~/.cache/pypto/jit/<key_hash>/ (survives process restarts).
Cache key encodes source hash, tensor shapes/dtypes, scalar values, and the
"""Compilation cache support for @pl.jit functions.
The active cache used by JITFunction is an in-memory dict on each instance.
This module also defines cache-key construction and related helper utilities.
Cache keys encode source hash, tensor shapes/dtypes, scalar values, and the
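A sketch of the key construction the docstring describes, under the assumption (from the PR summary) that dynamic dims are erased to None so different concrete sizes share one entry; make_cache_key and its parameter names are illustrative, not the real pypto.jit.cache API:

```python
import hashlib
import json

PYPTO_VERSION = "0.0.0"  # stand-in for the real version stamp


def make_cache_key(source_hash, tensor_shapes, tensor_dtypes, scalar_values,
                   dynamic_dims=()):
    """Build a stable key over source, shapes, dtypes, and scalar values.

    Dims listed in dynamic_dims as (param_name, dim_index) are stored as
    None so all concrete values of a bound DynVar hit one cache entry;
    the version stamp invalidates entries on PyPTO upgrades.
    """
    shapes = {
        name: tuple(None if (name, i) in dynamic_dims else d
                    for i, d in enumerate(shape))
        for name, shape in tensor_shapes.items()
    }
    payload = json.dumps(
        [PYPTO_VERSION, source_hash, shapes, tensor_dtypes, scalar_values],
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


# Dim 0 of "a" is dynamic: 64 vs 32 rows map to the same cache entry.
k1 = make_cache_key("abc", {"a": (64, 128)}, {"a": "f32"}, [2],
                    dynamic_dims={("a", 0)})
k2 = make_cache_key("abc", {"a": (32, 128)}, {"a": "f32"}, [2],
                    dynamic_dims={("a", 0)})
assert k1 == k2
```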

Comment thread python/pypto/jit/specializer.py Outdated
Comment on lines +855 to +858
is_out = name in out_params
ann = _build_tensor_annotation(
name,
ctx.tensor_meta[name],

Copilot AI Apr 20, 2026


ctx.tensor_meta[name] can raise a KeyError if a tensor parameter's metadata wasn't inferred (e.g., dep called with an intermediate pl.create_tensor whose shape/dtype couldn't be statically extracted). Add an explicit check and raise a clear user-facing error indicating which parameter is missing meta and how to make it inferable.

Suggested change
is_out = name in out_params
ann = _build_tensor_annotation(
name,
ctx.tensor_meta[name],
meta = ctx.tensor_meta.get(name)
if meta is None:
raise ValueError(
f"Missing inferred tensor metadata for parameter '{name}' during JIT specialization. "
"This usually means the tensor's shape and dtype could not be determined statically. "
"To make it inferable, pass the tensor directly as a function argument or ensure any "
"intermediate pl.create_tensor(...) used for this parameter has statically inferable "
"shape and dtype."
)
is_out = name in out_params
ann = _build_tensor_annotation(
name,
meta,

Comment thread python/pypto/jit/specializer.py Outdated
Comment on lines +386 to +395
# Dynamic dim: emit assignment (M = pl.dynamic("M") ref)
val: ast.expr = self._shape_dim_node(param_name, i)
stmts.append(
ast.Assign(
targets=[ast.Name(id=tgt.id, ctx=ast.Store())],
value=val,
lineno=node.lineno,
col_offset=node.col_offset,
)
)

Copilot AI Apr 20, 2026


When a dimension is dynamic and the unpack target variable has the same name as the DynVar binding, this expansion can emit a redundant self-assignment (e.g. M = M). Consider skipping the assignment when the RHS is the same name as the LHS to keep generated code clean and avoid confusing downstream tooling.

Suggested change
# Dynamic dim: emit assignment (M = pl.dynamic("M") ref)
val: ast.expr = self._shape_dim_node(param_name, i)
stmts.append(
ast.Assign(
targets=[ast.Name(id=tgt.id, ctx=ast.Store())],
value=val,
lineno=node.lineno,
col_offset=node.col_offset,
)
)
# Dynamic dim: emit assignment unless it would be a no-op
val: ast.expr = self._shape_dim_node(param_name, i)
if not (isinstance(val, ast.Name) and val.id == tgt.id):
stmts.append(
ast.Assign(
targets=[ast.Name(id=tgt.id, ctx=ast.Store())],
value=val,
lineno=node.lineno,
col_offset=node.col_offset,
)
)
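The skip-self-assignment rule can be checked with only the stdlib ast module; emit_dim_assign is an illustrative stand-in for the specializer's unpack-expansion helper:

```python
import ast


def emit_dim_assign(tgt_name: str, val: ast.expr) -> list:
    """Build `tgt = val`, unless it would be a no-op self-assignment."""
    if isinstance(val, ast.Name) and val.id == tgt_name:
        return []  # skip redundant `M = M`
    assign = ast.Assign(
        targets=[ast.Name(id=tgt_name, ctx=ast.Store())],
        value=val,
    )
    return [ast.fix_missing_locations(assign)]


# Dynamic dim bound under a different name: the assignment is emitted.
stmts = emit_dim_assign("rows", ast.Name(id="M", ctx=ast.Load()))
assert ast.unparse(ast.Module(body=stmts, type_ignores=[])) == "rows = M"

# Same name on both sides: nothing is emitted.
assert emit_dim_assign("M", ast.Name(id="M", ctx=ast.Load())) == []
```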


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 86b0fbbad8


Comment thread python/pypto/jit/decorator.py Outdated
Comment on lines +592 to +594
if run_config is not None:
return compiled(*args, config=run_config)
return compiled(*args)


P1: Forward bound keyword arguments to compiled program

After binding *args/**kwargs, execution still calls compiled(*args, ...), which drops any keyword-provided function arguments. A call like kernel(a=x, b=y, c=z) will compile using the right bound arguments but execute with zero positional args, causing CompiledProgram argument-count/type failures (and mixed positional/keyword calls can misroute values). Invoke the compiled program from the bound argument map in signature order instead of the raw args tuple.
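The fix amounts to normalizing args/kwargs through the Python signature before invoking the compiled program. A sketch, where kernel and the echoing compiled callable are stand-ins for the @pl.jit function and its CompiledProgram:

```python
import inspect


def call_in_signature_order(fn, compiled, *args, **kwargs):
    """Bind args/kwargs to fn's signature, then call compiled positionally."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    # bound.arguments preserves parameter order, so keyword-provided
    # values reach the compiled program in the right positions.
    return compiled(*bound.arguments.values())


def kernel(a, b, c):  # stand-in for the @pl.jit-decorated function
    ...


compiled = lambda *xs: xs  # stand-in CompiledProgram that echoes its args

# Keyword-only and mixed calls now reach the program in signature order.
assert call_in_signature_order(kernel, compiled, c=3, a=1, b=2) == (1, 2, 3)
assert call_in_signature_order(kernel, compiled, 1, c=3, b=2) == (1, 2, 3)
```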


Comment on lines +574 to +578
key = make_cache_key(
source_hash=self._get_source_hash(),
param_names=param_names,
tensor_shapes={n: m.shape for n, m in tensor_meta.items()},
tensor_dtypes={n: m.dtype for n, m in tensor_meta.items()},


P1: Include target platform in JIT cache key

The cache key is built without the requested RunConfig.platform, but compilation is platform-dependent (_compile(..., platform=...)) and CompiledProgram stores that platform for execution. If the first call compiles for one platform and a later call uses another platform with the same shapes/dtypes, the old cache entry is reused, so execution runs artifacts built for the wrong target.
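Folding the platform into the key keeps per-target artifacts separate. A tiny sketch with illustrative names ("sim"/"hw" are placeholder platform strings, and make_platform_aware_key is not the real API):

```python
def make_platform_aware_key(source_hash, shapes, dtypes, platform):
    """Cache key that distinguishes compilation targets."""
    return (source_hash, shapes, dtypes, platform)


k_sim = make_platform_aware_key("abc", ((64, 128),), ("f32",), "sim")
k_hw = make_platform_aware_key("abc", ((64, 128),), ("f32",), "hw")
assert k_sim != k_hw  # same shapes/dtypes, different target -> distinct entries
```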


@huxinyuan1215 force-pushed the feat/jit-preprocessor branch 2 times, most recently from c063149 to c6b8c6c on April 20, 2026 at 06:55
When @pl.jit compiles on cache miss, the platform was not forwarded to
ir_compile(), causing CompiledProgram to default to 'a2a3sim'.  Execution
then used the wrong platform (sim instead of hardware), triggering a
g++-15 dependency on the CI NPU machine.

Fix: extract platform from RunConfig before the cache lookup and pass it
to _compile(), which forwards it to ir_compile().  CompiledProgram then
stores the correct platform and execute_compiled uses it.
@huxinyuan1215 force-pushed the feat/jit-preprocessor branch from c6b8c6c to 8ef6e96 on April 20, 2026 at 06:58
@Hzfengsy merged commit d1253a3 into hw-native-sys:main on April 20, 2026
8 checks passed