Changes from all commits (28 commits)

- 5d9c272 quant_cfg as a list (shengliangxu, Mar 17, 2026)
- d99e4ae Make quant_cfg a list of tuples, dict is too much (shengliangxu, Mar 18, 2026)
- b5bea21 yaml config format update (shengliangxu, Mar 18, 2026)
- 1b8c4bf fix some extra quant_cfg (shengliangxu, Mar 18, 2026)
- ab4daec fix tests (shengliangxu, Mar 19, 2026)
- 4ffd2fa rename from format to cfg (shengliangxu, Mar 19, 2026)
- d599103 pattern to path (shengliangxu, Mar 19, 2026)
- fc53877 flatten the inner configs (shengliangxu, Mar 19, 2026)
- a19335f get rid of the special 'default' (shengliangxu, Mar 19, 2026)
- 04014ec remove default (shengliangxu, Mar 19, 2026)
- 22134ef match yaml file format (shengliangxu, Mar 20, 2026)
- f52d213 fix tests (shengliangxu, Mar 20, 2026)
- 8f59142 fix guide (shengliangxu, Mar 20, 2026)
- 3cda60f default to disable (shengliangxu, Mar 20, 2026)
- 43f9a1a tuple format is not needed, remove all of them (shengliangxu, Mar 20, 2026)
- 4549001 final remove tuple format (shengliangxu, Mar 20, 2026)
- 30bb041 add atomicity to doc (shengliangxu, Mar 20, 2026)
- ff9fdd9 fix more quant_cfg args (shengliangxu, Mar 20, 2026)
- a164f13 distinguish set_quantizer_attributes_full and set_quantizer_attribute… (shengliangxu, Mar 21, 2026)
- dc915f5 new partial set quantizer cfg for internal merging logic (shengliangxu, Mar 22, 2026)
- 10c4cdd enable semantic documentation (shengliangxu, Mar 22, 2026)
- a03d975 revert accidental test change (shengliangxu, Mar 22, 2026)
- fb3bb07 fix mypy (shengliangxu, Mar 22, 2026)
- aecf832 new tests and fix existing tests (shengliangxu, Mar 23, 2026)
- 5115452 python < 3.12 (shengliangxu, Mar 23, 2026)
- a481b d1 more fix dict to list (shengliangxu, Mar 23, 2026)
- fe2d2f3 KV config has only quant_cfg meaningful (shengliangxu, Mar 23, 2026)
- 3a3b112 Merge branch 'main' into shengliangx/quant_cfg-list (shengliangxu, Mar 23, 2026)
1 change: 1 addition & 0 deletions docs/source/guides/1_quantization.rst
@@ -19,6 +19,7 @@ Below, you can find the documentation for the quantization toolkit in ModelOpt:
 ./_basic_quantization.rst
 ./_choosing_quant_methods.rst
 ./_pytorch_quantization.rst
+./_quant_cfg.rst
 ./_customized_model_quantization.rst
 ./_compress_quantized_models.rst
 ./_onnx_quantization.rst
31 changes: 19 additions & 12 deletions docs/source/guides/_pytorch_quantization.rst
@@ -237,14 +237,16 @@ For debugging purposes or simple customizations, you can modify an existing config

 .. code-block:: python

-    # Create a copy of the default INT8 configuration
-    config = mtq.INT8_DEFAULT_CFG.copy()
+    import copy
+
+    # Create a deep copy of the default INT8 configuration
+    config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)

-    # Disable input quantizers for all layers
-    config["quant_cfg"]["*input_quantizer"]["enable"] = False
+    # Disable input quantizers for all layers (appended last, so it takes precedence)
+    config["quant_cfg"].append({"quantizer_path": "*input_quantizer", "enable": False})

     # Disable all quantizers for layers matching the pattern "layer1.*"
-    config["quant_cfg"]["*layer1.*"] = {"enable": False}
+    config["quant_cfg"].append({"quantizer_path": "*layer1.*", "enable": False})

Advanced Configuration Creation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -255,16 +255,19 @@ For exploring new quantization recipes, you can compose a completely new configuration:

     # Custom configuration for INT4 block-wise weights and INT8 dynamic activations
     MY_CUSTOM_CONFIG = {
-        "quant_cfg": {
+        "quant_cfg": [
+            # Disable all quantizers by default, then enable selectively
+            {"quantizer_path": "*", "enable": False},

             # Configure weight quantizers with 4-bit precision and 128-element blocks
-            "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
+            {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, "enable": True},

             # Configure input quantizers with 8-bit dynamic quantization
-            "*input_quantizer": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}},
+            {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}},

             # Include default disabled quantizer configurations
-            **_default_disabled_quantizer_cfg,
-        },
+            *mtq.config._default_disabled_quantizer_cfg,
+        ],
Comment on lines +260 to +272 (Contributor)

⚠️ Potential issue | 🟡 Minor

Define or qualify _default_disabled_quantizer_cfg in this example.

This snippet references _default_disabled_quantizer_cfg without importing or qualifying it, so it fails when copied verbatim. Please either add the import or qualify it as mtq.config._default_disabled_quantizer_cfg.

         "algorithm": "max",
     }

@@ -394,8 +399,10 @@ You can specify ``custom_calib`` as ``algorithm`` in ``quant_cfg`` to use it. He

     # create quantization configuration with "custom_calib" method
     quant_cfg = {
-        'quant_cfg': {'*weight_quantizer': ..},
-        'algorithm': {"method": 'custom_calib'},
+        'quant_cfg': [
+            {"quantizer_path": "*weight_quantizer", "cfg": {...}},
+        ],
+        'algorithm': {"method": 'custom_calib'},
     }


307 changes: 307 additions & 0 deletions docs/source/guides/_quant_cfg.rst
@@ -0,0 +1,307 @@
.. _quant-cfg:

======================================
Quantization Configuration (quant_cfg)
======================================

The ``quant_cfg`` field is the primary mechanism for controlling which quantizers are active in a
model and how they are configured. This guide explains the format, ordering semantics, and common
patterns for composing quantization configurations.

.. tip::

   For the list of built-in configs and supported formats, see :any:`quantization-formats`.
   For how to apply a config to a model, see :any:`_pytorch_quantization`.

----------

Overview
========

A quantization config is a Python dictionary with two top-level keys:

.. code-block:: python

    config = {
        "quant_cfg": [...],  # ordered list of QuantizerCfgEntry dicts
        "algorithm": "max",  # calibration algorithm
    }

The ``quant_cfg`` value is an **ordered list** of :class:`QuantizerCfgEntry
<modelopt.torch.quantization.config.QuantizerCfgEntry>` dicts. Each entry targets a set of
quantizer modules in the model and specifies their configuration.

----------

Entry Format
============

Each entry in the list is a dictionary with the following fields:

.. list-table::
   :header-rows: 1
   :widths: 20 15 65

   * - Field
     - Required
     - Description
   * - ``quantizer_path``
     - Yes
     - Wildcard string matched against quantizer module names (e.g. ``"*weight_quantizer"``).
       Uses :func:`fnmatch` rules.
   * - ``parent_class``
     - No
     - Restricts matching to quantizers whose immediate parent module is of this PyTorch class
       (e.g. ``"nn.Linear"``). If omitted, matching is based on ``quantizer_path`` alone,
       regardless of the parent's class.
   * - ``cfg``
     - No
     - A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig
       <modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of such dicts
       for sequential quantization (see :ref:`sequential-quantizers`).
   * - ``enable``
     - No
     - ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``.
       When ``cfg`` is absent, **only** the enabled/disabled state changes; all other
       attributes remain untouched. When ``cfg`` is present, ``enable`` sets the enabled
       state of the newly configured quantizer, and omitting it implicitly enables the
       quantizer (``True``).

.. note::

   Every entry must specify at least one of ``cfg`` or ``enable`` in addition to
   ``quantizer_path``. An entry with only ``quantizer_path`` and no other keys is **invalid**
   and raises a ``ValueError`` at config-processing time. This prevents subtle bugs where a
   bare ``{"quantizer_path": "*"}`` would silently behave as ``enable=True`` for all
   quantizers.
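The validity rule above can be sketched in plain Python. Note that ``validate_entry`` and
``is_valid`` are illustrative helpers written for this guide, not part of the modelopt API:

```python
def validate_entry(entry):
    """Check one quant_cfg entry against the rule above (illustrative sketch)."""
    if "quantizer_path" not in entry:
        raise ValueError("every entry needs a 'quantizer_path'")
    if "cfg" not in entry and "enable" not in entry:
        # A bare {"quantizer_path": "*"} is rejected rather than silently
        # behaving like enable=True for every matched quantizer.
        raise ValueError(
            "entry for %r must set 'cfg' and/or 'enable'" % entry["quantizer_path"]
        )


def is_valid(entry):
    try:
        validate_entry(entry)
        return True
    except ValueError:
        return False
```

Any of the following shapes pass: path plus ``cfg``, path plus ``enable``, or path plus both.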

----------

Default Quantizer Configuration
================================

When a quantizer is enabled but has never been touched by a ``cfg`` entry — either because no
entry in the list matched it, or because it was only reached by enable-only entries — it operates
with the default attributes of
:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`:

.. code-block:: python

    {
        "num_bits": 8,             # 8-bit integer quantization
        "axis": None,              # per-tensor scale (no per-channel axis)
        "fake_quant": True,        # simulate quantization in forward pass (PTQ / QAT)
        "unsigned": False,         # signed integer range, e.g. [-128, 127] for INT8
        "narrow_range": False,     # full range; True would restrict to [-127, 127] for INT8
        "type": "static",          # static calibration (not dynamic per-inference)
        "block_sizes": None,       # no block quantization; set for NF4 / MXFP formats
        "bias": None,              # no affine bias correction
        "calibrator": "max",       # use max-abs calibration to determine amax
        "rotate": False,           # no Hadamard rotation (QuaRot / SpinQuant)
        "pass_through_bwd": True,  # straight-through estimator for QAT gradients
        "trt_high_precision_dtype": "Float",  # cast QDQ nodes to fp32 for TRT StronglyType export
        "backend": None,           # use the built-in quantization backend
        "backend_extra_args": None,   # no extra args for custom backends
        "use_constant_amax": False,   # calibrate amax; True hard-codes FP8 E4M3 max (448.0)
    }

In practice this means an un-configured but enabled quantizer performs **INT8 per-tensor static
fake-quantization** with a max-calibrated scale. This is rarely the intended behavior — every
quantizer you want active should be explicitly configured with a ``cfg`` entry.
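The fill-from-defaults behavior can be modeled with a toy helper. Here ``DEFAULTS`` is a small
subset of the attributes listed above and ``effective_cfg`` is an illustrative function written
for this guide, not a modelopt API:

```python
# Reduced stand-in for the QuantizerAttributeConfig defaults listed above.
DEFAULTS = {"num_bits": 8, "axis": None, "type": "static", "block_sizes": None}


def effective_cfg(cfg=None):
    # Unspecified attributes fall back to their defaults.
    out = dict(DEFAULTS)
    out.update(cfg or {})
    return out


# Enabled but never configured: INT8 per-tensor static fake-quantization.
untouched = effective_cfg()

# An explicit cfg overrides only the attributes it names.
int4_blockwise = effective_cfg({"num_bits": 4, "block_sizes": {-1: 128}})
```

The ``untouched`` case is the pitfall described above: valid, but rarely what you wanted.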

----------

Ordering and Precedence
=======================

Entries are applied **in list order**. Later entries override earlier ones for any quantizer they
match. This gives a clear, composable precedence model:

- Put broad rules (e.g. deny-all) **first**.
- Put format-specific enable rules **after**.
- Put fine-grained exclusions (specific layers, classes) **last**.

The recommended pattern used by all built-in configs is:

.. code-block:: python

    "quant_cfg": [
        # 1. Deny all quantizers by default
        {"quantizer_path": "*", "enable": False},
        # 2. Enable and configure the target quantizers
        {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}},
        {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}},
        # 3. Apply standard exclusions last (BatchNorm, LM head, MoE routers, etc.)
        *mtq.config._default_disabled_quantizer_cfg,
    ]

.. note::

   The deny-all entry ``{"quantizer_path": "*", "enable": False}`` is available as
   :data:`modelopt.torch.quantization.config._base_disable_all` and is prepended to most
   built-in configs; some specialized recipes (such as the KV-cache-only configs)
   intentionally omit it. Where present, it ensures quantizers not explicitly targeted
   remain disabled.
Comment on lines +123 to +143 (Contributor)

⚠️ Potential issue | 🟡 Minor

Don't describe _base_disable_all as universal across built-ins.

Several specialized recipes in modelopt/torch/quantization/config.py intentionally omit the deny-all prelude (for example the KV-only configs). Saying it is prepended to every built-in config will send users to the wrong pattern for those formats.


----------

Entry Atomicity
===============

Each ``cfg``-bearing entry in ``quant_cfg`` is a **complete, self-contained configuration unit**.
When an entry with ``cfg`` matches a quantizer, it **completely replaces** that quantizer's
configuration — it does not merge with or incrementally update settings left by earlier entries.

Concretely, if an entry specifies only a subset of quantizer attributes (e.g. only ``num_bits``),
all unspecified attributes are filled in with their default values from
:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`.
The resulting *complete* config is then written to the quantizer, discarding whatever any prior
matching entry had set.

This means:

- **Last cfg-entry wins, fully.** If two entries both match ``*weight_quantizer`` and both carry
a ``cfg``, the second entry does not inherit the first entry's settings — it replaces them entirely.
- **No hidden state accumulation.** The final configuration of a quantizer depends only on the
*last* ``cfg``-bearing entry in the list that matched it, making behavior easy to reason about.
- **Changing one field requires a full spec.** Because each ``cfg`` entry is a complete replacement,
to change only one attribute of a quantizer that was already configured, you must reproduce the
full desired config in the new entry. Any attribute omitted from the entry will revert to its
default, not to the value set by an earlier entry.

**Enable-only entries are the exception.** An entry with no ``cfg`` (only ``enable``) is *not* a
full replacement — it solely flips the on/off state of matched quantizers, leaving all other
attributes unchanged:

- ``{"quantizer_path": "*", "enable": False}`` disables all quantizers without touching their
configured attributes. Use this as the first step in a deny-all-then-configure pattern.
- ``{"quantizer_path": "*weight_quantizer", "enable": True}`` (no ``cfg``) re-enables weight
quantizers using whatever attributes they currently carry (or their defaults if they were never
configured by a ``cfg`` entry).

For example, given the following two entries both matching ``*weight_quantizer``:

.. code-block:: python

    # Entry 1: sets FP8 per-channel
    {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}},
    # Entry 2: sets INT4 blockwise (axis is NOT inherited from Entry 1)
    {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}},

After Entry 2 is applied, the quantizer has ``num_bits=4``, ``block_sizes={-1: 128}``, and
``axis=None`` (the default). The ``axis=0`` set by Entry 1 is gone.

.. note::

   The deny-all-then-configure pattern is safe and predictable precisely because
   ``{"quantizer_path": "*", "enable": False}`` **only** disables quantizers without resetting
   their attributes. Subsequent ``cfg`` entries then configure targets from a known default state.
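The ordering, atomicity, and enable-only semantics described in this section can be modeled
end-to-end in a few lines of plain Python. This is an illustrative model only: ``apply_quant_cfg``
and the reduced ``DEFAULTS`` dict are written for this guide (quantizers are assumed to start
enabled with default attributes), not the library's implementation:

```python
from fnmatch import fnmatch

# Reduced stand-in for QuantizerAttributeConfig defaults.
DEFAULTS = {"num_bits": 8, "axis": None, "block_sizes": None}


def apply_quant_cfg(quantizer_names, quant_cfg):
    """Apply entries in list order (illustrative model of the semantics above)."""
    state = {name: {"enable": True, **DEFAULTS} for name in quantizer_names}
    for entry in quant_cfg:
        for name, q in state.items():
            if not fnmatch(name, entry["quantizer_path"]):
                continue
            if "cfg" in entry:
                # Atomic replacement: rebuild from defaults, discarding prior state.
                q.clear()
                q.update({"enable": entry.get("enable", True), **DEFAULTS})
                q.update(entry["cfg"])
            else:
                # Enable-only entry: flip the flag, leave attributes untouched.
                q["enable"] = entry["enable"]
    return state


state = apply_quant_cfg(
    ["layer1.weight_quantizer", "layer1.input_quantizer"],
    [
        {"quantizer_path": "*", "enable": False},
        {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}},
        {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}},
    ],
)
```

Running this reproduces the Entry 1 / Entry 2 example: the final weight quantizer carries
``num_bits=4`` with ``axis=None`` (the ``axis=0`` from the earlier entry is gone), while the
input quantizer stays disabled with its default attributes intact.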

----------

Common Patterns
===============

Skipping Specific Layers
------------------------

Append a disable entry after the existing config to exclude layers matched by a path pattern.
Because it is appended last, it takes precedence over all earlier entries:

.. code-block:: python

    import copy

    import modelopt.torch.quantization as mtq

    config = copy.deepcopy(mtq.FP8_DEFAULT_CFG)

    # Skip the final projection layer
    config["quant_cfg"].append({"quantizer_path": "*lm_head*", "enable": False})

    model = mtq.quantize(model, config, forward_loop)

Skipping Layers by Module Class
--------------------------------

Use ``parent_class`` to target quantizers only within a specific type of layer, leaving the
same quantizer path in other layer types unaffected:

.. code-block:: python

    config["quant_cfg"].append({
        "quantizer_path": "*input_quantizer",
        "parent_class": "nn.LayerNorm",
        "enable": False,
    })

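The combined path-plus-parent-class matching can be pictured with a small predicate. This is a
sketch written for this guide: the real implementation compares actual PyTorch classes, whereas
``entry_matches`` below uses class names as plain strings for illustration:

```python
from fnmatch import fnmatch


def entry_matches(entry, quantizer_name, parent_class_name):
    """Would this entry target the given quantizer? (illustrative sketch)"""
    if not fnmatch(quantizer_name, entry["quantizer_path"]):
        return False
    # With no parent_class, the path alone decides; otherwise the immediate
    # parent's class must match as well.
    return entry.get("parent_class") in (None, parent_class_name)


entry = {
    "quantizer_path": "*input_quantizer",
    "parent_class": "nn.LayerNorm",
    "enable": False,
}
```

With this entry, an input quantizer inside a LayerNorm is matched, while the same quantizer path
inside a Linear layer (or a different quantizer path inside a LayerNorm) is not.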
Overriding Quantizer Precision for Specific Layers
---------------------------------------------------

A later entry with a matching ``quantizer_path`` replaces the configuration set by an earlier
entry. This allows per-layer precision overrides without restructuring the entire config:

.. code-block:: python

    config = copy.deepcopy(mtq.FP8_DEFAULT_CFG)

    # Quantize attention output projections with INT8 per-channel instead of FP8
    config["quant_cfg"].append({
        "quantizer_path": "*o_proj*weight_quantizer",
        "cfg": {"num_bits": 8, "axis": 0},
    })

Building a Config from Scratch
-------------------------------

For entirely custom recipes, compose the list directly:

.. code-block:: python

    from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg

    MY_CUSTOM_CFG = {
        "quant_cfg": [
            *_base_disable_all,
            {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}},
            {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}},
            *_default_disabled_quantizer_cfg,
        ],
        "algorithm": "max",
    }

    model = mtq.quantize(model, MY_CUSTOM_CFG, forward_loop)

----------

.. _sequential-quantizers:

Sequential Quantization
=======================

When ``cfg`` is a **list** of attribute dicts, the matched
:class:`TensorQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.TensorQuantizer>`
is replaced with a
:class:`SequentialQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.SequentialQuantizer>`
that applies each format in sequence. This is used, for example, in W4A8 quantization where weights
are quantized first in INT4 and then in FP8:

.. code-block:: python

    {
        "quantizer_path": "*weight_quantizer",
        "cfg": [
            {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}},
            {"num_bits": (4, 3)},  # FP8
        ],
        "enable": True,
    }

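The replacement-with-a-sequence behavior amounts to function composition: each format in the
``cfg`` list is applied to the output of the previous one. A toy sketch, where
``SequentialQuantizerSketch`` is a hypothetical stand-in for the real module and the steps are
arbitrary callables rather than real quantizers:

```python
class SequentialQuantizerSketch:
    """Toy stand-in for SequentialQuantizer: applies each step in list order."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, x):
        for quantize in self.steps:
            x = quantize(x)
        return x


# Two toy "formats" applied in sequence, mirroring a two-element cfg list.
seq = SequentialQuantizerSketch([lambda x: round(x, 1), lambda x: round(x)])
```

The key point is that the steps are ordered: the second format sees values already coarsened by
the first, just as the FP8 step above operates on the INT4-quantized weights.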
----------

Reference
=========

- :class:`QuantizerCfgEntry <modelopt.torch.quantization.config.QuantizerCfgEntry>`
- :class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
- :class:`QuantizeConfig <modelopt.torch.quantization.config.QuantizeConfig>`
- :func:`set_quantizer_by_cfg <modelopt.torch.quantization.conversion.set_quantizer_by_cfg>`