23 changes: 19 additions & 4 deletions docs/source/guides/_quant_cfg.rst
@@ -55,9 +55,12 @@ Each entry in the list is a dictionary with the following fields:
(e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class.
* - ``cfg``
- No
- A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig
<modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of such dicts
for sequential quantization (see :ref:`sequential-quantizers`).
- A :class:`QuantizerAttributeConfig
<modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of
``QuantizerAttributeConfig`` objects for sequential quantization (see
:ref:`sequential-quantizers`). Equivalent Python dicts, YAML mappings, and lists of
dicts are still accepted for backward compatibility, but those weakly schematized forms
are deprecated.
* - ``enable``
- No
- ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``.
@@ -74,6 +77,11 @@ Each entry in the list is a dictionary with the following fields:
a bare ``{"quantizer_name": "*"}`` would silently behave as ``enable=True`` for all
quantizers.

Schema-backed YAML loading parses ``cfg`` mappings into
:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
values. Plain Python dicts and lists of dicts are accepted only as a backward-compatible,
weakly schematized input format.
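
For illustration, here is a minimal sketch (not part of this change) of the parsing
path described above. The dict-style access on the parsed config mirrors usage
elsewhere in this PR; the specific quantizer attributes and wildcard names are
illustrative assumptions::

    from modelopt.torch.quantization.config import (
        QuantizeConfig,
        QuantizerAttributeConfig,
    )

    # Weakly schematized input, e.g. a YAML mapping loaded into plain dicts.
    raw_cfg = {
        "quant_cfg": [
            {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}},
            {"quantizer_name": "*input_quantizer", "enable": False},
        ],
        "algorithm": "max",
    }

    # Schema-backed parsing: each "cfg" mapping is validated into a
    # QuantizerAttributeConfig instead of staying a plain dict.
    parsed = QuantizeConfig.model_validate(raw_cfg)
    weight_cfg = parsed["quant_cfg"][0].get("cfg")
    print(type(weight_cfg))  # expected: QuantizerAttributeConfig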

----------

Default Quantizer Configuration
@@ -278,7 +286,9 @@ For entirely custom recipes, compose the list directly:
Sequential Quantization
=======================

When ``cfg`` is a **list** of attribute dicts, the matched
When ``cfg`` is a **list** of
:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
values, the matched
:class:`TensorQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.TensorQuantizer>`
is replaced with a
:class:`SequentialQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.SequentialQuantizer>`
@@ -295,6 +305,11 @@ are quantized first in INT4 and then in FP8:
],
}

The list-of-dict spelling shown above remains accepted for existing Python configs and is the
natural YAML spelling, but it is a deprecated, weakly schematized input form. After schema-backed
loading or :class:`QuantizeConfig <modelopt.torch.quantization.config.QuantizeConfig>` parsing,
each element is a ``QuantizerAttributeConfig``.
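
For comparison, a hedged sketch of the same kind of entry spelled with explicit
``QuantizerAttributeConfig`` objects. The INT4 block size and the FP8
``num_bits=(4, 3)`` spelling are illustrative assumptions, not values taken from
this guide::

    from modelopt.torch.quantization.config import QuantizerAttributeConfig

    weight_entry = {
        "quantizer_name": "*weight_quantizer",
        "cfg": [
            # First pass: INT4 with a block size on the last axis (illustrative).
            QuantizerAttributeConfig(num_bits=4, block_sizes={-1: 128}),
            # Second pass: FP8 (E4M3), spelled as an (exponent, mantissa) tuple.
            QuantizerAttributeConfig(num_bits=(4, 3)),
        ],
    }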

----------

.. _migrating-from-dict-format:
13 changes: 11 additions & 2 deletions examples/diffusers/quantization/config.py
@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Mapping, MutableMapping

import torch.nn as nn
from calib.plugin_calib import PercentileCalibrator

from modelopt.torch.quantization.config import QuantizerAttributeConfig

FP8_DEFAULT_CONFIG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
@@ -104,8 +108,13 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **
quant_config["algorithm"] = algo_cfg

for entry in quant_config["quant_cfg"]:
p = entry.get("cfg", {})
if isinstance(p, dict) and "num_bits" in p and "trt_high_precision_dtype" not in p:
p = entry.get("cfg", {}) if isinstance(entry, Mapping) else {}
if not isinstance(p, MutableMapping):
continue
keys = p.explicit_keys() if isinstance(p, QuantizerAttributeConfig) else p.keys()
# TODO: Replace this membership-based config patching with a better config API;
# ``in``/``not in`` checks are fragile with schema-backed defaults.
if "num_bits" in keys and "trt_high_precision_dtype" not in keys:
p["trt_high_precision_dtype"] = trt_high_precision_dtype


35 changes: 24 additions & 11 deletions examples/diffusers/quantization/quantize.py
@@ -14,9 +14,11 @@
# limitations under the License.

import argparse
import copy
import logging
import sys
import time as time
from collections.abc import Mapping
from pathlib import Path
from typing import Any

@@ -49,6 +51,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.opt.config import ModeloptBaseConfig


def setup_logging(verbose: bool = False) -> logging.Logger:
@@ -119,14 +122,6 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
base_cfg = mtq.INT8_SMOOTHQUANT_CFG
else:
base_cfg = INT8_DEFAULT_CONFIG
if self.config.collect_method != CollectMethod.DEFAULT:
reset_set_int8_config(
base_cfg,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)
elif self.config.format == QuantFormat.FP8:
base_cfg = FP8_DEFAULT_CONFIG
elif self.config.format == QuantFormat.FP4:
@@ -138,15 +133,33 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
raise NotImplementedError(f"Unknown format {self.config.format}")

# Build a fresh config dict so we never mutate the global constants.
if isinstance(base_cfg, ModeloptBaseConfig):
base_cfg = base_cfg.model_dump(exclude_unset=True)
base_cfg = copy.deepcopy(base_cfg)

if (
self.config.format == QuantFormat.INT8
and self.config.collect_method != CollectMethod.DEFAULT
):
reset_set_int8_config(
base_cfg,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)

quant_cfg_list = list(base_cfg["quant_cfg"])

if self.config.format == QuantFormat.FP4:
for i, entry in enumerate(quant_cfg_list):
if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}):
new_block_sizes = {**entry["cfg"]["block_sizes"], -1: self.config.block_size}
cfg = entry.get("cfg", {}) if isinstance(entry, Mapping) else {}
block_sizes = cfg.get("block_sizes") if isinstance(cfg, Mapping) else None
if isinstance(block_sizes, Mapping):
new_block_sizes = {**block_sizes, -1: self.config.block_size}
quant_cfg_list[i] = {
**entry,
"cfg": {**entry["cfg"], "block_sizes": new_block_sizes},
"cfg": {**cfg, "block_sizes": new_block_sizes},
}

if self.config.quantize_mha:
6 changes: 3 additions & 3 deletions examples/llm_autodeploy/run_auto_quantize.py
@@ -24,9 +24,9 @@
from modelopt.torch.utils import create_forward_loop
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader

SUPPORT_QUANT_FORMAT = {
"fp8": mtq.FP8_DEFAULT_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
SUPPORT_QUANT_FORMAT: dict[str, str] = {
"fp8": "FP8_DEFAULT_CFG",
"nvfp4": "NVFP4_DEFAULT_CFG",
}


3 changes: 2 additions & 1 deletion examples/llm_ptq/cast_mxfp4_to_nvfp4.py
@@ -34,6 +34,7 @@
"""

import json
from collections.abc import Mapping
from contextlib import ExitStack, contextmanager
from pathlib import Path

@@ -304,7 +305,7 @@ def force_weight_quantizers_static(quant_cfg: list) -> None:
qname = entry.get("quantizer_name", "")
cfg = entry.get("cfg") or {}
bs = cfg.get("block_sizes")
if "weight_quantizer" in qname and isinstance(bs, dict):
if "weight_quantizer" in qname and isinstance(bs, Mapping):
quant_cfg[i] = {**entry, "cfg": {**cfg, "block_sizes": {**bs, "type": "static"}}}


45 changes: 26 additions & 19 deletions examples/llm_ptq/example_utils.py
@@ -23,6 +23,7 @@
import shutil
import sys
import warnings
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Any

@@ -41,6 +42,8 @@
ProcessorMixin,
)

from modelopt.torch.quantization.config import QuantizeConfig, QuantizerCfgEntry

try:
from huggingface_hub import snapshot_download
except ImportError:
@@ -203,17 +206,17 @@ def calibrate_loop(_model):

def build_quant_cfg(
qformat,
quant_cfg,
quant_cfg: QuantizeConfig | Mapping[str, Any],
awq_block_size,
model_type,
moe_calib_experts_ratio: float | None = None,
) -> dict[str, Any]:
quant_cfg = copy.deepcopy(quant_cfg)
if "awq" in str(quant_cfg.get("algorithm")):
) -> QuantizeConfig:
quant_cfg_obj: QuantizeConfig = QuantizeConfig.model_validate(copy.deepcopy(quant_cfg))
if "awq" in str(quant_cfg_obj.get("algorithm")):
from modelopt.torch.quantization.config import find_quant_cfg_entry_by_path

weight_quantizer_entry = find_quant_cfg_entry_by_path(
quant_cfg["quant_cfg"], "*weight_quantizer"
quant_cfg_obj["quant_cfg"], "*weight_quantizer"
)
weight_quantizer = weight_quantizer_entry.get("cfg") or {}
if isinstance(weight_quantizer, list):
@@ -224,34 +227,38 @@ def build_quant_cfg(

# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
quant_cfg_obj["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

if moe_calib_experts_ratio:
assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
if isinstance(quant_cfg["algorithm"], str):
quant_cfg["algorithm"] = {
"method": quant_cfg["algorithm"],
if isinstance(quant_cfg_obj["algorithm"], str):
quant_cfg_obj["algorithm"] = {
"method": quant_cfg_obj["algorithm"],
"moe_calib_experts_ratio": moe_calib_experts_ratio,
}
elif isinstance(quant_cfg["algorithm"], dict):
quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
elif isinstance(quant_cfg_obj["algorithm"], MutableMapping):
quant_cfg_obj["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
else:
warnings.warn(
f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio"
f"Quantization algorithm: {quant_cfg_obj['algorithm']} does not support setting moe_calib_experts_ratio"
)

# Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
if model_type == "gemma" and "int8_sq" in qformat:
quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
quant_cfg_obj["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

if model_type == "phi4mm":
# Only quantize the language model
quant_cfg["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
quant_cfg_obj["quant_cfg"].extend(
[
QuantizerCfgEntry(quantizer_name="*speech*", enable=False),
QuantizerCfgEntry(quantizer_name="*audio*", enable=False),
QuantizerCfgEntry(quantizer_name="*image*", enable=False),
QuantizerCfgEntry(quantizer_name="*vision*", enable=False),
]
)

return quant_cfg
return quant_cfg_obj


def is_speculative(hf_config):
@@ -842,7 +849,7 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
"""Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
algorithm = quant_cfg.get("algorithm")
if not isinstance(algorithm, dict):
if not isinstance(algorithm, Mapping):
return False
return algorithm.get("layerwise_checkpoint_dir") is not None

16 changes: 10 additions & 6 deletions examples/llm_ptq/hf_ptq.py
@@ -66,7 +66,11 @@
save_expert_token_count_table,
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.quantization.config import (
QuantizeConfig,
_default_disabled_quantizer_cfg,
need_calibration,
)
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.speculative.eagle.utils import (
@@ -89,18 +93,18 @@
def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"""Set use_constant_amax on KV cache quantizers.

Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
Updates the matched KV bmm quantizer entry in place.
"""
for i, entry in enumerate(quant_cfg):
for entry in quant_cfg:
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
cfg["use_constant_amax"] = True
entry["cfg"] = cfg
Comment on lines 101 to +103

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Avoid falsy-coercing cfg; it can drop typed empty configs.

Line 101 uses entry.get("cfg") or {}. If cfg is an existing empty mapping/model, it is falsy and gets replaced by a plain dict, which can unintentionally discard the schema-backed config type.

Suggested fix
-        cfg = entry.get("cfg") or {}
+        cfg = entry.get("cfg")
+        if cfg is None:
+            cfg = {}
         cfg["use_constant_amax"] = True
         entry["cfg"] = cfg
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        cfg = entry.get("cfg") or {}
+        cfg = entry.get("cfg")
+        if cfg is None:
+            cfg = {}
         cfg["use_constant_amax"] = True
         entry["cfg"] = cfg
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/hf_ptq.py` around lines 101-103, the current code uses
entry.get("cfg") or {} which treats any falsy cfg (like a typed empty mapping)
as missing and replaces it; instead, fetch cfg with entry.get("cfg") and only
replace it when it is actually None or not present (e.g., if entry.get("cfg") is
None: set a new dict), then set cfg["use_constant_amax"] = True and write back
entry["cfg"] = cfg so you preserve existing typed/empty config objects;
reference the variables entry and cfg and the key "use_constant_amax" when
making the change.

break


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
5 changes: 2 additions & 3 deletions examples/llm_ptq/multinode_ptq.py
@@ -22,7 +22,6 @@
import time
import warnings
from pathlib import Path
from typing import Any

import numpy as np
import torch
@@ -37,14 +36,14 @@
from modelopt.torch.export import get_model_type
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
from modelopt.torch.quantization.config import need_calibration
from modelopt.torch.quantization.config import QuantizeConfig, need_calibration
from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets

# Constants
RAND_SEED = 1234

QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
4 changes: 2 additions & 2 deletions examples/vllm_serve/vllm_ptq_utils.py
@@ -14,7 +14,7 @@
# limitations under the License.

import dataclasses
from collections.abc import Callable
from collections.abc import Callable, Mapping
from typing import Any

import torch
@@ -122,7 +122,7 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list:
(
e
for e in kv_quant_cfg
if isinstance(e, dict) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
if isinstance(e, Mapping) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
),
None,
)
4 changes: 1 addition & 3 deletions modelopt/onnx/llm_export_utils/quantization_utils.py
@@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"):
else:
raise ValueError(f"Unsupported precision: {precision}")

quant_cfg_list: list = [
e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e
]
quant_cfg_list: list = list(quant_cfg["quant_cfg"])

if lm_head_precision == "fp8":
quant_cfg_list.append(