-
Notifications
You must be signed in to change notification settings - Fork 307
Fix FP8 ONNX export to use standard QuantizeLinear/DequantizeLinear ops #1037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,8 @@ | |
| import torch | ||
| from onnx_graphsurgeon.ir.tensor import LazyValues | ||
|
|
||
| from modelopt.onnx.utils import is_fp8_constant | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def clear_inputs(node: gs.Node | gs.Tensor): | ||
| """Clear all inputs for a node or tensor in ONNX.""" | ||
|
|
@@ -81,37 +83,46 @@ def fold_fp8_qdq_to_dq(graph: gs.Graph): | |
| graph.cleanup().toposort().fold_constants().cleanup() | ||
|
|
||
| for node in graph.nodes: | ||
| if node.op == "TRT_FP8QuantizeLinear": | ||
| # Should not remove input QDQ | ||
| if not isinstance(node.inputs[0], gs.Constant): | ||
| continue | ||
|
|
||
| weights = node.inputs[0] | ||
| scale = node.inputs[1] | ||
| torch_weights = torch.from_numpy(weights.values) | ||
| torch_scale = torch.from_numpy(scale.values) | ||
| quantizer_name = scale.name.rsplit("/", 1)[0] | ||
| dq_op = node.outputs[0].outputs[0] | ||
| assert dq_op.op == "TRT_FP8DequantizeLinear", ( | ||
| f"QDQ does not occur in pairs. You reached {dq_op.op}" | ||
| ) | ||
|
|
||
| # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. | ||
| numpy_weights = ( | ||
| (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() | ||
| ) | ||
| tensor = onnx.TensorProto() | ||
| tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN | ||
| tensor.dims.extend(numpy_weights.shape) | ||
| tensor.raw_data = numpy_weights.tobytes() | ||
| values = LazyValues(tensor) | ||
| onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) | ||
|
|
||
| node.outputs.clear() | ||
| # DQ Op is separated out | ||
| dq_op.inputs[0] = onnx_weights_fp8 | ||
| dq_op.op = "DequantizeLinear" | ||
| dq_op.outputs[0].dtype = dq_op.inputs[1].dtype | ||
| is_trt_fp8_q = node.op == "TRT_FP8QuantizeLinear" | ||
| is_std_fp8_q = ( | ||
| node.op == "QuantizeLinear" | ||
| and len(node.inputs) >= 3 | ||
| and isinstance(node.inputs[2], gs.Constant) | ||
| and is_fp8_constant(node.inputs[2]) | ||
| ) | ||
| if not (is_trt_fp8_q or is_std_fp8_q): | ||
| continue | ||
|
|
||
| # Should not remove input QDQ | ||
| if not isinstance(node.inputs[0], gs.Constant): | ||
| continue | ||
|
|
||
| weights = node.inputs[0] | ||
| scale = node.inputs[1] | ||
| torch_weights = torch.from_numpy(weights.values) | ||
| torch_scale = torch.from_numpy(scale.values) | ||
| quantizer_name = scale.name.rsplit("/", 1)[0] | ||
| dq_op = node.outputs[0].outputs[0] | ||
| assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), ( | ||
| f"QDQ does not occur in pairs. You reached {dq_op.op}" | ||
| ) | ||
|
Comment on lines
+105
to
+108
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Guard Q output-chain indexing before dereference. At Line 105, 🔧 Proposed fix:
- dq_op = node.outputs[0].outputs[0]
- assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
- f"QDQ does not occur in pairs. You reached {dq_op.op}"
- )
+ if not node.outputs or not node.outputs[0].outputs:
+ continue
+ dq_op = node.outputs[0].outputs[0]
+ if dq_op.op not in ("TRT_FP8DequantizeLinear", "DequantizeLinear"):
+ continue🤖 Prompt for AI Agents |
||
|
|
||
| # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. | ||
| numpy_weights = ( | ||
| (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() | ||
| ) | ||
|
Comment on lines
+110
to
+113
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain 🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Locate FP8 QuantizeLinear construction and zero_point setup in exporter paths.
rg -nP --type=py -C3 '\bQuantizeLinear\b|TRT_FP8QuantizeLinear|FLOAT8E4M3FN|zero_point|float8_e4m3fn'
# 2) Locate folding assumptions and any explicit zero_point validation.
rg -nP --type=py -C3 'fold_fp8_qdq_to_dq|is_std_fp8_q|is_fp8_constant|zero_point|axis'
# 3) Check whether tests cover non-zero zero_point / axis cases for FP8 fold/export.
rg -nP --type=py -C3 'test_.*fp8.*(zero_point|axis)|fp8_onnx_export_shape|fold_fp8_qdq_to_dq'Repository: NVIDIA/Model-Optimizer Length of output: 50379 🏁 Script executed: # First, let's read the specific file and lines mentioned in the review
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '100,125p'Repository: NVIDIA/Model-Optimizer Length of output: 1413 🏁 Script executed: # Search for the function containing this code and zero_point handling
rg -A20 -B5 "Replace it with Dequantize with FP8 weights" modelopt/onnx/llm_export_utils/surgeon_utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 1146 🏁 Script executed: # Look for fold_fp8_qdq or similar folding functions
fd surgeon_utils.py -x cat -n {} | head -200Repository: NVIDIA/Model-Optimizer Length of output: 5481 🏁 Script executed: # Search for zero_point handling in the surgeon_utils file and related functions
rg -B10 -A10 "zero_point" modelopt/onnx/llm_export_utils/surgeon_utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Check ONNX standard or comments about FP8 zero_point semantics
rg -C5 "TRT_FP8QuantizeLinear|is_std_fp8_q" modelopt/onnx/ --type=py | head -100Repository: NVIDIA/Model-Optimizer Length of output: 3014 🏁 Script executed: # Look for FP8 quantization logic and zero_point usage
rg -C5 "float8_e4m3fn.*zero|zero.*float8_e4m3fn" modelopt/onnx/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Check how FP8 QuantizeLinear nodes are created - what zero_point values are used
rg -B5 -A5 "QuantizeLinear.*fp8|float8.*QuantizeLinear" modelopt/onnx/ --type=py | head -80Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look for zero_point creation/assignment in FP8 contexts
rg "zero_point.*=.*0|zero_point.*fp8|fp8.*zero" modelopt/onnx/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 145 🏁 Script executed: # Check what the ONNX standard and comments say about FP8 and zero_point
rg -B3 -A3 "FP8|float8_e4m3fn" modelopt/onnx/quantization/fp8.py | head -120Repository: NVIDIA/Model-Optimizer Length of output: 2904 🏁 Script executed: # Look for how zero_point is handled when creating FP8 QuantizeLinear nodes
rg -B10 -A10 "create.*qdq|QuantizeLinear.*zero" modelopt/onnx/quantization/autotune/export_utils.py | head -100Repository: NVIDIA/Model-Optimizer Length of output: 2035 🏁 Script executed: # Check how int8_to_fp8 handles zero_point conversion
rg -B5 -A15 "def int8_to_fp8" modelopt/onnx/quantization/fp8.pyRepository: NVIDIA/Model-Optimizer Length of output: 797 🏁 Script executed: # Search for any comments about FP8 and zero_point semantics in the codebase
rg -i "fp8.*zero_point|zero_point.*fp8" modelopt/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look for test files that test FP8 folding with zero_point
fd "test.*fp8|fp8.*test" -t f modelopt/ | xargs rg -l "zero_point|fold_fp8" 2>/dev/null | head -5Repository: NVIDIA/Model-Optimizer Length of output: 325 🏁 Script executed: # Get the full int8_to_fp8 function to see how it handles zero_point
cat -n modelopt/onnx/quantization/fp8.py | sed -n '40,130p'Repository: NVIDIA/Model-Optimizer Length of output: 4531 🏁 Script executed: # Check if there are any tests that verify folding with non-zero zero_point
rg -l "fold_fp8" tests/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look at the actual ONNX spec documentation in comments or code
rg -B5 -A5 "QuantizeLinear.*zero_point|zero_point.*QuantizeLinear" modelopt/onnx/quantization/fp8.pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look at what zero_point values are used when creating FP8 nodes
rg -B10 -A10 "get_zero_point_for_quant_type" modelopt/onnx/quantization/autotune/export_utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 1393 🏁 Script executed: # Check if FP8 quantization ever uses non-zero zero_point
rg -B5 -A5 "quant_type.*fp8|fp8.*quant" modelopt/onnx/quantization/autotune/export_utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 619 🏁 Script executed: # Look at the full fold_fp8_qdq_to_dq function and see how zero_point is accessed
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '66,130p'Repository: NVIDIA/Model-Optimizer Length of output: 3096 🏁 Script executed: # Check if DequantizeLinear with FP8 uses zero_point in its operation
rg -B3 -A3 "DequantizeLinear.*zero_point|zero_point.*DequantizeLinear" modelopt/onnx/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Check if FP8 is ever created with non-zero zero_point
rg -B5 -A10 "default_q_zero_point.*fp8|fp8.*default.*zero" modelopt/onnx/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look at tests to see if they verify zero_point handling for FP8
fd "test.*fp8" tests/ --type=py -x grep -l "zero_point" {} \;Repository: NVIDIA/Model-Optimizer Length of output: 235 🏁 Script executed: # Check the ONNX standard comment in create_qdq_nodes for how zero_point is used
cat -n modelopt/onnx/quantization/autotune/export_utils.py | sed -n '149,200p'Repository: NVIDIA/Model-Optimizer Length of output: 2278 🏁 Script executed: # Verify: does FP8 always use zero_point=0 in the export pipeline?
rg -B10 -A10 "export_qdq_onnx.*fp8|fp8.*export_qdq|config.default_quant_type.*fp8" modelopt/onnx/quantization/autotune/ --type=py | head -60Repository: NVIDIA/Model-Optimizer Length of output: 2026 🏁 Script executed: # Check what zero_point is used when creating FP8 QuantizeLinear nodes
cat -n modelopt/onnx/quantization/autotune/export_utils.py | sed -n '215,230p'Repository: NVIDIA/Model-Optimizer Length of output: 945 🏁 Script executed: # Verify how int8_to_fp8 handles zero_point: does it convert INT8 zero_point to FP8?
cat -n modelopt/onnx/quantization/fp8.py | sed -n '98,115p'Repository: NVIDIA/Model-Optimizer Length of output: 1146 🏁 Script executed: # Search for where fold_fp8_qdq_to_dq is called - is there a constraint on zero_point?
rg -B5 -A5 "fold_fp8_qdq_to_dq" modelopt/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 1099 🏁 Script executed: # Check ONNX QuantizeLinear semantics: does DequantizeLinear with FP8 use zero_point?
rg -B10 -A10 "TRT_FP8Dequantize|FLOAT8E4M3FN.*dequant" modelopt/onnx/ --type=py | head -80Repository: NVIDIA/Model-Optimizer Length of output: 3763 🏁 Script executed: # Check the ONNX DequantizeLinear spec: for FP8, is zero_point used in dequantization?
# Look at what happens with zero_point in DequantizeLinear operations
rg -B5 -A10 "def.*dequantize|DequantizeLinear.*zero" modelopt/onnx/ --type=py | head -100Repository: NVIDIA/Model-Optimizer Length of output: 4724 🏁 Script executed: # Critical: check if FP8 nodes are ALWAYS created with zero_point=0
rg -B3 -A3 "default_quant_type.*fp8" modelopt/onnx/quantization/autotune/common.pyRepository: NVIDIA/Model-Optimizer Length of output: 656 🏁 Script executed: # Check if there's any validation that enforces zero_point=0 for FP8
rg -i "fp8.*zero_point|zero_point.*fp8|fp8.*zero" modelopt/onnx/quantization/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 186 🏁 Script executed: # Final verification: Check ONNX DequantizeLinear spec - does it use zero_point?
# Look at how DequantizeLinear inputs are documented
rg -B5 -A10 "inputs.*scale.*zero_point" modelopt/onnx/ --type=py | head -60Repository: NVIDIA/Model-Optimizer Length of output: 2534 🏁 Script executed: # Check if there's any assertion or validation that zero_point must be 0 for FP8
rg "zero_point.*==.*0|assert.*zero" modelopt/onnx/llm_export_utils/surgeon_utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Verify: can FP8 nodes have non-zero zero_point? Check if fold function handles this
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '85,105p'Repository: NVIDIA/Model-Optimizer Length of output: 972 🏁 Script executed: # Check ONNX QuantizeLinear/DequantizeLinear formula in documentation
# Look for how zero_point affects the quantization/dequantization
rg -B3 -A3 "saturate.*round|round.*zero_point" modelopt/onnx/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Key question: if zero_point != 0, what happens when folding?
# Let's verify what the actual mathematical operation should be
cat -n modelopt/onnx/llm_export_utils/surgeon_utils.py | sed -n '66,130p' > /tmp/surgeon.txt && cat /tmp/surgeon.txtRepository: NVIDIA/Model-Optimizer Length of output: 3096 🏁 Script executed: # Check if there's ever a case where FP8 nodes are created with non-zero zero_point
rg -B10 "int8_to_fp8" modelopt/onnx/quantization/fp8.py | tail -20Repository: NVIDIA/Model-Optimizer Length of output: 984 🏁 Script executed: # Final check: are there any tests for fold_fp8_qdq_to_dq with non-zero zero_point?
find tests/ -name "*.py" -exec grep -l "fold_fp8_qdq_to_dq\|surgeon_utils" {} \;Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Verify: What does is_fp8_constant check? Does it ensure zero_point=0?
rg -B5 -A10 "def is_fp8_constant" modelopt/onnx/utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 619 🏁 Script executed: # Check if DequantizeLinear semantics actually use zero_point
# Search for ONNX spec or comments about the dequantization formula
rg -B5 -A5 "y = \(x - zero|saturate.*round" modelopt/onnx/ --type=pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Verify the actual behavior: if zero_point is FP8, can it be non-zero?
python3 << 'EOF'
import numpy as np
# FP8 E4M3FN format can represent values - check if zero_point=0 is always used
# Check the range and representation
print("FP8 E4M3FN can represent:")
print("- Zero values: 0.0")
print("- Positive values: up to ~480")
print("- Negative values: down to ~-480")
print("\nFor quantization semantics:")
print("- INT8: zero_point typically 0 (signed) or 128 (unsigned)")
print("- FP8: zero_point would be 0.0 (no unsigned representation)")
EOFRepository: NVIDIA/Model-Optimizer Length of output: 320 🏁 Script executed: # Final: Check if any code path can create FP8 with non-zero zero_point
rg -B15 "default_q_zero_point" modelopt/onnx/quantization/autotune/common.py | grep -A15 "default_quant_type"Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Verify: ONNX QuantizeLinear formula - what does zero_point do?
# Also check if code ever sets non-zero zero_point for FP8
rg "default_q_zero_point.*0|zero_point.*=.*0" modelopt/onnx/quantization/autotune/export_utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Critical: Check the int8_to_fp8 conversion - does it preserve zero_point values?
cat -n modelopt/onnx/quantization/fp8.py | sed -n '98,112p'Repository: NVIDIA/Model-Optimizer Length of output: 973 🏁 Script executed: # Verify the formula: if zero_point is in QuantizeLinear, does the fold code handle it?
python3 << 'EOF'
# ONNX standard formulas:
# QuantizeLinear: y = saturate(round(x / scale) + zero_point)
# DequantizeLinear: y = (x - zero_point) * scale
# Current fold operation:
# folded_weights = (original_weights / scale)
# Later dequantize: (folded_weights - zero_point) * scale
# = ((original_weights / scale) - zero_point) * scale
# = original_weights - (zero_point * scale) ← WRONG if zero_point != 0
# Correct fold should be:
# folded_weights = (original_weights / scale) + zero_point
# Then dequantize: (folded_weights - zero_point) * scale = original_weights ✓
print("Issue confirmed: if zero_point != 0, folding is mathematically incorrect")
print("Current code ignores zero_point in the folding calculation")
EOF
Repository: NVIDIA/Model-Optimizer Length of output: 197
FP8 QuantizeLinear folding ignores zero_point. At Line 112, the code folds the weights without accounting for zero_point. Recommend either:
🤖 Prompt for AI Agents |
||
| tensor = onnx.TensorProto() | ||
| tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN | ||
| tensor.dims.extend(numpy_weights.shape) | ||
| tensor.raw_data = numpy_weights.tobytes() | ||
| values = LazyValues(tensor) | ||
| onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) | ||
|
|
||
| node.outputs.clear() | ||
| # DQ Op is separated out | ||
| dq_op.inputs[0] = onnx_weights_fp8 | ||
| dq_op.op = "DequantizeLinear" | ||
| dq_op.outputs[0].dtype = dq_op.inputs[1].dtype | ||
|
|
||
| graph.cleanup().toposort() | ||
| end_time = time.time() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.