Enable gfx1151 (RDNA3.5 / Strix Halo): LDS, FP8 guards, f16/bf16 WMMA GEMM#567
Conversation
15e1b5c to
bc3382f
Compare
38af4cc to
9a9561a
Compare
There was a problem hiding this comment.
Pull request overview
Enables initial gfx1151 (RDNA3.5 / Strix Halo) support in FlyDSL’s kernel/test/benchmark stack by adding LDS capacity metadata, introducing a gfx11 WMMA f16/bf16 GEMM kernel variant, and adding clearer failure modes for unsupported FP8 WMMA paths on gfx11.
Changes:
- Add
gfx1151to shared-memory (LDS) capacity map for correct allocation checks. - Introduce
kernels/rdna3_f16_gemm.py(gfx11 WMMA f16/bf16 GEMM) and dispatch between gfx11 vs gfx12 WMMA ABIs in tests/benchmarks. - Fail fast on gfx11 when FP8 WMMA is requested (both via
default_f8_type()and the FP8 GEMM compiler entrypoint), and extend benchmark runner gating to include gfx11 WMMA.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
tests/kernels/test_rdna_gemm.py |
Dispatch f16/bf16 WMMA tests to gfx11 vs gfx12 kernel variants; keep FP8 tests gfx120x-only. |
tests/kernels/benchmark_common.py |
Arch-dispatch WMMA GEMM benchmark between gfx11/gfx12; skip FP8 sweep on gfx11. |
scripts/run_benchmark.sh |
Enable WMMA benchmark section on gfx11* as well as gfx12*. |
python/flydsl/utils/smem_allocator.py |
Add gfx1151 LDS capacity entry. |
python/flydsl/expr/typing.py |
Raise a clear error when FP8 types are requested on gfx11. |
kernels/rdna3_f16_gemm.py |
New gfx11 WMMA f16/bf16 GEMM kernel using legacy v16 operand ABI and gfx11 accumulator layout. |
kernels/rdna_fp8_preshuffle_gemm.py |
Add an early gfx11 guard to avoid late LLVM “cannot select intrinsic” failures. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
9a9561a to
bf5ad65
Compare
bf5ad65 to
5731db0
Compare
|
Thanks for your contribution, we will test and review the PR ASAP! |
5731db0 to
b7d65b6
Compare
Add gfx1151 to SMEM_CAPACITY_MAP at 64 KB per workgroup. The WGP has
128 KB of physical LDS but it is split between the two CUs, so a
single workgroup gets 64 KB — confirmed by HIP (sharedMemPerBlock =
65536) and by the AMDGPU backend, which rejects kernels exceeding
65536 bytes of static LDS with
"local memory (...) exceeds limit (65536)".
Validated on a Strix Halo system:
bash scripts/run_tests.sh
Pytest: 575 passed, 2485 skipped, 904 deselected
Examples: PASS 01-vectorAdd.py, PASS 02-tiledCopy.py
MLIR: all FileCheck tests pass
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
gfx11* (RDNA3 / RDNA3.5) has no native FP8 instructions — no WMMA fp8/bf8, no cvt_pk_fp8/bf8. Attempting FP8 compute on gfx1151 today surfaces as a late LLVM "cannot select intrinsic" abort deep inside ISel, after the kernel has already JIT-compiled most of the way. Two guards make this fail with a clear message at the entry point: - default_f8_type() on gfx11 raises rather than silently returning E4M3FNUZ (the CDNA-only format, which would also be the wrong encoding even if the chip could handle it). - rdna_fp8_preshuffle_gemm.compile_fp8_gemm() rejects gfx11* up front, before any rocdl.wmma_f32_16x16x16_fp8_fp8 op is emitted. Neither path changes behavior on gfx94*, gfx95*, or gfx12*. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Port of kernels/rdna_f16_gemm.py to the legacy v16-operand WMMA ABI
used by RDNA3 / RDNA3.5 (gfx11*). Same algorithm — 4-warp 128x128x32
LDS ping-pong with software pipelining — but two ABI-level changes:
- Input operands A and B are vector<16>, not vector<8>. Each lane
carries 16 contiguous K-elements of one M (or N) row. We do the
v16 load as two v8 LDS loads concatenated via vector.shuffle so
the existing LDS storage layout (v8 chunks) stays unchanged.
Lanes 16-31 mirror lanes 0-15 in this ABI, so they re-issue the
same LDS loads — wasted bandwidth, but the alternative is a
wave-half broadcast which complicates the schedule.
- The accumulator stays vector<8> on both ABIs, but the per-lane
row distribution differs. On gfx12 lane L holds rows
8*(L/16)..8*(L/16)+7; on gfx11 lane L holds rows
2*si + (L/16) for si in 0..7 (even rows in lanes 0-15, odd rows
in lanes 16-31). The store-back loop uses the gfx11 mapping.
The barrier sequence is also gfx11-specific: ``s_waitcnt
lgkmcnt(0); s_barrier`` instead of the gfx12+
``s_wait_dscnt 0x0; s_wait_storecnt 0x0; s_barrier_signal -1;
s_barrier_wait -1``.
Wiring in tests/kernels/test_rdna_gemm.py:
- ``_requires_rdna_wmma()`` is the new gate for f16/bf16 cases
(accepts gfx11* or gfx12*). ``_requires_rdna4()`` stays as-is for
the FP8 cases — those still need gfx12.
- ``create_wmma_gemm_module()`` becomes an arch-dispatching wrapper
that picks rdna3_f16_gemm on gfx11 and rdna_f16_gemm on gfx12.
Both have the same call signature.
Validated on Strix Halo (gfx1151):
bash scripts/run_tests.sh
Pytest: 584 passed, 2476 skipped, 904 deselected
(the +9 over the previous baseline is exactly the 9 newly
unlocked test_f16_gemm_* cases.)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
scripts/run_benchmark.sh was gating its WMMA section on a single ``IS_RDNA4`` flag (gfx120* only). Inside benchmark_common.py the ``run_wmma_sweep`` helper had the same gate, and the ``wmma_gemm`` op hardcoded the gfx12 kernel import. Three small changes wire up gfx11: - run_benchmark.sh: add ``IS_RDNA_WMMA`` covering gfx11* or gfx12* and gate the WMMA section on it. ``IS_RDNA4`` stays for future gfx12-only paths (e.g. an FP8-only entry point). - benchmark_common.run_wmma_sweep: widen the arch gate, then skip the ``wmma_fp8_gemm`` inner loop on gfx11* (no native FP8). - benchmark_common._bench_flydsl_torch(op="wmma_gemm"): pick rdna3_f16_gemm on gfx11*, rdna_f16_gemm on gfx12*. Same call signature, so the surrounding harness code is unchanged. Standalone harness output on Strix Halo: op shape dtype FlyDSL torch wmma_gemm 256x256x256 bf16 16.8u 7.5u wmma_gemm 1024x1024x1024 bf16 126.6u 71.1u wmma_gemm 2048x2048x2048 bf16 507.0u 457.9u wmma_gemm 4096x4096x4096 bf16 5829.6u 3770.3u (matches the standalone bench script within ~1%) Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
|
Thanks for the gfx1151 enablement. I found one merge blocker and one arch-gating question:
I verified locally: |
gfx1250 has its own dedicated WMMA kernels (wmma_gemm_gfx1250.py) and a different WMMA ABI; the rdna_f16_gemm.create_wmma_gemm_module path is for gfx120x (RDNA4) only.
Thanks, addressed both! |
|
Thanks for the gfx1151 enablement. Did an independent run on real hardware, sharing what I found. Setup: Radeon 8060S (Strix Halo, gfx1151) . Ubuntu 26.04 + mainline kernel 7.0 (not 22.04/24.04 — so a bit off the official ROCm matrix) + ROCm 7.2.3 from the Verified end-to-end on real GPU: 6/6 PASS on the (in_dtype × A couple of findings along the way:
A couple of tiny non-blocking nits while I'm here:
Couldn't fully reproduce the perf table — I got 46 μs / 0.73 TF on 256³ vs your 16.8 μs / 2.0 TF. Likely the |
|
@mgehre-amd do you have navi3.5 CI machines for these? |
- Add missing f16 cast in store-back: without it, f32 accumulator bits were reinterpreted by buffer_store as 2 f16 elements, silently producing NaN/Inf for out_dtype="f16". - Extend test_f16_gemm_correctness to parametrize (in_dtype, out_dtype) over all 4 combos so CI covers the previously orphan f16-output paths. - Document the ds_swizzle_b32 XOR 16 broadcast as a TODO(perf) for the v16-ABI duplicate-load waste on lanes 16-31. - default_f8_type() error message now points gfx11* callers to rdna3_f16_gemm.create_wmma_gemm_module for a direct migration line.
When working on ROCm/vLLM repo, we had requested ROCm DevOps to share the gfx1151 runners from TheRock with us, which they did. I guess the same can work here. But note that there are only a few runners and there will be wait times, so maybe make those opt-in for PRs that touch gfx11. |
assert is stripped under `python -O`, which would let an invalid num_k_tiles silently reach the prefetch pipeline.
Thanks a lot for testing this so deeply! I addressed the issues you raised. |
Without the f16 cast, the f32 accumulator bits were reinterpreted by buffer_store as 2 f16 elements, silently producing NaN/Inf. Same bug that was just fixed in rdna3_f16_gemm; mirror the fix here so the extended test matrix passes on gfx120x (RDNA4) too.
Summary
Four commits enabling gfx1151 (Strix Halo / RDNA3.5):
Add gfx1151 LDS capacity — single entry in
SMEM_CAPACITY_MAP(gfx1151: 65536). 64 KB per workgroup (the WGP has 128 KB of physical LDS split across 2 CUs); confirmed by HIP (sharedMemPerBlock = 65536) and by the AMDGPU backend rejecting kernels exceeding 65536 bytes withlocal memory (...) exceeds limit (65536).Fail-fast on FP8 paths from gfx11 — gfx11 has no native FP8 instructions, but
default_f8_type()was silently returning E4M3FNUZ andcompile_fp8_gemm()was emittingrocdl.wmma_f32_16x16x16_fp8_fp8directly, both surfacing asLLVM ERROR: cannot select intrinsicdeep in ISel. Add early arch checks so the failure is clear at the entry point. No behavior change on gfx94*, gfx95*, gfx12*.Add
rdna3_f16_gemm.py: f16/bf16 WMMA GEMM for gfx11 — port ofrdna_f16_gemm.pyto the legacy v16-operand WMMA ABI. Two ABI-level differences from gfx12: (a) A/B operands arevector<16>notvector<8>(done as two v8 LDS loads +vector.shuffleconcat, lanes 16-31 mirror lanes 0-15), and (b) the v8 accumulator's per-lane row distribution is stride-2 (lane L holds rows2*si + L/16) instead of contiguous-8 (lane L holds rows8*(L/16) + si); the store-back loop uses the gfx11 mapping. Barrier is also gfx11-specific (s_waitcnt lgkmcnt(0); s_barrier).tests/kernels/test_rdna_gemm.pygets an arch-dispatch wrapper aroundcreate_wmma_gemm_moduleso f16/bf16 tests run on either ABI; FP8 cases stay gated to gfx12 via_requires_rdna4().Run wmma_gemm benchmark on gfx11 too —
scripts/run_benchmark.shgets anIS_RDNA_WMMAflag covering gfx11*/gfx12*,benchmark_common.run_wmma_sweepwidens its arch gate and skips the FP8 inner loop on gfx11*, and_bench_flydsl_torch(op="wmma_gemm", …)arch-dispatches between the two kernel variants. Same call signature, so the rest of the harness is unchanged.Motivation
gfx1151 (Strix Halo / RDNA3.5) was not on the verified-platform list. The arch-detection infrastructure landed in #221 already covered gfx11* via
is_rdna_arch(), so the generic pipeline (575 tests) already worked once a small LDS-capacity entry was added — but the production f16/bf16 GEMM kernel was hardcoded to the gfx12 v8-operand WMMA ABI and refused to run, leaving the platform without a usable GEMM kernel. The FP8 kernel additionally crashed deep in LLVM ISel rather than failing cleanly. This PR closes both gaps.Changes
python/flydsl/utils/smem_allocator.py: +1 line —gfx1151: 65536inSMEM_CAPACITY_MAP.python/flydsl/expr/typing.py:default_f8_type()raises ongfx11*with a clear message instead of returning the CDNA-only E4M3FNUZ format.kernels/rdna_fp8_preshuffle_gemm.py:compile_fp8_gemm()rejectsgfx11*up front (earlyRuntimeError) instead of letting the call torocdl.wmma_f32_16x16x16_fp8_fp8reach ISel.kernels/rdna3_f16_gemm.py: new file (~380 lines) — f16/bf16 WMMA GEMM tuned for the gfx11 v16-operand ABI.tests/kernels/test_rdna_gemm.py: adds_requires_rdna_wmma()(gfx11* or gfx12*) and acreate_wmma_gemm_modulewrapper that arch-dispatches betweenrdna3_f16_gemmandrdna_f16_gemm. f16/bf16 tests use the wrapper; FP8 tests stay gated to_requires_rdna4().scripts/run_benchmark.sh: addsIS_RDNA_WMMAflag (gfx11* or gfx12*) and uses it to gate the WMMA section;IS_RDNA4retained for any future gfx12-only entries (e.g. FP8).tests/kernels/benchmark_common.py:run_wmma_sweepwidens its arch gate and skips the FP8 inner loop on gfx11*;_bench_flydsl_torch(op="wmma_gemm", …)arch-dispatches the kernel import.Performance (Strix Halo, Radeon 8060S, 65 GB unified memory)
Standalone harness, identical NN matmul shapes:
Reference —
hipblaslt-benchon the same shapes (NN, f32 compute,--use_gpu_timer --rotating 256 --cold_iters 10 --iters 50):FlyDSL's tile config (
reg_m=4, reg_n=4, waves_m=2, waves_n=2) was hand-tuned for the gfx12 v8 ABI on gfx1201 — it lands at peak (~34 TFLOPS) at 2048³ on gfx1151 but pays for it at 1024³ (launch + tile granularity) and 4096³ (likely cache blocking). A gfx1151-specific tile sweep is left as a follow-up; the hipBLASLt numbers serve as a static reference for that work, not as something this PR is trying to match.Testing
tests/kernels/test_rdna_gemm.pynow exercises 9 f16/bf16 cases on gfx11 (was 0).bash scripts/run_benchmark.sh(WMMA section) on gfx1151. Numbers above.bash scripts/run_tests.sh→ 584 passed, 2476 skipped (CDNA/gfx1250-only), 904 deselected. +9 vs. previous gfx1151 baseline — exactly thetest_f16_gemm_*cases now unlocked. Both RDNA-whitelisted examples pass; all MLIR FileCheck tests pass._create_wmma_gemm_module_gfx12is still picked there)Note: on a wheel-only venv (no system ROCm), running locally requires
ROCM_PATHpointed at a<root>/llvm/bin/ld.lld+<root>/amdgcn/bitcodetoolkit dir, or stacking PR #568 which auto-discovers it. CI on a properly-configured runner doesn't need this.Dependencies
Breaking Changes
None. All existing call sites of
default_f8_type()andcompile_fp8_gemm()on supported arches (gfx94*, gfx95*, gfx12*) are unchanged; only gfx11* now raises (where it previously crashed).test_rdna_gemm.py's_requires_rdna4()is unchanged; new_requires_rdna_wmma()is added alongside it. TheIS_RDNA4shell variable is retained.