
Add W4A8/W4A_FP8 MoE support with groupwise scale #202

Open
ClementLinCF wants to merge 8 commits into main from feature/w4a8-moe-port

Conversation

ClementLinCF (Contributor) commented Mar 12, 2026

Motivation

The existing fused MoE 2-stage kernel supports fp8, fp16, bf16, int8, and W4A16 (int4_bf16) data types. This PR extends it with W4A8 (int4) and W4A_FP8 (int4_fp8) support, and adds groupwise scale (group_size=32) for all three int4 weight variants — enabling lower-precision MoE inference paths that are critical for production deployment of large MoE models (e.g., Kimi K2.5).

Technical Details

New dtype: int4_fp8 (W4A_FP8)

  • FP8 activations + packed int4 weights, using mfma_f32_16x16x32_fp8_fp8.

  • In-kernel int4→fp8 unpack via the cvt_pk_fp8_f32 ROCDL intrinsic (the bit manipulation is sketched after this list).

  • 8-byte K64 weight loads (buffer_load_dwordx2) for improved memory efficiency.
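For reference, the int4→int8 sign-extension that precedes the fp8 conversion can be modeled on the host as below — a minimal sketch assuming the usual two-nibbles-per-byte packing with low-nibble-first order (an assumption); the kernel performs the equivalent bit manipulation in registers and then converts to fp8 via cvt_pk_fp8_f32, and the function name here is illustrative.

```python
import numpy as np

def unpack_int4_pairs(packed: np.ndarray) -> np.ndarray:
    """Unpack bytes holding two signed int4 values each into int8.

    Host-side model only; the in-kernel path does the same
    sign-extension in registers before cvt_pk_fp8_f32 produces fp8.
    Nibble order (low nibble first) is an assumption.
    """
    b = packed.astype(np.int8)
    lo = np.left_shift(b, 4) >> 4  # shift low nibble up, arithmetic shift back to sign-extend
    hi = b >> 4                    # arithmetic shift sign-extends the high nibble
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
```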

Groupwise scale (group_size=32) for W4A8/W4A16/W4A_FP8

  • Per-K32 groupwise accumulation: fresh MFMA accumulator + per-group scale + running f32 accumulator, for both stage1 and stage2 (see the sketch after this list).

  • Groupwise scale address formula using [E, num_groups, N] layout with preshuffled scale tensors.

  • Epilogue correctly skips sitofp for groupwise accumulators (already f32 from per-K32 accumulation).
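As a rough host-side model of this pattern (a sketch under assumed shapes and naming, not the kernel code): each K-group of 32 gets a fresh accumulator, which is scaled by its group's factor and folded into a running f32 accumulator, with the scale read from the [E, num_groups, N] tensor.

```python
import numpy as np

def groupwise_gemm_ref(a, w, scales, e, group_size=32):
    """Illustrative reference for per-K32 groupwise accumulation.

    a:      [M, K] activations (f32 here for clarity)
    w:      [K, N] unpacked int4 weights as int8
    scales: [E, num_groups, N], num_groups = K // group_size; the flat
            scale offset follows e * num_groups * N + g * N + n
    e:      expert index
    """
    K = a.shape[1]
    out = np.zeros((a.shape[0], w.shape[1]), dtype=np.float32)
    for g in range(K // group_size):
        k0 = g * group_size
        # fresh accumulator per K-group (models a fresh MFMA accumulator)
        acc = a[:, k0:k0 + group_size] @ w[k0:k0 + group_size].astype(np.float32)
        out += acc * scales[e, g]  # apply per-group scale, fold into running f32 sum
    return out
```

Because the running accumulator is already f32, the epilogue's sitofp skip mentioned above falls out naturally.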

bf16 output dtype test coverage

  • Added out_dtype="bf16" to test parametrization (was only f16/f32).
  • Fixed run_moe_stage2 test helper to accept bf16 output dtype.
  • Fixed the bytes_moved bandwidth calculation to treat bf16 as 2-byte (see the sketch after this list).
  • The stage2 kernel already supported bf16 output via bf16 global atomics (gfx94+/gfx95+), but the test harness blocked it.
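The bandwidth fix boils down to a dtype-width lookup; a minimal sketch of the idea (names illustrative, not the actual test-helper code):

```python
# bf16 and f16 are both 2 bytes wide; the original bytes_moved
# calculation did not account for "bf16" and mis-sized the output.
ELEM_BYTES = {"f32": 4, "f16": 2, "bf16": 2, "fp8": 1, "int8": 1}

def output_bytes(num_elems: int, out_dtype: str) -> int:
    return num_elems * ELEM_BYTES[out_dtype]
```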

Test Plan

pytest tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage covering:

  • Groupwise (g32) + out_f16/out_bf16: shapes S/M/L × int4, int4_bf16, int4_fp8 × atomic/reduce × eager/graph
  • Groupwise (g32) + out_f32: shapes S/M/L × int4, int4_bf16, int4_fp8 × atomic × eager/graph
  • Non-groupwise (perrow) + all dtypes: shapes S/M/L × fp8/fp16/bf16/int8/int8smooth/int4/int4_bf16/int4_fp8 × f16/bf16/f32 × atomic/reduce × eager/graph

Each test verifies correctness against a torch reference (for S/M shapes) and runs perf timing.

Test Results

MI308 (gfx942) — Original tests

120 passed, 360 skipped, 288 deselected in 35.39s

  • Groupwise f16 (S/M/L): all PASSED (int4, int4_bf16, int4_fp8)
  • Groupwise f32 (S/M/L): all PASSED (int4, int4_bf16, int4_fp8)
  • Non-groupwise perrow fp8 (S/M/L): all PASSED
  • 0 failures

MI355X (gfx950) — Full test suite

264 passed, 504 skipped, 0 failed in 3m43s

All applicable tests pass on gfx950. Skips are expected:

  • mask tests (valid_mask not supported on gfx950)
  • graph mode tests (HIP graph capture skipped)
  • out_f32 reduce tests (accumulate=False forbids it)
  • Non-int4 dtypes in g32 mode (groupwise scale does not apply to fp8/fp16/bf16/int8)

MI355X (gfx950) — bf16 output tests (new)

24/24 passed — all input dtypes × S/M/L sizes with out_dtype=bf16

Peak performance on MI355X (L size):

| Path | Stage1 TFLOPS | Stage2 TFLOPS | Bandwidth |
| --- | --- | --- | --- |
| fp8 → bf16 out | 944 | 653 | 1.97 TB/s |
| bf16 → bf16 out | 541 | 490 | 1.54 TB/s |
| int8 → bf16 out | 921 | 653 | 1.97 TB/s |
| int4_bf16 (g32) → bf16 out | 175 | 422 | 1.32 TB/s |
| int4_bf16 (g32) → f16 out | 372 | 348 | 1.09 TB/s |

E2E test

Kimi-K2.5 W4A16 vs. W4A8 on MI308 (con = concurrency)

| Metric | W4A8 con=2 | W4A8 con=40 | W4A16 con=2 | W4A16 con=40 |
| --- | --- | --- | --- | --- |
| Output throughput (tok/s) | 61.37 | 261.02 | 24.48 | 104.97 |
| Peak output throughput (tok/s) | 82.00 | 800.00 | 60.00 | 442.00 |
| Total throughput (tok/s) | 1288.80 | 5481.35 | 514.06 | 2204.34 |
| Mean TPOT (ms) | 25.82 | 102.51 | 35.80 | 223.64 |
| Median TPOT (ms) | 25.83 | 102.98 | 35.51 | 231.12 |
| Mean TTFT (ms) | 3485.58 | 26063.61 | 23530.88 | 68158.30 |
| Median ITL (ms) | 24.89 | 52.07 | 34.08 | 83.65 |
| Mean E2E latency (ms) | 16679.06 | 78443.72 | 41823.70 | 182436.60 |

Build & Test Environment (MI355X)

Built and tested inside a Docker container (lmsysorg/sglang-rocm:v0.5.9-rocm700-mi35x-20260322):

  • LLVM built from source (commit 7f77ca0dbda4) with 128 threads
  • FlyDSL built from source via pip install -e . (flydsl 0.1.1.dev413)
  • Hardware: AMD Instinct MI355X (gfx950), ROCm 7.0, PyTorch 2.9.0a0

Note: The PyPI wheel flydsl==0.1.1 is incompatible with this branch due to commit 4d84ee8 (native idx2crd APIs). Building from source is required.


Copilot AI review requested due to automatic review settings March 12, 2026 15:53
Copilot AI left a comment


Pull request overview

This PR adds W4A8 (int4) and W4A_FP8 (int4_fp8) MoE support with groupwise scaling (group_size=32) to the FlyDSL fused MoE 2-stage kernel. It extends the existing int4_bf16 (W4A16) path with new load/unpack helpers and per-K32 group accumulation logic.

Changes:

  • New int4_fp8 dtype support with FP8 activations + packed int4 weights, using mfma_f32_16x16x32_fp8_fp8 and in-kernel int4→fp8 conversion via cvt_pk_fp8_f32.
  • Groupwise scale (group_size=32) for all int4 weight variants (int4, int4_bf16, int4_fp8) with per-K32 fresh-accumulator + scale + running f32 accumulator pattern.
  • New load/unpack helpers (load_b_raw_w4a8_k64, load_b_raw_w4a8_groupwise_k64, unpack_b_w4a8, unpack_b_w4a_fp8, etc.) in mfma_preshuffle_pipeline.py.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| kernels/mfma_preshuffle_pipeline.py | New load/unpack helpers for W4A8, W4A_FP8, and groupwise scale variants |
| kernels/moe_gemm_2stage.py | Extended stage1/stage2 compile functions with int4_fp8 dtype and groupwise scale paths |
| tests/kernels/test_moe_gemm.py | Added int4_fp8 to test parameterization and corresponding quantization/routing logic |
| tests/test_common.py | Minor whitespace cleanup |


coderfeli requested a review from yadaish March 17, 2026 01:14
MHYangAMD and others added 4 commits March 17, 2026 09:47
…tions

- Extract _unpack_int4_to_int8_pair(): shared 7-op int4->int8 bit
  manipulation used by unpack_b_w4a16, unpack_b_w4a8, unpack_b_w4a_fp8,
  and load_b_pack_k32 (was copy-pasted in 4 places)
- Extract _pack_i32_pair_to_i64(): shared (even, odd) -> i64 packing (sketched below)
- Extract _load_groupwise_scale(): shared scale address calculation and
  buffer_load for W4A16 and W4A8 groupwise paths
- Have load_b_raw_w4a8_groupwise_k64 delegate weight load to
  load_b_raw_w4a8_k64 (matching W4A16 groupwise pattern)
- Replace ir.IntegerType.get_signless(32) / ir.F32Type.get() with
  T.i32 / T.f32 to follow project conventions
- Replace arith.constant(..., index=True) with fx.Index(...) throughout
- Add 'bf16' to out_dtype parametrization (was only f16/f32)
- Fix run_moe_stage2 to accept bf16 output dtype
- Fix bytes_moved calculation to treat bf16 as 2-byte (like f16)

The stage2 kernel (compile_moe_gemm2) already supports out_dtype='bf16'
using bf16 global atomics on gfx94+/gfx95+, but the test harness
blocked it. Verified all 24 new test cases pass on MI355X (gfx950).
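For intuition, the (even, odd) → i64 packing factored into _pack_i32_pair_to_i64 amounts to the following Python sketch; the real helper emits MLIR ops, and which word lands in the high half is an assumption here.

```python
def pack_i32_pair_to_i64(even: int, odd: int) -> int:
    """Pack two 32-bit words into one 64-bit value.

    Assumed layout: `odd` in the high 32 bits, `even` in the low 32 bits.
    """
    return ((odd & 0xFFFFFFFF) << 32) | (even & 0xFFFFFFFF)
```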