[Frontend] Fix #238: MLIRCSEVariable/MaskCSEVariable encode wrapper-C ABI invariant#246
Open
YWHyuk wants to merge 2 commits into
YWHyuk
commented
May 26, 2026
| assert isinstance(ret, OpResult), ( | ||
| f"op {name!r} must return (code, OpResult|None); got {type(ret).__name__}" | ||
| ) | ||
| if ret.is_mask: |
Collaborator
Author
Why is this needed?
| var = self.create_cse_var(name, ValueRanges.unknown()) | ||
| self.register_var_info(var, [size, dtype]) | ||
| return var | ||
| def create_cse_var(self, *args, **kwargs): |
Collaborator
Author
Why is this wrapper needed, instead of just using MLIRCSEVariable directly?
| self.buffer_types : dict = None # format: dtype, numel, size, stride | ||
| # Create compute idx | ||
| self.compute_idx = self.register_var_cse("compute_idx", 1, "index") | ||
| self.compute_idx = self.make_named_csevar("compute_idx", vec_size=1) |
Collaborator
Author
use default vec_size
| body_iter_arg = self.iterator_cse.generate(self.compute, f"reduction {body_key}body_iter_arg", write=False) | ||
| self.register_var_info(body_iter_arg, [vec_len, type_name]) | ||
| body_iter_arg.vec_size = vec_len | ||
| body_iter_arg.dtype = dtype |
Collaborator
Author
This manual assignment is not good.
| self.register_var_info(acc, [reduction_size, type_name]) | ||
| acc.vec_size = reduction_size | ||
| acc.dtype = dtype | ||
| assert(vec_len % reduction_size==0) |
| # Initialize base vector | ||
| if not self.base_vector_initialized: | ||
| init_iter = self.register_var_cse("init_iter", 1, "index") | ||
| init_iter = self.make_named_csevar("init_iter", vec_size=1) |
Collaborator
Author
use default vec_size
| if key not in self.consts: | ||
| self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") | ||
| self.register_var_info(self.consts[key], [1, dtype]) | ||
| self.consts[key].vec_size = 1 |
| return DTYPE_TO_MLIR[self.dtype] | ||
| class MaskCSEVariable(common.CSEVariable): |
Collaborator
Author
Why is this required?
| out = ops._load(vsize, mlir_dtype, sram_var, compute_index_var, tile_shape) | ||
| self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) | ||
| out.vec_size = self.compute_body_loop.step | ||
| out.dtype = dtype |
Collaborator
Author
Is this needed? Can't the updated proxy handle this?
YWHyuk
commented
May 26, 2026
| self.r_dim_size = template_fusion_info['r_dim_size'] | ||
| self.reduction_nr_outer_loop = nr_outer_loop | ||
| self.reduction_loop_idx = self.register_var_cse("reduce_loop_idx", 1, "index") | ||
| self.reduction_loop_idx = self.make_named_csevar("reduce_loop_idx", vec_size=1) |
Collaborator
Author
use default vec_size
dump_args:109 had `(N+7)//8` for torch.bool, which computed a bit-packed byte count, inconsistent with the rest of the data path which uses 1 byte per bool throughout:

- write_arg writes tensor.untyped_storage() as raw bytes (N bytes for N bools)
- C wrapper load_arg reads N * sizeof(uint8_t) bytes
- C wrapper malloc uses bits=8 for bool (mlir_caller_codegen.py:101)
- DTYPE_TO_C[torch.bool] = "uint8_t"
- spike sees the raw N bytes

The (N+7)//8 line was a vestige of an abandoned i1-storage experiment; the wrapper C ABI is structurally byte-aligned and cannot accept bit-packed bool shapes. Plus array_size is unused at the caller (`_, file_path = self.dump_args(...)` at line 134), so this dead line was silently inconsistent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
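To make the size mismatch concrete, here is a minimal sketch in plain Python. The helper names `packed_bool_bytes` and `byte_per_bool_bytes` are hypothetical and do not exist in the repository; only the `(N+7)//8` formula and the one-byte-per-bool convention are taken from the commit message above.

```python
def packed_bool_bytes(n: int) -> int:
    """Bit-packed size: 8 bools per byte, rounded up (the removed formula)."""
    return (n + 7) // 8


def byte_per_bool_bytes(n: int) -> int:
    """Size when every bool occupies one uint8_t, as the rest of the data path assumes."""
    return n  # N * sizeof(uint8_t), and sizeof(uint8_t) == 1


# Print both sizes side by side for a few element counts.
for n in (1, 8, 9, 1024):
    print(n, packed_bool_bytes(n), byte_per_bool_bytes(n))
```

The two formulas agree only for `N <= 1`, so any consumer that trusted the packed count would have read a buffer roughly eight times too small.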
1811855 to
9a771e4
Compare
…ributes

Issue #238 was the visible symptom (silent bool/uint8 -> int8 downcast via the lossy `MLIR_TO_DTYPE[var_info[1]]` round-trip at mlir_codegen_backend.py:1535) of a deeper architectural smell: PyTorchSim maintained a parallel `self.var_info` dict tracking `[vec_size, mlir_dtype_string]` per csevar, duplicating type info that already lives on Inductor's `CSEVariable.dtype`. The lossy MLIR->torch round-trip was the only place this duplication actively caused corruption, but collapsing the two systems is the structural fix.

Core changes:

- New type:
  - `MLIRCSEVariable(common.CSEVariable)` carries `vec_size: int` and inherits `dtype: Optional[torch.dtype]`. `mlir_dtype` is a derived @property from `dtype` via `DTYPE_TO_MLIR`. There is no separate predicate/mask subclass: `torch.bool` maps to MLIR `"i1"` directly (DTYPE_TO_MLIR[torch.bool] = "i1"). MLIR-to-LLVM lowering pads i1 storage to bytes, matching the wrapper C ABI (`uint8_t*`, one byte per element). The wrapper architecturally cannot accept bit-packed i1 storage (mlir_caller_codegen.py uses sizeof(ctype) loads), so the `memref<...xi1>` -> `i8`-backed pipeline is the natural fit.
  - `OpResult(vec_size, dtype)` frozen dataclass replaces the legacy `[vec_size, mlir_dtype_string]` ret_info list. `OpResult.from_var` and `OpResult.from_mlir` are classmethod constructors.
  - `INDEX_DTYPE` singleton sentinel for MLIR `index` type (no torch equivalent). `MLIR_TO_DTYPE["index"] = INDEX_DTYPE` and `DTYPE_TO_MLIR[INDEX_DTYPE] = "index"` so the dicts are bijective for all known types, which is clearer than overloading `None`.
- `MLIRCSE(common.CSE)` extends Inductor's CSE with a `vec_size` axis:
  - `newvar` / `namedvar` construct `MLIRCSEVariable` directly, bypassing the kernel-side `V.kernel.create_cse_var` hook (which is no longer needed).
  - `generate(buffer, code, *, vec_size=N, dtype=X, ...)` plumbs `vec_size` to `newvar` via a transient instance attribute, calling `super().generate(...)` for the rest. No need to reimplement the upstream generate body.
- Handler proxy (mlir_common.py CSEProxy) rewritten to expect `(code, OpResult|None)` from ops. Single uniform path: `target_cse.generate(buf, code, dtype=ret.dtype, vec_size=ret.vec_size)`, with no post-hoc attribute assignment.
- All ops in mlir_ops.py, mlir_template.py, mlir_sort_template.py return `(code, OpResult)` (or `OpResult.from_var` / `OpResult.from_mlir` helpers). Legacy `[size, mlir_str]` shape gone.
- `register_var_info` / `register_var_cse` deleted. Six previously-named csevars (`compute_idx`, `itervar_cses`, `init_iter`, `reduce_loop_idx`, `idx_step_index`, `idx_base`) now use `cse.namedvar(..., dtype=..., vec_size=...)` directly. `make_named_csevar` wrapper removed.
- ~108 read sites of `var_info[v][...]` migrated to attribute access (`v.vec_size`, `v.mlir_dtype`). `var_info[v][1] == "i1"` patterns collapse to `v.dtype == torch.bool` since the mask subclass is gone.
- `self.var_info` dict removed entirely.
- Issue #238 fix at mlir_codegen_backend.py:1535:

      csevar = self.cse.varname_map[target_dim]
      dtype = csevar.dtype

  No more round-trip; the torch dtype set at csevar construction is preserved end-to-end.

Files touched: mlir_common.py (foundation), mlir_codegen_backend.py (#238 site + read migration + memory-entry call sites), mlir_ops.py (ops layer ret_info migration), mlir_template.py + mlir_sort_template.py (template ops + named csevar sites).

Sample-verified: test_add, test_softmax, test_sort (i1 mask path via cmp), test_matmul, test_layernorm, test_indirect_access (#238 critical path), test_expert_mask, test_transcendental, test_reduce.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
9a771e4 to
bd7e98f
Compare
Summary
Closes #238 (uint8/bool silently downcast to int8 via `MLIR_TO_DTYPE` round-trip) by replacing the parallel `var_info` dict with first-class attributes on two `CSEVariable` subclasses. The wrapper-C byte-aligned ABI invariant is now encoded in the type system rather than as an implicit convention.

Root cause
`mlir_codegen_backend.py:1536` did `dtype = MLIR_TO_DTYPE[var_info[1]]`. `DTYPE_TO_MLIR` is many-to-one (`torch.bool`/`torch.uint8`/`torch.int8` → `"i8"`), so the inverse silently picked `torch.int8`. Downstream `DTYPE_TO_C` then chose `int8_t` instead of `uint8_t` for SRAM buffer declarations: a silent-corruption-class bug.

The fix is not "make the inverse table richer" (i1 vs i8 is genuinely two representations: predicate SSA vs byte storage) but "stop throwing the torch dtype away in the first place".
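The failure mode can be reproduced in a few lines. This is a sketch with assumptions: plain strings stand in for torch dtypes, the tables contain only the three entries discussed here, and the inverse table is built by naive dict inversion, which may differ from how the real `MLIR_TO_DTYPE` was written. The helper names are hypothetical.

```python
# Forward table: three torch dtypes share one MLIR storage type.
DTYPE_TO_MLIR = {
    "torch.bool": "i8",
    "torch.uint8": "i8",
    "torch.int8": "i8",
}

# C types used for buffer declarations.
DTYPE_TO_C = {
    "torch.bool": "uint8_t",
    "torch.uint8": "uint8_t",
    "torch.int8": "int8_t",
}

# Naive inversion: the last key seen wins, so "i8" can only map back to one dtype.
MLIR_TO_DTYPE = {mlir: dtype for dtype, mlir in DTYPE_TO_MLIR.items()}


def c_type_via_round_trip(dtype):
    """Old path: torch dtype -> MLIR string -> torch dtype -> C type."""
    return DTYPE_TO_C[MLIR_TO_DTYPE[DTYPE_TO_MLIR[dtype]]]


def c_type_direct(dtype):
    """Fixed path: keep the torch dtype and look the C type up directly."""
    return DTYPE_TO_C[dtype]


for dt in DTYPE_TO_MLIR:
    print(dt, c_type_via_round_trip(dt), c_type_direct(dt))
```

No choice of inverse entry can be right for all three inputs, which is why the fix keeps the original dtype instead of enriching the table.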
The decisive invariant
The wrapper C (`mlir_caller_codegen.py:57-105`) loads/stores in `sizeof(ctype)` units and uses `bits=8` for bool malloc. i1 storage is structurally not supported. That made the old `(N+7)//8` line in `dump_args` a dead workaround for an abandoned i1-storage experiment.

Knowing this, `mlir_dtype = "i1"` on any value implies "SSA-only predicate". This is now expressed via two `CSEVariable` subclasses:

- `MLIRCSEVariable`: storage-eligible. `mlir_dtype` is a `@property` derived from torch `dtype`. Never `"i1"`.
- `MaskCSEVariable`: i1 predicate. `mlir_dtype = "i1"` is a class-level constant. Memory-entry ops (`store`, `allocate_sram_buffer`) `assert not isinstance(value, MaskCSEVariable)`.

What changed
- `mlir_common.py`: `MLIRCSEVariable`, `MaskCSEVariable`, `OpResult`. `BaseMLIRKernel.create_cse_var` overridden.
- Handler proxy expects `(code, OpResult|None)` from ops; constructs `MaskCSEVariable` for `is_mask=True`, otherwise `MLIRCSEVariable`.
- `mlir_ops.py`, `mlir_template.py`, `mlir_sort_template.py` migrated to return `OpResult` (or `op_result_from_mlir` / `op_result_from_var` helpers).
- Read sites of `var_info[v][...]` migrated to attribute access (`v.vec_size`, `v.mlir_dtype`). `i1` string compares became `isinstance(_, MaskCSEVariable)`.
- `register_var_*` replaced with `make_named_csevar(...)` or direct attribute set on existing csevars.
- Removed: `self.var_info` dict, `register_var_info`, `register_var_cse` helpers. (`MLIR_TO_DTYPE` table kept for legitimate uses by `op_result_from_mlir` and `get_const_cse`; it is no longer on the silent-corruption path.)
- Issue #238 fix at `mlir_codegen_backend.py:1535-1543`.
- `ops.to_predicate` split out from `ops.to_bool` (storage→i1). `ops.to_bool` retained as a thin alias.
- `(N+7)//8` line in `Simulator/simulator.py:109` removed (independent commit).

Commit layout
For review clarity the PR is split into logical commits:

- `[Sim] Remove dead bit-pack formula for bool in dump_args`: independent precondition.
- `[Frontend] Introduce MLIRCSEVariable / MaskCSEVariable / OpResult`: additive foundation.
- `[Frontend] Migrate type tracking to CSEVariable attributes; fix Issue #238`: the bulk of the work (handler proxy, ops layer, var_info removal, #238 fix).
- `[Frontend] Encode wrapper-C byte-aligned ABI invariant via isinstance guards`: isinstance assertions on memory-entry ops + `to_predicate` split + ops.ext fix.

Verification
Locally passed: `test_add`, `test_softmax`, `test_sort` (i1 mask path), `test_matmul`, `test_layernorm`, `test_indirect_access` (the #238 critical path), `test_expert_mask`, `test_transcendental`, `test_reduce`.

`test_sparse_core` fails identically on `develop`: pre-existing, unrelated.

Why this is more than a one-line fix
Issue #238 looks like a typo in one table. But the lossy-table pattern is a symptom of PyTorchSim maintaining parallel type-tracking infrastructure (`self.var_info`) on top of Inductor's existing `CSEVariable.dtype`. The two-class refactor:

- removes `var_info` and all its readers (108 sites);

🤖 Generated with Claude Code