
[BUG] Support for repeat_interleave operation to enable Grouped Query Attention (GQA) #198

@Dignity-ghost

Description:

The PyTorchSim MLIR backend does not support the repeat_interleave
operation, which is required for Grouped Query Attention (GQA) used in
modern LLMs like Mixtral 8x7B and Llama 2. In GQA, the number of KV heads
(n_local_heads) is smaller than the number of Q heads (n_head), and
repeat_interleave is used to expand K/V tensors to match Q's head count
(e.g., k.repeat_interleave(n_head // n_local_heads, dim=1)).
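For reference, a minimal PyTorch sketch of the pattern involved (shapes are illustrative, roughly Llama 2 70B-like, and not taken from the simulator's test suite):

```python
import torch

# Illustrative GQA shapes (assumed for this example): 64 query heads, 8 KV heads.
batch, seq_len, head_dim = 1, 16, 128
n_head, n_local_heads = 64, 8

q = torch.randn(batch, n_head, seq_len, head_dim)
k = torch.randn(batch, n_local_heads, seq_len, head_dim)
v = torch.randn(batch, n_local_heads, seq_len, head_dim)

# Expand K/V along the head dimension so they line up with Q's head count.
repeat = n_head // n_local_heads            # 8 in this example
k = k.repeat_interleave(repeat, dim=1)      # -> (1, 64, 16, 128)
v = v.repeat_interleave(repeat, dim=1)      # -> (1, 64, 16, 128)

attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```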

The error occurs in PyTorchSimFrontend/mlir/mlir_common.py at lines 682-683
in the extract_dividers function, which raises NotImplementedError("Not
supporting this view operation...!") when it encounters index expressions
with multiple free symbols. The repeat_interleave operation generates a
FloorDiv indexing pattern (e.g., index // 4 for 4x repetition) that the
current implementation cannot handle.
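To make the FloorDiv pattern concrete, here is a small self-contained check using the same 4x factor (the exact symbolic expression that reaches extract_dividers may differ; this only shows the index arithmetic the op implies):

```python
import torch

# 4x repetition along dim=1, matching the "index // 4" example above.
x = torch.arange(2 * 2 * 3).reshape(2, 2, 3)   # (batch, kv_heads, dim)
y = x.repeat_interleave(4, dim=1)              # (2, 8, 3)

# The output is a pure gather with floor-division indexing:
#   y[b, h, d] == x[b, h // 4, d]
h = torch.arange(8)
assert torch.equal(y, x[:, h // 4, :])
```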

Current workaround: Set n_local_heads = -1 (which defaults to n_head),
effectively disabling GQA and using standard Multi-Head Attention (MHA).
This loses the memory efficiency benefits of GQA (4-8x KV cache
reduction).
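A back-of-the-envelope estimate of what that workaround costs, under an assumed Llama 2 70B-like configuration (all numbers are illustrative assumptions, not simulator measurements):

```python
# Hypothetical model configuration for sizing the KV cache.
n_head, n_local_heads = 64, 8
n_layers, head_dim, seq_len = 80, 128, 4096
bytes_per_elem = 2  # fp16

def kv_cache_bytes(kv_heads: int) -> int:
    # K and V tensors, across all layers, for one sequence.
    return 2 * n_layers * kv_heads * head_dim * seq_len * bytes_per_elem

mha = kv_cache_bytes(n_head)          # n_local_heads = -1 falls back to n_head (MHA)
gqa = kv_cache_bytes(n_local_heads)   # intended GQA configuration
print(f"{mha / 2**30:.1f} GiB (MHA) vs {gqa / 2**30:.1f} GiB (GQA), {mha / gqa:.0f}x")
```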

Affected models: Mixtral 8x7B, Llama 2 70B, and other models using GQA.

Labels

bug (Something isn't working)