Skip to content

Commit f479ecf

Browse files
authored
Fuse linear projections and the SiLU activation, and replace conv1d, to reduce kernel launches (pytorch#18392)
Apply 5 optimizations validated on nano_qwen35_moe: 1. Fuse SiLU into MoE GEMM2: the new _fused_moe_silu_kernel reads gate+up from the GEMM1 output and applies SiLU on the fly during GEMM2, eliminating the intermediate buffer and 1 kernel launch per layer. 2. Fuse QKV projections in full attention: the separate q_proj, k_proj, and v_proj are replaced with a single qkv_proj. Saves 2 kernel launches per full attention layer (10 layers = 20 launches). 3. Fuse GDN input projections: the separate in_proj_qkv, in_proj_z, in_proj_b, and in_proj_a are replaced with a single in_proj. Saves 3 kernel launches per GDN layer (30 layers = 90 launches). 4. Fuse gate+up in the shared expert: the separate gate_proj and up_proj are replaced with a single gate_up_proj. Saves 1 kernel launch per layer (40 layers = 40 launches). 5. Replace F.conv1d with a manual depthwise conv (4 slice-multiply-adds). The conv1d->conv2d decomposition generated a catastrophically slow Triton kernel (2.1ms/call at 8192 channels). The manual approach produces simple element-wise ops that Inductor fuses efficiently. Eliminates 81.8% of decode CUDA time (64ms -> 4ms per step). Before: decode throughput 12.41 tokens/s, prefill throughput 47.3 tokens/s. After: decode throughput 58.5 tokens/s, prefill throughput 96 tokens/s. Measured on A100.
1 parent 5ae3bf2 commit f479ecf

3 files changed

Lines changed: 215 additions & 61 deletions

File tree

backends/cuda/triton/kernels/fused_moe.py

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,109 @@ def _fused_moe_kernel(
147147
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)
148148

149149

150+
@triton.jit
def _fused_moe_silu_kernel(
    # Pointers
    A,  # [M * top_k, 2*inter] bf16 GEMM1 output (gate | up)
    B,  # [E, N, K//2] int8 packed INT4 weights
    C,  # [M * top_k, N] bf16 output
    B_scale,  # [E, N, K//group_size] bf16 scales
    topk_ids,  # [M * top_k] int64 expert indices
    topk_weights,  # [M * top_k] float32 router weights
    # Dimensions
    N: tl.constexpr,
    K: tl.constexpr,  # intermediate_size
    num_token_expert_pairs,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Config
    group_size: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    compute_type: tl.constexpr,
):
    """GEMM2 with fused SiLU activation.

    Reads gate and up columns from GEMM1 output (A), applies SiLU(gate)*up
    on-the-fly, and multiplies by INT4 w2 weights. Router weights are applied
    to the output. Eliminates the intermediate activation buffer.

    Work decomposition: the 1-D launch grid is split so that each program
    handles one (token, expert) pair row of A and one BLOCK_SIZE_N-wide slice
    of the N output columns, looping over K in BLOCK_SIZE_K steps.

    Args:
        A, B, C, B_scale, topk_ids, topk_weights: base pointers; layouts as
            annotated on the parameters above.
        N: number of output columns.
        K: reduction length; gate occupies A columns [0, K) and up occupies
            [K, 2*K).
        num_token_expert_pairs: number of valid rows in A / C; programs whose
            pair index falls at or beyond this value exit immediately.
        stride_*: element strides for A (am, ak), B (be, bk, bn), C (cm, cn)
            and B_scale (bse, bsk, bsn).
        group_size: number of consecutive K elements sharing one scale.
        BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes along N and K.
        compute_type: dtype used for the activation/weight product and for
            the stored output.
    """
    pid = tl.program_id(0)
    num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N)
    pair_idx = pid // num_n_blocks
    n_block = pid % num_n_blocks

    if pair_idx >= num_token_expert_pairs:
        return

    # Widen the row index before it is multiplied by a row stride. The grid
    # index is 32-bit, so `pair_idx * stride_am` / `pair_idx * stride_cm`
    # could wrap for large inputs; expert_id and offs_n are already widened
    # to int64 for the same reason.
    row = pair_idx.to(tl.int64)

    expert_id = tl.load(topk_ids + row).to(tl.int64)

    offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    n_mask = offs_n < N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # A pointers: gate at columns [0, K), up at columns [K, 2*K)
    a_gate_ptrs = A + row * stride_am + offs_k * stride_ak
    a_up_ptrs = a_gate_ptrs + K * stride_ak

    # B pointer: [expert_id, offs_n, offs_k//2]
    # Two consecutive K elements share one packed byte.
    b_ptrs = (
        B
        + expert_id * stride_be
        + (offs_k[:, None] // 2) * stride_bk
        + offs_n[None, :] * stride_bn
    )
    # Even k selects the low nibble (shift 0), odd k the high nibble (shift 4).
    b_shifter = (offs_k[:, None] % 2) * 4

    acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)

    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask off the tail when K is not a multiple of BLOCK_SIZE_K.
        k_remaining = K - k_step * BLOCK_SIZE_K
        k_mask = offs_k < k_remaining

        # Load gate and up, apply SiLU(gate) * up in float32, then narrow
        # once to compute_type.
        gate = tl.load(a_gate_ptrs, mask=k_mask, other=0.0).to(tl.float32)
        up = tl.load(a_up_ptrs, mask=k_mask, other=0.0).to(tl.float32)
        a = (gate * tl.sigmoid(gate) * up).to(compute_type)

        # Load and dequantize INT4 weights
        b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
        b = (b >> b_shifter) & 0xF

        # One scale per (expert, n, k // group_size); the absolute k index is
        # the in-tile offset plus the tile start.
        scale_ptrs = (
            B_scale
            + expert_id * stride_bse
            + offs_n[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
        )
        b_scale = tl.load(
            scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
        ).to(tl.float32)

        # Unsigned nibble [0, 15] is re-centred by 8 before scaling.
        b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)
        # `a` is already compute_type, so no further cast is needed here.
        acc += tl.sum(a[:, None] * b_dequant, axis=0)

        a_gate_ptrs += BLOCK_SIZE_K * stride_ak
        a_up_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

    # Multiply by router weight
    weight = tl.load(topk_weights + row)
    acc = acc * weight

    c_ptrs = C + row * stride_cm + offs_n * stride_cn
    tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)
251+
252+
150253
# ---------------------------------------------------------------------------
151254
# triton_op wrapper
152255
# ---------------------------------------------------------------------------
@@ -231,18 +334,13 @@ def fused_moe(
231334
compute_type=tl.bfloat16,
232335
)
233336

234-
# ---- Activation: SiLU(gate) * up ----
235-
gate = cache1[:, :intermediate]
236-
up = cache1[:, intermediate:]
237-
cache2 = torch.nn.functional.silu(gate) * up
238-
239-
# ---- GEMM2: down projection, multiply by router weights ----
337+
# ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ----
240338
cache3 = torch.empty(
241339
num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device
242340
)
243341
grid2 = (num_pairs * triton.cdiv(N2, BLOCK_SIZE_N),)
244-
wrap_triton(_fused_moe_kernel)[grid2](
245-
cache2,
342+
wrap_triton(_fused_moe_silu_kernel)[grid2](
343+
cache1,
246344
w2,
247345
cache3,
248346
w2_scale,
@@ -251,8 +349,8 @@ def fused_moe(
251349
N=N2,
252350
K=intermediate,
253351
num_token_expert_pairs=num_pairs,
254-
stride_am=cache2.stride(0),
255-
stride_ak=cache2.stride(1),
352+
stride_am=cache1.stride(0),
353+
stride_ak=cache1.stride(1),
256354
stride_be=w2.stride(0),
257355
stride_bk=w2.stride(2),
258356
stride_bn=w2.stride(1),
@@ -264,8 +362,6 @@ def fused_moe(
264362
group_size=group_size,
265363
BLOCK_SIZE_N=BLOCK_SIZE_N,
266364
BLOCK_SIZE_K=BLOCK_SIZE_K,
267-
MUL_ROUTED_WEIGHT=True,
268-
top_k=1,
269365
compute_type=tl.bfloat16,
270366
)
271367

examples/models/qwen3_5_moe/export.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def export_and_lower(model, config, args):
267267
to_edge_transform_and_lower,
268268
)
269269
from executorch.exir.passes import MemoryPlanningPass
270-
from torch._inductor.decomposition import conv1d_to_conv2d
271270
from torch.export import Dim, export
272271

273272
# Coordinate descent recompiles each kernel trying config perturbations,
@@ -293,11 +292,6 @@ def export_and_lower(model, config, args):
293292
)
294293
print("Export successful!")
295294

296-
# conv1d → conv2d decomposition (required for CUDA backend)
297-
exported = exported.run_decompositions(
298-
{torch.ops.aten.conv1d.default: conv1d_to_conv2d}
299-
)
300-
301295
# Lower with CUDA backend
302296
print("Lowering to ExecuTorch with CUDA...")
303297
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]

0 commit comments

Comments
 (0)