
Conversation

negvet (Collaborator) commented Dec 2, 2025

Description

Post-RHT amax can be estimated from pre-RHT amax.

This PR optimizes out the post-RHT amax (RHT+amax) kernel by estimating the post-RHT amax from the pre-RHT amax with a linear scale factor.
Amax fusion is required to see the perf benefits.
The feature is opt-in via NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION=1.
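
For illustration, enabling the feature could look like the sketch below (hypothetical usage: it assumes the environment variable is read when the NVFP4 recipe is constructed, and uses the NVFP4BlockScaling recipe and fp8_autocast named elsewhere in this PR):

    import os

    # Opt in before the recipe is constructed; default behavior is unchanged
    # when the variable is unset (per the PR description).
    os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION"] = "1"

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling

    recipe = NVFP4BlockScaling()
    layer = te.Linear(1024, 1024).cuda()
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(torch.randn(16, 1024, device="cuda"))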

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

negvet and others added 2 commits December 2, 2025 15:24
Signed-off-by: Evgeny <etsykunov@nvidia.com>
negvet marked this pull request as ready for review December 2, 2025 16:07
greptile-apps bot (Contributor) commented Dec 2, 2025

Greptile Overview

Greptile Summary

This PR introduces an experimental optimization for NVFP4 quantization that estimates post-RHT (Random Hadamard Transform) amax from pre-RHT amax using a configurable linear scale factor, eliminating the need for a separate RHT+amax kernel launch.

Key Changes:

  • Adds amax_estimation_scale configuration parameter throughout the quantization pipeline (C++ structs, Python dataclasses, and CUDA kernels)
  • Modifies the RHT cast fusion kernel to apply the estimation scale when computing global encode scale
  • Updates activation, bias, and normalization extensions to use fused paths when amax estimation is enabled
  • Feature is opt-in via NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION=1 environment variable
  • Default scale factors: 2.0 for forward input activations, 1.0 for backward gradients

Performance Impact:

  • Reduces kernel launch overhead by skipping the RHT+amax kernel when amax fusion is available
  • Trade-off: estimated amax may affect numerical accuracy compared to true post-RHT amax
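
In effect, the fused kernel replaces a measured post-RHT amax with a single multiply in the epilogue. A schematic sketch of the relation (illustrative host-side Python, not the CUDA epilogue; the values are made up):

    import torch

    pre_rht_amax = torch.tensor(3.7)   # illustrative value, as produced by
                                       # the fused activation+amax kernel
    amax_estimation_scale = 2.0        # default for fwd activations (1.0 bwd)

    # What this PR computes instead of launching a separate RHT+amax kernel:
    est_post_rht_amax = amax_estimation_scale * pre_rht_amax
    # The global FP4 encode scale is then derived from this estimate exactly
    # as it would be from a true post-RHT amax.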

Confidence Score: 4/5

  • This PR is safe to merge - it's an opt-in experimental feature behind an environment variable that doesn't affect default behavior.
  • The implementation is well-structured with consistent changes across C++, CUDA, and Python layers. The feature is properly gated behind an environment variable. The code follows existing patterns in the codebase. The only concern is ensuring the fallback path (non-fused kernel) correctly handles the amax estimation when inputs don't meet fusion kernel requirements.
  • transformer_engine/pytorch/csrc/quantizer.cpp - verify the fallback path correctly applies amax estimation when the fused RHT kernel cannot be used.

Important Files Changed

File Analysis

  • transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu (5/5): Core kernel changes: passes amax_scale through the RHT+cast fusion kernel to multiply global_amax by the estimation scale factor in the epilogue.
  • transformer_engine/common/recipe/init.py (4/5): Adds the amax estimation configuration to the NVFP4BlockScaling recipe with env var controls. Docstrings match the implementation defaults (2.0 for fwd, 1.0 for bwd).
  • transformer_engine/pytorch/csrc/quantizer.cpp (4/5): Core quantizer changes: handles amax estimation by computing the pre-RHT amax and passing the scale to the fusion kernel. Adds a fallback path that computes the pre-RHT amax when estimation is enabled.
  • transformer_engine/pytorch/tensor/nvfp4_tensor.py (5/5): Adds the amax_estimation_scale parameter to the NVFP4Quantizer class and propagates it through the copy() method.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Recipe as NVFP4BlockScaling Recipe
    participant Quantizer as NVFP4Quantizer
    participant Activation as Activation/Norm Extension
    participant Kernel as CUDA Kernel

    User->>Recipe: Create recipe with use_post_rht_amax_estimation=True
    Recipe->>Recipe: Set amax_estimation_scale (2.0 fwd, 1.0 bwd)
    Recipe->>Quantizer: Pass amax_estimation_scale via QParams
    
    alt Fused Path (amax estimation enabled)
        Activation->>Activation: Select FUSED_ACTIVATION_AMAX_NVFP4 impl
        Activation->>Kernel: Compute activation + pre-RHT amax
        Kernel-->>Quantizer: Return pre-RHT amax
        Quantizer->>Kernel: RHT cast fusion with amax_scale
        Kernel->>Kernel: global_amax_val = pre_rht_amax * amax_scale
        Kernel->>Kernel: Compute FP4 quantization with scaled amax
    else Unfused Path (true post-RHT amax)
        Activation->>Activation: Select UNFUSED impl
        Quantizer->>Kernel: nvte_hadamard_transform_amax
        Kernel-->>Quantizer: Return true post-RHT amax
        Quantizer->>Kernel: Quantize with true amax
    end

greptile-apps bot (Contributor) left a comment

12 files reviewed, 1 comment

negvet and others added 2 commits December 3, 2025 15:23
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
greptile-apps bot (Contributor) left a comment

12 files reviewed, no comments

qijiaxing commented

Interesting idea. But how should these two estimation scales be set?

Updated the default scale factor for forward input activations in post-RHT amax estimation to 2.0.

Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
greptile-apps bot (Contributor) left a comment

12 files reviewed, no comments

negvet (Collaborator, Author) commented Dec 5, 2025

how should these two estimation scales be set?

One way to go is to estimate the real distribution of the data (i.e., how amax is affected by the RHT). From my experiments, I observe that amax(RHT(X)) / amax(X) is up to 2.0 and amax(RHT(G)) / amax(G) is up to 1.0, so setting the scales to 2.0 and 1.0 is an option.

In practice, quite a wide range of scales works well, thanks to the wide dynamic range of E4M3 (the amax misestimation eventually cancels out, as long as it stays within the E4M3 range).
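
As a rough illustration of that measurement, the ratio could be probed offline with a reference Hadamard transform (a sketch only; it uses a plain normalized Hadamard in place of TE's RHT kernel, and random data in place of captured activations):

    import torch

    def reference_rht(x: torch.Tensor) -> torch.Tensor:
        # Plain normalized Hadamard transform over the last dim; a stand-in
        # for TE's RHT (the real transform also applies random signs).
        n = x.shape[-1]
        h = torch.tensor([[1.0]])
        while h.shape[0] < n:
            h = torch.cat([torch.cat([h, h], 1), torch.cat([h, -h], 1)], 0)
        return x @ (h / n**0.5)

    # Stand-ins for captured activations X (or gradients G).
    samples = [torch.randn(256, 64) for _ in range(8)]
    ratios = [(reference_rht(x).abs().max() / x.abs().max()).item()
              for x in samples]
    print(f"max amax(RHT(X)) / amax(X) over samples: {max(ratios):.3f}")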

negvet (Collaborator, Author) commented Dec 9, 2025

/te-ci

greptile-apps bot (Contributor) left a comment
Additional Comments (1)

  1. transformer_engine/pytorch/csrc/extensions/activation.cpp, lines 45-54

    style: Identical logic block duplicated in both forward and backward paths - consider extracting this decision logic into a helper function to avoid code duplication

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

12 files reviewed, 1 comment


Signed-off-by: Evgeny <etsykunov@nvidia.com>
greptile-apps bot (Contributor) left a comment

12 files reviewed, no comments
