[PyTorch] Enable post-RHT amax estimation #2442
base: main
Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR introduces an experimental optimization for NVFP4 quantization that estimates post-RHT (Random Hadamard Transform) amax from pre-RHT amax using a configurable linear scale factor, eliminating the need for a separate RHT+amax kernel launch.

Key Changes:
Performance Impact:
Confidence Score: 4/5
Important Files Changed (File Analysis)
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Recipe as NVFP4BlockScaling Recipe
    participant Quantizer as NVFP4Quantizer
    participant Activation as Activation/Norm Extension
    participant Kernel as CUDA Kernel
    User->>Recipe: Create recipe with use_post_rht_amax_estimation=True
    Recipe->>Recipe: Set amax_estimation_scale (2.0 fwd, 1.0 bwd)
    Recipe->>Quantizer: Pass amax_estimation_scale via QParams
    alt Fused Path (amax estimation enabled)
        Activation->>Activation: Select FUSED_ACTIVATION_AMAX_NVFP4 impl
        Activation->>Kernel: Compute activation + pre-RHT amax
        Kernel-->>Quantizer: Return pre-RHT amax
        Quantizer->>Kernel: RHT cast fusion with amax_scale
        Kernel->>Kernel: global_amax_val = pre_rht_amax * amax_scale
        Kernel->>Kernel: Compute FP4 quantization with scaled amax
    else Unfused Path (true post-RHT amax)
        Activation->>Activation: Select UNFUSED impl
        Quantizer->>Kernel: nvte_hadamard_transform_amax
        Kernel-->>Quantizer: Return true post-RHT amax
        Quantizer->>Kernel: Quantize with true amax
    end
```
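In the fused path, the scaled-amax step reduces to a single multiply. A minimal numpy sketch of the math (illustrative only, not the PR's CUDA kernel; `estimate_post_rht_amax` is a hypothetical helper name, and 2.0 is the PR's forward-pass default scale):

```python
import numpy as np

def estimate_post_rht_amax(x: np.ndarray, amax_scale: float) -> float:
    """Estimate post-RHT amax by linearly scaling the pre-RHT amax.

    amax_scale plays the role of the recipe's amax_estimation_scale
    (2.0 forward, 1.0 backward in this PR), which lets the separate
    RHT+amax kernel launch be skipped entirely.
    """
    pre_rht_amax = float(np.abs(x).max())  # pre-RHT amax: one reduction
    return pre_rht_amax * amax_scale       # "global_amax_val" in the diagram

rng = np.random.default_rng(0)
x = rng.standard_normal((128, 128)).astype(np.float32)
est = estimate_post_rht_amax(x, amax_scale=2.0)
```

The estimate then feeds the FP4 quantization exactly where the true post-RHT amax would have been used.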
12 files reviewed, 1 comment
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
12 files reviewed, no comments
Interesting idea. But how should these two estimation scales be set?
Updated the default scale factor for forward input activations in post-RHT amax estimation to 2.0. Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
12 files reviewed, no comments
One way to go is to estimate the real distribution of the data (i.e., how amax is affected by RHT). In practice, from my experiments, a quite wide range of scales works well, thanks to the wide dynamic range of e4m3 (amax misestimation eventually gets cancelled out, as long as it stays within the e4m3 range).
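To ground that empirically, one can measure how a normalized blockwise Hadamard transform actually moves the amax on sample data. A sketch (illustrative only, not code from this PR; `hadamard`, `measure_amax_ratio`, and the 16-element block size are assumptions, and a real RHT additionally applies random signs):

```python
import numpy as np

def hadamard(n: int) -> np.ndarray:
    # Sylvester construction; n must be a power of two.
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

def measure_amax_ratio(x: np.ndarray, block: int = 16) -> float:
    # Apply a normalized Hadamard transform blockwise along the last
    # axis and compare post-transform amax to pre-transform amax.
    H = hadamard(block) / np.sqrt(block)
    y = x.reshape(-1, block) @ H
    return float(np.abs(y).max() / np.abs(x).max())

rng = np.random.default_rng(0)
ratios = [measure_amax_ratio(rng.standard_normal((256, 64))) for _ in range(8)]
```

On Gaussian-like data the measured ratio tends to hover near 1, consistent with the observation that a fairly wide range of scales works in practice.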
/te-ci |
Additional Comments (1)
- transformer_engine/pytorch/csrc/extensions/activation.cpp, lines 45-54: style: identical logic block duplicated in both the forward and backward paths; consider extracting this decision logic into a helper function to avoid code duplication.
12 files reviewed, 1 comment
Signed-off-by: Evgeny <etsykunov@nvidia.com>
12 files reviewed, no comments
Description
Post-RHT amax can be estimated from pre-RHT amax.
This PR optimizes out the post-RHT amax (RHT+amax) kernel, estimating post-RHT amax from pre-RHT amax with linear scaling.
amax fusion is required to see the perf benefits.
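As a usage sketch, the opt-in described below boils down to setting one environment variable (the variable name comes from this PR's description; it must be set before Transformer Engine reads the flag):

```python
import os

# Opt in to post-RHT amax estimation, e.g. at the top of the
# training script, before Transformer Engine reads the flag.
os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION"] = "1"

enabled = os.environ.get("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION") == "1"
```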
The feature is opt-in via `NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION=1`.

Type of change
Changes
Please list the changes introduced in this PR:
Checklist: