Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ NVIDIA Model Optimizer Changelog
- Enable PTQ workflow for Qwen3.5 MoE models.
- Enable PTQ workflow for the Kimi-K2.5 model.
- Add ``nvfp4_omlp_only`` quantization format for NVFP4 quantization. This is similar to ``nvfp4_mlp_only`` but also quantizes the output projection layer in attention.
- Add ``nvfp4_experts_only`` quantization config that targets only MoE routed expert layers (excluding shared) with NVFP4 quantization.
- ``pass_through_bwd`` in the quantization config is now default to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
- Add :meth:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.
- **Autotune**: New tool for automated Q/DQ (Quantize/Dequantize) placement optimization for ONNX models. Uses TensorRT latency measurements to choose insertion schemes that minimize inference time. Discovers regions automatically, groups them by structural pattern, and tests multiple Q/DQ schemes per pattern. Supports INT8 and FP8 quantization, pattern cache for warm-start on similar models, checkpoint/resume, and importing patterns from an existing QDQ baseline. CLI: ``python -m modelopt.onnx.quantization.autotune``. See the Autotune guide in the documentation.
Expand Down
10 changes: 5 additions & 5 deletions examples/llm_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def forward_loop(model):
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
```

> *For higher NVFP4 PTQ accuracy, we recommend using `mtq.NVFP4_MLP_ONLY_CFG` or `mtq.NVFP4_OMLP_ONLY_CFG` instead of `mtq.NVFP4_DEFAULT_CFG`. `NVFP4_MLP_ONLY_CFG` applies NVFP4 quantization to MLP (and MoE) layers, leaving attention layers unquantized. `NVFP4_OMLP_ONLY_CFG` additionally quantizes the `o_proj` layer. Both preserve accuracy in the sensitive attention QKV projections while still providing significant compression.*
> *For higher NVFP4 PTQ accuracy, we recommend using `mtq.NVFP4_MLP_ONLY_CFG`, `mtq.NVFP4_EXPERTS_ONLY_CFG`, or `mtq.NVFP4_OMLP_ONLY_CFG` instead of `mtq.NVFP4_DEFAULT_CFG`. `NVFP4_MLP_ONLY_CFG` applies NVFP4 quantization to MLP (and MoE) layers, leaving attention layers unquantized. `NVFP4_EXPERTS_ONLY_CFG` quantizes only expert layers (`*mlp.experts*` and `*block_sparse_moe*`), useful for MoE models where dense MLP and attention stay in higher precision. `NVFP4_OMLP_ONLY_CFG` additionally quantizes the `o_proj` layer. All preserve accuracy in the sensitive attention QKV projections while still providing significant compression.*

### 2. Export Quantized Model

Expand Down Expand Up @@ -129,7 +129,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)* \
> *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*

> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only` or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.*
> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP/expert layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.*

> You can also create your own custom config using [this](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#custom-calibration-algorithm) guide.

Expand All @@ -147,7 +147,7 @@ For LLM models like [Llama-3](https://huggingface.co/meta-llama):
# Install model specific pip dependencies if needed

export HF_PATH=<the downloaded LLaMA checkpoint from the Hugging Face hub, or simply the model card>
scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|nvfp4_mlp_only|nvfp4_omlp_only|int8_sq|int4_awq|w4a8_awq] --tp [1|2|4|8]
scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|nvfp4_mlp_only|nvfp4_experts_only|nvfp4_omlp_only|int8_sq|int4_awq|w4a8_awq] --tp [1|2|4|8]
```

> *By default `trust_remote_code` is set to false. Please turn it on if model calibration and eval requires it using `--trust_remote_code`.*
Expand Down Expand Up @@ -298,7 +298,7 @@ accelerate launch --config_file fsdp2.yaml \
--fsdp_transformer_layer_cls_to_wrap=<decoder_layer_name>
multinode_ptq.py \
--pyt_ckpt_path <path_to_model> \
--qformat <fp8/nvfp4/nvfp4_mlp_only/nvfp4_omlp_only/nvfp4_awq/int8> \
--qformat <fp8/nvfp4/nvfp4_mlp_only/nvfp4_experts_only/nvfp4_omlp_only/nvfp4_awq/int8> \
--kv_cache_qformat <fp8/nvfp4/nvfp4_affine/none> \
--batch_size <calib_batch_size> \
--calib_size <num_calib_samples> \
Expand Down Expand Up @@ -463,4 +463,4 @@ There are many quantization schemes supported in the example scripts:

1. The W4A8 AWQ is an extension of the INT4 AWQ quantization that it also uses FP8 for activation for more speed up and acceleration.

1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell. For higher accuracy with NVFP4 PTQ, we recommend `nvfp4_mlp_only` or `nvfp4_omlp_only`. `nvfp4_mlp_only` restricts NVFP4 quantization to MLP (and MoE) layers only, leaving attention layers in higher precision. `nvfp4_omlp_only` extends this by also quantizing the `o_proj` layer, providing a middle ground between full NVFP4 and MLP-only quantization.
1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell. For higher accuracy with NVFP4 PTQ, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only`. `nvfp4_mlp_only` restricts NVFP4 quantization to MLP (and MoE) layers only, leaving attention layers in higher precision. `nvfp4_experts_only` quantizes only expert layers (`*mlp.experts*` and `*block_sparse_moe*`), ideal for MoE models. `nvfp4_omlp_only` extends MLP-only by also quantizing the `o_proj` layer, providing a middle ground between full NVFP4 and MLP-only quantization.
2 changes: 2 additions & 0 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _set_kv_cache_constant_amax(quant_cfg: dict) -> None:
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
Expand Down Expand Up @@ -275,6 +276,7 @@ def auto_quantize(
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"mxfp8",
]
Expand Down
2 changes: 2 additions & 0 deletions examples/llm_ptq/multinode_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
}

KV_QUANT_CFG_CHOICES = {
Expand Down
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_experts_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_experts_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8]" >&2
exit 1
;;
esac
Expand Down
124 changes: 39 additions & 85 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,32 @@
"enable": True,
}

NVFP4_DEFAULT_CFG = {
"quant_cfg": {
"*weight_quantizer": _nvfp4_quantizer,
"*input_quantizer": _nvfp4_quantizer,
**_default_disabled_quantizer_cfg,
},
"algorithm": "max",
_nvfp4_quantizer_bs32 = {
"num_bits": (2, 1),
"block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)},
"enable": True,
}


def _nvfp4_selective_quant_cfg(
layer_patterns: list[str],
*,
quantizer: dict = _nvfp4_quantizer,
weight_only: bool = False,
algorithm: str | dict = "max",
) -> dict:
"""Build an NVFP4 config that quantizes only the specified layer patterns."""
quant_cfg: dict[str, object] = {}
for pattern in layer_patterns:
quant_cfg[f"{pattern}weight_quantizer"] = quantizer
if not weight_only:
quant_cfg[f"{pattern}input_quantizer"] = quantizer
quant_cfg.update(_default_disabled_quantizer_cfg)
return {"quant_cfg": quant_cfg, "algorithm": algorithm}


NVFP4_DEFAULT_CFG = _nvfp4_selective_quant_cfg(["*"])

NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
Expand Down Expand Up @@ -481,32 +498,13 @@
}


NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": _nvfp4_quantizer,
"*input_quantizer": _nvfp4_quantizer,
**_default_disabled_quantizer_cfg,
},
"algorithm": "awq_lite",
}
NVFP4_AWQ_LITE_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm="awq_lite")

NVFP4_AWQ_CLIP_CFG = {
"quant_cfg": {
"*weight_quantizer": _nvfp4_quantizer,
"*input_quantizer": _nvfp4_quantizer,
**_default_disabled_quantizer_cfg,
},
"algorithm": {"method": "awq_clip"},
}
NVFP4_AWQ_CLIP_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm={"method": "awq_clip"})

NVFP4_AWQ_FULL_CFG = {
"quant_cfg": {
"*weight_quantizer": _nvfp4_quantizer,
"*input_quantizer": _nvfp4_quantizer,
**_default_disabled_quantizer_cfg,
},
"algorithm": {"method": "awq_full", "alpha_step": 0.1},
}
NVFP4_AWQ_FULL_CFG = _nvfp4_selective_quant_cfg(
["*"], algorithm={"method": "awq_full", "alpha_step": 0.1}
)


NVFP4_AFFINE_KV_CFG = {
Expand Down Expand Up @@ -569,14 +567,9 @@
"algorithm": "max",
}

NVFP4_SVDQUANT_DEFAULT_CFG = {
"quant_cfg": {
"*weight_quantizer": _nvfp4_quantizer,
"*input_quantizer": _nvfp4_quantizer,
**_default_disabled_quantizer_cfg,
},
"algorithm": {"method": "svdquant", "lowrank": 32},
}
NVFP4_SVDQUANT_DEFAULT_CFG = _nvfp4_selective_quant_cfg(
["*"], algorithm={"method": "svdquant", "lowrank": 32}
)

W4A8_NVFP4_FP8_CFG = {
"quant_cfg": {
Expand Down Expand Up @@ -611,52 +604,12 @@
"algorithm": None,
}

NVFP4_MLP_WEIGHT_ONLY_CFG = {
"quant_cfg": {
"*mlp*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {
-1: 32,
"type": "dynamic",
"scale_bits": (4, 3),
}, # Note: block_size is 32 here
"enable": True,
},
"*block_sparse_moe*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {
-1: 32,
"type": "dynamic",
"scale_bits": (4, 3),
}, # Note: block_size is 32 here
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": "max",
}

_nvfp4_mlp_only_quant_cfg = {
"*mlp*weight_quantizer": _nvfp4_quantizer,
"*mlp*input_quantizer": _nvfp4_quantizer,
"*block_sparse_moe*weight_quantizer": _nvfp4_quantizer,
"*block_sparse_moe*input_quantizer": _nvfp4_quantizer,
**_default_disabled_quantizer_cfg,
}

NVFP4_MLP_ONLY_CFG = {
"quant_cfg": _nvfp4_mlp_only_quant_cfg,
"algorithm": "max",
}

NVFP4_OMLP_ONLY_CFG = {
"quant_cfg": {
"*o_proj*weight_quantizer": _nvfp4_quantizer,
"*o_proj*input_quantizer": _nvfp4_quantizer,
**_nvfp4_mlp_only_quant_cfg,
},
"algorithm": "max",
}
NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg(
["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_quantizer_bs32, weight_only=True
)
NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])

# DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
# modelopt_recipes/general/ptq/ as a yaml file
Expand Down Expand Up @@ -689,6 +642,7 @@
"NVFP4_MLP_WEIGHT_ONLY_CFG",
"MXFP4_MLP_WEIGHT_ONLY_CFG",
"NVFP4_MLP_ONLY_CFG",
"NVFP4_EXPERTS_ONLY_CFG",
"NVFP4_OMLP_ONLY_CFG",
"MAMBA_MOE_NVFP4_CONSERVATIVE_CFG",
"MAMBA_MOE_NVFP4_AGGRESSIVE_CFG",
Expand Down
86 changes: 86 additions & 0 deletions modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

metadata:
recipe_type: ptq
description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max calibration.
ptq_cfg:
algorithm: max
quant_cfg:
'*mlp.experts*weight_quantizer':
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
enable: true
'*mlp.experts*input_quantizer':
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
enable: true
'*block_sparse_moe*weight_quantizer':
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
enable: true
'*block_sparse_moe*input_quantizer':
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
enable: true
default:
enable: false
'*block_sparse_moe.gate*':
enable: false
'*linear_attn.conv1d*':
enable: false
'*lm_head*':
enable: false
'*mixer.conv1d*':
enable: false
'*mlp.gate.*':
enable: false
'*mlp.shared_expert_gate.*':
enable: false
'*output_layer*':
enable: false
'*proj_out.*':
enable: false
'*router*':
enable: false
output.*:
enable: false
nn.BatchNorm1d:
'*':
enable: false
nn.BatchNorm2d:
'*':
enable: false
nn.BatchNorm3d:
'*':
enable: false
nn.LeakyReLU:
'*':
enable: false
'*[kv]_bmm_quantizer':
num_bits: e4m3
enable: true
Loading