[ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher#573

Open
lizamd wants to merge 1 commit into ROCm:dev from lizamd:fix/ck-grouped-gemm-bf16-fp32-output

Conversation


@lizamd lizamd commented May 4, 2026

The is_supported_dtype check in nvte_multi_tensor_gemm previously required A==B==D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32 case where the GEMM output is fp32 for gradient accumulation. This forced a fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop), bypassing the CK grouped GEMM kernel entirely on ROCm.

The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY (fp32, fp16, bf16). The wrapper check is the only thing that prevents it from being reached.
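As a rough illustration of that mechanism (the macro name above is real, but everything below is a hypothetical standalone reimplementation with invented type names, not the actual Transformer Engine macro), a runtime-to-compile-time type switch looks like:

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Illustrative dtype tag and per-type marker structs; these names are
// invented for this sketch and are not the real Transformer Engine types.
enum class DType { kFloat32, kFloat16, kBFloat16 };
struct Fp32Tag {};
struct Fp16Tag {};
struct Bf16Tag {};

// In the spirit of TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY: map a runtime
// D dtype to a compile-time tag, so the grouped GEMM kernel template can be
// instantiated with an output type independent of the A/B input type.
template <typename F>
auto switch_d_dtype(DType d, F&& f) {
  switch (d) {
    case DType::kFloat32:  return f(Fp32Tag{});
    case DType::kFloat16:  return f(Fp16Tag{});
    case DType::kBFloat16: return f(Bf16Tag{});
  }
  return f(Fp32Tag{});  // unreachable: all enum values handled above
}

// Example use: report which kernel instantiation would be selected.
std::string kernel_name_for(DType d) {
  return switch_d_dtype(d, [](auto tag) -> std::string {
    using T = decltype(tag);
    if constexpr (std::is_same_v<T, Fp32Tag>) return "grouped_gemm<fp16, fp32>";
    else if constexpr (std::is_same_v<T, Fp16Tag>) return "grouped_gemm<fp16, fp16>";
    else return "grouped_gemm<fp16, bf16>";
  });
}
```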

The check is relaxed to require A==B in {fp16, bf16} and D in {fp32, fp16, bf16}, which matches what the CK dispatcher actually accepts. Verified on Qwen3-30B-A3B MoE training on MI355X (gfx950): the fallback warning rate drops from ~1040/step (every GEMM) to ~28/step (the ~3% of shapes that the CK kernel itself rejects via Kernel::IsSupportedArgument). Throughput is essentially unchanged in this workload because hipblaslt's per-shape autotuning happens to be competitive with the hardcoded CK tile configs for these MoE shapes; the gain will materialize once the CK dispatcher gains more tile configs (or shape-aware tile selection by aggregate M).
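In code terms, the before/after predicate can be sketched like this (a minimal standalone model; the enum values and function names are illustrative stand-ins for the real NVTEDType-based checks, not the actual source):

```cpp
#include <cassert>

// Illustrative stand-in for the dtype enum; not the actual NVTEDType.
enum class DType { kFloat32, kFloat16, kBFloat16 };

static bool is_half(DType t) {
  return t == DType::kFloat16 || t == DType::kBFloat16;
}

// Before: the fp16/bf16 path required A == B == D, so bf16/bf16/fp32
// (fp32 output for gradient accumulation) was rejected and fell back to
// the per-expert hipblaslt loop.
bool is_supported_dtype_old(DType a, DType b, DType d) {
  return is_half(a) && a == b && b == d;
}

// After: A == B in {fp16, bf16}, with D independently in {fp32, fp16, bf16},
// matching what ck_tile_grouped_gemm_fp16_dispatch accepts.
bool is_supported_dtype_new(DType a, DType b, DType d) {
  bool d_ok = d == DType::kFloat32 || d == DType::kFloat16 ||
              d == DType::kBFloat16;
  return is_half(a) && a == b && d_ok;
}
```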

This is a CUDA path file; the same patch applies to the AMD path via hipify. No CUDA-side behavior change since cuBLAS/cutlass dispatch on NVIDIA still requires A==B==D in the cutlass fast-path pre-conditions.

Description

Relax the is_supported_dtype check in nvte_multi_tensor_gemm so the fp16/bf16 path accepts an independent fp32/fp16/bf16 output dtype (notably bf16/bf16/fp32 for gradient accumulation), letting those grouped GEMMs reach the CK grouped GEMM kernel on ROCm instead of falling back to the per-expert hipblaslt loop.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Relaxed the fp16/bf16 branch of is_supported_dtype in nvte_multi_tensor_gemm to require A==B with D in {fp32, fp16, bf16}, instead of A==B==D, so bf16/bf16/fp32 grouped GEMMs dispatch to the CK grouped GEMM kernel on ROCm

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@lizamd lizamd force-pushed the fix/ck-grouped-gemm-bf16-fp32-output branch from d1637a7 to 764cb65 Compare May 4, 2026 23:56

Follow-ups (out of scope for this PR):

- Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64)
  and shape-aware tile selection by aggregate M per call. Currently
  throughput is unchanged on this workload because the existing hipblaslt
  fallback is well-tuned and the 3 hardcoded CK tile configs
  (TileCfg_256x256x64, TileCfg_256x128x64, TileCfg_256x128x64_padding)
  don't fit MoE shapes (highly variable per-expert M) optimally. Real
  CK-grouped-GEMM perf wins will materialize once tile selection adapts
  to M.
- Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument
  rejection (likely small per-expert M values that fail tile-size
  constraints in the current TileCfg_256x* instantiations).
@lizamd lizamd force-pushed the fix/ck-grouped-gemm-bf16-fp32-output branch from 764cb65 to ff19241 Compare May 5, 2026 00:02
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label May 5, 2026
@wenchenvincent
Collaborator

@matthiasdiener @aris134 Could you review this PR?

