lmdeploy/lite/quantization/weight/quant_utils.py: 22 changes (15 additions, 7 deletions)
@@ -18,13 +18,13 @@ def fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor:
return result.to(torch.int32)


-def fast_pow2_torch(x: torch.Tensor) -> torch.Tensor:
+def fast_pow2_torch(x: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    bits_x = (x + 127) << 23
-    return bits_x.view(torch.float32)
+    return bits_x.view(weight_dtype)
Comment on lines +21 to +23
Copilot AI Dec 24, 2025
The fast_pow2_torch function performs bit manipulation that assumes the IEEE 754 float32 representation (adding the 127 exponent bias and shifting it past the 23 mantissa bits into the exponent field), so the bitcast on line 23 must always target float32 regardless of the input weight dtype. Using weight_dtype for the view operation will produce incorrect results for non-float32 dtypes such as bfloat16 or float16, which have different bit layouts; the result should remain float32, since that is what the bit pattern represents.

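For reference, a minimal sketch of the variant this comment points toward, assuming the intent of the new weight_dtype argument is to control the output dtype rather than the bitcast target (the trailing .to(weight_dtype) cast is an assumption, not code from the PR):

```python
import torch

def fast_pow2_torch(x: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    # x holds int32 exponents; adding the IEEE 754 float32 bias (127) and
    # shifting past the 23 mantissa bits into the exponent field builds the
    # bit pattern of 2**x, so the bitcast target must stay float32.
    bits_x = (x + 127) << 23
    result = bits_x.view(torch.float32)
    # Convert dtype only after the bitcast; this is a value cast, not a reinterpretation.
    return result.to(weight_dtype)
```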


-def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor) -> torch.Tensor:
-    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max))
+def fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
+    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max), weight_dtype)
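To make the ue8m0 rounding concrete, a small worked example with hypothetical numbers (448.0 is the float8_e4m3fn maximum):

```python
import math

amax, fp8_max = 100.0, 448.0
ratio = amax / fp8_max             # ~0.2232
exp = math.ceil(math.log2(ratio))  # -2
scale = 2.0 ** exp                 # 0.25: the power-of-two (ue8m0) scale
assert scale >= ratio              # rounding the exponent up keeps amax / scale <= fp8_max
```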


def _get_quant_scaling(weight: torch.Tensor,
@@ -34,13 +34,22 @@ def _get_quant_scaling(weight: torch.Tensor,
"""Get the scaling factor for FP8 quantization."""
finfo = torch.finfo(fp8_dtype)
fmax = finfo.max
-    amax = weight.abs().amax(dim, keepdim=True).clamp_min(1e-6).float()
+    eps = torch.finfo(weight.dtype).eps
Copilot AI Dec 24, 2025
Getting eps from weight.dtype may produce unexpected behavior when weight.dtype is bfloat16. The epsilon for bfloat16 is much larger than for float32 (approximately 0.0078 vs 1.19e-07), which could significantly affect quantization accuracy. Consider using a consistent epsilon value or explicitly documenting this dtype-dependent behavior.

Suggested change
-    eps = torch.finfo(weight.dtype).eps
+    eps = torch.finfo(torch.float32).eps

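For reference, the dtype-dependent epsilons behind this concern (a quick check, not part of the PR):

```python
import torch

print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.float16).eps)   # ~9.77e-04
print(torch.finfo(torch.bfloat16).eps)  # 7.8125e-03
```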

amax = weight.abs().amax(dim, keepdim=True)

if scale_fmt == 'ue8m0':
-        return fast_round_scale_torch(amax, fmax)
+        scaling = fast_round_scale_torch(amax, fmax, weight.dtype)
else:
# default
scaling = amax / fmax

scaling = torch.where(
scaling == 0,
torch.tensor(eps, dtype=scaling.dtype, device=scaling.device),
scaling,
)

return scaling
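The zero guard matters because an all-zero block would otherwise leave scaling at 0 and produce inf/nan when the weight is later divided by it. A standalone sketch of the default (non-ue8m0) path under stated assumptions: toy_scaling is a hypothetical helper, per-row amax stands in for the blocked dims, and fp8_dtype is fixed to torch.float8_e4m3fn.

```python
import torch

def toy_scaling(weight: torch.Tensor) -> torch.Tensor:
    fmax = torch.finfo(torch.float8_e4m3fn).max       # 448.0
    eps = torch.finfo(weight.dtype).eps
    amax = weight.abs().amax(dim=-1, keepdim=True)
    scaling = amax / fmax
    # Floor all-zero rows at eps so the division below never produces inf/nan.
    return torch.where(scaling == 0, torch.full_like(scaling, eps), scaling)

w = torch.zeros(2, 4, dtype=torch.bfloat16)
s = toy_scaling(w)
q = (w / s).to(torch.float8_e4m3fn)                   # finite even for zero rows
```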


@@ -65,7 +74,6 @@ def quant_blocked_fp8(weight: torch.Tensor,

# reverse pixel shuffle
weight = weight.unflatten(-2, (-1, block_size)).unflatten(-1, (-1, block_size))
-    weight = weight.to(torch.float32)

# get scaling
scaling = _get_quant_scaling(weight, fp8_dtype, dim=(-3, -1), scale_fmt=scale_fmt)
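A shape walk-through of the block reshaping above, with an assumed block_size of 128 and an arbitrary weight shape (the dtype stays bfloat16 now that the float32 upcast is removed):

```python
import torch

block_size = 128
w = torch.randn(512, 1024, dtype=torch.bfloat16)
blocked = w.unflatten(-2, (-1, block_size)).unflatten(-1, (-1, block_size))
print(blocked.shape)  # torch.Size([4, 128, 8, 128]); amax over dims (-3, -1) is per 128x128 block
```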