10 changes: 8 additions & 2 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -1163,12 +1163,18 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
   auto A_dt = inputA->data.dtype;
   auto B_dt = inputB->data.dtype;
   auto D_dt = OutputD->data.dtype;
+  // The CK FP16/BF16 grouped-GEMM dispatcher (ck_tile_grouped_gemm_fp16_dispatch)
+  // already supports an independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
+  // (fp32, fp16, bf16). The previous check required A == B == D, which incorrectly
+  // rejected the common bf16/bf16/fp32 case (training with fp32 gradient
+  // accumulation), forcing a fallback to the per-expert hipblaslt loop.
+  // Relaxed to require A == B in fp16/bf16 and D in {fp32, fp16, bf16}.
   return (
     (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
   ) ||
   (
-    (A_dt == B_dt) && (A_dt == D_dt) &&
-    (is_fp16_dtype(A_dt))
+    (A_dt == B_dt) && is_fp16_dtype(A_dt) &&
+    (is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32)
   );
 #else
   auto A_type = get_cuda_dtype(inputA->data.dtype);
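
For context, a minimal, self-contained sketch of the relaxed dispatch predicate. The DType enum and the is_fp8_dtype/is_fp16_dtype helpers below are hypothetical stand-ins for the Transformer Engine definitions (in particular, the sketch assumes is_fp16_dtype accepts both fp16 and bf16, as the diff's comment implies); it illustrates the predicate, not the library's actual code.

#include <cstdio>

// Hypothetical stand-in for transformer_engine::DType; only the members used here.
enum class DType { kFloat8E4M3, kFloat8E5M2, kFloat16, kBFloat16, kFloat32 };

static bool is_fp8_dtype(DType t) {
  return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
}

// Assumed to accept both half-precision types, matching the diff's comment.
static bool is_fp16_dtype(DType t) {
  return t == DType::kFloat16 || t == DType::kBFloat16;
}

// The relaxed check from the diff: FP8 A/B, or matching fp16/bf16 A/B with
// D in {fp16, bf16, fp32}.
static bool use_ck_grouped_gemm(DType A_dt, DType B_dt, DType D_dt) {
  return (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ||
         ((A_dt == B_dt) && is_fp16_dtype(A_dt) &&
          (is_fp16_dtype(D_dt) || D_dt == DType::kFloat32));
}

int main() {
  // bf16/bf16/fp32 (fp32 gradient accumulation): rejected before, accepted now.
  std::printf("bf16/bf16/fp32 -> %d\n",
              use_ck_grouped_gemm(DType::kBFloat16, DType::kBFloat16, DType::kFloat32));
  // bf16/fp16 mismatch: still rejected, so it falls back to the hipblaslt loop.
  std::printf("bf16/fp16/fp32 -> %d\n",
              use_ck_grouped_gemm(DType::kBFloat16, DType::kFloat16, DType::kFloat32));
  return 0;
}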