Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 42 additions & 31 deletions modelopt/onnx/export/fp8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch
from onnx_graphsurgeon.ir.tensor import LazyValues

from modelopt.onnx.utils import is_fp8_constant

from .base_exporter import ONNXQuantExporter


Expand Down Expand Up @@ -61,37 +63,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
graph.cleanup().toposort().fold_constants().cleanup()

for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
# Should not remove input QDQ
if not isinstance(node.inputs[0], gs.Constant):
continue

weights = node.inputs[0]
scale = node.inputs[1]
torch_weights = torch.from_numpy(weights.values)
torch_scale = torch.from_numpy(scale.values)
quantizer_name = scale.name.rsplit("/", 1)[0]
dq_op = node.outputs[0].outputs[0]
assert dq_op.op == "TRT_FP8DequantizeLinear", (
f"QDQ does not occur in pairs. You reached {dq_op.op}"
)

# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
tensor = onnx.TensorProto()
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
tensor.dims.extend(numpy_weights.shape)
tensor.raw_data = numpy_weights.tobytes()
values = LazyValues(tensor)
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
is_trt_fp8_q = node.op == "TRT_FP8QuantizeLinear"
is_std_fp8_q = (
node.op == "QuantizeLinear"
and len(node.inputs) >= 3
and isinstance(node.inputs[2], gs.Constant)
and is_fp8_constant(node.inputs[2])
)
if not (is_trt_fp8_q or is_std_fp8_q):
continue

# Should not remove input QDQ
if not isinstance(node.inputs[0], gs.Constant):
continue

weights = node.inputs[0]
scale = node.inputs[1]
torch_weights = torch.from_numpy(weights.values)
torch_scale = torch.from_numpy(scale.values)
quantizer_name = scale.name.rsplit("/", 1)[0]
dq_op = node.outputs[0].outputs[0]
assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
f"QDQ does not occur in pairs. You reached {dq_op.op}"
)

# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
tensor = onnx.TensorProto()
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
tensor.dims.extend(numpy_weights.shape)
tensor.raw_data = numpy_weights.tobytes()
values = LazyValues(tensor)
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype

graph.cleanup().toposort()
end_time = time.time()
Expand Down
73 changes: 42 additions & 31 deletions modelopt/onnx/llm_export_utils/surgeon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import torch
from onnx_graphsurgeon.ir.tensor import LazyValues

from modelopt.onnx.utils import is_fp8_constant


def clear_inputs(node: gs.Node | gs.Tensor):
"""Clear all inputs for a node or tensor in ONNX."""
Expand Down Expand Up @@ -81,37 +83,46 @@ def fold_fp8_qdq_to_dq(graph: gs.Graph):
graph.cleanup().toposort().fold_constants().cleanup()

for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
# Should not remove input QDQ
if not isinstance(node.inputs[0], gs.Constant):
continue

weights = node.inputs[0]
scale = node.inputs[1]
torch_weights = torch.from_numpy(weights.values)
torch_scale = torch.from_numpy(scale.values)
quantizer_name = scale.name.rsplit("/", 1)[0]
dq_op = node.outputs[0].outputs[0]
assert dq_op.op == "TRT_FP8DequantizeLinear", (
f"QDQ does not occur in pairs. You reached {dq_op.op}"
)

# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
tensor = onnx.TensorProto()
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
tensor.dims.extend(numpy_weights.shape)
tensor.raw_data = numpy_weights.tobytes()
values = LazyValues(tensor)
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
is_trt_fp8_q = node.op == "TRT_FP8QuantizeLinear"
is_std_fp8_q = (
node.op == "QuantizeLinear"
and len(node.inputs) >= 3
and isinstance(node.inputs[2], gs.Constant)
and is_fp8_constant(node.inputs[2])
)
if not (is_trt_fp8_q or is_std_fp8_q):
continue

# Should not remove input QDQ
if not isinstance(node.inputs[0], gs.Constant):
continue

weights = node.inputs[0]
scale = node.inputs[1]
torch_weights = torch.from_numpy(weights.values)
torch_scale = torch.from_numpy(scale.values)
quantizer_name = scale.name.rsplit("/", 1)[0]
dq_op = node.outputs[0].outputs[0]
assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
f"QDQ does not occur in pairs. You reached {dq_op.op}"
)
Comment on lines +105 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard Q output-chain indexing before dereference.

At Line 105, node.outputs[0].outputs[0] can raise IndexError on non-canonical graphs before your pair check at Line 106 runs.

🔧 Proposed fix
-        dq_op = node.outputs[0].outputs[0]
-        assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
-            f"QDQ does not occur in pairs. You reached {dq_op.op}"
-        )
+        if not node.outputs or not node.outputs[0].outputs:
+            continue
+        dq_op = node.outputs[0].outputs[0]
+        if dq_op.op not in ("TRT_FP8DequantizeLinear", "DequantizeLinear"):
+            continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/llm_export_utils/surgeon_utils.py` around lines 105 - 108, The
code dereferences node.outputs[0].outputs[0] into dq_op before checking its
existence, which can raise IndexError; modify the logic in surgeon_utils.py so
you first verify node.outputs is non-empty and node.outputs[0].outputs is
non-empty (e.g., if not node.outputs or not node.outputs[0].outputs: raise a
descriptive error or return) before assigning dq_op, then perform the existing
assert on dq_op.op; reference the variables node and dq_op to locate the fix.


# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
Comment on lines +110 to +113
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Locate FP8 QuantizeLinear construction and zero_point setup in exporter paths.
rg -nP --type=py -C3 '\bQuantizeLinear\b|TRT_FP8QuantizeLinear|FLOAT8E4M3FN|zero_point|float8_e4m3fn'

# 2) Locate folding assumptions and any explicit zero_point validation.
rg -nP --type=py -C3 'fold_fp8_qdq_to_dq|is_std_fp8_q|is_fp8_constant|zero_point|axis'

# 3) Check whether tests cover non-zero zero_point / axis cases for FP8 fold/export.
rg -nP --type=py -C3 'test_.*fp8.*(zero_point|axis)|fp8_onnx_export_shape|fold_fp8_qdq_to_dq'

Repository: NVIDIA/Model-Optimizer

Length of output: 50379


🏁 Script executed:

# First, let's read the specific file and lines mentioned in the review
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '100,125p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1413


🏁 Script executed:

# Search for the function containing this code and zero_point handling
rg -A20 -B5 "Replace it with Dequantize with FP8 weights" modelopt/onnx/llm_export_utils/surgeon_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1146


🏁 Script executed:

# Look for fold_fp8_qdq or similar folding functions
fd surgeon_utils.py -x cat -n {} | head -200

Repository: NVIDIA/Model-Optimizer

Length of output: 5481


🏁 Script executed:

# Search for zero_point handling in the surgeon_utils file and related functions
rg -B10 -A10 "zero_point" modelopt/onnx/llm_export_utils/surgeon_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check ONNX standard or comments about FP8 zero_point semantics
rg -C5 "TRT_FP8QuantizeLinear|is_std_fp8_q" modelopt/onnx/ --type=py | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 3014


🏁 Script executed:

# Look for FP8 quantization logic and zero_point usage
rg -C5 "float8_e4m3fn.*zero|zero.*float8_e4m3fn" modelopt/onnx/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check how FP8 QuantizeLinear nodes are created - what zero_point values are used
rg -B5 -A5 "QuantizeLinear.*fp8|float8.*QuantizeLinear" modelopt/onnx/ --type=py | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for zero_point creation/assignment in FP8 contexts
rg "zero_point.*=.*0|zero_point.*fp8|fp8.*zero" modelopt/onnx/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 145


🏁 Script executed:

# Check what the ONNX standard and comments say about FP8 and zero_point
rg -B3 -A3 "FP8|float8_e4m3fn" modelopt/onnx/quantization/fp8.py | head -120

Repository: NVIDIA/Model-Optimizer

Length of output: 2904


🏁 Script executed:

# Look for how zero_point is handled when creating FP8 QuantizeLinear nodes
rg -B10 -A10 "create.*qdq|QuantizeLinear.*zero" modelopt/onnx/quantization/autotune/export_utils.py | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 2035


🏁 Script executed:

# Check how int8_to_fp8 handles zero_point conversion
rg -B5 -A15 "def int8_to_fp8" modelopt/onnx/quantization/fp8.py

Repository: NVIDIA/Model-Optimizer

Length of output: 797


🏁 Script executed:

# Search for any comments about FP8 and zero_point semantics in the codebase
rg -i "fp8.*zero_point|zero_point.*fp8" modelopt/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for test files that test FP8 folding with zero_point
fd "test.*fp8|fp8.*test" -t f modelopt/ | xargs rg -l "zero_point|fold_fp8" 2>/dev/null | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 325


🏁 Script executed:

# Get the full int8_to_fp8 function to see how it handles zero_point
cat -n modelopt/onnx/quantization/fp8.py | sed -n '40,130p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4531


🏁 Script executed:

# Check if there are any tests that verify folding with non-zero zero_point
rg -l "fold_fp8" tests/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the actual ONNX spec documentation in comments or code
rg -B5 -A5 "QuantizeLinear.*zero_point|zero_point.*QuantizeLinear" modelopt/onnx/quantization/fp8.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at what zero_point values are used when creating FP8 nodes
rg -B10 -A10 "get_zero_point_for_quant_type" modelopt/onnx/quantization/autotune/export_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1393


🏁 Script executed:

# Check if FP8 quantization ever uses non-zero zero_point
rg -B5 -A5 "quant_type.*fp8|fp8.*quant" modelopt/onnx/quantization/autotune/export_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 619


🏁 Script executed:

# Look at the full fold_fp8_qdq_to_dq function and see how zero_point is accessed
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '66,130p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3096


🏁 Script executed:

# Check if DequantizeLinear with FP8 uses zero_point in its operation
rg -B3 -A3 "DequantizeLinear.*zero_point|zero_point.*DequantizeLinear" modelopt/onnx/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if FP8 is ever created with non-zero zero_point
rg -B5 -A10 "default_q_zero_point.*fp8|fp8.*default.*zero" modelopt/onnx/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at tests to see if they verify zero_point handling for FP8
fd "test.*fp8" tests/ --type=py -x grep -l "zero_point" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 235


🏁 Script executed:

# Check the ONNX standard comment in create_qdq_nodes for how zero_point is used
cat -n modelopt/onnx/quantization/autotune/export_utils.py | sed -n '149,200p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2278


🏁 Script executed:

# Verify: does FP8 always use zero_point=0 in the export pipeline?
rg -B10 -A10 "export_qdq_onnx.*fp8|fp8.*export_qdq|config.default_quant_type.*fp8" modelopt/onnx/quantization/autotune/ --type=py | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 2026


🏁 Script executed:

# Check what zero_point is used when creating FP8 QuantizeLinear nodes
cat -n modelopt/onnx/quantization/autotune/export_utils.py | sed -n '215,230p'

Repository: NVIDIA/Model-Optimizer

Length of output: 945


🏁 Script executed:

# Verify how int8_to_fp8 handles zero_point: does it convert INT8 zero_point to FP8?
cat -n modelopt/onnx/quantization/fp8.py | sed -n '98,115p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1146


🏁 Script executed:

# Search for where fold_fp8_qdq_to_dq is called - is there a constraint on zero_point?
rg -B5 -A5 "fold_fp8_qdq_to_dq" modelopt/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 1099


🏁 Script executed:

# Check ONNX QuantizeLinear semantics: does DequantizeLinear with FP8 use zero_point?
rg -B10 -A10 "TRT_FP8Dequantize|FLOAT8E4M3FN.*dequant" modelopt/onnx/ --type=py | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 3763


🏁 Script executed:

# Check the ONNX DequantizeLinear spec: for FP8, is zero_point used in dequantization?
# Look at what happens with zero_point in DequantizeLinear operations
rg -B5 -A10 "def.*dequantize|DequantizeLinear.*zero" modelopt/onnx/ --type=py | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 4724


🏁 Script executed:

# Critical: check if FP8 nodes are ALWAYS created with zero_point=0
rg -B3 -A3 "default_quant_type.*fp8" modelopt/onnx/quantization/autotune/common.py

Repository: NVIDIA/Model-Optimizer

Length of output: 656


🏁 Script executed:

# Check if there's any validation that enforces zero_point=0 for FP8
rg -i "fp8.*zero_point|zero_point.*fp8|fp8.*zero" modelopt/onnx/quantization/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 186


🏁 Script executed:

# Final verification: Check ONNX DequantizeLinear spec - does it use zero_point?
# Look at how DequantizeLinear inputs are documented
rg -B5 -A10 "inputs.*scale.*zero_point" modelopt/onnx/ --type=py | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 2534


🏁 Script executed:

# Check if there's any assertion or validation that zero_point must be 0 for FP8
rg "zero_point.*==.*0|assert.*zero" modelopt/onnx/llm_export_utils/surgeon_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Verify: can FP8 nodes have non-zero zero_point? Check if fold function handles this
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '85,105p'

Repository: NVIDIA/Model-Optimizer

Length of output: 972


🏁 Script executed:

# Check ONNX QuantizeLinear/DequantizeLinear formula in documentation
# Look for how zero_point affects the quantization/dequantization
rg -B3 -A3 "saturate.*round|round.*zero_point" modelopt/onnx/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Key question: if zero_point != 0, what happens when folding?
# Let's verify what the actual mathematical operation should be
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '66,130p' > /tmp/surgeon.txt && cat /tmp/surgeon.txt

Repository: NVIDIA/Model-Optimizer

Length of output: 3096


🏁 Script executed:

# Check if there's ever a case where FP8 nodes are created with non-zero zero_point
rg -B10 "int8_to_fp8" modelopt/onnx/quantization/fp8.py | tail -20

Repository: NVIDIA/Model-Optimizer

Length of output: 984


🏁 Script executed:

# Final check: are there any tests for fold_fp8_qdq_to_dq with non-zero zero_point?
find tests/ -name "*.py" -exec grep -l "fold_fp8_qdq_to_dq\|surgeon_utils" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Verify: What does is_fp8_constant check? Does it ensure zero_point=0?
rg -B5 -A10 "def is_fp8_constant" modelopt/onnx/utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 619


🏁 Script executed:

# Check if DequantizeLinear semantics actually use zero_point
# Search for ONNX spec or comments about the dequantization formula
rg -B5 -A5 "y = \(x - zero|saturate.*round" modelopt/onnx/ --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Verify the actual behavior: if zero_point is FP8, can it be non-zero?
python3 << 'EOF'
import numpy as np
# FP8 E4M3FN format can represent values - check if zero_point=0 is always used
# Check the range and representation
print("FP8 E4M3FN can represent:")
print("- Zero values: 0.0")
print("- Positive values: up to ~480")
print("- Negative values: down to ~-480")
print("\nFor quantization semantics:")
print("- INT8: zero_point typically 0 (signed) or 128 (unsigned)")
print("- FP8: zero_point would be 0.0 (no unsigned representation)")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 320


🏁 Script executed:

# Final: Check if any code path can create FP8 with non-zero zero_point
rg -B15 "default_q_zero_point" modelopt/onnx/quantization/autotune/common.py | grep -A15 "default_quant_type"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Verify: ONNX QuantizeLinear formula - what does zero_point do?
# Also check if code ever sets non-zero zero_point for FP8
rg "default_q_zero_point.*0|zero_point.*=.*0" modelopt/onnx/quantization/autotune/export_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Critical: Check the int8_to_fp8 conversion - does it preserve zero_point values?
cat -n modelopt/onnx/quantization/fp8.py | sed -n '98,112p'

Repository: NVIDIA/Model-Optimizer

Length of output: 973


🏁 Script executed:

# Verify the formula: if zero_point is in QuantizeLinear, does the fold code handle it?
python3 << 'EOF'
# ONNX standard formulas:
# QuantizeLinear: y = saturate(round(x / scale) + zero_point)
# DequantizeLinear: y = (x - zero_point) * scale

# Current fold operation:
# folded_weights = (original_weights / scale)
# Later dequantize: (folded_weights - zero_point) * scale
#                 = ((original_weights / scale) - zero_point) * scale
#                 = original_weights - (zero_point * scale)  ← WRONG if zero_point != 0

# Correct fold should be:
# folded_weights = (original_weights / scale) + zero_point
# Then dequantize: (folded_weights - zero_point) * scale = original_weights ✓

print("Issue confirmed: if zero_point != 0, folding is mathematically incorrect")
print("Current code ignores zero_point in the folding calculation")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 197


FP8 QuantizeLinear folding ignores zero_point in dequantization formula.

At Line 112, the code folds weights / scale without accounting for zero_point. ONNX QuantizeLinear applies the formula saturate(round(x / scale) + zero_point), and the paired DequantizeLinear applies (x - zero_point) * scale. When folding, the stored FP8 weights should encode (x / scale) + zero_point so that subsequent dequantization yields correct results. If zero_point != 0, the current implementation produces numerically incorrect folded weights.

Recommend either:

  1. Enforce zero_point == 0 before folding (add validation), or
  2. Include zero_point in the weight conversion: (torch_weights / torch_scale + torch_zero_point).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/llm_export_utils/surgeon_utils.py` around lines 110 - 113, The
FP8 folding currently computes numpy_weights from (torch_weights / torch_scale)
but ignores QuantizeLinear's zero_point; update the folding in surgeon_utils.py
to either validate that torch_zero_point == 0 and raise/log if not, or
incorporate the zero point into the conversion by computing (torch_weights /
torch_scale + torch_zero_point) before casting to FP8; modify the numpy_weights
creation (referencing numpy_weights, torch_weights, torch_scale,
torch_zero_point) to include this change and ensure correct rounding/typing for
the subsequent .to(torch.float8_e4m3fn).

tensor = onnx.TensorProto()
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
tensor.dims.extend(numpy_weights.shape)
tensor.raw_data = numpy_weights.tobytes()
values = LazyValues(tensor)
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype

graph.cleanup().toposort()
end_time = time.time()
Expand Down
14 changes: 14 additions & 0 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,27 @@
import onnx_graphsurgeon as gs
from onnx.helper import get_attribute_value
from onnx_graphsurgeon import Constant, Node, Variable
from onnx_graphsurgeon.ir.tensor import LazyValues

from modelopt.onnx.logging_config import logger

# Base minimum opset for quantization (opset 19 is the first to support fp16 scales)
BASE_MIN_OPSET = 19


def is_fp8_constant(const: Constant) -> bool:
    """Check whether a gs.Constant wraps a lazily-loaded FLOAT8E4M3FN tensor.

    Only constants still backed by a ``LazyValues`` wrapper can be FP8, since
    numpy (and thus an eagerly-materialized constant) has no fp8 dtype. The
    underlying TensorProto is fetched via ``getattr`` so a future rename of the
    private ``_tensor`` attribute degrades to ``False`` instead of raising.
    """
    # Non-lazy constants are plain numpy arrays and therefore never FP8.
    if not isinstance(const.values, LazyValues):
        return False
    proto = getattr(const.values, "_tensor", None)
    # NOTE(review): relies on the private LazyValues._tensor attribute — verify
    # against the installed onnx_graphsurgeon version.
    return proto is not None and proto.data_type == onnx.TensorProto.FLOAT8E4M3FN


def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]:
"""This function returns the inputs names of the given onnx model in bytes.

Expand Down
36 changes: 22 additions & 14 deletions modelopt/torch/quantization/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@
}
mha_valid_precisions = {"Half", "BFloat16"}

torch_dtype_map = {"Float": torch.float32, "Half": torch.float16, "BFloat16": torch.bfloat16}
torch_dtype_map = {
"Float": torch.float32,
"Half": torch.float16,
"BFloat16": torch.bfloat16,
"Float8": torch.float8_e4m3fn,
}


def export_int8(
Expand Down Expand Up @@ -221,8 +226,7 @@ def _fp8_quantize(
"""Helper Function for Quantization."""
output_shape = sym_help._get_tensor_sizes(inputs)

# TRT StronglyType only supports FP16 QDQs
# custom ops, so cast the input if needed.
# Cast the input to the high-precision dtype if needed.
input_type = inputs.type().scalarType()
assert trt_high_precision_dtype in (input_type, "Float"), (
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
Expand All @@ -234,9 +238,12 @@ def _fp8_quantize(
"Constant",
value_t=torch.tensor(scale_inv).to(torch_dtype_map[trt_high_precision_dtype]),
)
q_op = g.op("trt::TRT_FP8QuantizeLinear", inputs, scale).setType(
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)
)
# Use standard ONNX QuantizeLinear with FLOAT8E4M3FN zero_point (opset 19).
# The zero_point dtype determines the output dtype per the ONNX spec.
zero_point = g.op("Constant", value_t=torch.tensor(0.0))
zero_point = g.op("Cast", zero_point, to_i=onnx_dtype_map["Float8"])
q_op = g.op("QuantizeLinear", inputs, scale, zero_point, saturate_i=1)
q_op.setType(inputs.type().with_dtype(torch.float8_e4m3fn).with_sizes(output_shape))
return q_op


Expand All @@ -249,21 +256,22 @@ def _fp8_dequantize(
):
"""Helper Function for Dequantization."""
output_shape = sym_help._get_tensor_sizes(inputs)
assert trt_high_precision_dtype in (otype, "Float"), (
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
)
scale = g.op(
"Constant",
value_t=torch.tensor(scale_inv, dtype=torch_dtype_map[otype]), # type: ignore[index]
)
out = g.op("trt::TRT_FP8DequantizeLinear", inputs, scale).setType(
# Use standard ONNX DequantizeLinear with FLOAT8E4M3FN zero_point (opset 19).
# Per the ONNX spec, DequantizeLinear with FLOAT8E4M3FN input outputs float32.
zero_point = g.op("Constant", value_t=torch.tensor(0.0))
zero_point = g.op("Cast", zero_point, to_i=onnx_dtype_map["Float8"])
out = g.op("DequantizeLinear", inputs, scale, zero_point)
out.setType(
inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
)

# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the output if needed.
if trt_high_precision_dtype != otype:
out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index]
# DequantizeLinear outputs float32 in opset 19; cast back to original type if needed.
if otype in torch_dtype_map and otype != "Float":
out = g.op("Cast", out, to_i=onnx_dtype_map[otype])
return out


Expand Down
Loading
Loading