[Bug][Frontend] Prologue fusion buffer-size matcher ignores slice/chunk views (llama4 SwiGLU) #244

@YWHyuk

Summary

The prologue-fusion matcher in mlir_template.py:codegen_template_code decides which input buffer to reuse as spad by comparing the prologue node's get_numel() against the element count of each candidate read buffer, computed from that buffer's full get_size(). When the prologue reads a slice/chunk view of a larger buffer (view numel < parent numel), no candidate matches and the bare assert(candidate_found) fires, aborting the compile.

Triggered by llama4 MoE's SwiGLU FFN: gate_up = bmm(...) # [..., 2*E] then chunk(2, dim=-1) to produce gate/up, followed by silu(gate) * up as the prologue for the next bmm.

Repro

On develop @ 7b6daed (PR #231 merged), transformers 4.51.3:

python scripts/op_coverage.py --models llama4

(num_hidden_layers=2, interleave_moe_layer_step=2, num_local_experts=4, batch=1, seq_len=32, fp32.)

Original traceback

File ".../PyTorchSimFrontend/mlir/mlir_template.py", line 542, in codegen_template_code
    assert(candidate_found)
torch._inductor.exc.InductorError: AssertionError:

Diagnostic (patched assert locally with shape printout)

[prologue fusion] no input buffer matches numel of prologue node
  node: SchedulerNode(name='op95')
  node.get_numel(): 393216
  node.node.get_size(): [4, 32, 3072]
  reads: ['buf90', 'buf90']
  candidate buffers:
    buf90: size=[4, 32, 6144] numel=786432

393216 * 2 == 786432 -- the prologue reads two slice views of buf90, each half its parent. The matcher only checks the parent buffer's full numel and never considers the view.
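The numbers in the printout can be checked from the two shapes alone. The sketch below uses only the standard library; `chunk_shape` is a hypothetical helper written for this illustration, not a PyTorch or PyTorchSim function.

```python
import math


def chunk_shape(shape, chunks, dim=-1):
    """Shape of one piece when `shape` is split evenly into `chunks` along `dim`."""
    out = list(shape)
    if out[dim] % chunks != 0:
        raise ValueError("this sketch only handles even splits")
    out[dim] //= chunks
    return out


parent = [4, 32, 6144]          # buf90 in the diagnostic
view = chunk_shape(parent, 2)   # one of gate / up after chunk(2, dim=-1)

print(view, math.prod(view))    # [4, 32, 3072] 393216
print(math.prod(parent))        # 786432
```

Each view holds exactly half of the parent's elements, so an equality test against the parent's numel cannot succeed for either read.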

Source pattern (transformers/models/llama4/modeling_llama4.py)

gate_up = torch.bmm(hidden_states, self.gate_up_proj)   # [..., 2 * expert_dim]
gate, up = gate_up.chunk(2, dim=-1)                     # each [..., expert_dim]
next_h = silu(gate) * up                                # prologue of down_proj

Root cause

mlir_template.py:537-542:

for candidate_read in read_list:
    if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel():
        prologue_input_arg = candidate_read
        candidate_found = True
        break
assert(candidate_found)

  • buf_dict[candidate_read].get_size() is the parent buffer's size, not the view's.
  • read_list is derived from node.read_writes.reads (memdeps), which lose the view info before reaching this code.
  • For any chunk/split/slice read that covers only part of its parent buffer, the numel comparison is guaranteed to mismatch.
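The mismatch can be reproduced without PyTorchSim by running the same comparison against stand-in objects. This is a sketch under assumptions: `FakeBuffer` and `match_by_parent_numel` are hypothetical names, only `get_size()` is modelled, and the function returns `None` where the real code asserts.

```python
import operator
from functools import reduce


class FakeBuffer:
    """Hypothetical stand-in for a buffer; only get_size() is modelled."""

    def __init__(self, size):
        self._size = list(size)

    def get_size(self):
        return self._size


def match_by_parent_numel(read_list, buf_dict, node_numel):
    """Same comparison as the quoted loop, but returns None instead of asserting."""
    for candidate_read in read_list:
        if candidate_read not in buf_dict:
            continue
        parent_numel = reduce(operator.mul, buf_dict[candidate_read].get_size(), 1)
        if parent_numel == node_numel:
            return candidate_read
    return None


buf_dict = {"buf90": FakeBuffer([4, 32, 6144])}
read_list = ["buf90", "buf90"]  # two views of the same parent buffer

found = match_by_parent_numel(read_list, buf_dict, 393216)
print(found)  # None
```

With the values from the diagnostic, the loop visits `buf90` twice and rejects it both times, which is the state that trips `assert(candidate_found)`.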

Suggested fix

Two options:

  1. Make the matcher view-aware: walk node.read_writes.reads and use each MemoryDep's actual access size (or node.node.layout's view) instead of the parent buffer's get_size(). The current comment "memdep.get_size() != data.get_size()" already acknowledges this gap.

  2. Bail out gracefully when no candidate matches: turn the bare assert into a fallback that skips prologue fusion for this node (codegen the prologue as a standalone kernel). At minimum this should be the behavior; currently a recoverable scheduling choice kills the whole compile.
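The two options can be combined, as in the following sketch. It assumes the shape each read actually accesses can be recovered somewhere upstream, which this issue does not confirm. `FakeRead`, `find_prologue_input`, and `plan_prologue` are hypothetical names, not existing PyTorchSim or Inductor APIs.

```python
import math
from dataclasses import dataclass
from typing import Optional, Sequence, Set


@dataclass(frozen=True)
class FakeRead:
    """Hypothetical read record: buffer name plus the shape actually accessed."""

    name: str
    access_size: Sequence[int]


def find_prologue_input(reads, known_buffers: Set[str], node_numel: int) -> Optional[str]:
    """Option 1: compare against the accessed (view) numel, not the parent's."""
    for read in reads:
        if read.name in known_buffers and math.prod(read.access_size) == node_numel:
            return read.name
    return None


def plan_prologue(reads, known_buffers: Set[str], node_numel: int) -> str:
    """Option 2: fall back to a standalone kernel instead of asserting."""
    candidate = find_prologue_input(reads, known_buffers, node_numel)
    if candidate is None:
        return "standalone"
    return f"fused:{candidate}"


# Values from the diagnostic: two half-sized views of buf90.
view_reads = [FakeRead("buf90", (4, 32, 3072)), FakeRead("buf90", (4, 32, 3072))]
print(plan_prologue(view_reads, {"buf90"}, 393216))    # fused:buf90

# If only the parent size is known, the node is skipped rather than crashing.
parent_reads = [FakeRead("buf90", (4, 32, 6144))]
print(plan_prologue(parent_reads, {"buf90"}, 393216))  # standalone
```

Returning a decision instead of asserting keeps the failure local: a node that cannot be matched costs one missed fusion rather than the whole compile.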

The bare assert(candidate_found) is also worth replacing with the diagnostic message above so future shape mismatches surface their context instead of an empty AssertionError.
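A message like the one in the Diagnostic section could be assembled by a small helper. This is only a sketch: `format_no_candidate` and its parameters are hypothetical, and the layout simply mirrors the printout shown earlier in this issue.

```python
import math


def format_no_candidate(node_name, node_numel, node_size, reads, candidate_sizes):
    """Build a multi-line message describing why no prologue input matched."""
    lines = [
        "[prologue fusion] no input buffer matches numel of prologue node",
        f"  node: {node_name}",
        f"  node.get_numel(): {node_numel}",
        f"  node.node.get_size(): {list(node_size)}",
        f"  reads: {list(reads)}",
        "  candidate buffers:",
    ]
    for name, size in candidate_sizes.items():
        lines.append(f"    {name}: size={list(size)} numel={math.prod(size)}")
    return "\n".join(lines)


message = format_no_candidate(
    node_name="SchedulerNode(name='op95')",
    node_numel=393216,
    node_size=[4, 32, 3072],
    reads=["buf90", "buf90"],
    candidate_sizes={"buf90": [4, 32, 6144]},
)
print(message)
```

The resulting string can be passed as the message of an assertion or exception, or logged before falling back to a standalone prologue kernel.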

Scope

Blocks llama4 MoE end-to-end forward; the same pattern likely affects any future model that uses split/chunk immediately before a fused-prologue gemm/bmm.

Environment
