[Fix] fix a quant scale calculation bug in quant_utils.py #4233
Conversation
There are linting errors. You may resolve them as follows: make sure that the Python version is 3.10.
Pull request overview
This PR aims to fix a quantization scale calculation bug that caused bf16 to fp8 weight conversion to differ from official fp8 model weights. The changes modify the scale computation logic and dtype handling in the quantization utilities.
Key changes:
- Modified `fast_pow2_torch` and `fast_round_scale_torch` to accept a dtype parameter
- Updated `_get_quant_scaling` to use dtype-specific epsilon values and prevent division by zero
- Removed the explicit float32 conversion in `quant_blocked_fp8` to preserve the original weight dtype
```python
scaling = torch.where(scaling == 0,
    torch.tensor(eps, dtype=scaling.dtype, device=scaling.device),
    scaling,
)
```
Copilot AI · Dec 24, 2025
The `torch.tensor` creation inside `torch.where` can be inefficient, as it creates a new tensor on each call. Consider using `scaling.clamp_min(eps)` instead, which is more efficient and cleaner and achieves the same result of ensuring the scaling values are at least `eps`.
Suggested change:
```diff
-scaling = torch.where(scaling == 0,
-    torch.tensor(eps, dtype=scaling.dtype, device=scaling.device),
-    scaling,
-)
+scaling = scaling.clamp_min(eps)
```
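As a quick illustration of the reviewer's point (this snippet is not part of the PR): `clamp_min` avoids materializing a separate eps tensor, and for the non-negative scaling values produced here it also lifts values that are positive but smaller than eps, whereas the `torch.where` form only replaces exact zeros.

```python
import torch

eps = torch.finfo(torch.float32).eps
scaling = torch.tensor([0.0, 5e-9, 0.25])

# Original form: only exact zeros are replaced by eps.
where_form = torch.where(scaling == 0,
                         torch.tensor(eps, dtype=scaling.dtype, device=scaling.device),
                         scaling)

# Suggested form: every value below eps is raised to eps.
clamp_form = scaling.clamp_min(eps)

print(where_form)  # ~ tensor([1.1921e-07, 5.0000e-09, 2.5000e-01])
print(clamp_form)  # ~ tensor([1.1921e-07, 1.1921e-07, 2.5000e-01])
```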
```python
finfo = torch.finfo(fp8_dtype)
fmax = finfo.max
amax = weight.abs().amax(dim, keepdim=True).clamp_min(1e-6).float()
eps = torch.finfo(weight.dtype).eps
```
Copilot AI · Dec 24, 2025
Getting `eps` from `weight.dtype` may produce unexpected behavior when `weight.dtype` is bfloat16. The epsilon value for bfloat16 is much larger than for float32 (approximately 0.0078 vs 1.19e-07), which could significantly affect quantization accuracy. Consider using a consistent epsilon value or explicitly documenting this dtype-dependent behavior.
Suggested change:
```diff
-eps = torch.finfo(weight.dtype).eps
+eps = torch.finfo(torch.float32).eps
```
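For reference, the epsilon gap the reviewer is describing can be checked directly (illustrative snippet, not part of the PR):

```python
import torch

# Machine epsilon differs by several orders of magnitude between dtypes,
# so deriving eps from weight.dtype changes the lower bound on the scale.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
print(torch.finfo(torch.float16).eps)   # 0.0009765625
print(torch.finfo(torch.float32).eps)   # 1.1920928955078125e-07
```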
```diff
 def fast_pow2_torch(x: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
     bits_x = (x + 127) << 23
-    return bits_x.view(torch.float32)
+    return bits_x.view(weight_dtype)
```
Copilot AI · Dec 24, 2025
The `fast_pow2_torch` function performs bit manipulation that assumes the IEEE 754 float32 representation (adding the 127 exponent bias and shifting by 23 mantissa bits). The bitcast operation on line 23 should always return float32 to maintain correctness, regardless of the input weight dtype. Using `weight_dtype` for the view operation will produce incorrect results for non-float32 dtypes like bfloat16 or float16, as they have different bit representations. The result should always be float32, since that is what the bit pattern represents.
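To make this concrete, here is a small stand-alone illustration (not code from the PR) of the float32 bit trick and why the bitcast target matters:

```python
import torch

# Computing 2**e by building an IEEE 754 float32 bit pattern.
# The exponent bias (127) and mantissa width (23 bits) are float32-specific,
# so the bitcast must target torch.float32 regardless of the weight dtype.
e = torch.tensor([-3, 0, 5], dtype=torch.int32)
bits = (e + 127) << 23            # place the biased exponent in bits 23..30
pow2 = bits.view(torch.float32)   # -> tensor([0.1250, 1.0000, 32.0000])

# Viewing the same 32-bit pattern as a 16-bit dtype such as bfloat16 would
# reinterpret each 4-byte word as two unrelated 2-byte values. If a different
# dtype is really needed, an explicit cast after the bitcast keeps the value:
scale = pow2.to(torch.bfloat16)
```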
```diff
-def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor) -> torch.Tensor:
-    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max))
+def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
+    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max), weight_dtype)
```
Copilot AI · Dec 24, 2025
Remove the trailing whitespace at the end of this line.
Suggested change (the only difference is the removed trailing space):
```diff
-    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max), weight_dtype) 
+    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max), weight_dtype)
```
```python
def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max), weight_dtype)
```
Copilot AI · Dec 24, 2025
The parameter `weight_dtype` is misleading because this function does not operate on weights but on scaling factors. Additionally, the bit manipulation in this function assumes the IEEE 754 float32 representation, so the result type should always be float32. This parameter should be removed and the function should always return float32.
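A minimal sketch of the shape this comment seems to suggest: keep the bit-trick helpers float32-only, drop the dtype parameter, and cast at the call site if another dtype is needed. The body of `fast_log2_ceil_torch` is not shown in this diff, so the version below is only an assumed stand-in for illustration.

```python
import torch

def fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in; the real implementation is not part of this diff.
    return torch.ceil(torch.log2(x)).to(torch.int32)

def fast_pow2_torch(x: torch.Tensor) -> torch.Tensor:
    # IEEE 754 float32 bit trick: bias the integer exponent and shift it
    # into bits 23..30, then reinterpret the bits as float32.
    return ((x + 127) << 23).view(torch.float32)

def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor) -> torch.Tensor:
    # Always returns float32; the caller can cast to the weight dtype afterwards.
    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max))
```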
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and easier to give feedback on. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.
Motivation
Due to a bug in the scale calculation, model weights converted from bf16 to fp8 by this code differ slightly from the weights of the official fp8 model.
Modification
The file quant_utils.py under lmdeploy/lmdeploy/lite/quantization/weight has been modified. In the `quant_blocked_fp8` function, the code that converted the weight to fp32 has been removed. In the `_get_quant_scaling` function, a dtype-specific `eps` value is used as a lower bound on the scaling magnitude to prevent division by zero; a sketch of the resulting flow is given below.
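The following is a rough reconstruction assembled from the review hunks above, not a copy of the actual file: the signature details and the `scaling = amax / fmax` step are assumptions based on common blocked-FP8 practice.

```python
import torch

def _get_quant_scaling(weight: torch.Tensor,
                       fp8_dtype: torch.dtype,
                       dim: int = -1) -> torch.Tensor:
    """Sketch only: reassembled from the review hunks, not copied from quant_utils.py."""
    finfo = torch.finfo(fp8_dtype)
    fmax = finfo.max
    # Per-group max magnitude, kept in float32 for the division below.
    amax = weight.abs().amax(dim, keepdim=True).clamp_min(1e-6).float()
    # dtype-specific epsilon introduced by this PR (the review instead
    # suggests a fixed float32 eps).
    eps = torch.finfo(weight.dtype).eps

    # Assumed step: the scale relates amax to the fp8 representable maximum.
    scaling = amax / fmax
    # Guard against zero scales so later divisions by `scaling` are safe.
    scaling = torch.where(scaling == 0,
                          torch.tensor(eps, dtype=scaling.dtype, device=scaling.device),
                          scaling)
    return scaling

# Example usage (torch.float8_e4m3fn requires a recent PyTorch, roughly >= 2.1):
w = torch.randn(128, 128, dtype=torch.bfloat16)
scale = _get_quant_scaling(w, torch.float8_e4m3fn, dim=-1)
```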
Checklist