Enable gfx1151 (RDNA3.5 / Strix Halo): LDS, FP8 guards, f16/bf16 WMMA GEMM #567

Merged: coderfeli merged 9 commits into main from matthias.enable-gfx1151 on May 29, 2026

@mgehre-amd
Contributor

@mgehre-amd mgehre-amd commented May 26, 2026

Summary

Four commits enabling gfx1151 (Strix Halo / RDNA3.5):

  1. Add gfx1151 LDS capacity — single entry in SMEM_CAPACITY_MAP (gfx1151: 65536). 64 KB per workgroup (the WGP has 128 KB of physical LDS split across 2 CUs); confirmed by HIP (sharedMemPerBlock = 65536) and by the AMDGPU backend rejecting kernels exceeding 65536 bytes with local memory (...) exceeds limit (65536).

  2. Fail fast on FP8 paths on gfx11: gfx11 has no native FP8 instructions, but default_f8_type() silently returned E4M3FNUZ and compile_fp8_gemm() emitted rocdl.wmma_f32_16x16x16_fp8_fp8 directly; both surfaced as LLVM ERROR: cannot select intrinsic deep in ISel. Early arch checks now make the failure clear at the entry point. No behavior change on gfx94*, gfx95*, gfx12*.

  3. Add rdna3_f16_gemm.py: f16/bf16 WMMA GEMM for gfx11 — port of rdna_f16_gemm.py to the legacy v16-operand WMMA ABI. Two ABI-level differences from gfx12: (a) A/B operands are vector<16> not vector<8> (done as two v8 LDS loads + vector.shuffle concat, lanes 16-31 mirror lanes 0-15), and (b) the v8 accumulator's per-lane row distribution is stride-2 (lane L holds rows 2*si + L/16) instead of contiguous-8 (lane L holds rows 8*(L/16) + si); the store-back loop uses the gfx11 mapping. Barrier is also gfx11-specific (s_waitcnt lgkmcnt(0); s_barrier). tests/kernels/test_rdna_gemm.py gets an arch-dispatch wrapper around create_wmma_gemm_module so f16/bf16 tests run on either ABI; FP8 cases stay gated to gfx12 via _requires_rdna4().

  4. Run wmma_gemm benchmark on gfx11 too: scripts/run_benchmark.sh gets an IS_RDNA_WMMA flag covering gfx11*/gfx12*, benchmark_common.run_wmma_sweep widens its arch gate and skips the FP8 inner loop on gfx11*, and _bench_flydsl_torch(op="wmma_gemm", …) arch-dispatches between the two kernel variants. Same call signature, so the rest of the harness is unchanged.

Motivation

gfx1151 (Strix Halo / RDNA3.5) was not on the verified-platform list. The arch-detection infrastructure landed in #221 already covered gfx11* via is_rdna_arch(), so the generic pipeline (575 tests) already worked once a small LDS-capacity entry was added — but the production f16/bf16 GEMM kernel was hardcoded to the gfx12 v8-operand WMMA ABI and refused to run, leaving the platform without a usable GEMM kernel. The FP8 kernel additionally crashed deep in LLVM ISel rather than failing cleanly. This PR closes both gaps.

Changes

  • python/flydsl/utils/smem_allocator.py: +1 line — gfx1151: 65536 in SMEM_CAPACITY_MAP.
  • python/flydsl/expr/typing.py: default_f8_type() raises on gfx11* with a clear message instead of returning the CDNA-only E4M3FNUZ format.
  • kernels/rdna_fp8_preshuffle_gemm.py: compile_fp8_gemm() rejects gfx11* up front (early RuntimeError) instead of letting the call to rocdl.wmma_f32_16x16x16_fp8_fp8 reach ISel.
  • kernels/rdna3_f16_gemm.py: new file (~380 lines) — f16/bf16 WMMA GEMM tuned for the gfx11 v16-operand ABI.
  • tests/kernels/test_rdna_gemm.py: adds _requires_rdna_wmma() (gfx11* or gfx12*) and a create_wmma_gemm_module wrapper that arch-dispatches between rdna3_f16_gemm and rdna_f16_gemm. f16/bf16 tests use the wrapper; FP8 tests stay gated to _requires_rdna4().
  • scripts/run_benchmark.sh: adds IS_RDNA_WMMA flag (gfx11* or gfx12*) and uses it to gate the WMMA section; IS_RDNA4 retained for any future gfx12-only entries (e.g. FP8).
  • tests/kernels/benchmark_common.py: run_wmma_sweep widens its arch gate and skips the FP8 inner loop on gfx11*; _bench_flydsl_torch(op="wmma_gemm", …) arch-dispatches the kernel import.

Performance (Strix Halo, Radeon 8060S, 65 GB unified memory)

Standalone harness, identical NN matmul shapes:

ROCM_PATH=<toolkit-shim> gpu-lock python3 -c \
  "from tests.kernels.benchmark_common import run_wmma_sweep, print_perf_table; \
   print_perf_table(run_wmma_sweep())"

| Configuration | Before (gfx11 baseline) | After (this PR) | TFLOPS |
| --- | --- | --- | --- |
| wmma_gemm 256³ bf16 | unsupported / crashes | 16.8 μs | 2.0 |
| wmma_gemm 1024³ bf16 | unsupported / crashes | 126.6 μs | 17.0 |
| wmma_gemm 2048³ bf16 | unsupported / crashes | 507.0 μs | 33.9 |
| wmma_gemm 4096³ bf16 | unsupported / crashes | 5,829.6 μs | 23.6 |
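
The TFLOPS column can be cross-checked from the timings. The sketch below assumes the usual 2·M·N·K FLOP count for a GEMM; `gemm_tflops` is a hypothetical helper for illustration, not part of the benchmark harness.

```python
def gemm_tflops(m: int, n: int, k: int, time_us: float) -> float:
    """Achieved TFLOPS for an MxNxK GEMM that took `time_us` microseconds."""
    flops = 2 * m * n * k                      # one multiply + one add per MAC
    return flops / (time_us * 1e-6) / 1e12


# Timings taken from the table above (bf16, cubic shapes).
for size, time_us in [(256, 16.8), (1024, 126.6), (2048, 507.0), (4096, 5829.6)]:
    print(f"{size}^3: {gemm_tflops(size, size, size, time_us):.1f} TFLOPS")
```

With the reported timings this should reproduce the 2.0 / 17.0 / 33.9 / 23.6 figures after rounding to one decimal.
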

Reference — hipblaslt-bench on the same shapes (NN, f32 compute, --use_gpu_timer --rotating 256 --cold_iters 10 --iters 50):

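| Shape | dtype | hipBLASLt μs | hipBLASLt TFLOPS | FlyDSL TFLOPS | FlyDSL/hipBLASLt |
| --- | --- | --- | --- | --- | --- |
| 1024³ | bf16 | 79.5 | 27.0 | 17.0 | 0.63× |
| 1024³ | f16 | 76.8 | 27.9 | 17.0 | 0.61× |
| 2048³ | bf16 | 708.1 | 24.3 | 33.9 | 1.40× |
| 2048³ | f16 | 712.1 | 24.1 | 33.7 | 1.40× |
| 4096³ | bf16 | 3,689.4 | 37.3 | 23.6 | 0.63× |
| 4096³ | f16 | 3,719.0 | 37.0 | 23.5 | 0.63× |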

FlyDSL's tile config (reg_m=4, reg_n=4, waves_m=2, waves_n=2) was hand-tuned for the gfx12 v8 ABI on gfx1201 — it lands at peak (~34 TFLOPS) at 2048³ on gfx1151 but pays for it at 1024³ (launch + tile granularity) and 4096³ (likely cache blocking). A gfx1151-specific tile sweep is left as a follow-up; the hipBLASLt numbers serve as a static reference for that work, not as something this PR is trying to match.

Testing

  • Unit tests added/updated — tests/kernels/test_rdna_gemm.py now exercises 9 f16/bf16 cases on gfx11 (was 0).
  • Performance benchmarks run — bash scripts/run_benchmark.sh (WMMA section) on gfx1151. Numbers above.
  • Tested on MI300X — no access. CI will cover.
  • Tested on gfx1151 (Strix Halo): bash scripts/run_tests.sh reports 584 passed, 2476 skipped (CDNA/gfx1250-only), 904 deselected. +9 vs. the previous gfx1151 baseline, exactly the test_f16_gemm_* cases now unlocked. Both RDNA-whitelisted examples pass; all MLIR FileCheck tests pass.
  • gfx1201 sanity (gfx12 dispatch path unchanged but worth a CI run to confirm _create_wmma_gemm_module_gfx12 is still picked there)

Note: on a wheel-only venv (no system ROCm), running locally requires ROCM_PATH pointed at a <root>/llvm/bin/ld.lld + <root>/amdgcn/bitcode toolkit dir, or stacking PR #568 which auto-discovers it. CI on a properly-configured runner doesn't need this.

Dependencies

  • No new third-party dependencies added.

Breaking Changes

None. All existing call sites of default_f8_type() and compile_fp8_gemm() on supported arches (gfx94*, gfx95*, gfx12*) are unchanged; only gfx11* now raises (where it previously crashed). test_rdna_gemm.py's _requires_rdna4() is unchanged; new _requires_rdna_wmma() is added alongside it. The IS_RDNA4 shell variable is retained.

@mgehre-amd mgehre-amd force-pushed the matthias.enable-gfx1151 branch from 15e1b5c to bc3382f Compare May 26, 2026 11:54
@mgehre-amd mgehre-amd changed the title Enable gfx1151 (RDNA3.5 / Strix Halo) Add gfx1151 LDS capacity (RDNA3.5 / Strix Halo) May 26, 2026
@mgehre-amd mgehre-amd changed the title Add gfx1151 LDS capacity (RDNA3.5 / Strix Halo) Enable gfx1151 (RDNA3.5 / Strix Halo): LDS, FP8 guards, f16/bf16 WMMA GEMM May 26, 2026
@mgehre-amd mgehre-amd force-pushed the matthias.enable-gfx1151 branch 2 times, most recently from 38af4cc to 9a9561a Compare May 26, 2026 18:33
@mgehre-amd mgehre-amd marked this pull request as ready for review May 26, 2026 20:36
Copilot AI review requested due to automatic review settings May 26, 2026 20:36
Contributor

Copilot AI left a comment

Pull request overview

Enables initial gfx1151 (RDNA3.5 / Strix Halo) support in FlyDSL’s kernel/test/benchmark stack by adding LDS capacity metadata, introducing a gfx11 WMMA f16/bf16 GEMM kernel variant, and adding clearer failure modes for unsupported FP8 WMMA paths on gfx11.

Changes:

  • Add gfx1151 to shared-memory (LDS) capacity map for correct allocation checks.
  • Introduce kernels/rdna3_f16_gemm.py (gfx11 WMMA f16/bf16 GEMM) and dispatch between gfx11 vs gfx12 WMMA ABIs in tests/benchmarks.
  • Fail fast on gfx11 when FP8 WMMA is requested (both via default_f8_type() and the FP8 GEMM compiler entrypoint), and extend benchmark runner gating to include gfx11 WMMA.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| tests/kernels/test_rdna_gemm.py | Dispatch f16/bf16 WMMA tests to gfx11 vs gfx12 kernel variants; keep FP8 tests gfx120x-only. |
| tests/kernels/benchmark_common.py | Arch-dispatch WMMA GEMM benchmark between gfx11/gfx12; skip FP8 sweep on gfx11. |
| scripts/run_benchmark.sh | Enable WMMA benchmark section on gfx11* as well as gfx12*. |
| python/flydsl/utils/smem_allocator.py | Add gfx1151 LDS capacity entry. |
| python/flydsl/expr/typing.py | Raise a clear error when FP8 types are requested on gfx11. |
| kernels/rdna3_f16_gemm.py | New gfx11 WMMA f16/bf16 GEMM kernel using legacy v16 operand ABI and gfx11 accumulator layout. |
| kernels/rdna_fp8_preshuffle_gemm.py | Add an early gfx11 guard to avoid late LLVM “cannot select intrinsic” failures. |

Comment thread kernels/rdna3_f16_gemm.py Outdated
@vivienfanghuagood
Collaborator

Thanks for your contribution, we will test and review the PR ASAP!

Comment thread kernels/rdna3_f16_gemm.py Outdated
sjfeng1999 previously approved these changes May 27, 2026
Add gfx1151 to SMEM_CAPACITY_MAP at 64 KB per workgroup. The WGP has
128 KB of physical LDS but it is split between the two CUs, so a
single workgroup gets 64 KB — confirmed by HIP (sharedMemPerBlock =
65536) and by the AMDGPU backend, which rejects kernels exceeding
65536 bytes of static LDS with
"local memory (...) exceeds limit (65536)".

Validated on a Strix Halo system:

  bash scripts/run_tests.sh
    Pytest:    575 passed, 2485 skipped, 904 deselected
    Examples:  PASS 01-vectorAdd.py, PASS 02-tiledCopy.py
    MLIR:      all FileCheck tests pass

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
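
As a rough illustration of what the capacity entry buys, here is a minimal sketch of a per-arch LDS budget check. Only the gfx1151 value (65536) comes from the commit message above; the names `SMEM_CAPACITY_BYTES` and `check_lds_budget` are hypothetical and do not claim to match the real smem_allocator.py API.

```python
# Hypothetical stand-in for the capacity map; only the gfx1151 value is from the PR.
SMEM_CAPACITY_BYTES = {"gfx1151": 65536}


def check_lds_budget(arch: str, requested_bytes: int) -> None:
    """Raise early if a kernel asks for more LDS than one workgroup can get."""
    capacity = SMEM_CAPACITY_BYTES.get(arch)
    if capacity is None:
        raise KeyError(f"no LDS capacity entry for {arch}")
    if requested_bytes > capacity:
        raise ValueError(
            f"local memory ({requested_bytes}) exceeds limit ({capacity}) on {arch}"
        )


check_lds_budget("gfx1151", 40 * 1024)  # a 40 KB request fits in the 64 KB budget
```

Checking on the Python side means an oversized allocation fails before the AMDGPU backend ever sees the kernel.
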
gfx11* (RDNA3 / RDNA3.5) has no native FP8 instructions — no WMMA
fp8/bf8, no cvt_pk_fp8/bf8. Attempting FP8 compute on gfx1151 today
surfaces as a late LLVM "cannot select intrinsic" abort deep inside
ISel, after the kernel has already JIT-compiled most of the way.

Two guards make this fail with a clear message at the entry point:

- default_f8_type() on gfx11 raises rather than silently returning
  E4M3FNUZ (the CDNA-only format, which would also be the wrong
  encoding even if the chip could handle it).
- rdna_fp8_preshuffle_gemm.compile_fp8_gemm() rejects gfx11* up
  front, before any rocdl.wmma_f32_16x16x16_fp8_fp8 op is emitted.

Neither path changes behavior on gfx94*, gfx95*, or gfx12*.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
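
The fail-fast pattern described in this commit can be sketched as follows. This is an illustration under assumptions: `require_native_fp8` and `compile_fp8_gemm_sketch` are hypothetical names, the arch string is passed in explicitly, and the real guards live inside default_f8_type() and compile_fp8_gemm().

```python
def require_native_fp8(arch: str) -> None:
    """Hypothetical guard: reject gfx11* before any FP8 op is emitted."""
    if arch.startswith("gfx11"):
        raise RuntimeError(
            f"{arch} has no native FP8 instructions; use an f16/bf16 kernel instead"
        )


def compile_fp8_gemm_sketch(arch: str) -> str:
    """Placeholder entry point showing where the guard sits."""
    require_native_fp8(arch)            # fail here, not deep inside ISel
    return f"fp8 gemm for {arch}"       # stands in for the real code generation


print(compile_fp8_gemm_sketch("gfx1201"))
```

Putting the check first means the caller sees a Python exception naming the arch rather than an LLVM abort.
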
Port of kernels/rdna_f16_gemm.py to the legacy v16-operand WMMA ABI
used by RDNA3 / RDNA3.5 (gfx11*). Same algorithm — 4-warp 128x128x32
LDS ping-pong with software pipelining — but two ABI-level changes:

- Input operands A and B are vector<16>, not vector<8>. Each lane
  carries 16 contiguous K-elements of one M (or N) row. We do the
  v16 load as two v8 LDS loads concatenated via vector.shuffle so
  the existing LDS storage layout (v8 chunks) stays unchanged.
  Lanes 16-31 mirror lanes 0-15 in this ABI, so they re-issue the
  same LDS loads — wasted bandwidth, but the alternative is a
  wave-half broadcast which complicates the schedule.

- The accumulator stays vector<8> on both ABIs, but the per-lane
  row distribution differs. On gfx12 lane L holds rows
  8*(L/16)..8*(L/16)+7; on gfx11 lane L holds rows
  2*si + (L/16) for si in 0..7 (even rows in lanes 0-15, odd rows
  in lanes 16-31). The store-back loop uses the gfx11 mapping.

The barrier sequence is also gfx11-specific: ``s_waitcnt
lgkmcnt(0); s_barrier`` instead of the gfx12+
``s_wait_dscnt 0x0; s_wait_storecnt 0x0; s_barrier_signal -1;
s_barrier_wait -1``.

Wiring in tests/kernels/test_rdna_gemm.py:

- ``_requires_rdna_wmma()`` is the new gate for f16/bf16 cases
  (accepts gfx11* or gfx12*). ``_requires_rdna4()`` stays as-is for
  the FP8 cases — those still need gfx12.
- ``create_wmma_gemm_module()`` becomes an arch-dispatching wrapper
  that picks rdna3_f16_gemm on gfx11 and rdna_f16_gemm on gfx12.
  Both have the same call signature.

Validated on Strix Halo (gfx1151):

  bash scripts/run_tests.sh
    Pytest:    584 passed, 2476 skipped, 904 deselected
    (the +9 over the previous baseline is exactly the 9 newly
     unlocked test_f16_gemm_* cases.)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
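
The two accumulator layouts from the commit message above can be written out on the host to see which rows each lane owns. The formulas are taken directly from the text; the function names are made up for this sketch.

```python
def acc_rows_gfx12(lane: int) -> list[int]:
    """Rows held by `lane` in the v8 accumulator on gfx12 (contiguous-8)."""
    return [8 * (lane // 16) + si for si in range(8)]


def acc_rows_gfx11(lane: int) -> list[int]:
    """Rows held by `lane` in the v8 accumulator on gfx11 (stride-2)."""
    return [2 * si + (lane // 16) for si in range(8)]


# Lane 0 and lane 16 represent the two wave halves.
print("gfx12:", acc_rows_gfx12(0), acc_rows_gfx12(16))
print("gfx11:", acc_rows_gfx11(0), acc_rows_gfx11(16))
```

In both layouts the two wave halves together cover rows 0..15 exactly once; only the interleaving differs, which is why the store-back loop needs an arch-specific row index.
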
scripts/run_benchmark.sh was gating its WMMA section on a single
``IS_RDNA4`` flag (gfx120* only). Inside benchmark_common.py the
``run_wmma_sweep`` helper had the same gate, and the ``wmma_gemm`` op
hardcoded the gfx12 kernel import.

Three small changes wire up gfx11:

- run_benchmark.sh: add ``IS_RDNA_WMMA`` covering gfx11* or gfx12* and
  gate the WMMA section on it. ``IS_RDNA4`` stays for future
  gfx12-only paths (e.g. an FP8-only entry point).
- benchmark_common.run_wmma_sweep: widen the arch gate, then skip the
  ``wmma_fp8_gemm`` inner loop on gfx11* (no native FP8).
- benchmark_common._bench_flydsl_torch(op="wmma_gemm"): pick
  rdna3_f16_gemm on gfx11*, rdna_f16_gemm on gfx12*. Same call
  signature, so the surrounding harness code is unchanged.

Standalone harness output on Strix Halo:

  op         shape              dtype  FlyDSL    torch
  wmma_gemm  256x256x256        bf16     16.8u    7.5u
  wmma_gemm  1024x1024x1024     bf16    126.6u   71.1u
  wmma_gemm  2048x2048x2048     bf16    507.0u  457.9u
  wmma_gemm  4096x4096x4096     bf16   5829.6u 3770.3u

(matches the standalone bench script within ~1%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@coderfeli
Collaborator

Thanks for the gfx1151 enablement. I found one merge blocker and one arch-gating question:

  • Check Python Code Style still fails on kernels/rdna3_f16_gemm.py due to unused elem_bytes and v8_in_ty; after removing them, the T import likely becomes unused too.
  • Should the new gfx12* gates include gfx1250? The repo has separate gfx1250 WMMA kernels/tests, while this path dispatches non-gfx11 gfx12* to rdna_f16_gemm. If gfx1250 should not use that kernel, the gate may need to stay at gfx120* or dispatch explicitly.

I verified locally: bash scripts/check_python_style.sh fails with the same two F841 issues, and git diff --check origin/main...HEAD passes.

gfx1250 has its own dedicated WMMA kernels (wmma_gemm_gfx1250.py)
and a different WMMA ABI; the rdna_f16_gemm.create_wmma_gemm_module
path is for gfx120x (RDNA4) only.
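
A minimal sketch of the resulting gating, assuming the arch name is available as a plain string. `pick_wmma_gemm_variant` is a hypothetical helper; the module names are the ones mentioned in this thread, and the real wrapper is create_wmma_gemm_module in the tests.

```python
def pick_wmma_gemm_variant(arch: str) -> str:
    """Return the kernel module name to use for `arch` (illustrative only)."""
    if arch.startswith("gfx11"):
        return "rdna3_f16_gemm"        # legacy v16-operand WMMA ABI
    if arch.startswith("gfx120"):
        return "rdna_f16_gemm"         # gfx120x (RDNA4), v8-operand ABI
    if arch.startswith("gfx1250"):
        return "wmma_gemm_gfx1250"     # dedicated kernels, different ABI
    raise RuntimeError(f"no WMMA GEMM variant for {arch}")


for arch in ("gfx1151", "gfx1201", "gfx1250"):
    print(arch, "->", pick_wmma_gemm_variant(arch))
```

Matching on gfx120 rather than gfx12 is the point of this commit: a bare gfx12 prefix would also catch gfx1250 and route it to the wrong kernel.
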
@mgehre-amd
Contributor Author

> Thanks for the gfx1151 enablement. I found one merge blocker and one arch-gating question:
>
>   • Check Python Code Style still fails on kernels/rdna3_f16_gemm.py due to unused elem_bytes and v8_in_ty; after removing them, the T import likely becomes unused too.
>   • Should the new gfx12* gates include gfx1250? The repo has separate gfx1250 WMMA kernels/tests, while this path dispatches non-gfx11 gfx12* to rdna_f16_gemm. If gfx1250 should not use that kernel, the gate may need to stay at gfx120* or dispatch explicitly.
>
> I verified locally: bash scripts/check_python_style.sh fails with the same two F841 issues, and git diff --check origin/main...HEAD passes.

Thanks, addressed both!

@GeisYaO

GeisYaO commented May 27, 2026

Thanks for the gfx1151 enablement. Did an independent run on real hardware, sharing what I found.

Setup: Radeon 8060S (Strix Halo, gfx1151). Ubuntu 26.04 + mainline kernel 7.0 (not 22.04/24.04, so a bit off the official ROCm matrix) + ROCm 7.2.3 from the noble channel + PyTorch 2.10.0+rocm7.0 wheel. Built FlyDSL from source on the PR branch (pr567-v2 / 4f807c4) via scripts/build_llvm.sh + build.sh. Tests run with LD_PRELOAD=/opt/rocm/lib/libhsa-runtime64.so.1 to override the stale hsa-rocr (1.18.0-rocm-rel-7.0-56) bundled inside the torch wheel, which otherwise segfaults on AMDKFD_IOC_CREATE_QUEUE against the newer kernel KFD ABI.

Verified end-to-end on real GPU: 6/6 PASS on the (in_dtype × out_dtype="bf16") paths your test_f16_gemm_correctness parametrizes — reproduces your reported 9/9 pass. Tolerance: torch.allclose(atol=0.05, rtol=0.05) matching your verify_output. Also independently LLVM-mc verified the static claims (LDS 64KB, gfx11 v16 / gfx12 v8 WMMA operand layouts, FP8 unsupported on gfx11). Algorithm reads clean against the gfx12 baseline — g_row = 2*si + klane stride-2 layout matches the RDNA3 ABI, K-loop double-buffer indexing is consistent, LDS 40KB is in budget.

A couple of findings along the way:

  1. One thing that looked off in the dtype matrix: the kernel signature exposes out_dtype as a free parameter and in_dtype already parametrizes both bf16 and f16 in your test, so I tried the 4 (in_dtype, out_dtype) combinations on the full dtype square. The out_dtype="f16" branch (both f16→f16 and bf16→f16) silently produces NaN/Inf because the store-back only casts when out_dtype=="bf16":

    if const_expr(out_dtype == "bf16"):
        val = val.to(fx.BFloat16)
    # val is still f32 here for any other out_dtype value
    buffer_ops.buffer_store(val, c_rsrc, elem_off)

    buffer_store then reinterprets the 4-byte f32 as 2 consecutive f16 elements — output ends up full of NaN/Inf. Same shape on kernels/rdna_f16_gemm.py (gfx12 baseline) — so it's not introduced by this PR, just copied across. Grep'd all callers in the repo to make sure: test_f16_gemm_correctness / test_f16_gemm_benchmark / _bench_flydsl_torch all pass out_dtype="bf16"; test_f16_gemm_f32_output passes "f32". out_dtype="f16" is currently an orphan branch in the dtype matrix — accepted by the API but never exercised by CI.

    Two ways to land this depending on intent:

    • If out_dtype="f16" is meant to work: a 2-line elif const_expr(out_dtype == "f16"): val = val.to(fx.Float16) in both kernels. Verified locally — 96/96 pass across an 8-seed × 3-shape × 4-dtype-combo matrix (seeds 0, 1, 42, 1337, 0xCAFE, 12345, 0xDEAD, 2025; shapes 128³, 256³, 256x256x512; dtype combos {bf16,f16} × {bf16,f16}). Previously-untested out_dtype="f16" paths now bit-equal torch.matmul reference across all 8 seeds; the out_dtype="bf16" paths are unaffected (no regression).

    • If out_dtype="f16" was never intended to be a supported configuration: reject it upfront in create_wmma_gemm_module with a clear ValueError, and the if const_expr(out_dtype == "bf16") guard in the store-back can collapse since it would then be the only valid case.

    I have the cast patch ready as a follow-up PR (after this one lands so main and the new gfx11 kernel can get the same elif in one commit), but happy to switch to the reject route if that matches the design intent better — that's a smaller change and doesn't expand the supported surface.

  2. Suggest extending the test matrix so the second half of the dtype grid runs in CI:

    @pytest.mark.parametrize("in_dtype,out_dtype", [
        ("bf16", "bf16"), ("f16", "bf16"),
        ("f16",  "f16"),  ("bf16", "f16"),  # would catch the bug above
    ])

  3. RDNA3.5 ISA limit note (not a PR issue, just a heads-up for future gfx1151 perf work): buffer_load_lds_* and global_load_lds_* are both gone on gfx1150/1151, even though they still exist on RDNA3 dGPUs (gfx1100). So any future async-prefetch-style optimization for gfx1151 has to stay on the GMEM→VGPR→LDS path. The v16-ABI duplicate loads (lanes 16–31 mirroring lanes 0–15) that you already call out in the comment are probably the biggest single perf lever left: a ds_swizzle_b32 XOR 16 half-wave broadcast would cut them in half. Worth a TODO(perf) so it doesn't get lost.

A couple of tiny non-blocking nits while I'm here:

  • default_f8_type() error message could append "Use bf16/f16 GEMM via rdna3_f16_gemm.create_wmma_gemm_module on gfx11* targets" — gives folks a direct migration line.

  • assert num_k_tiles >= 2 would survive python -O better as raise ValueError(...).

Couldn't fully reproduce the perf table — I got 46 μs / 0.73 TF on 256³ vs your 16.8 μs / 2.0 TF. Likely the gpu-lock step in your harness (we didn't replicate it) plus thermal state on a laptop chassis. Didn't dig further since the numerics path is fine.

@coderfeli
Collaborator

@mgehre-amd do you have navi3.5 CI machines for these?

- Add missing f16 cast in store-back: without it, f32 accumulator
  bits were reinterpreted by buffer_store as 2 f16 elements, silently
  producing NaN/Inf for out_dtype="f16".
- Extend test_f16_gemm_correctness to parametrize (in_dtype, out_dtype)
  over all 4 combos so CI covers the previously orphan f16-output paths.
- Document the ds_swizzle_b32 XOR 16 broadcast as a TODO(perf) for the
  v16-ABI duplicate-load waste on lanes 16-31.
- default_f8_type() error message now points gfx11* callers to
  rdna3_f16_gemm.create_wmma_gemm_module for a direct migration line.
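
The reinterpretation behind the first bullet can be reproduced on the host with the standard struct module. This is only an illustration of reading f32 bytes as two f16 values, assuming little-endian layout; it does not model buffer_store itself, and `reinterpret_f32_as_two_f16` is a hypothetical name.

```python
import struct


def reinterpret_f32_as_two_f16(value: float) -> tuple[float, float]:
    """Pack `value` as one f32, then read the same 4 bytes as two f16 values."""
    raw = struct.pack("<f", value)       # 4 bytes holding the f32 bit pattern
    return struct.unpack("<2e", raw)     # same bytes, viewed as two half floats


def cast_f32_to_f16(value: float) -> float:
    """What the store-back should do instead: convert the value, then store 2 bytes."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


print(reinterpret_f32_as_two_f16(1.0))   # two unrelated f16 values, not 1.0
print(cast_f32_to_f16(1.0))              # the intended result
```

Neither half of the reinterpreted pair equals the original number, which is why the output is garbage unless the accumulator is converted before the store.
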
@mgehre-amd
Contributor Author

> @mgehre-amd do you have navi3.5 CI machines for these?

When working on the ROCm/vLLM repo, we asked ROCm DevOps to share the gfx1151 runners from TheRock with us, which they did. I guess the same could work here. Note that there are only a few runners and there will be wait times, so maybe make them opt-in for PRs that touch gfx11.

assert is stripped under `python -O`, which would let an invalid
num_k_tiles silently reach the prefetch pipeline.
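
A minimal sketch of the change, assuming the check is a plain precondition on an integer. `check_num_k_tiles` is a hypothetical helper; in the kernel the check sits inline.

```python
def check_num_k_tiles(num_k_tiles: int) -> None:
    """Validate the K-tile count with an exception that survives `python -O`."""
    # `assert num_k_tiles >= 2` would be removed entirely under -O;
    # an explicit raise is always executed.
    if num_k_tiles < 2:
        raise ValueError(f"num_k_tiles must be >= 2, got {num_k_tiles}")


check_num_k_tiles(4)        # passes silently
try:
    check_num_k_tiles(1)
except ValueError as err:
    print(err)
```
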
@mgehre-amd
Contributor Author

> Thanks for the gfx1151 enablement. Did an independent run on real hardware, sharing what I found.

Thanks a lot for testing this so deeply! I addressed the issues you raised.

@mgehre-amd mgehre-amd requested a review from sjfeng1999 May 28, 2026 07:06
mgehre-amd and others added 2 commits May 28, 2026 03:25
Without the f16 cast, the f32 accumulator bits were reinterpreted by
buffer_store as 2 f16 elements, silently producing NaN/Inf. Same bug
that was just fixed in rdna3_f16_gemm; mirror the fix here so the
extended test matrix passes on gfx120x (RDNA4) too.
@coderfeli coderfeli merged commit 9861ab2 into main May 29, 2026
17 of 18 checks passed
@coderfeli coderfeli deleted the matthias.enable-gfx1151 branch May 29, 2026 08:06

6 participants