NVFP4 recipe with GEMM via BF16 dequant#518

Merged
matthiasdiener merged 101 commits into dev from mdiener/nvfp4-gemm
Apr 29, 2026

Conversation

Contributor

matthiasdiener commented Apr 2, 2026

Description

Part of https://github.com/ROCm/frameworks-internal/issues/15682

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Micky774 and others added 30 commits March 27, 2026 09:27
Remove TODO regarding userbuffers
Userbuffer Enablement for ROCm
* Update Dockerfile to use ROCm TheRock
* Update wheels building script to work with ROCm TheRock and the latest Manylinux image
* Support default ROCm location /opt/rocm/core
* Fix UB code build on TheRock
* Support comma separated list of target GPU architectures
* Guess ROCm build from HIP_PLATFORM
Comment thread transformer_engine/common/hadamard_transform/wht16.cuh
Comment thread tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
Comment thread transformer_engine/common/gemm/rocm_gemm.cu Outdated
@matthiasdiener matthiasdiener requested a review from ipanfilo April 24, 2026 03:47
Comment thread tests/cpp/operator/test_cast_nvfp4_transpose.cu
Comment on lines +112 to +113
const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f;
const float factor_inv = 1.0f / (6.0f * fp8_max);
Contributor

Same comment as above regarding using Numeric_Traits_fp8e4m3 here

Contributor Author

Done in a08e8c5
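
For context on the constants in the snippet above, here is an illustrative sketch (not the PR's code; the function name `nvfp4_factor_inv` is hypothetical): 448 is the largest finite value of OCP FP8 E4M3, 240 is the largest finite value of the FNUZ E4M3 variant used on some AMD GPUs, and 6 is the largest magnitude representable in FP4 E2M1, so the factor maps a tensor amax into the combined FP8-scale × FP4-value dynamic range.

```cpp
#include <cassert>
#include <cmath>

// Illustrative sketch of the scaling constants (not the PR's code):
//   6.0f   = largest magnitude in FP4 E2M1 {0, 0.5, 1, 1.5, 2, 3, 4, 6}
//   448.0f = largest finite OCP FP8 E4M3 value
//   240.0f = largest finite value of the FNUZ E4M3 variant
float nvfp4_factor_inv(bool fp8_fnuz) {
  const float fp4_max = 6.0f;
  const float fp8_max = fp8_fnuz ? 240.0f : 448.0f;
  // Multiplying an amax by this maps it into the range covered by an
  // FP8 block scale times an FP4 value.
  return 1.0f / (fp4_max * fp8_max);
}
```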

Comment thread tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Comment on lines +438 to +470
if (is_fp4_dtype(param.Atype)) {
  hip_bfloat16* a_bf16 = reinterpret_cast<hip_bfloat16*>(ws_ptr);
  ws_ptr += a_bf16_bytes;
  const int64_t total_a = static_cast<int64_t>(m) * k;
  const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA.scale_inv
                                               : inputA.columnwise_scale_inv;
  const int64_t a_num_cols = (transa == CUBLAS_OP_T)
                                 ? inputA.data.shape.back()
                                 : inputA.columnwise_data.shape.back();
  const int64_t a_scale_stride =
      (a_sinv.shape.size() >= 2) ? a_sinv.shape[1] : (a_num_cols / 16);
  launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a,
                             a_num_cols, a_scale_stride, stream);
  param.A = a_bf16;
  param.Atype = DType::kBFloat16;
  param.A_scale_inv = nullptr;
}

if (is_fp4_dtype(param.Btype)) {
  hip_bfloat16* b_bf16 = reinterpret_cast<hip_bfloat16*>(ws_ptr);
  ws_ptr += b_bf16_bytes;
  const int64_t total_b = static_cast<int64_t>(k) * n;
  const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB.scale_inv
                                               : inputB.columnwise_scale_inv;
  const int64_t b_num_cols = (transb == CUBLAS_OP_N)
                                 ? inputB.data.shape.back()
                                 : inputB.columnwise_data.shape.back();
  const int64_t b_scale_stride =
      (b_sinv.shape.size() >= 2) ? b_sinv.shape[1] : (b_num_cols / 16);
  launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b,
                             b_num_cols, b_scale_stride, stream);
  param.B = b_bf16;
  param.Btype = DType::kBFloat16;
  param.B_scale_inv = nullptr;
}
Contributor

Minor comment: would it make sense to factor the repeated FP4→BF16 staging logic for A/B into a small helper? The two blocks look structurally similar, aside from the operand-specific shape/layout details.

Contributor Author

Thanks, I factored this out into a lambda function in fae76d3

Contributor

@aris134 left a comment

LGTM. I left one minor non-blocking suggestion, but this looks good to me overall.

Comment thread transformer_engine/common/gemm/rocm_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/rocm_gemm.cu Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu Outdated
@matthiasdiener matthiasdiener force-pushed the mdiener/nvfp4-gemm branch 7 times, most recently from 0f240ad to 9ed88ff on April 28, 2026 at 19:13
@matthiasdiener matthiasdiener merged commit 95a65d6 into dev Apr 29, 2026
9 checks passed
@matthiasdiener matthiasdiener deleted the mdiener/nvfp4-gemm branch April 29, 2026 18:09

Labels

ci-level 1 (CI test level 1)


5 participants