Skip to content

Commit acf15bf

Browse files
authored
Enable fused MoE kernel for Qwen 3.5 MoE model (pytorch#18388)
Replace the compute-all-gather approach (ConditionalFeedForward with grouped nn.Linear) with FusedMoEExperts that calls the fused MoE Triton kernel directly. Expert weights are quantized to simple packed INT4 using torchao primitives, separate from the tinygemm path used for attention and shared expert linears. For decode (M=1), only 8 of 256 experts' weights are loaded from HBM per layer (128x less memory traffic vs the old approach). Depends on the fused MoE Triton kernel (triton::fused_moe). Measured on A100: decode throughput 12.41 tokens/s, prefill throughput 47.3 tokens/s.
1 parent f1a61fc commit acf15bf

2 files changed

Lines changed: 149 additions & 94 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
import torch.nn as nn
1414

15-
from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE
15+
from executorch.examples.models.qwen3_5_moe.model import FusedMoEExperts, Qwen35MoE
1616

1717

1818
# ---------------------------------------------------------------------------
@@ -43,6 +43,89 @@ def load_and_quantize(args):
4343
return model, config
4444

4545

46+
def _quantize_experts_int4(model, config, group_size=32, use_hqq=False):
47+
"""Quantize expert weights to packed INT4 for the fused MoE kernel.
48+
49+
Two quantization methods:
50+
--hqq: HQQ (Half-Quadratic Quantization) iteratively refines scales
51+
via least-squares for better accuracy (slower).
52+
default: Standard min/max symmetric quantization (faster).
53+
54+
Converts w1_weight [E, N, K] and w2_weight [E, N, K] to:
55+
w1 [E, N, K//2] int8 packed, w1_scale [E, N, K//gs] bf16
56+
w2 [E, N, K//2] int8 packed, w2_scale [E, N, K//gs] bf16
57+
"""
58+
if use_hqq:
59+
from torchao.quantization.quant_primitives import (
60+
_choose_qparams_and_quantize_scale_only_hqq,
61+
)
62+
else:
63+
from torchao.quantization.quant_primitives import (
64+
choose_qparams_affine,
65+
MappingType,
66+
quantize_affine,
67+
)
68+
69+
method = "HQQ" if use_hqq else "min/max"
70+
71+
for i, layer in enumerate(model.layers):
72+
experts = layer.mlp.experts
73+
if not isinstance(experts, FusedMoEExperts):
74+
continue
75+
76+
experts.group_size = group_size
77+
for name in ("w1_weight", "w2_weight"):
78+
w = getattr(experts, name).data.float()
79+
E, N, K = w.shape
80+
81+
if use_hqq:
82+
qdata, scale = _choose_qparams_and_quantize_scale_only_hqq(
83+
w.view(E * N, K),
84+
block_size=[1, group_size],
85+
qmin=-8,
86+
qmax=7,
87+
)
88+
int_data = qdata.to(torch.int8).view(E, N, K)
89+
scale = scale.view(E, N, -1)
90+
else:
91+
block_size = (1, 1, group_size)
92+
scale, zero_point = choose_qparams_affine(
93+
w,
94+
MappingType.SYMMETRIC,
95+
block_size,
96+
target_dtype=torch.int8,
97+
quant_min=-8,
98+
quant_max=7,
99+
)
100+
int_data = quantize_affine(
101+
w,
102+
block_size,
103+
scale,
104+
zero_point,
105+
output_dtype=torch.int8,
106+
quant_min=-8,
107+
quant_max=7,
108+
)
109+
scale = scale.reshape(E, N, -1)
110+
111+
# Pack two int4 values per byte: even K -> low nibble, odd K -> high nibble
112+
uint4 = (int_data + 8).to(torch.int16) # shift to unsigned [0, 15]
113+
low = uint4[:, :, 0::2]
114+
high = uint4[:, :, 1::2]
115+
packed = (low | (high << 4)).to(torch.int8) # [E, N, K//2]
116+
117+
buf_name = name.replace("_weight", "")
118+
experts.register_buffer(buf_name, packed)
119+
experts.register_buffer(f"{buf_name}_scale", scale.to(torch.bfloat16))
120+
delattr(experts, name)
121+
122+
print(
123+
f" Quantized experts (INT4 {method}) layer {i + 1}/{config.num_hidden_layers}",
124+
end="\r",
125+
)
126+
print()
127+
128+
46129
def _to_device_skip_meta(module, device, dtype=None):
47130
"""Move submodules to device, skipping any that have meta-device buffers.
48131
@@ -71,6 +154,10 @@ def _quantize(model, config, args):
71154
"""
72155
from executorch.extension.llm.export.quantize import quantize_model_
73156

157+
# Quantize MoE expert weights (packed INT4 for fused_moe kernel)
158+
if args.qlinear:
159+
_quantize_experts_int4(model, config, args.qlinear_group_size, use_hqq=args.hqq)
160+
74161
# Untie lm_head/embedding so they can be quantized independently:
75162
# embedding uses index lookup (8w), lm_head uses matmul (4w).
76163
if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr():
@@ -287,8 +374,16 @@ def main():
287374
parser.add_argument(
288375
"--qembedding", default=None, choices=["8w"], help="Quantize embedding layers."
289376
)
377+
parser.add_argument(
378+
"--hqq",
379+
action="store_true",
380+
help="Use HQQ scale-only optimization for expert quantization (slower, better accuracy).",
381+
)
290382
args = parser.parse_args()
291383

384+
if args.hqq and not args.qlinear:
385+
parser.error("--hqq requires --qlinear")
386+
292387
# Register FLA Triton kernel
293388
import executorch.backends.cuda.triton.kernels # noqa: F401
294389

examples/models/qwen3_5_moe/model.py

Lines changed: 53 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -399,71 +399,54 @@ def forward(self, x, input_pos):
399399

400400

401401
# ---------------------------------------------------------------------------
402-
# MoE: stacked expert weights + index by top-k
402+
# MoE: expert weights for fused MoE Triton kernel
403403

404-
# 16 experts per group keeps each nn.Linear under ~32K output features,
405-
# within tinygemm int4 packing limits while keeping the graph small
406-
# (32 matmul nodes per layer instead of 768 with per-expert linears).
407-
_EXPERTS_PER_GROUP = 16
408404

405+
class FusedMoEExperts(nn.Module):
406+
"""Expert weights stored as stacked tensors for the fused MoE Triton kernel.
409407
410-
class ConditionalFeedForward(nn.Module):
411-
"""Grouped expert weights as nn.Linear for quantization compatibility.
408+
Before quantization: w1_weight [E, 2*inter, hidden] and w2_weight [E, hidden, inter]
409+
are nn.Parameter tensors loaded from the checkpoint.
412410
413-
Experts are split into groups of _EXPERTS_PER_GROUP. Each group has:
414-
gate_up_projs[g]: nn.Linear(hidden_size, G * intermediate_size * 2)
415-
down_projs[g]: nn.Linear(intermediate_size, G * hidden_size)
416-
This keeps each nn.Linear small enough for tinygemm int4 packing while
417-
allowing quantize_model_() to handle them automatically.
411+
After quantization (in export.py): replaced with packed INT4 buffers
412+
w1 [E, 2*inter, hidden//2], w1_scale, w2 [E, hidden, inter//2], w2_scale.
418413
"""
419414

420-
def __init__(self, hidden_size, intermediate_size, num_experts):
415+
def __init__(self, config):
421416
super().__init__()
422-
self.num_experts = num_experts
423-
self.intermediate_size = intermediate_size
424-
self.hidden_size = hidden_size
425-
G = _EXPERTS_PER_GROUP
426-
assert num_experts % G == 0
427-
num_groups = num_experts // G
428-
429-
self.gate_up_projs = nn.ModuleList(
430-
[
431-
nn.Linear(hidden_size, G * intermediate_size * 2, bias=False)
432-
for _ in range(num_groups)
433-
]
417+
self.num_experts = config.num_experts
418+
self.intermediate_size = config.moe_intermediate_size
419+
self.hidden_size = config.hidden_size
420+
self.group_size = 32
421+
422+
self.w1_weight = nn.Parameter(
423+
torch.empty(
424+
config.num_experts,
425+
2 * config.moe_intermediate_size,
426+
config.hidden_size,
427+
)
434428
)
435-
self.down_projs = nn.ModuleList(
436-
[
437-
nn.Linear(intermediate_size, G * hidden_size, bias=False)
438-
for _ in range(num_groups)
439-
]
429+
self.w2_weight = nn.Parameter(
430+
torch.empty(
431+
config.num_experts,
432+
config.hidden_size,
433+
config.moe_intermediate_size,
434+
)
440435
)
441436

442-
def forward(self, x, expert_indices):
443-
# x: (T, D), expert_indices: (T, top_k)
444-
T = x.size(0)
445-
top_k = expert_indices.size(1)
446-
G = _EXPERTS_PER_GROUP
447-
H = self.intermediate_size
448-
D = self.hidden_size
449-
450-
# Gate + Up: compute per-group, cat, gather top-k
451-
gate_up_parts = [proj(x).view(T, G, 2, H) for proj in self.gate_up_projs]
452-
gate_up = torch.cat(gate_up_parts, dim=1) # (T, E, 2, H)
453-
454-
idx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, H)
455-
gate_up_sel = gate_up.gather(1, idx) # (T, top_k, 2, H)
456-
intermediate = F.silu(gate_up_sel[:, :, 0, :]) * gate_up_sel[:, :, 1, :]
457-
458-
# Down: compute per-group, cat, gather correct expert per slot
459-
intermediate_flat = intermediate.reshape(T * top_k, H)
460-
down_parts = [
461-
proj(intermediate_flat).view(T, top_k, G, D) for proj in self.down_projs
462-
]
463-
all_down = torch.cat(down_parts, dim=2) # (T, top_k, E, D)
464-
465-
eidx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, D)
466-
return all_down.gather(2, eidx).squeeze(2) # (T, top_k, D)
437+
def forward(self, x, expert_weights, expert_indices, top_k):
438+
return torch.ops.triton.fused_moe(
439+
x,
440+
self.w1,
441+
self.w1_scale,
442+
self.w2,
443+
self.w2_scale,
444+
expert_weights,
445+
expert_indices,
446+
top_k,
447+
self.num_experts,
448+
self.group_size,
449+
)
467450

468451

469452
class SwiGLU(nn.Module):
@@ -484,12 +467,9 @@ class SparseMoE(nn.Module):
484467
def __init__(self, config):
485468
super().__init__()
486469
self.top_k = config.num_experts_per_tok
470+
self.num_experts = config.num_experts
487471
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
488-
self.cond_ffn = ConditionalFeedForward(
489-
config.hidden_size,
490-
config.moe_intermediate_size,
491-
config.num_experts,
492-
)
472+
self.experts = FusedMoEExperts(config)
493473
self.shared_expert = SwiGLU(
494474
config.hidden_size, config.shared_expert_intermediate_size
495475
)
@@ -503,8 +483,9 @@ def forward(self, x):
503483
expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1)
504484
expert_weights = expert_weights.softmax(dim=-1)
505485

506-
expert_outs = self.cond_ffn(x_flat, expert_indices)
507-
routed_out = torch.einsum("tai,ta->ti", expert_outs, expert_weights)
486+
routed_out = self.experts(
487+
x_flat, expert_weights.float(), expert_indices, self.top_k
488+
)
508489

509490
shared_out = self.shared_expert(x_flat)
510491
shared_gate = torch.sigmoid(self.shared_expert_gate(x_flat))
@@ -641,9 +622,8 @@ def _load_and_remap_checkpoint(model_dir, config):
641622
expert_weights,
642623
)
643624

644-
# Stack per-expert weights, split into groups, reshape for nn.Linear
625+
# Stack per-expert weights into [E, N, K] tensors for FusedMoEExperts
645626
if expert_weights:
646-
G = _EXPERTS_PER_GROUP
647627
for layer_idx in range(config.num_hidden_layers):
648628
gate_list = [
649629
expert_weights.get((layer_idx, "gate", e))
@@ -661,21 +641,13 @@ def _load_and_remap_checkpoint(model_dir, config):
661641
if gate_list[0] is not None:
662642
w_gate = torch.stack(gate_list, dim=0) # (E, H, D)
663643
w_up = torch.stack(up_list, dim=0)
664-
fused = torch.cat([w_gate, w_up], dim=1) # (E, 2*H, D)
665-
num_groups = config.num_experts // G
666-
for g in range(num_groups):
667-
chunk = fused[g * G : (g + 1) * G]
668-
state_dict[
669-
f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight"
670-
] = chunk.reshape(-1, chunk.size(-1))
644+
state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = torch.cat(
645+
[w_gate, w_up], dim=1
646+
) # (E, 2*H, D)
671647
if down_list[0] is not None:
672-
w_down = torch.stack(down_list, dim=0) # (E, D, H)
673-
num_groups = config.num_experts // G
674-
for g in range(num_groups):
675-
chunk = w_down[g * G : (g + 1) * G]
676-
state_dict[
677-
f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight"
678-
] = chunk.reshape(-1, chunk.size(-1))
648+
state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = torch.stack(
649+
down_list, dim=0
650+
) # (E, D, H)
679651
del expert_weights
680652

681653
# Handle tied embeddings
@@ -697,27 +669,15 @@ def _process_checkpoint_key(ckpt_key, tensor, state_dict, expert_weights):
697669
if norm_key.startswith(("model.visual.", "model.mtp_")):
698670
return
699671

700-
# Fused expert weights: split into groups of _EXPERTS_PER_GROUP
672+
# Fused expert weights: store directly as [E, N, K] for FusedMoEExperts
701673
m = _FUSED_EXPERT_RE.match(norm_key)
702674
if m:
703675
layer_idx = int(m.group(1))
704676
proj_name = m.group(2)
705-
G = _EXPERTS_PER_GROUP
706-
num_groups = tensor.size(0) // G
707677
if proj_name == "gate_up_proj":
708-
# (E, 2*H, D) → groups of (G, 2*H, D) → each (G*2*H, D)
709-
for g in range(num_groups):
710-
chunk = tensor[g * G : (g + 1) * G]
711-
state_dict[
712-
f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight"
713-
] = chunk.reshape(-1, chunk.size(-1)).contiguous()
678+
state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = tensor
714679
else:
715-
# down_proj: (E, D, H) → groups of (G, D, H) → each (G*D, H)
716-
for g in range(num_groups):
717-
chunk = tensor[g * G : (g + 1) * G]
718-
state_dict[f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight"] = (
719-
chunk.reshape(-1, chunk.size(-1)).contiguous()
720-
)
680+
state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = tensor
721681
return
722682

723683
# Per-expert weights

0 commit comments

Comments
 (0)