Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ def validate_calibrator(cls, v, info: ValidationInfo):
)


class QuantizeAlgorithmConfig(ModeloptBaseConfig):
class CalibrationConfig(ModeloptBaseConfig):
"""Calibration algorithm config base."""

method: Literal[None] = ModeloptField(
Expand Down Expand Up @@ -1084,8 +1084,39 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
),
)

include_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to include in calibration.",
description=(
"If provided, only modules whose names match at least one of the fnmatch patterns are "
"calibrated. Modules that do not match any pattern are skipped and retain their "
"pre-existing calibration state. "
"If a module name matches both ``include_modules`` and ``exclude_modules``, "
"exclusion takes precedence and the module is skipped. "
"Note: filtering applies only to quantized linear modules; TensorQuantizers in "
"non-linear modules (e.g. layer norms, embeddings) are unaffected."
),
)

exclude_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to exclude from calibration.",
description=(
"If provided, modules whose names match at least one of the fnmatch patterns are "
"skipped during calibration and retain their pre-existing calibration state. "
"If a module name matches both ``include_modules`` and ``exclude_modules``, "
"exclusion takes precedence. "
"Note: filtering applies only to quantized linear modules; TensorQuantizers in "
"non-linear modules (e.g. layer norms, embeddings) are unaffected."
),
)


# Backward-compatible alias — deprecated, will be removed in a future release.
QuantizeAlgorithmConfig = CalibrationConfig


class MaxCalibConfig(QuantizeAlgorithmConfig):
class MaxCalibConfig(CalibrationConfig):
"""The config for max calibration algorithm.

Max calibration estimates the max values of activations or weights and uses these max values
Expand All @@ -1102,7 +1133,7 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
)


class MseCalibConfig(QuantizeAlgorithmConfig):
class MseCalibConfig(CalibrationConfig):
"""Configuration for per-tensor MSE calibration.

Finds a scale s (via amax a, with s = a / q_max) that minimizes the
Expand Down Expand Up @@ -1152,7 +1183,7 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
)


class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
class LocalHessianCalibConfig(CalibrationConfig):
"""Configuration for local Hessian-weighted MSE calibration.

This algorithm uses activation information to optimize per-block scales for weight
Expand Down Expand Up @@ -1219,7 +1250,7 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
)


class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
class SmoothQuantCalibConfig(CalibrationConfig):
"""The config for ``smoothquant`` algorithm (SmoothQuant).

SmoothQuant applies a smoothing factor which balances the scale of outliers in weights and activations.
Expand All @@ -1241,7 +1272,7 @@ class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
)


class AWQLiteCalibConfig(QuantizeAlgorithmConfig):
class AWQLiteCalibConfig(CalibrationConfig):
"""The config for ``awq_lite`` (AWQ lite) algorithm.

AWQ lite applies a channel-wise scaling factor which minimizes the output difference after quantization.
Expand All @@ -1265,7 +1296,7 @@ class AWQLiteCalibConfig(QuantizeAlgorithmConfig):
)


class AWQClipCalibConfig(QuantizeAlgorithmConfig):
class AWQClipCalibConfig(CalibrationConfig):
"""The config for ``awq_clip`` (AWQ clip) algorithm.

AWQ clip searches for a clipped amax for per-group quantization. This search requires much more compute
Expand Down Expand Up @@ -1331,7 +1362,7 @@ class AWQFullCalibConfig(AWQLiteCalibConfig, AWQClipCalibConfig):
)


class SVDQuantConfig(QuantizeAlgorithmConfig):
class SVDQuantConfig(CalibrationConfig):
"""The config for SVDQuant.

Refer to the `SVDQuant paper <https://arxiv.org/pdf/2411.05007>`_ for more details.
Expand All @@ -1349,7 +1380,7 @@ class SVDQuantConfig(QuantizeAlgorithmConfig):
)


class GPTQLiteConfig(QuantizeAlgorithmConfig):
class GPTQLiteConfig(CalibrationConfig):
"""The config for GPTQ lite.

GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation.
Expand Down Expand Up @@ -1394,7 +1425,7 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig):
| dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]],
]

_QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None
_QuantizeAlgoCfgType = str | dict | CalibrationConfig | None

QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None

Expand Down
70 changes: 37 additions & 33 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
AWQClipCalibConfig,
AWQFullCalibConfig,
AWQLiteCalibConfig,
CalibrationConfig,
CompressConfig,
GPTQLiteConfig,
LocalHessianCalibConfig,
MaxCalibConfig,
MseCalibConfig,
QuantizeAlgoCfgType,
QuantizeAlgorithmConfig,
QuantizeConfig,
SmoothQuantCalibConfig,
SVDQuantConfig,
Expand All @@ -59,6 +59,7 @@
)
from .model_calib import (
awq,
filter_calib_modules,
gptq_lite,
local_hessian_calibrate,
max_calibrate,
Expand Down Expand Up @@ -210,7 +211,7 @@ def name(self) -> str:

def wrapped_calib_func(
model: ModelLikeModule,
config: QuantizeAlgorithmConfig,
config: CalibrationConfig,
forward_loop: ForwardLoop | None = None,
func: Callable | None = None,
) -> ConvertReturnType:
Expand All @@ -223,6 +224,8 @@ def wrapped_calib_func(
kwargs = config.model_dump()
method = kwargs.pop("method")
sequential = kwargs.pop("use_sequential", False)
include_modules = kwargs.pop("include_modules", None)
exclude_modules = kwargs.pop("exclude_modules", None)
if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method
Expand All @@ -243,22 +246,23 @@ def wrapped_calib_func(
module._moe_count_expert_calib_tokens = True

if func is not None:
if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)
with filter_calib_modules(model, include_modules, exclude_modules):
if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
Comment on lines +249 to +262
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Keep module filtering out of sequential activation collection.

Line 249 wraps the entire sequential_calibrate(...) call. sequential_calibrate() first recomputes each layer's inputs from the full model before invoking calib_func(layer, ...), so excluded modules are also disabled during those recomputation passes. That means later included layers can be calibrated against activations that no longer reflect the already-calibrated earlier layers. Please scope the filter to the per-layer calibration call after inputs are captured, not to the whole sequential wrapper.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/mode.py` around lines 249 - 262, The current code
applies filter_calib_modules around the entire sequential_calibrate call, which
disables excluded modules during the recomputation of layer inputs; instead,
remove the outer with filter_calib_modules(...) wrapper and apply filtering only
around the per-layer calibration invocation after inputs are captured.
Concretely, stop wrapping sequential_calibrate(model, ...) with
filter_calib_modules; either update sequential_calibrate to call
calib_func(layer, ...) inside a with filter_calib_modules(model,
include_modules, exclude_modules) block for each layer after inputs are
recomputed, or wrap the calib_func argument (func) with a small wrapper that
enters filter_calib_modules only when calling calib_func for that layer so
include_modules/exclude_modules only affect the calibration call, not the input
recomputation. Ensure symbols referenced are forward_loop, sequential_calibrate,
func/calib_func, filter_calib_modules, include_modules, exclude_modules, and
model.

else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)

# Lets get the latest metadata for the quantizer states
metadata = {}
Expand All @@ -270,7 +274,7 @@ class BaseCalibrateModeDescriptor(ModeDescriptor):
"""Base class for quantization calibration algorithm modes.

All calibration algorithm modes must be derived from this base class.
In addition, the `config_class` for the mode must return a subclass of :class:`QuantizeAlgorithmConfig`.
In addition, the `config_class` for the mode must return a subclass of :class:`CalibrationConfig`.

This base class also provides some convenient wrappers/utilities for calibration algorithms to be
translated into ModelOpt mode.
Expand All @@ -289,8 +293,8 @@ class BaseCalibrateModeDescriptor(ModeDescriptor):

def __init__(self, *args, **kwargs):
"""Initialize Base calibrate mode descriptor."""
assert issubclass(self.config_class, QuantizeAlgorithmConfig), (
f"`config_class` of {self.__class__} must be a subclass of `QuantizeAlgorithmConfig`!, "
assert issubclass(self.config_class, CalibrationConfig), (
f"`config_class` of {self.__class__} must be a subclass of `CalibrationConfig`!, "
f"got {self.config_class}!"
)
super().__init__(*args, **kwargs)
Expand All @@ -311,7 +315,7 @@ def name(self) -> str:

@property
@abstractmethod
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""

@property
Expand Down Expand Up @@ -386,9 +390,9 @@ class NoneCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for no calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return QuantizeAlgorithmConfig
return CalibrationConfig

_calib_func = None

Expand All @@ -398,7 +402,7 @@ class MaxCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for max calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return MaxCalibConfig

Expand All @@ -410,7 +414,7 @@ class MseCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for mse calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return MseCalibConfig

Expand All @@ -426,7 +430,7 @@ class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
"""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return LocalHessianCalibConfig

Expand All @@ -438,7 +442,7 @@ class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for smoothquant calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return SmoothQuantCalibConfig

Expand All @@ -450,7 +454,7 @@ class AWQLiteModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for AWQ lite calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return AWQLiteCalibConfig

Expand All @@ -462,7 +466,7 @@ class AWQClipModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for AWQ clip calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return AWQClipCalibConfig

Expand All @@ -474,7 +478,7 @@ class AWQFullModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for AWQ full calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return AWQFullCalibConfig

Expand All @@ -486,7 +490,7 @@ class SVDQuantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for SVDQuant calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return SVDQuantConfig

Expand All @@ -503,7 +507,7 @@ class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for GPTQ calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return GPTQLiteConfig

Expand Down
Loading
Loading