46 changes: 44 additions & 2 deletions modelopt/torch/export/quant_utils.py
@@ -236,6 +236,31 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
return scaling_factor


def _ensure_weight_quantizer_calibrated(
weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
) -> None:
"""Calibrate weight quantizer if amax is not set.

This is a lazy calibration pattern used during export when weight quantizers
may not have been calibrated during the main calibration phase.

Args:
weight_quantizer: The weight quantizer to calibrate
weight: The weight tensor to use for calibration
module_name: Optional module name for better warning messages
"""
if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
warn(
f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
f"Computing amax from the weights instead. This can happen when some experts were "
f"not activated during calibration (expected for MoE models); try increasing --calib_size."
)
weight_quantizer.reset_amax()
enable_stats_collection(weight_quantizer)
weight_quantizer(weight)
finish_stats_collection(weight_quantizer)


def get_activation_scaling_factor(
module: nn.Module, input_quantizer_name: str = "input_quantizer"
) -> torch.Tensor:
@@ -279,6 +304,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
# This is because the kernel dequantizes weight to fp8, which is in range 448.
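A quick numeric check of the comment above (values assumed for illustration): with weight_scaling_factor_2 = amax / 448, the per-block scaling factors top out at 448 / 6, matching an FP8 (E4M3, max 448) intermediate and FP4 (E2M1, max 6) elements.

amax = 3.2                       # example per-tensor weight amax (assumed)
wsf2 = amax / 448.0              # second-level scale, as in the diff
block_amax = amax                # worst case: a block contains the tensor amax
wsf = block_amax / (6.0 * wsf2)  # first-level (per-block) scale
assert abs(wsf - 448.0 / 6.0) < 1e-9
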
@@ -307,13 +336,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
if weight_quantizer is None:
return None

if get_quantization_format(module) in [
quantization_format = get_quantization_format(module)

# For all NVFP4 variants, calibrate the weight quantizer if amax is not set
if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
weight = getattr(module, weight_name)
module_name = f"{type(module).__name__}.{weight_name}"
_ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
# weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
# This is because the kernel dequantizes weight to fp8, which is in range 448.
return weight_quantizer._amax.float() / 448.0
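To see these helpers end to end, a hedged usage sketch (setup assumed, not from this PR): quantize a toy model with modelopt's NVFP4 config, then read the scaling factors via the functions touched in this diff. The mtq.quantize call and NVFP4_DEFAULT_CFG follow modelopt's public API as I understand it; the one-batch forward_loop is a placeholder calibration.

import torch
import modelopt.torch.quantization as mtq
from modelopt.torch.export.quant_utils import (
    get_weight_scaling_factor,
    get_weight_scaling_factor_2,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
model = mtq.quantize(
    model,
    mtq.NVFP4_DEFAULT_CFG,
    forward_loop=lambda m: m(torch.randn(8, 64)),
)
linear = model[0]
wsf = get_weight_scaling_factor(linear)     # per-block scales
wsf2 = get_weight_scaling_factor_2(linear)  # per-tensor second-level scale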