1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -13,6 +13,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
- Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
- Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details on its usage.
- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
38 changes: 36 additions & 2 deletions modelopt/onnx/autocast/convert.py
@@ -33,6 +33,7 @@
from modelopt.onnx.autocast.nodeclassifier import NodeClassifier, NodeRuleBase
from modelopt.onnx.autocast.precisionconverter import PrecisionConverter
from modelopt.onnx.autocast.referencerunner import ReferenceRunner
from modelopt.onnx.utils import get_min_opset_for_precisions, get_qdq_precisions

"""
FP16 accuracy decreases in accordance with the data's magnitude.
@@ -202,6 +203,7 @@ def convert_to_f16(
    tensor_block_dict: dict[str, dict[str, list[int]]] = {},
    trt_plugins: list[str] | None = [],
    use_standalone_type_inference: bool = False,
    opset: int | None = None,
) -> onnx.ModelProto:
    """Convert model to mixed precision, using PrecisionConverter.

@@ -217,13 +219,45 @@
        use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
                                       infer_shapes. This is a workaround (WAR) when only type inference is
                                       needed without shape inference. Default: False.
        opset: Target ONNX opset version. If None, uses default minimum opset based on precision type
               (22 for bf16, 13 for fp16) and Q/DQ node requirements. The opset may be automatically
               increased if Q/DQ nodes in the model require a higher version (e.g., FP8 requires 19,
               INT4 requires 21, NVFP4 requires 23).
    """
    assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"

    # Opset 21 is needed for NVFP4 quantization support (DQ with 'block_size' attribute)
    # Check Q/DQ precision types in the model and determine required opset
    qdq_precisions = get_qdq_precisions(model)
    qdq_min_opset = get_min_opset_for_precisions(qdq_precisions)

    # Base minimum opset for FP16/BF16 conversion
    # Opset 19 is the first to support fp16 scales in Q/DQ nodes
    base_min_opset = 22 if low_precision_type == "bf16" else 19
Review comment (Contributor) on lines +222 to +235:

⚠️ Potential issue | 🟡 Minor

Docstring mismatch: FP16 minimum opset is now 19, not 13.
The implementation enforces base_min_opset = 19 for fp16, but the docstring still says 13. Please align the docstring.

✏️ Suggested doc fix
-        opset: Target ONNX opset version. If None, uses default minimum opset based on precision type
-               (22 for bf16, 13 for fp16) and Q/DQ node requirements. The opset may be automatically
+        opset: Target ONNX opset version. If None, uses default minimum opset based on precision type
+               (22 for bf16, 19 for fp16) and Q/DQ node requirements. The opset may be automatically
                increased if Q/DQ nodes in the model require a higher version (e.g., FP8 requires 19,
                INT4 requires 21, NVFP4 requires 23).


    # Determine target opset version
    if opset is not None:
        min_opset = opset
        # Check if Q/DQ nodes require a higher opset
        if qdq_precisions and qdq_min_opset > min_opset:
            logger.warning(
                f"Model contains Q/DQ nodes with precisions {qdq_precisions} that require "
                f"opset >= {qdq_min_opset}. Upgrading from specified opset {opset} to {qdq_min_opset}."
            )
            min_opset = qdq_min_opset
        # Also ensure we meet base minimum for precision type
        if min_opset < base_min_opset:
            logger.warning(
                f"Opset {min_opset} is below minimum opset {base_min_opset} for {low_precision_type}. "
                f"Upgrading to opset {base_min_opset}."
            )
            min_opset = base_min_opset
    else:
        # Use the highest required opset between base and Q/DQ requirements
        min_opset = max(base_min_opset, qdq_min_opset)

    sanitizer = GraphSanitizer(
        model,
-        min_opset=21,
+        min_opset=min_opset,
        trt_plugins=trt_plugins,
        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
    )
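For orientation, here is a minimal usage sketch of the new opset argument. The leading parameters of convert_to_f16 are not visible in this hunk, so the call layout below is an assumption, and the model file names are placeholders:

import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

# Placeholder input: an ONNX model that already contains INT4 Q/DQ nodes.
model = onnx.load("model_int4_qdq.onnx")

# Requesting opset 19 would be auto-upgraded by the logic above: INT4 Q/DQ nodes
# need opset >= 21, so min_opset is raised and a warning is logged.
converted = convert_to_f16(model, low_precision_type="fp16", opset=19)
onnx.save(converted, "model_int4_qdq_fp16.onnx")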
10 changes: 10 additions & 0 deletions modelopt/onnx/quantization/__main__.py
@@ -286,6 +286,15 @@ def get_parser() -> argparse.ArgumentParser:
"The currently supported precisions are {fp16, int8, fp8}."
),
)
    argparser.add_argument(
        "--opset",
        type=int,
        help=(
            "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset "
            "(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased "
            "if certain operations require a higher version."
        ),
    )
Review comment (Contributor) on lines +289 to +297:

⚠️ Potential issue | 🟡 Minor

Clarify BF16 minimum opset in --opset help text.
--high_precision_dtype bf16 can require opset 22; the help message lists fp16/int4/nvfp4 only. Adding bf16 avoids confusion.

✏️ Suggested doc tweak
-            "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset "
-            "(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased "
+            "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset "
+            "(19 for fp16 scales support, 22 for bf16, 21 for int4, 23 for nvfp4). The opset may be automatically increased "
             "if certain operations require a higher version."

Also applies to: 364-364

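As an illustration only, the new flag can be exercised through the parser built in this module; --onnx_path and --quantize_mode are assumed here to be existing flags of this CLI:

# Hypothetical smoke test of the new option (flag names other than --opset are assumptions).
from modelopt.onnx.quantization.__main__ import get_parser

args = get_parser().parse_args(
    ["--onnx_path", "model.onnx", "--quantize_mode", "int8", "--opset", "21"]
)
assert args.opset == 21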

return argparser


@@ -352,6 +361,7 @@ def main():
simplify=args.simplify,
calibrate_per_node=args.calibrate_per_node,
direct_io_types=args.direct_io_types,
opset=args.opset,
)


2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/fp8.py
@@ -182,6 +182,7 @@ def quantize(
calibrate_per_node: bool = False,
custom_ops_to_quantize: list[str] = [],
direct_io_types: bool = False,
opset: int | None = None,
**kwargs,
) -> onnx.ModelProto:
"""Applies FP8 GEMM only quantization to an ONNX file.
@@ -328,6 +329,7 @@
tensor_block_dict=custom_ops_to_cast_fp32 or {},
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
opset=opset,
)

current_opsets = {opset.domain: opset.version for opset in onnx_model.opset_import}
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/int8.py
@@ -132,6 +132,7 @@ def quantize(
calibrate_per_node: bool = False,
custom_ops_to_quantize: list[str] = [],
direct_io_types: bool = False,
opset: int | None = None,
**kwargs,
) -> onnx.ModelProto:
"""Applies INT8 quantization to an ONNX file using the compiler friendly heuristics.
@@ -289,6 +290,7 @@
tensor_block_dict=custom_ops_to_cast_fp32 or {},
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
opset=opset,
)

if nodes_to_quantize:
53 changes: 45 additions & 8 deletions modelopt/onnx/quantization/quantize.py
@@ -69,6 +69,8 @@
)
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
from modelopt.onnx.utils import (
BASE_MIN_OPSET,
QDQ_PRECISION_MIN_OPSET,
duplicate_shared_constants,
get_opset_version,
name_onnx_nodes,
@@ -88,6 +90,7 @@ def _preprocess_onnx(
override_shapes: str,
simplify: bool = False,
quantize_mode: str = "int8",
opset: int | None = None,
) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
logger.info(f"Preprocessing the model {onnx_path}")
intermediate_generated_files = []
@@ -118,16 +121,43 @@
" '--trt_plugins' flag (requires TRT 10+)."
)

-    # Per-Channel support with QDQ format requires onnx opset version 13 or above
-    opset_version = get_opset_version(onnx_model)
+    # Opset 19 is the minimum required for fp16 scales in Q/DQ nodes
+    # Higher opsets required for specific quantization modes (int4: 21, nvfp4: 23)
+    original_opset_version = get_opset_version(onnx_model)

-    required_opset_version = 13
-    if opset_version < required_opset_version and opset_version != 1:
-        opset_version = required_opset_version
-        onnx_model = onnx.version_converter.convert_version(onnx_model, opset_version)
-        onnx_path = os.path.join(output_dir, f"{model_name}_opset{opset_version}.onnx")
+    # Determine minimum required opset based on quantization mode
+    mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(quantize_mode, BASE_MIN_OPSET)

Review comment (Contributor) on lines +124 to +130:

⚠️ Potential issue | 🟠 Major

Normalize quantize_mode before min‑opset lookup.

The code at line 523 uses substring matching (elif "int4" in quantize_mode), implying variant strings are supported. However, the min‑opset lookup at line 129 uses exact key matching, so variant values like int4_awq will incorrectly fall back to BASE_MIN_OPSET (19) instead of the required 21. Additionally, the documentation references opset 23 for nvfp4, but that key doesn't exist in QDQ_PRECISION_MIN_OPSET; only float4_e2m1fn is present. This creates a mismatch where int4 variants could proceed with insufficient opset. Normalize quantize_mode before lookup to handle variants consistently.

🔧 Suggested fix
-    mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(quantize_mode, BASE_MIN_OPSET)
+    if "int4" in quantize_mode:
+        mode_min_opset = QDQ_PRECISION_MIN_OPSET["int4"]
+    elif quantize_mode in QDQ_PRECISION_MIN_OPSET:
+        mode_min_opset = QDQ_PRECISION_MIN_OPSET[quantize_mode]
+    else:
+        mode_min_opset = BASE_MIN_OPSET
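A compact sketch of the normalization the comment asks for; the helper name _normalize_quantize_mode and the exact variant strings are assumptions for illustration, not the project's implementation:

def _normalize_quantize_mode(quantize_mode: str) -> str:
    """Map CLI quantize-mode variants onto QDQ_PRECISION_MIN_OPSET keys (sketch)."""
    if "int4" in quantize_mode:  # e.g. "int4_awq"
        return "int4"
    if "fp4" in quantize_mode:  # e.g. "nvfp4"
        return "float4_e2m1fn"
    if quantize_mode == "fp8":
        return "float8_e4m3fn"
    return quantize_mode  # "int8" and already-normalized names pass through


mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(
    _normalize_quantize_mode(quantize_mode), BASE_MIN_OPSET
)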

    # Determine target opset version
    if opset is not None:
        target_opset = opset
        # Warn if user-specified opset is below mode minimum (but still respect it)
        if opset < mode_min_opset:
            logger.warning(
                f"Opset {opset} is below the minimum opset {mode_min_opset} required for "
                f"{quantize_mode} quantization. Upgrading to opset {mode_min_opset}."
            )
            target_opset = mode_min_opset
        # Warn if user-specified opset is lower than original
        if opset < original_opset_version:
            logger.warning(
                f"Specified opset {opset} is lower than the original model's opset {original_opset_version}. "
                f"Using original model's opset {original_opset_version}."
            )
            target_opset = max(target_opset, original_opset_version)
    else:
        # Use model's opset if it's >= mode_min_opset, otherwise upgrade to mode_min_opset
        target_opset = (
            max(original_opset_version, mode_min_opset)
            if original_opset_version != 1
            else mode_min_opset
        )

    if original_opset_version < target_opset and original_opset_version != 1:
        onnx_model = onnx.version_converter.convert_version(onnx_model, target_opset)
        onnx_path = os.path.join(output_dir, f"{model_name}_opset{target_opset}.onnx")
        save_onnx(onnx_model, onnx_path, use_external_data_format)
-        logger.info(f"Model is cloned to {onnx_path} with opset_version {opset_version}")
+        logger.info(f"Model is cloned to {onnx_path} with opset_version {target_opset}")
        intermediate_generated_files.append(onnx_path)

# Simplify model if requested
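To make the selection rules above concrete, here is a small self-contained sketch that mirrors just the decision logic (the constants are copied from this diff; the helper name and scenarios are illustrative, and the real function additionally logs warnings):

BASE_MIN_OPSET = 19
QDQ_PRECISION_MIN_OPSET = {"int8": 19, "float8_e4m3fn": 19, "int4": 21, "uint4": 21, "float4_e2m1fn": 23}


def pick_target_opset(original: int, quantize_mode: str, opset: int | None) -> int:
    """Sketch of the target-opset selection performed in _preprocess_onnx."""
    mode_min = QDQ_PRECISION_MIN_OPSET.get(quantize_mode, BASE_MIN_OPSET)
    if opset is not None:
        target = max(opset, mode_min)  # upgrade requests below the mode minimum
        return max(target, original)   # never downgrade below the original model's opset
    return max(original, mode_min) if original != 1 else mode_min


# Illustrative outcomes:
assert pick_target_opset(original=13, quantize_mode="int8", opset=None) == 19  # upgraded to minimum
assert pick_target_opset(original=17, quantize_mode="int4", opset=19) == 21    # int4 floor wins
assert pick_target_opset(original=22, quantize_mode="int8", opset=19) == 22    # original opset kept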
@@ -231,6 +261,7 @@ def quantize(
calibrate_per_node: bool = False,
input_shapes_profile: Sequence[dict[str, str]] | None = None,
direct_io_types: bool = False,
opset: int | None = None,
**kwargs: Any,
) -> None:
"""Quantizes the provided ONNX model.
@@ -350,6 +381,10 @@
direct_io_types:
If True, modify the I/O types in the quantized ONNX model to be lower precision whenever possible.
If False, keep the I/O types in the quantized ONNX model the same as in the given ONNX model.
opset:
Target ONNX opset version for the quantized model. If None, uses required minimum opset
(19 for int8/fp8, 21 for int4, 23 for nvfp4). If the specified opset is lower than the required minimum,
a warning will be issued and the opset will be upgraded to the required minimum.
kwargs:
Additional keyword arguments for int4 quantization, including:
- awqlite_alpha_step (float): Alpha step for lite, range [0, 1].
@@ -420,6 +455,7 @@
override_shapes, # type: ignore[arg-type]
simplify,
quantize_mode,
opset,
)
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]

@@ -481,6 +517,7 @@
calibrate_per_node=calibrate_per_node,
custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
direct_io_types=direct_io_types,
opset=opset,
**kwargs,
)
elif "int4" in quantize_mode:
64 changes: 64 additions & 0 deletions modelopt/onnx/utils.py
@@ -30,6 +30,9 @@

from modelopt.onnx.logging_config import logger

# Base minimum opset for quantization (opset 19 is the first to support fp16 scales)
BASE_MIN_OPSET = 19


def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]:
"""This function returns the inputs names of the given onnx model in bytes.
@@ -696,6 +699,67 @@ def get_opset_version(model: onnx.ModelProto) -> int:
return ai_onnx_domain[0].version


def get_qdq_precisions(model: onnx.ModelProto) -> set:
    """Gets the Q/DQ precision types present in the model.

    Args:
        model: Loaded in-memory onnx ModelProto.

    Returns:
        set: Set of Q/DQ precision types present in the model (e.g., 'float8_e4m3fn', 'int8',
            'int4', 'float4_e2m1fn').
    """
    graph = gs.import_onnx(model)
    precisions = set()

    # Check for custom 'NVFP4' nodes
    custom_fp4_q_nodes = [node for node in graph.nodes if node.op == "TRT_FP4DynamicQuantize"]
    if custom_fp4_q_nodes:
        precisions.add("float4_e2m1fn")

    # Check for precision in DQ nodes
    dq_nodes = [node for node in graph.nodes if node.op == "DequantizeLinear"]
    for dq_node in dq_nodes:
        if len(dq_node.inputs) >= 3 and dq_node.inputs[2] is not None:
            # If zero-point is set, return that as the quantization mode
            if isinstance(dq_node.inputs[2], Constant) and dq_node.inputs[2].values is not None:
                precisions.add(dq_node.inputs[2].values.dtype.name)
        elif isinstance(dq_node.inputs[0], Constant) and dq_node.inputs[0].values is not None:
            # Else, return the node's input precision (ex: 'NVFP4' weight quantization)
            precisions.add(dq_node.inputs[0].values.dtype.name)

    return precisions
Review comment (Contributor) on lines +702 to +731:

⚠️ Potential issue | 🟠 Major

get_qdq_precisions misses QuantizeLinear nodes and non-constant Q/DQ paths.

The function only checks DequantizeLinear nodes and only when inputs/zero-points are Constant types. It never processes QuantizeLinear nodes or handles Variable-typed inputs, which are typical for activations. This causes activation Q/DQ and weight QuantizeLinear with non-constant parameters to be under-reported. Since this set is used for opset version selection (e.g., int4 requires opset 21, float4_e2m1fn requires opset 23), insufficient precision detection could result in incompatible opset versions being selected. Add checks for QuantizeLinear.output_dtype attribute and consider inspecting Variables or value_info types to cover non-constant inputs.

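A minimal sketch of the extension the comment suggests, assuming the onnx-graphsurgeon objects already used in this file (graph.nodes, node.attrs, Constant); the dtype-name mapping and helper name are illustrative, not the project's implementation:

from onnx import TensorProto
from onnx_graphsurgeon import Constant


def get_q_precisions(graph) -> set:
    """Hypothetical companion helper: inspect QuantizeLinear nodes as well."""
    precisions = set()
    for node in graph.nodes:
        if node.op != "QuantizeLinear":
            continue
        out_dtype = node.attrs.get("output_dtype")
        if out_dtype is not None:
            # 'output_dtype' (opset 21+) is a TensorProto enum value, e.g. INT4.
            precisions.add(TensorProto.DataType.Name(out_dtype).lower())
        elif len(node.inputs) >= 3 and isinstance(node.inputs[2], Constant):
            # Fall back to the zero-point dtype when it is a constant initializer.
            precisions.add(node.inputs[2].values.dtype.name)
    return precisions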



# Minimum opset requirements by quantization mode/precision
# Base minimum is 19 (first opset that allows fp16 scales in Q/DQ nodes)
# Supports both quantize modes (e.g., "fp8") and dtype prefixes (e.g., "float8" for "float8_e4m3fn")
QDQ_PRECISION_MIN_OPSET = {
    "int8": BASE_MIN_OPSET,
    "float8_e4m3fn": BASE_MIN_OPSET,
    "int4": 21,
    "uint4": 21,
    "float4_e2m1fn": 23,
}


def get_min_opset_for_precisions(precisions: set) -> int:
    """Gets the minimum required opset version for a set of Q/DQ precision types.

    Args:
        precisions: Set of precision type strings (e.g., 'float8_e4m3fn', 'int4').

    Returns:
        int: Minimum required opset version for the given precisions.
    """
    min_opset = BASE_MIN_OPSET  # Base minimum for fp16 scales support
    for precision in precisions:
        # Direct lookup first
        if precision in QDQ_PRECISION_MIN_OPSET:
            min_opset = max(min_opset, QDQ_PRECISION_MIN_OPSET[precision])
    return min_opset
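For reference, a short usage sketch of the two helpers above; model stands for any loaded ModelProto:

# Hypothetical use: inspect a loaded model and derive the opset floor it needs.
precisions = get_qdq_precisions(model)  # e.g. {"int8", "float4_e2m1fn"}
min_opset = get_min_opset_for_precisions(precisions)
# With the mapping above, {"int8", "float4_e2m1fn"} resolves to max(19, 19, 23) == 23.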


def bfloat16_to_float32(bf16_array):
"""Converts a bfloat16 array (as raw data) to a float32 array."""
uint32_array = bf16_array.astype(np.uint32) << 16
4 changes: 2 additions & 2 deletions setup.py
@@ -48,8 +48,8 @@
"onnx-graphsurgeon",
"onnx~=1.19.0",
"onnxconverter-common~=1.16.0",
"onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
"onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin'",
"onnxruntime~=1.23.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
"onnxruntime-gpu~=1.23.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin'",
Review comment (Collaborator):

Let's leave the Windows version unchanged, as they saw some regressions. cc @hthadicherla

Suggested change:
-    "onnxruntime-gpu~=1.23.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin'",
+    "onnxruntime-gpu~=1.23.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'",  # noqa: E501
+    "onnxruntime-gpu==1.22.0; platform_system == 'Windows'",

Review comment (Contributor):

The regression was initially observed by @ajrasane. He saw that quantizing with the latest ORT caused accuracy degradations in some vision models on Linux. When I tested these models later, I found the exact same regressions on Windows.

Better to leave it at 1.22 in setup.py. In the LLM quantization examples, we reinstall the latest ORT version by pinning ort==1.23 in requirements.txt.

"onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test
"onnxslim>=0.1.76",
"polygraphy>=0.49.22",