@vuiseng9 commented on Nov 18, 2025

This PR introduces a proof-of-concept implementation of NVFP4 forward + MXFP8 backward training.
The work is intentionally scoped as a local PoC and serves as a foundation for subsequent iterations.

Motivations

  1. Implementation challenge in full NVFP4 training: The initial goal was end-to-end NVFP4 (forward + backward). However, NVFP4 matmuls in cuBLASLt currently support TN-only layouts, which would require an additional transpose kernel for the backward pass. We defer this to subsequent work. (NVIDIA’s official release v2.8 already has full NVFP4 support.)

  2. Use-case 1: More efficient than NVFP4-QAT. MXFP8 backward is substantially more efficient than NVFP4-QAT pipelines, which still rely on 16-/32-bit backward passes.

  3. Use-case 2: Practicality of full NVFP4 training
    While NVFP4 training has advanced significantly, it still requires several supporting techniques: (1) Hadamard transforms, (2) selective higher-precision layers, and (3) switching back to higher precision for the last fraction of training, as also seen in the recent MLPerf v5.1 NVFP4 submission. MXFP8 backward can therefore be valuable, either for last-mile convergence or from the get-go.

Quick Summary of the Implementation:

  1. Generalize the mxfp8 (de)quantization path to also cover nvfp4, which uses block size 16, an E4M3 scale, and 2×E2M1 data (un)packing (see the format sketch after this list).
  2. Plumb the NVFP4Quantizer interface from the C++ side through the Python bindings, all the way up to the PyTorch level.
  3. Create the NVFP4FwdMXFP8BwdScaling recipe. At the recipe level, each linear holds two quantizers per tensor location, one NVFP4Quantizer and one MXFP8Quantizer, which entails two kernel launches instead of a single launch at the lower level. This is not optimal but is a good first step; optimal performance would require broader architectural changes. (A hypothetical usage sketch appears further below.)
  4. Enable Linear, LayerNorm_Linear, and LayerNorm_MLP for now.
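To illustrate item 1, here is a minimal NumPy sketch of the NVFP4 block format the kernels are generalized toward: blocks of 16 values, one per-block scale (stored as E4M3 in the actual kernel, kept as plain fp32 here), and E2M1 data packed two values per byte. This is not the PR's CUDA code; the function names and the simple round-to-nearest-magnitude scheme are assumptions for illustration only.

```python
import numpy as np

# The 8 non-negative E2M1 magnitudes; the sign lives in a separate bit.
E2M1_MAGNITUDES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def quantize_nvfp4_block(block):
    """Quantize one block of 16 fp32 values to 8 packed bytes + a per-block scale."""
    assert block.shape == (16,)
    amax = float(np.abs(block).max())
    scale = amax / 6.0 if amax > 0 else 1.0   # 6.0 = largest E2M1 magnitude
    scaled = block / scale
    # Round each value to the nearest E2M1 magnitude (simple nearest-value pick).
    idx = np.abs(np.abs(scaled)[:, None] - E2M1_MAGNITUDES[None, :]).argmin(axis=1)
    codes = idx.astype(np.uint8)
    codes[scaled < 0] |= 8                    # sign bit in bit 3 of the 4-bit code
    packed = (codes[0::2] | (codes[1::2] << 4)).astype(np.uint8)  # 16 values -> 8 bytes
    return packed, np.float32(scale)

def dequantize_nvfp4_block(packed, scale):
    """Unpack 8 bytes back to 16 fp32 values (inverse of the sketch above)."""
    codes = np.empty(16, dtype=np.uint8)
    codes[0::2], codes[1::2] = packed & 0x0F, packed >> 4
    signs = np.where(codes & 8, -1.0, 1.0)
    return (signs * E2M1_MAGNITUDES[codes & 7] * scale).astype(np.float32)

x = np.random.randn(16).astype(np.float32)
packed, scale = quantize_nvfp4_block(x)
print(np.abs(x - dequantize_nvfp4_block(packed, scale)).max())
```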

Use-case studies here.
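For context, a hypothetical usage sketch at the PyTorch level, assuming the new recipe plugs into Transformer Engine's existing autocast path the same way MXFP8BlockScaling does. The NVFP4FwdMXFP8BwdScaling name comes from this PR; the exact import path and constructor arguments here are assumptions, not confirmed by the code.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Recipe introduced in this PR (assumed import location and default constructor).
nvfp4_fwd_mxfp8_bwd = recipe.NVFP4FwdMXFP8BwdScaling()

layer = te.Linear(4096, 4096, bias=True).cuda()
x = torch.randn(16, 4096, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=nvfp4_fwd_mxfp8_bwd):
    y = layer(x)          # forward GEMM runs with NVFP4 (TN layout)

loss = y.float().sum()
loss.backward()           # dgrad/wgrad GEMMs run with MXFP8 (NN/NT layouts)
```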

* create NVFP4Quantizer on the TE C++ side
* modify mxfp8_quantize/cast_mxfp8_2D_kernel for nvfp4 generalization
* temporarily hijack the mxfp8 torch side to call nvfp4 quantization (will revert)
* generalize dequantize_mxfp8_kernel to dequantize_mxnv_kernel
* create nvfp4 extension interface, not fully enabled yet
* mxfp8 trainability restored
* create NVFP4BlockScaling, NVFP4BlockScalingRecipeState class
* subclassing:
    - NVFP4TensorBase(MXFP8TensorBase)
    - NVFP4Quantizer(MXFP8Quantizer)
    - NVFP4Tensor(MXFP8Tensor)
* forward pass functional; backward raises an exception because only the TN layout is allowed in cublaslt nvfp4
* motivation: due to the current TN-only layout for the cublaslt NVFP4 matmul,
  this recipe uses TN NVFP4 forward and NN/NT MXFP8 backward, avoiding
  tensor relayout, which can be costly to materialize for large models
* piggyback a shadow MXFP8Quantizer on NVFP4Quantizer, needed for the backward pass
* remove redundant quantization; step time improves
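To make the subclassing and shadow-quantizer structure listed above more concrete, here is a purely structural Python sketch. The class names follow the commit list; the bodies are placeholders and the method names are assumptions, not the PR's actual code.

```python
# Structural illustration only: NVFP4 classes subclass the MXFP8 ones, and the
# NVFP4 quantizer piggybacks a "shadow" MXFP8 quantizer so forward tensors are
# quantized to NVFP4 while backward tensors are quantized to MXFP8.
class MXFP8Quantizer:
    block_size = 32                      # MXFP8: 32-element blocks

    def quantize(self, tensor):
        ...                              # real work happens in CUDA kernels

class NVFP4Quantizer(MXFP8Quantizer):
    block_size = 16                      # NVFP4: 16-element blocks, E4M3 scales

    def __init__(self):
        super().__init__()
        # Shadow quantizer used only for tensors consumed by the backward pass.
        self.bwd_mxfp8_quantizer = MXFP8Quantizer()

    def quantize_for_forward(self, tensor):
        return self.quantize(tensor)                      # NVFP4 (TN GEMM)

    def quantize_for_backward(self, tensor):
        return self.bwd_mxfp8_quantizer.quantize(tensor)  # MXFP8 (NN/NT GEMMs)
```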
@vuiseng9 changed the title from "Nvfp4 forward + Mxfp8 backward Recipe" to "NVFP4 forward + MXFP8 backward Recipe" on Nov 18, 2025