Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ NVIDIA Model Optimizer Changelog

- User no longer needs to manually register MOE modules to ensure expert calibration coverage in the PTQ workflow.
- ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory.
- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to all the experts.
- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to None (not enabled).
- Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Add support for rotating the input before quantization for RHT.
- Add support for advanced weight scale search for NVFP4 quantization and its export path.
Expand Down
6 changes: 3 additions & 3 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,16 +1158,16 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--moe_calib_experts_ratio",
type=float,
default=1.0,
default=None,
help=(
"Fraction of experts to calibrate during forward pass (ratio in (0.0, 1.0]). "
"Only used for MOE models; used to reduce the number of experts calibrated during the forward pass."
"Only used for MOE models; used to reduce the number of experts calibrated during the forward pass. "
"Does not impact non-MOE models."
),
)

args = parser.parse_args()
if not (0.0 < args.moe_calib_experts_ratio <= 1.0):
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
return args

Expand Down
16 changes: 5 additions & 11 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,21 +1055,15 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):

moe_calib_experts_ratio: float | None = ModeloptField(
default=None,
gt=0.0,
le=1.0,
title="% of experts to calibrate during forward pass.",
description=(
"If specified, we force forward tokens to % of experts during the calibration"
" pass. This forward is for calibration purpose only and will not affect the"
" actual inference. Not supported for all MoE architectures; currently works"
" with a few HuggingFace models such as Mixtral, Qwen3Moe, MiniMax."
),
)

moe_count_expert_calib_tokens: bool = ModeloptField(
default=False,
title="Enable expert token counting during MoE calibration.",
description=(
"If True, counts how many tokens are routed to each expert during calibration."
" Not supported for all MoE architectures; currently works with a few HuggingFace"
" actual inference. NOTE: when set, ``layer_sync_moe_local_experts_amax`` is"
" disabled so each expert maintains its own calibration statistics. Not"
" supported for all MoE architectures; currently works with a few HuggingFace"
" models such as Mixtral, Qwen3Moe, MiniMax."
),
)
Expand Down
6 changes: 0 additions & 6 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,6 @@ def wrapped_calib_func(
if hasattr(module, "_moe_calib_experts_ratio"):
module._moe_calib_experts_ratio = moe_calib_experts_ratio

moe_count_expert_calib_tokens = kwargs.pop("moe_count_expert_calib_tokens", False)
if moe_count_expert_calib_tokens:
for module in model.modules():
if hasattr(module, "_moe_count_expert_calib_tokens"):
module._moe_count_expert_calib_tokens = True

if func is not None:
if sequential:
if forward_loop is None:
Expand Down
38 changes: 18 additions & 20 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,16 +444,15 @@ class _QuantSparseMoe(QuantModule):

Supports ``layer_sync_moe_local_experts_amax`` to sync input quantizer amax across experts.

Optionally supports two config-driven features (disabled by default):
Optionally supports config-driven features (disabled by default):
- ``_moe_calib_experts_ratio``: force-forward tokens to more experts during calibration.
- ``_moe_count_expert_calib_tokens``: count tokens routed to each expert during calibration.
When set to a value > 0, also enables token counting per expert.

When both are disabled, forward is a direct pass-through with zero overhead.
When disabled, forward is a direct pass-through with zero overhead.
"""
Comment on lines +447 to 452
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Docstring inaccuracy: token counting is not enabled when ratio == 1.0

Line 449 states "When set to a value > 0, also enables token counting per expert," but line 512 only enables counting when ratio < 1.0. The docstring should clarify this edge case.

📝 Suggested docstring fix
     Optionally supports config-driven features (disabled by default):
     - ``_moe_calib_experts_ratio``: force-forward tokens to more experts during calibration.
-      When set to a value > 0, also enables token counting per expert.
+      When set to a value in (0, 1), also enables token counting per expert.
+      At ratio == 1.0, all experts are calibrated so counting is skipped.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 447 - 452,
Update the docstring to accurately reflect when token counting per expert is
enabled: the config key _moe_calib_experts_ratio does not enable per-expert
token counting for any value > 0 — counting is only activated when
_moe_calib_experts_ratio < 1.0 (the code path that checks ratio < 1.0 enables
counting), so change the sentence to state that token counting is enabled only
for ratios strictly less than 1.0 and clarify that ratio == 1.0 results in a
direct pass-through without counting.


def _setup(self):
self._moe_calib_experts_ratio = None
self._moe_count_expert_calib_tokens = False
self._token_counting_initialized = False

def _init_token_counting(self):
Expand Down Expand Up @@ -501,24 +500,18 @@ def _gate_forward_hook(self, module, input, output):
self.expert_token_count += counts.to(self.expert_token_count.device)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if not self._moe_calib_experts_ratio and not self._moe_count_expert_calib_tokens:
if self._moe_calib_experts_ratio is None:
return super().forward(hidden_states)

if self._moe_count_expert_calib_tokens and not self._token_counting_initialized:
self._init_token_counting()

is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
self._count_expert_tokens = is_calib and self._moe_count_expert_calib_tokens

# If any of the experts are in calibration mode, we will forward all tokens to
# self._moe_calib_experts_ratio % of the experts to improve the calibration coverage.
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
if is_calib and self._moe_calib_experts_ratio:
self._count_expert_tokens = True
assert 0 < self._moe_calib_experts_ratio <= 1, (
"moe_calib_experts_ratio must be between 0 and 1"
)

# During calibration, forward all tokens to a larger fraction of experts to improve
# calibration coverage, then re-run with the original top_k for actual outputs.
if is_calib:
# Skip counting when all experts are calibrated (ratio == 1.0).
self._count_expert_tokens = self._moe_calib_experts_ratio < 1.0
if self._count_expert_tokens and not self._token_counting_initialized:
self._init_token_counting()
if TRANSFORMERS_VERSION_GE_5_0:
assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
original_top_k = self.gate.top_k
Expand Down Expand Up @@ -559,7 +552,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return output

def layer_sync_moe_local_experts_amax(self):
"""Sync input_quantizer amax across experts so all share the same amax per quantizer."""
"""Sync input_quantizer amax across experts so all share the same amax per quantizer.

Skipped when _moe_calib_experts_ratio is set, as each expert is calibrated independently.
"""
if self._moe_calib_experts_ratio is not None:
return
sync_moe_expert_amax(self.experts)


Expand Down
13 changes: 8 additions & 5 deletions tests/unit/torch/quantization/plugins/test_sparse_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def test_setup_config_knobs_default(self):

converted = QuantModuleRegistry.convert(moe_block)
assert converted._moe_calib_experts_ratio is None
assert converted._moe_count_expert_calib_tokens is False
assert not hasattr(converted, "expert_token_count")

def test_forward_default_config_passthrough(self):
Expand Down Expand Up @@ -259,17 +258,22 @@ def test_forward_calib_restores_top_k(self):
assert converted.top_k == original_top_k

def test_token_counting_lazy_init(self):
"""When moe_count_expert_calib_tokens is enabled, token counting infra is lazy-inited."""
"""When moe_calib_experts_ratio > 0, token counting infra is lazy-inited."""
model = get_tiny_qwen3_moe()
moe_block = self._get_moe_block(model)
if QuantModuleRegistry.get(type(moe_block)) is None:
register_sparse_moe_on_the_fly(model)

converted = QuantModuleRegistry.convert(moe_block)
converted._moe_count_expert_calib_tokens = True
converted._moe_calib_experts_ratio = 0.5

assert not hasattr(converted, "expert_token_count")

# Simulate calibration mode so lazy-init triggers during forward
# Set _if_calib on an expert sub-module (not set by default since only the MoE
# block was converted, not the full model).
next(converted.experts.modules())._if_calib = True

x = torch.randn(1, 4, 32)
with torch.no_grad():
converted(x)
Expand Down Expand Up @@ -305,8 +309,7 @@ def test_qwen3_moe_quantize_with_token_forcing_and_counting():
quant_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
quant_cfg["algorithm"] = {
"method": "max",
"moe_calib_experts_ratio": 1.0,
"moe_count_expert_calib_tokens": True,
"moe_calib_experts_ratio": 0.5,
}

def calib_fn(model):
Expand Down
Loading