Skip to content

[Frontend] Fix #238: MLIRCSEVariable/MaskCSEVariable encode wrapper-C ABI invariant#246

Open
YWHyuk wants to merge 2 commits into
developfrom
bugfix/issue-238-type-system
Open

[Frontend] Fix #238: MLIRCSEVariable/MaskCSEVariable encode wrapper-C ABI invariant#246
YWHyuk wants to merge 2 commits into
developfrom
bugfix/issue-238-type-system

Conversation

@YWHyuk
Copy link
Copy Markdown
Collaborator

@YWHyuk YWHyuk commented May 26, 2026

Summary

Closes #238 (uint8/bool silently downcast to int8 via MLIR_TO_DTYPE round-trip) by replacing the parallel var_info dict with first-class attributes on two CSEVariable subclasses. The wrapper-C byte-aligned ABI invariant is now encoded in the type system rather than as an implicit convention.

Root cause

mlir_codegen_backend.py:1536 did dtype = MLIR_TO_DTYPE[var_info[1]]. DTYPE_TO_MLIR is many-to-one (torch.bool/torch.uint8/torch.int8"i8"), so the inverse silently picked torch.int8. Downstream DTYPE_TO_C then chose int8_t instead of uint8_t for SRAM buffer declarations — a silent-corruption-class bug.

The fix is not "make the inverse table richer" (i1 vs i8 is genuinely two representations: predicate SSA vs byte storage) but "stop throwing the torch dtype away in the first place".

The decisive invariant

The wrapper C (mlir_caller_codegen.py:57-105) loads/stores in sizeof(ctype) units and uses bits=8 for bool malloc. i1 storage is structurally not supported. That made the old (N+7)//8 line in dump_args a dead workaround for an abandoned i1-storage experiment.

Knowing this, mlir_dtype = "i1" on any value implies "SSA-only predicate". This is now expressed via two CSEVariable subclasses:

  • MLIRCSEVariable: storage-eligible. mlir_dtype is a @property derived from torch dtype. Never "i1".
  • MaskCSEVariable: i1 predicate. mlir_dtype = "i1" is a class-level constant. Memory-entry ops (store, allocate_sram_buffer) assert not isinstance(value, MaskCSEVariable).

What changed

  1. New types in mlir_common.py: MLIRCSEVariable, MaskCSEVariable, OpResult. BaseMLIRKernel.create_cse_var overridden.
  2. Handler proxy rewritten to expect (code, OpResult|None) from ops; constructs MaskCSEVariable for is_mask=True, otherwise MLIRCSEVariable.
  3. All ~58 ops in mlir_ops.py, mlir_template.py, mlir_sort_template.py migrated to return OpResult (or op_result_from_mlir/op_result_from_var helpers).
  4. 108 read sites of var_info[v][...] migrated to attribute access (v.vec_size, v.mlir_dtype). i1 string compares became isinstance(_, MaskCSEVariable).
  5. 9 write sites of register_var_* replaced with make_named_csevar(...) or direct attribute set on existing csevars.
  6. Deleted: self.var_info dict, register_var_info, register_var_cse helpers. (MLIR_TO_DTYPE table kept for legitimate uses by op_result_from_mlir and get_const_cse — no longer on the silent-corruption path.)
  7. [Bug][Frontend] uint8 and bool dtypes silently downcast to int8 via MLIR_TO_DTYPE round-trip #238 fix at mlir_codegen_backend.py:1535-1543:
    # Before
    var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0]
    dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]]
    # After
    csevar = self.cse.varname_map[target_dim]
    assert isinstance(csevar, mlir_common.MLIRCSEVariable)
    dtype = csevar.dtype
  8. ops.to_predicate split out from ops.to_bool (storage→i1). ops.to_bool retained as a thin alias.
  9. Dead (N+7)//8 line in Simulator/simulator.py:109 removed (independent commit).

Commit layout

For review clarity the PR is split into logical commits:

  1. [Sim] Remove dead bit-pack formula for bool in dump_args — independent precondition.
  2. [Frontend] Introduce MLIRCSEVariable / MaskCSEVariable / OpResult — additive foundation.
  3. [Frontend] Migrate type tracking to CSEVariable attributes; fix Issue #238 — the bulk of the work (handler proxy, ops layer, var_info removal, [Bug][Frontend] uint8 and bool dtypes silently downcast to int8 via MLIR_TO_DTYPE round-trip #238 fix).
  4. [Frontend] Encode wrapper-C byte-aligned ABI invariant via isinstance guards — isinstance assertions on memory-entry ops + to_predicate split + ops.ext fix.

Verification

Locally passed:

test_sparse_core fails identically on develop — pre-existing, unrelated.

Why this is more than a one-line fix

Issue #238 looks like a typo in one table. But the lossy-table pattern is a symptom of PyTorchSim maintaining parallel type-tracking infrastructure (self.var_info) on top of Inductor's existing CSEVariable.dtype. The two-class refactor:

  • eliminates var_info and all its readers (108 sites);
  • makes the wrapper-C ABI invariant ("i1 cannot enter memory") structural rather than conventional;
  • prevents the entire class of "MLIR string → torch dtype" round-trip corruption from being possible in the future.

🤖 Generated with Claude Code

@YWHyuk YWHyuk changed the base branch from master to develop May 26, 2026 10:45
Comment thread PyTorchSimFrontend/mlir/mlir_common.py Outdated
assert isinstance(ret, OpResult), (
f"op {name!r} must return (code, OpResult|None); got {type(ret).__name__}"
)
if ret.is_mask:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it needed?

Comment thread PyTorchSimFrontend/mlir/mlir_common.py Outdated
var = self.create_cse_var(name, ValueRanges.unknown())
self.register_var_info(var, [size, dtype])
return var
def create_cse_var(self, *args, **kwargs):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this wrapper needed? Inestaed of just using MLIRCSEVariable

Comment thread PyTorchSimFrontend/mlir/mlir_common.py Outdated
self.buffer_types : dict = None # format: dtype, numel, size, stride
# Create compute idx
self.compute_idx = self.register_var_cse("compute_idx", 1, "index")
self.compute_idx = self.make_named_csevar("compute_idx", vec_size=1)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use default vec_size

body_iter_arg = self.iterator_cse.generate(self.compute, f"reduction {body_key}body_iter_arg", write=False)
self.register_var_info(body_iter_arg, [vec_len, type_name])
body_iter_arg.vec_size = vec_len
body_iter_arg.dtype = dtype
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This manual assign is not good

self.register_var_info(acc, [reduction_size, type_name])
acc.vec_size = reduction_size
acc.dtype = dtype
assert(vec_len % reduction_size==0)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

# Initialize base vector
if not self.base_vector_initialized:
init_iter = self.register_var_cse("init_iter", 1, "index")
init_iter = self.make_named_csevar("init_iter", vec_size=1)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use default vec_size

if key not in self.consts:
self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}")
self.register_var_info(self.consts[key], [1, dtype])
self.consts[key].vec_size = 1
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not good

Comment thread PyTorchSimFrontend/mlir/mlir_common.py Outdated
return DTYPE_TO_MLIR[self.dtype]


class MaskCSEVariable(common.CSEVariable):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is required

out = ops._load(vsize, mlir_dtype, sram_var, compute_index_var, tile_shape)
self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype])
out.vec_size = self.compute_body_loop.step
out.dtype = dtype
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed?

Does the updated proxy can't handle this?

self.r_dim_size = template_fusion_info['r_dim_size']
self.reduction_nr_outer_loop = nr_outer_loop
self.reduction_loop_idx = self.register_var_cse("reduce_loop_idx", 1, "index")
self.reduction_loop_idx = self.make_named_csevar("reduce_loop_idx", vec_size=1)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use default vec_size

dump_args:109 had `(N+7)//8` for torch.bool which computed a bit-packed
byte count — inconsistent with the rest of the data path which uses 1
byte per bool throughout:

- write_arg writes tensor.untyped_storage() as raw bytes (N bytes for N bools)
- C wrapper load_arg reads N * sizeof(uint8_t) bytes
- C wrapper malloc uses bits=8 for bool (mlir_caller_codegen.py:101)
- DTYPE_TO_C[torch.bool] = "uint8_t"
- spike sees the raw N bytes

The (N+7)//8 line was a vestige of an abandoned i1-storage experiment;
the wrapper C ABI is structurally byte-aligned and cannot accept
bit-packed bool shapes. Plus array_size is unused at the caller
(`_, file_path = self.dump_args(...)` at line 134), so this dead line
was silently inconsistent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@YWHyuk YWHyuk force-pushed the bugfix/issue-238-type-system branch 2 times, most recently from 1811855 to 9a771e4 Compare May 26, 2026 14:59
…ributes

Issue #238 was the visible symptom (silent bool/uint8 -> int8 downcast
via the lossy `MLIR_TO_DTYPE[var_info[1]]` round-trip at
mlir_codegen_backend.py:1535) of a deeper architectural smell: PyTorchSim
maintained a parallel `self.var_info` dict tracking `[vec_size,
mlir_dtype_string]` per csevar, duplicating type info that already lives
on Inductor's `CSEVariable.dtype`. The lossy MLIR->torch round-trip was
the only place this duplication actively caused corruption, but
collapsing the two systems is the structural fix.

Core changes:

- New type:
  - `MLIRCSEVariable(common.CSEVariable)` carries `vec_size: int` and
    inherits `dtype: Optional[torch.dtype]`. `mlir_dtype` is a derived
    @Property from `dtype` via `DTYPE_TO_MLIR`. There is no separate
    predicate/mask subclass: `torch.bool` maps to MLIR `"i1"` directly
    (DTYPE_TO_MLIR[torch.bool] = "i1"). MLIR-to-LLVM lowering pads i1
    storage to bytes, matching the wrapper C ABI (`uint8_t*`, one byte
    per element). The wrapper architecturally cannot accept bit-packed
    i1 storage (mlir_caller_codegen.py uses sizeof(ctype) loads), so the
    `memref<...xi1>` -> `i8`-backed pipeline is the natural fit.
  - `OpResult(vec_size, dtype)` frozen dataclass replaces the legacy
    `[vec_size, mlir_dtype_string]` ret_info list. `OpResult.from_var`
    and `OpResult.from_mlir` are classmethod constructors.
  - `INDEX_DTYPE` singleton sentinel for MLIR `index` type (no torch
    equivalent). `MLIR_TO_DTYPE["index"] = INDEX_DTYPE` and
    `DTYPE_TO_MLIR[INDEX_DTYPE] = "index"` so the dicts are
    bijective for all known types — clearer than overloading `None`.

- `MLIRCSE(common.CSE)` extends Inductor's CSE with a `vec_size` axis:
  - `newvar` / `namedvar` construct `MLIRCSEVariable` directly,
    bypassing the kernel-side `V.kernel.create_cse_var` hook (which is
    no longer needed).
  - `generate(buffer, code, *, vec_size=N, dtype=X, ...)` plumbs
    `vec_size` to `newvar` via a transient instance attribute, calling
    `super().generate(...)` for the rest. No need to reimplement the
    upstream generate body.

- Handler proxy (mlir_common.py CSEProxy) rewritten to expect
  `(code, OpResult|None)` from ops. Single uniform path:
  `target_cse.generate(buf, code, dtype=ret.dtype, vec_size=ret.vec_size)`
  — no post-hoc attribute assignment.

- All ops in mlir_ops.py, mlir_template.py, mlir_sort_template.py
  return `(code, OpResult)` (or `OpResult.from_var` / `OpResult.from_mlir`
  helpers). Legacy `[size, mlir_str]` shape gone.

- `register_var_info` / `register_var_cse` deleted. Six previously-named
  csevars (`compute_idx`, `itervar_cses`, `init_iter`, `reduce_loop_idx`,
  `idx_step_index`, `idx_base`) now use `cse.namedvar(..., dtype=...,
  vec_size=...)` directly. `make_named_csevar` wrapper removed.

- ~108 read sites of `var_info[v][...]` migrated to attribute access
  (`v.vec_size`, `v.mlir_dtype`). `var_info[v][1] == "i1"` patterns
  collapse to `v.dtype == torch.bool` since the mask subclass is gone.

- `self.var_info` dict removed entirely.

- Issue #238 fix at mlir_codegen_backend.py:1535:
    csevar = self.cse.varname_map[target_dim]
    dtype = csevar.dtype
  No more round-trip; the torch dtype set at csevar construction is
  preserved end-to-end.

Files touched: mlir_common.py (foundation), mlir_codegen_backend.py
(#238 site + read migration + memory-entry call sites),
mlir_ops.py (ops layer ret_info migration), mlir_template.py +
mlir_sort_template.py (template ops + named csevar sites).

Sample-verified: test_add, test_softmax, test_sort (i1 mask path via cmp),
test_matmul, test_layernorm, test_indirect_access (#238 critical path),
test_expert_mask, test_transcendental, test_reduce.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@YWHyuk YWHyuk force-pushed the bugfix/issue-238-type-system branch from 9a771e4 to bd7e98f Compare May 27, 2026 03:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug][Frontend] uint8 and bool dtypes silently downcast to int8 via MLIR_TO_DTYPE round-trip

1 participant