
Add calib_include/exclude_modules to calibration algorithms #1043

Open

Fridah-nv wants to merge 4 commits into main from fridah/in_exclude_mod

Conversation

Fridah-nv (Contributor) commented Mar 16, 2026

What does this PR do?

Type of change: New feature

Adds calib_include_modules and calib_exclude_modules fields to QuantizeAlgorithmConfig
so users can restrict any calibration algorithm (max, mse, smoothquant, awq, …) to a
subset of the model's layers. Patterns are fnmatch wildcards matched against module names
(e.g. "*lm_head*", "*self_attn*").

Implementation:

  • New filter_calib_modules context manager in model_calib.py temporarily disables
    TensorQuantizer instances in non-matching modules (see the sketch after this list).
    TensorQuantizer.disable() does not clear _amax, so excluded modules retain their
    pre-existing calibration state.
  • Fields live on the base QuantizeAlgorithmConfig, so filtering applies uniformly to all
    algorithms with no per-algorithm changes.
  • wrapped_calib_func in mode.py pops these fields and wraps every calibration call in
    filter_calib_modules automatically.
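A minimal sketch of such a context manager, reconstructed from the description above (this is not the PR's actual code: the import path and the _disabled attribute are assumptions drawn from the discussion below, and the sketch walks all modules rather than only quantized linears):

    import fnmatch
    from contextlib import contextmanager

    # Assumed import path for TensorQuantizer; adjust to the library's actual layout.
    from modelopt.torch.quantization.nn import TensorQuantizer


    @contextmanager
    def filter_calib_modules_sketch(model, include_modules=None, exclude_modules=None):
        """Temporarily disable quantizers outside the selected modules."""

        def should_calibrate(name):
            if include_modules and not any(fnmatch.fnmatch(name, p) for p in include_modules):
                return False
            if exclude_modules and any(fnmatch.fnmatch(name, p) for p in exclude_modules):
                return False
            return True

        disabled = []
        for name, module in model.named_modules():
            if isinstance(module, TensorQuantizer) or should_calibrate(name):
                continue
            for child in module.children():
                # Quantizers already disabled via quant_cfg never enter the restore
                # list, so their disabled state survives the context unchanged.
                if isinstance(child, TensorQuantizer) and not child._disabled:
                    child.disable()  # disable() does not clear _amax
                    disabled.append(child)
        try:
            yield model
        finally:
            for quantizer in disabled:
                quantizer.enable()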

Interaction with "enable": false in quant_cfg:
Quantizers disabled via quant_cfg (i.e. _disabled=True) are skipped by filter_calib_modules:
they are never added to the restore list and never re-enabled, so their disabled state is
fully preserved regardless of calib_include_modules / calib_exclude_modules.
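As an illustration of how the two mechanisms compose (a sketch assuming the usual wildcard-keyed quant_cfg layout, with model and forward_loop as in the Usage snippet below):

    import copy

    import modelopt.torch.quantization as mtq

    quant_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
    # Quantizers in lm_head are disabled outright; filter_calib_modules leaves them alone.
    quant_cfg["quant_cfg"]["*lm_head*"] = {"enable": False}
    # Attention modules stay quantized but are skipped by this calibration pass,
    # keeping whatever amax they already have.
    quant_cfg["algorithm"] = {"method": "max", "calib_exclude_modules": ["*self_attn*"]}

    mtq.quantize(model, quant_cfg, forward_loop=forward_loop)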

Lower-level API:
When calling calibration functions directly (outside mtq.calibrate()), wrap manually with the
context manager:

from modelopt.torch.quantization.model_calib import filter_calib_modules, mse_calibrate

with filter_calib_modules(model, exclude_modules=["*lm_head*"]):
    mse_calibrate(model, forward_loop)

Usage

    quant_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
    quant_cfg["algorithm"] = {"method": "max", "calib_exclude_modules": ["*attn*"]}

    mtq.quantize(model, quant_cfg, forward_loop=forward_loop)

Testing

Added 6 new unit tests in tests/unit/torch/quantization/test_calib.py covering include/exclude
patterns, wildcard matching, no-op behavior, and integration with existing calibration flows.

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ❌ (pending)

Additional Information

Summary by CodeRabbit

  • New Features

    • Selectively include or exclude model modules during quantization calibration via new module-filtering options.
  • Tests

    • Added comprehensive tests covering include/exclude patterns, wildcard matching, no-op behavior, and integration with existing calibration flows.

Adds calib_include_modules and calib_exclude_modules fields to
QuantizeAlgorithmConfig so users can restrict any calibration algorithm
(max, mse, smoothquant, awq, ...) to a subset of the model's layers.
Filtering is applied via the new filter_calib_modules context manager,
which temporarily disables TensorQuantizer instances in non-matching
modules while preserving their pre-existing _amax values.

Also exposes --calib_include_modules / --calib_exclude_modules CLI args
in the hf_ptq.py example and wires them through build_quant_cfg in
example_utils.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Fridah-nv self-assigned this Mar 16, 2026
Fridah-nv requested review from a team as code owners March 16, 2026 03:35
Fridah-nv requested a review from Edwardf0t1 March 16, 2026 03:35
coderabbitai bot (Contributor) commented Mar 16, 2026

📝 Walkthrough

Adds calibration module filtering via two new config fields, a context manager that scopes calibration to fnmatch-matched modules, integrates that context into the calibration wrapper, adds tests for include/exclude/wildcard behaviors, and makes a small deepcopy fix in example quant config construction.

Changes

  • Configuration Schema (modelopt/torch/quantization/config.py):
    Added calib_include_modules and calib_exclude_modules (list[str] | None, default None)
    fields to QuantizeAlgorithmConfig.
  • Calibration Core (modelopt/torch/quantization/model_calib.py):
    Added filter_calib_modules(model, include_modules=None, exclude_modules=None) context
    manager that temporarily disables TensorQuantizers for non-matching modules; exported
    via __all__.
  • Calibration Integration (modelopt/torch/quantization/mode.py):
    Wrapped calibration execution to extract include/exclude lists from config and run
    calibration inside the filter_calib_modules context; updated imports.
  • Examples, minor (examples/llm_ptq/example_utils.py):
    Constructs quant_cfg via deepcopy of the chosen template to avoid in-place mutation;
    removed a redundant deepcopy in the AWQ branch.
  • CLI, whitespace only (examples/llm_ptq/hf_ptq.py):
    Whitespace/formatting change in parse_args() (no behavioral change).
  • Tests (tests/unit/torch/quantization/test_calib.py):
    Added comprehensive tests and helpers covering include/exclude/no-op/wildcard filtering;
    imports filter_calib_modules and max_calibrate.

Sequence Diagram(s)

sequenceDiagram
    participant Runner as Quantize Runner / Caller
    participant Config as Quantize Config
    participant Mode as calibration wrapper (mode.py)
    participant Filter as filter_calib_modules
    participant Model as Model & TensorQuantizers

    Runner->>Config: provide quant_cfg (with calib_include/exclude)
    Runner->>Mode: start calibration with quant_cfg
    Mode->>Filter: __enter__ (include/exclude patterns)
    Filter->>Model: disable quantizers for non-matching modules
    Mode->>Model: perform calibration (update amax)
    Mode->>Filter: __exit__
    Filter->>Model: restore quantizers to original state

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 68.75%, below the required threshold
    of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)

  • Description Check ✅: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅: The title directly and accurately describes the main change: adding
    calib_include_modules and calib_exclude_modules fields to calibration algorithms in
    QuantizeAlgorithmConfig.
  • Security Anti-Patterns ✅: No security anti-patterns detected. No unsafe
    torch.load/numpy.load calls, eval/exec usage, nosec comments, or hardcoded unsafe
    settings; fnmatch patterns are used safely for module filtering.



codecov bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 91.66667% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.14%. Comparing base (58417e5) to head (5a8c907).
⚠️ Report is 8 commits behind head on main.

Files with missing lines              | Patch % | Lines
modelopt/torch/quantization/mode.py   | 62.50%  | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1043      +/-   ##
==========================================
+ Coverage   70.07%   70.14%   +0.06%     
==========================================
  Files         221      221              
  Lines       25499    25572      +73     
==========================================
+ Hits        17869    17938      +69     
- Misses       7630     7634       +4     

coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/unit/torch/quantization/test_calib.py (1)

469-473: Strengthen the no-op assertion to validate actual no-op behavior.

This currently only checks that amax is present. It should compare against the baseline snapshot to verify no values changed.

💡 Proposed test tightening
     # Amaxes should be consistent with standard max calibration (not None)
     for name in amaxes_before:
         amax_after = _get_weight_amax(model, name)
-        assert amax_after is not None, f"{name} should have a valid amax after calibration"
+        assert torch.allclose(amaxes_before[name], amax_after), (
+            f"{name} changed under no-op filter context"
+        )
Inline comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 252-266: The code currently mutates quant_cfg["algorithm"] in
place (when it's a dict or converted from a string), which can leak changes into
shared presets; update the logic to create a copy of the algorithm dict before
adding calib_exclude_modules/calib_include_modules (e.g., if
isinstance(quant_cfg["algorithm"], str) set quant_cfg["algorithm"] = {"method":
...} as a new dict, and if it's a dict replace it with a shallow or deep copy
like new_alg = dict(quant_cfg["algorithm"]) or copy.deepcopy(...) and assign
quant_cfg["algorithm"] = new_alg) then add the calib keys to that copy,
referencing quant_cfg["algorithm"], calib_exclude_modules, and
calib_include_modules when applying the changes.

In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1249-1258: The parsed module-pattern lists
args.calib_exclude_modules and args.calib_include_modules must drop
empty/whitespace-only entries; update the list comprehensions to filter out
items where p.strip() is empty (e.g., use [p.strip() for p in
args.calib_exclude_modules.split(",") if p.strip()] and similarly for
calib_include_modules) and keep the existing conditional that yields None when
the original arg is falsy.

In `@modelopt/torch/quantization/model_calib.py`:
- Around line 107-113: The current loop only inspects modules where
is_quantized_linear(module) is true, so TensorQuantizer instances in non-linear
modules are not disabled when _should_calibrate(name) is false; change the logic
to iterate all modules from model.named_modules(), and for any module whose name
fails _should_calibrate(name) traverse its children to find TensorQuantizer
instances and call disable() on those not already _disabled, appending them to
disabled (i.e., replace the is_quantized_linear(module) guard with a direct
check of _should_calibrate(name) and disable all TensorQuantizer children
accordingly).



📥 Commits

Reviewing files that changed from the base of the PR and between 1070d89 and 2a0feaf.

📒 Files selected for processing (6)
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • tests/unit/torch/quantization/test_calib.py

- Fix shared preset mutation in build_quant_cfg by always deep-copying
  the preset dict before modification (previously only awq path did this)
- Document linear-only filtering limitation in filter_calib_modules
  docstring and calib_include/exclude_modules field descriptions
- Filter empty strings from CLI pattern parsing in hf_ptq.py to handle
  trailing commas gracefully
- Strengthen test_filter_no_op_when_none to assert amax value equality
  rather than just non-None presence

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
coderabbitai bot left a comment

Actionable comments posted: 1


Inline comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 251-265: The Gemma/smoothquant override replaces
quant_cfg["algorithm"] with a new dict and inadvertently drops
calib_include_modules/calib_exclude_modules; when you set the Gemma override
(the code path that assigns quant_cfg["algorithm"] = {"method": "int8_sq", ...}
or similar), merge or copy any existing
calib_include_modules/calib_exclude_modules from the previous
quant_cfg["algorithm"] (or from the local
calib_include_modules/calib_exclude_modules variables) into the new dict instead
of overwriting them — i.e., build the override dict then set
override_dict["calib_include_modules"]=calib_include_modules (if present) and
override_dict["calib_exclude_modules"]=calib_exclude_modules (if present) before
assigning back to quant_cfg["algorithm"] so the filters are preserved for
Gemma/smoothquant.


📥 Commits

Reviewing files that changed from the base of the PR and between 2a0feaf and 378d524.

📒 Files selected for processing (5)
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/model_calib.py
  • tests/unit/torch/quantization/test_calib.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/quantization/config.py
  • examples/llm_ptq/hf_ptq.py

Fridah-nv and others added 2 commits March 16, 2026 23:15
Users should set calib_include_modules / calib_exclude_modules directly
in the algorithm dict of their quantization config rather than via
dedicated CLI flags. Remove --calib_exclude_modules / --calib_include_modules
from hf_ptq.py and the corresponding parameters from build_quant_cfg.

Update test_filter_via_config_api to exercise the intended usage path:
embedding both fields in the algorithm dict and calling mtq.quantize,
covering exclude and include variants and asserting that uncalibrated
module _amax buffers are absent.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
calib_include/exclude_modules is a core library feature accessed via
the algorithm config dict; example scripts should not be modified.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
coderabbitai bot left a comment

♻️ Duplicate comments (1)
examples/llm_ptq/example_utils.py (1)

250-251: ⚠️ Potential issue | 🟠 Major

Gemma override still drops calibration filters and other algorithm fields.

Replacing quant_cfg["algorithm"] here discards previously set keys (e.g., calib_include_modules, calib_exclude_modules, and moe_calib_experts_ratio), so filtering silently stops working on this path.

Suggested fix
-    if model_type == "gemma" and "int8_sq" in qformat:
-        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
+    if model_type == "gemma" and "int8_sq" in qformat:
+        if isinstance(quant_cfg.get("algorithm"), dict):
+            quant_cfg["algorithm"]["method"] = "smoothquant"
+            quant_cfg["algorithm"]["alpha"] = 0.5
+        else:
+            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}


📥 Commits

Reviewing files that changed from the base of the PR and between 378d524 and 22bd94c.

📒 Files selected for processing (3)
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • tests/unit/torch/quantization/test_calib.py
💤 Files with no reviewable changes (1)
  • examples/llm_ptq/hf_ptq.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/torch/quantization/test_calib.py

Comment on modelopt/torch/quantization/config.py:

    calib_include_modules: list[str] | None = ModeloptField(
A contributor commented:

Do we need the calib prefix? Isn't it obvious that this is for calib include_modules?

Suggested change
-    calib_include_modules: list[str] | None = ModeloptField(
+    include_modules: list[str] | None = ModeloptField(

Comment on lines +418 to +421
    # net.4 should be untouched
    assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), (
        "Excluded module net.4 should have unchanged amax"
    )
A contributor commented:

Are we using the same dataset for mtq.calibrate? How will amax change if we do this?
How about we don't pass forward_loop in _make_quantized_mlp() and instead pass forward_loop
to mtq.calibrate()? That way net.4 won't have an amax.

realAsma (Contributor) left a comment:

Looks great! Left some comments - please address them.

Comment on modelopt/torch/quantization/mode.py:

    else:
        # Direct calibration (existing behavior)
        func(model, forward_loop=forward_loop, **kwargs)
    with filter_calib_modules(model, calib_include_modules, calib_exclude_modules):
A contributor commented:

Just to double check: if in the future we want to run multiple algorithms in a sequential
flow, for example local_hessian followed by gptq, will this work?

A contributor commented:

Another follow-up question: do we plan on running different calibration algorithms for
different layers in the future? Can this context manager be helpful in that case?

cjluo-nv (Collaborator) commented:
Questions:

  1. Will there be a case that include and exclude have conflicts (a module listed in both)?
  2. If a module is not in the include list or exclude list, what's the behavior?

Comment on lines +1087 to +1108
    calib_include_modules: list[str] | None = ModeloptField(
        default=None,
        title="Patterns of modules to include in calibration.",
        description=(
            "If provided, only modules whose names match at least one of the fnmatch patterns are "
            "calibrated. Modules that do not match any pattern are skipped and retain their "
            "pre-existing calibration state. "
            "Note: filtering applies only to quantized linear modules; TensorQuantizers in "
            "non-linear modules (e.g. layer norms, embeddings) are unaffected."
        ),
    )

    calib_exclude_modules: list[str] | None = ModeloptField(
        default=None,
        title="Patterns of modules to exclude from calibration.",
        description=(
            "If provided, modules whose names match at least one of the fnmatch patterns are "
            "skipped during calibration and retain their pre-existing calibration state. "
            "Note: filtering applies only to quantized linear modules; TensorQuantizers in "
            "non-linear modules (e.g. layer norms, embeddings) are unaffected."
        ),
    )
A contributor commented:
If a user passes both calib_include_modules and calib_exclude_modules, is the behavior
implicitly "include first, then exclude"? Do you think we need to either:

  • document explicitly what happens if a module matches both, or
  • validate and raise an error if both are set simultaneously?

A contributor commented:
I think the ordering does not actually matter, but we do need to document a clear semantic.
These are actually 2 exclude module lists:

  1. exclude those that are not in include_modules
  2. exclude those that are in exclude_modules
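Under that semantic the order is immaterial: a module is calibrated only if it survives both filters, so a module matching both lists is excluded either way. A tiny self-contained sketch of that predicate (helper name hypothetical):

    import fnmatch


    def should_calibrate(name, include_modules=None, exclude_modules=None):
        # Filter 1: excluded for not matching any include pattern (when a list is given).
        ok = not include_modules or any(fnmatch.fnmatch(name, p) for p in include_modules)
        # Filter 2: excluded for matching any exclude pattern.
        return ok and not (
            exclude_modules and any(fnmatch.fnmatch(name, p) for p in exclude_modules)
        )


    # A module matching both lists is excluded either way:
    print(should_calibrate(
        "model.layers.0.self_attn.q_proj",
        include_modules=["*q_proj*"],
        exclude_modules=["*self_attn*"],
    ))  # False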

Comment on modelopt/torch/quantization/config.py:

    class QuantizeAlgorithmConfig(ModeloptBaseConfig):

A contributor commented:

Just a side note: I have been feeling like CalibrationConfig would be a better name for this class.

Comment on lines +1087 to +1108
calib_include_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to include in calibration.",
description=(
"If provided, only modules whose names match at least one of the fnmatch patterns are "
"calibrated. Modules that do not match any pattern are skipped and retain their "
"pre-existing calibration state. "
"Note: filtering applies only to quantized linear modules; TensorQuantizers in "
"non-linear modules (e.g. layer norms, embeddings) are unaffected."
),
)

calib_exclude_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to exclude from calibration.",
description=(
"If provided, modules whose names match at least one of the fnmatch patterns are "
"skipped during calibration and retain their pre-existing calibration state. "
"Note: filtering applies only to quantized linear modules; TensorQuantizers in "
"non-linear modules (e.g. layer norms, embeddings) are unaffected."
),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the ordering does not actually matter, but we do need to document a clear semantic. These 2 actually are 2 exclude module lists:

  1. exclude those that not in the incllude_modules
  2. exclude those that in the exclude_modeuls

