-
Notifications
You must be signed in to change notification settings - Fork 307
Add calib_include/exclude_modules to calibration algorithms #1043
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2a0feaf
378d524
22bd94c
5a8c907
c1ad3b2
68a4cee
8e185df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,13 +36,13 @@ | |
| AWQClipCalibConfig, | ||
| AWQFullCalibConfig, | ||
| AWQLiteCalibConfig, | ||
| CalibrationConfig, | ||
| CompressConfig, | ||
| GPTQLiteConfig, | ||
| LocalHessianCalibConfig, | ||
| MaxCalibConfig, | ||
| MseCalibConfig, | ||
| QuantizeAlgoCfgType, | ||
| QuantizeAlgorithmConfig, | ||
| QuantizeConfig, | ||
| SmoothQuantCalibConfig, | ||
| SVDQuantConfig, | ||
|
|
@@ -59,6 +59,7 @@ | |
| ) | ||
| from .model_calib import ( | ||
| awq, | ||
| filter_calib_modules, | ||
| gptq_lite, | ||
| local_hessian_calibrate, | ||
| max_calibrate, | ||
|
|
@@ -210,7 +211,7 @@ def name(self) -> str: | |
|
|
||
| def wrapped_calib_func( | ||
| model: ModelLikeModule, | ||
| config: QuantizeAlgorithmConfig, | ||
| config: CalibrationConfig, | ||
| forward_loop: ForwardLoop | None = None, | ||
| func: Callable | None = None, | ||
| ) -> ConvertReturnType: | ||
|
|
@@ -223,6 +224,8 @@ def wrapped_calib_func( | |
| kwargs = config.model_dump() | ||
| method = kwargs.pop("method") | ||
| sequential = kwargs.pop("use_sequential", False) | ||
| include_modules = kwargs.pop("include_modules", None) | ||
| exclude_modules = kwargs.pop("exclude_modules", None) | ||
| if method is not None and "awq" in method: | ||
| # For backward compatibility | ||
| kwargs["algorithm"] = method | ||
|
|
@@ -243,22 +246,23 @@ def wrapped_calib_func( | |
| module._moe_count_expert_calib_tokens = True | ||
|
|
||
| if func is not None: | ||
| if sequential: | ||
| if forward_loop is None: | ||
| raise ValueError("forward_loop is required for calibration but got None.") | ||
| assert method in ["max"], ( | ||
| f"Sequential calibration currently only supports max calibration, got {method}" | ||
| ) | ||
| # Wrap with sequential processing | ||
| sequential_calibrate( | ||
| model, | ||
| forward_loop=forward_loop, | ||
| calib_func=func, | ||
| **kwargs, | ||
| ) | ||
| else: | ||
| # Direct calibration (existing behavior) | ||
| func(model, forward_loop=forward_loop, **kwargs) | ||
| with filter_calib_modules(model, include_modules, exclude_modules): | ||
| if sequential: | ||
| if forward_loop is None: | ||
| raise ValueError("forward_loop is required for calibration but got None.") | ||
| assert method in ["max"], ( | ||
| f"Sequential calibration currently only supports max calibration, got {method}" | ||
| ) | ||
| # Wrap with sequential processing | ||
| sequential_calibrate( | ||
| model, | ||
| forward_loop=forward_loop, | ||
| calib_func=func, | ||
| **kwargs, | ||
| ) | ||
|
Comment on lines
+249
to
+262
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keep module filtering out of sequential activation collection. Line 249 wraps the entire … 🤖 Prompt for AI Agents |
||
| else: | ||
| # Direct calibration (existing behavior) | ||
| func(model, forward_loop=forward_loop, **kwargs) | ||
|
|
||
| # Lets get the latest metadata for the quantizer states | ||
| metadata = {} | ||
|
|
@@ -270,7 +274,7 @@ class BaseCalibrateModeDescriptor(ModeDescriptor): | |
| """Base class for quantization calibration algorithm modes. | ||
|
|
||
| All calibration algorithm modes must be derived from this base class. | ||
| In addition, the `config_class` for the mode must return a subclass of :class:`QuantizeAlgorithmConfig`. | ||
| In addition, the `config_class` for the mode must return a subclass of :class:`CalibrationConfig`. | ||
|
|
||
| This base class also provides some convenient wrappers/utilities for calibration algorithms to be | ||
| translated into ModelOpt mode. | ||
|
|
@@ -289,8 +293,8 @@ class BaseCalibrateModeDescriptor(ModeDescriptor): | |
|
|
||
| def __init__(self, *args, **kwargs): | ||
| """Initialize Base calibrate mode descriptor.""" | ||
| assert issubclass(self.config_class, QuantizeAlgorithmConfig), ( | ||
| f"`config_class` of {self.__class__} must be a subclass of `QuantizeAlgorithmConfig`!, " | ||
| assert issubclass(self.config_class, CalibrationConfig), ( | ||
| f"`config_class` of {self.__class__} must be a subclass of `CalibrationConfig`!, " | ||
| f"got {self.config_class}!" | ||
| ) | ||
| super().__init__(*args, **kwargs) | ||
|
|
@@ -311,7 +315,7 @@ def name(self) -> str: | |
|
|
||
| @property | ||
| @abstractmethod | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
|
|
||
| @property | ||
|
|
@@ -386,9 +390,9 @@ class NoneCalibrateModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for no calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return QuantizeAlgorithmConfig | ||
| return CalibrationConfig | ||
|
|
||
| _calib_func = None | ||
|
|
||
|
|
@@ -398,7 +402,7 @@ class MaxCalibrateModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for max calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return MaxCalibConfig | ||
|
|
||
|
|
@@ -410,7 +414,7 @@ class MseCalibrateModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for mse calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return MseCalibConfig | ||
|
|
||
|
|
@@ -426,7 +430,7 @@ class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor): | |
| """ | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return LocalHessianCalibConfig | ||
|
|
||
|
|
@@ -438,7 +442,7 @@ class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for smoothquant calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return SmoothQuantCalibConfig | ||
|
|
||
|
|
@@ -450,7 +454,7 @@ class AWQLiteModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for AWQ lite calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return AWQLiteCalibConfig | ||
|
|
||
|
|
@@ -462,7 +466,7 @@ class AWQClipModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for AWQ clip calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return AWQClipCalibConfig | ||
|
|
||
|
|
@@ -474,7 +478,7 @@ class AWQFullModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for AWQ full calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return AWQFullCalibConfig | ||
|
|
||
|
|
@@ -486,7 +490,7 @@ class SVDQuantModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for SVDQuant calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return SVDQuantConfig | ||
|
|
||
|
|
@@ -503,7 +507,7 @@ class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): | |
| """Mode for GPTQ calibration algorithm.""" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[QuantizeAlgorithmConfig]: | ||
| def config_class(self) -> type[CalibrationConfig]: | ||
| """Specifies the config class for the mode.""" | ||
| return GPTQLiteConfig | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.