[Fix] fix a quant scale calculation bug in quant_utils.py #4233
Changes from all commits — quant_utils.py

```diff
@@ -18,13 +18,13 @@ def fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor:
     return result.to(torch.int32)
 
 
-def fast_pow2_torch(x: torch.Tensor) -> torch.Tensor:
+def fast_pow2_torch(x: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
     bits_x = (x + 127) << 23
-    return bits_x.view(torch.float32)
+    return bits_x.view(weight_dtype)
 
 
-def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor) -> torch.Tensor:
-    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max))
+def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
+    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max), weight_dtype)
 
 
 def _get_quant_scaling(weight: torch.Tensor,
@@ -34,13 +34,22 @@ def _get_quant_scaling(weight: torch.Tensor,
     """Get the scaling factor for FP8 quantization."""
     finfo = torch.finfo(fp8_dtype)
     fmax = finfo.max
     amax = weight.abs().amax(dim, keepdim=True).clamp_min(1e-6).float()
+    eps = torch.finfo(weight.dtype).eps
```
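For context, these helpers round the FP8 scale up to the nearest power of two via float bit manipulation. Below is a minimal sketch of that rounding idea, checked against plain `torch.exp2`/`torch.log2`; the function names are illustrative, the log2-ceil step is a stand-in for `fast_log2_ceil_torch`, and `torch.float8_e4m3fn` is assumed to be available (recent PyTorch builds).

```python
import torch

def pow2_scale_reference(amax: torch.Tensor, fp8_max: float) -> torch.Tensor:
    # Plain-math reference: round the scale amax / fp8_max up to the
    # nearest power of two.
    return torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))

def pow2_scale_bit_trick(amax: torch.Tensor, fp8_max: float) -> torch.Tensor:
    # Same result via IEEE 754 bit manipulation: for an integer exponent e,
    # the float32 bit pattern (e + 127) << 23 encodes exactly 2**e
    # (sign bit 0, biased exponent e + 127, zero mantissa).
    exp = torch.ceil(torch.log2(amax / fp8_max)).to(torch.int32)  # stand-in for fast_log2_ceil_torch
    bits = (exp + 127) << 23
    return bits.view(torch.float32)  # bitcast: valid only when read back as float32

amax = torch.tensor([3.0, 0.7, 100.0])
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
print(pow2_scale_reference(amax, fp8_max))
print(pow2_scale_bit_trick(amax, fp8_max))  # identical power-of-two scales
```

The trick works because a float32 value 2**e has sign 0, biased exponent e + 127, and a zero mantissa, so building those bits directly avoids the exp2 call.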
Suggested change:

```diff
-    eps = torch.finfo(weight.dtype).eps
+    eps = torch.finfo(torch.float32).eps
```
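The dtype in that `torch.finfo(...)` call matters more than it might look, because machine epsilon varies by orders of magnitude across dtypes. A quick check (standard IEEE 754 / PyTorch values):

```python
import torch

# Machine epsilon by dtype: the half-precision values are orders of
# magnitude larger than float32's 2**-23.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, torch.finfo(dtype).eps)
# -> torch.float32: 1.1920928955078125e-07 (2**-23)
# -> torch.float16: 0.0009765625           (2**-10)
# -> torch.bfloat16: 0.0078125             (2**-7)
```

How `eps` is used further down the hunk is not shown here, but with bfloat16 weights the per-dtype eps is 2**16 times larger than float32's, which is presumably what the suggestion is guarding against.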
The `fast_pow2_torch` function performs bit manipulation that assumes the IEEE 754 float32 representation (adding the 127 bias and shifting by 23 bits for the mantissa). The bitcast operation on line 23 should always return float32 to maintain correctness, regardless of the input weight dtype. Using `weight_dtype` for the view operation will produce incorrect results for non-float32 dtypes like bfloat16 or float16, as they have different bit representations. The result should always be float32, since that is what the bit pattern represents.
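To illustrate the reviewer's point, here is a hedged sketch (not the project's actual code; `fast_pow2_float32` is a hypothetical name) that keeps the bitcast in float32 and only converts values to the caller's dtype afterwards:

```python
import torch

def fast_pow2_float32(exp: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    # (exp + 127) << 23 builds an IEEE 754 *float32* bit pattern
    # (8-bit exponent biased by 127, 23-bit mantissa), so the bitcast
    # must target torch.float32. A value-preserving .to() cast then
    # produces the caller's dtype without touching the bit layout.
    bits = (exp.to(torch.int32) + 127) << 23
    return bits.view(torch.float32).to(weight_dtype)

exp = torch.tensor([-3, 0, 5], dtype=torch.int32)
print(fast_pow2_float32(exp, torch.bfloat16))      # 2**-3, 2**0, 2**5 -> 0.125, 1.0, 32.0
print(torch.exp2(exp.float()).to(torch.bfloat16))  # same values
```

Reinterpreting the int32 bits directly as a 16-bit dtype would instead split each 4-byte element in two and read the bits under a different exponent/mantissa layout, so both the shape and the values would come out wrong.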