[ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher#573
Open
The is_supported_dtype check in nvte_multi_tensor_gemm previously required
A==B==D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32
case where the GEMM output is fp32 for gradient accumulation. This forced
a fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop),
bypassing the CK grouped GEMM kernel entirely on ROCm.
The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already
supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
(fp32, fp16, bf16). The wrapper check is the only thing that prevents it
from being reached.
The check is relaxed to require A==B in {fp16, bf16} and D in {fp32, fp16, bf16}, which
matches what the CK dispatcher actually accepts. Verified on Qwen3-30B-A3B
MoE training on MI355X (gfx950): fallback warning rate drops from
~1040/step (every GEMM) to ~28/step (~3% of shapes that the CK kernel
itself rejects via Kernel::IsSupportedArgument). Throughput is essentially
unchanged in this workload because hipblaslt's per-shape autotuning
happens to be competitive with the hardcoded CK tile configs for these
MoE shapes; the gain will materialize once the CK dispatcher gains more
tile configs (or shape-aware tile selection by aggregate M).
This is a CUDA path file; the same patch applies to the AMD path via
hipify. No CUDA-side behavior change since cuBLAS/cutlass dispatch on
NVIDIA still requires A==B==D in the cutlass fast-path pre-conditions.
Follow-ups (out of scope for this PR):
- Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64)
and shape-aware tile selection by aggregate M per call. Currently
throughput is unchanged on this workload because the existing hipblaslt
fallback is well-tuned and the 3 hardcoded CK tile configs
(TileCfg_256x256x64, TileCfg_256x128x64, TileCfg_256x128x64_padding)
don't fit MoE shapes (highly variable per-expert M) optimally. Real
CK-grouped-GEMM perf wins will materialize once tile selection adapts
to M.
- Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument
rejection (likely small per-expert M values that fail tile-size
constraints in the current TileCfg_256x* instantiations).
Collaborator
@matthiasdiener @aris134 Could you review this PR?