Added column-major storage of weights and scales in INT4 quantization for model load time improvement in TRT-RTX #811

base: main

Conversation
…improvement in TRT-RTX
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
📝 Walkthrough

This PR introduces a column-major storage optimization feature for ONNX INT4 quantization targeting the NvTensorRtRtx execution provider. It adds a CLI flag to the quantization script, integrates it through the quantization pipeline, and provides utility functions for applying column-major transformations to GEMM weights and inserting transpose operations in DQ-only quantization modes.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant CLI as quantize.py<br/>(CLI)
    participant API as int4.py<br/>(quantize)
    participant Transform as qdq_utils.py<br/>(apply_column_major)
    participant Graph as Graph<br/>(ONNX)
    User->>CLI: --use_column_major flag
    CLI->>API: quantize(...,<br/>use_column_major=True)
    API->>Transform: apply_column_major_transformation(<br/>weights, scales, ...)
    Transform->>Transform: Transpose weights &<br/>scales in-place
    Transform->>API: Return DQ attributes<br/>(axis=1)
    API->>Graph: Create DQ nodes with<br/>column-major attributes
    API->>Transform: add_transpose_nodes_for_column_major(graph)
    Transform->>Graph: Insert Transpose nodes<br/>after DQ nodes
    Transform->>Graph: Update MatMul/Gemm<br/>inputs
    Graph-->>User: Optimized ONNX model
```
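The transformation amounts to storing the quantized weight and its scales transposed, then restoring the row-major layout with a Transpose node so downstream math is unchanged. A minimal sketch of the idea, assuming numpy/onnx; the helper names `to_column_major` and `insert_transpose_back` are illustrative, not the PR's actual functions:

```python
import numpy as np
import onnx
from onnx import helper


def to_column_major(weight: np.ndarray, scale: np.ndarray):
    """Store weight and per-block scale transposed; the DQ axis flips from 0 to 1."""
    return np.ascontiguousarray(weight.T), np.ascontiguousarray(scale.T)


def insert_transpose_back(dq_output: str, consumer: onnx.NodeProto) -> onnx.NodeProto:
    """Add a Transpose after the DQ node and rewire the consuming MatMul/Gemm,
    exploiting A @ B == A @ (B.T).T so model outputs are unchanged."""
    transposed = dq_output + "_transposed"
    node = helper.make_node("Transpose", [dq_output], [transposed], perm=[1, 0])
    consumer.input[1] = transposed  # the weight is typically input B
    return node
```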
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🤖 Fix all issues with AI agents

In `modelopt/onnx/quantization/qdq_utils.py`:

- Around lines 1083-1091: the Gemm handling in apply_column_major_transformation skips nodes whose attrs contain transB=1, which breaks semantics for column-major weights. When encountering a Gemm with transB set, flip node.attrs["transB"] to 0 (or delete the attribute) so the Gemm consumes the already-transposed weight (B^T) directly; locate the Gemm handling block (the node.op == "Gemm" check and the transB logic) and replace the skip-with-warning with this attribute flip so outputs remain correct.
```python
# For Gemm nodes, check if transB is already set
if node.op == "Gemm":
    trans_b = False
    if hasattr(node, "attrs") and "transB" in node.attrs:
        trans_b = node.attrs["transB"] > 0
    if trans_b:
        logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
        continue
```
Gemm transB=1 skip breaks correctness with column‑major weights.
apply_column_major_transformation already transposes weights. If a Gemm has transB=1, skipping the transpose‑back makes Gemm consume B instead of B^T, changing outputs. Either always insert the transpose‑back or flip transB to 0 so Gemm consumes B^T directly.
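A tiny numeric check of this point (illustrative only, not PR code): once the weight initializer is stored transposed, keeping transB=1 effectively multiplies by B rather than Bᵀ, while flipping transB to 0 restores the original result.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 3))  # square so both paths are shape-compatible

original = A @ B.T        # Gemm(A, B, transB=1) before the transformation
stored = B.T              # weight initializer after column-major storage

kept_transb = A @ stored.T   # transB=1 kept: effectively A @ B -- wrong
flipped = A @ stored         # transB=0 after the fix: A @ B.T -- matches

assert not np.allclose(original, kept_transb)  # outputs change -> bug
assert np.allclose(original, flipped)          # semantics preserved
```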
🐛 Proposed fix (flip transB to 0 and keep semantics)

```diff
-# For Gemm nodes, check if transB is already set
-if node.op == "Gemm":
-    trans_b = False
-    if hasattr(node, "attrs") and "transB" in node.attrs:
-        trans_b = node.attrs["transB"] > 0
-    if trans_b:
-        logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
-        continue
+# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
+if node.op == "Gemm":
+    trans_b = bool((node.attrs or {}).get("transB", 0))
+    if trans_b:
+        node.attrs = node.attrs or {}
+        node.attrs["transB"] = 0
+        logger.debug(
+            f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
+        )
+        continue
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
if node.op == "Gemm":
    trans_b = bool((node.attrs or {}).get("transB", 0))
    if trans_b:
        node.attrs = node.attrs or {}
        node.attrs["transB"] = 0
        logger.debug(
            f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
        )
        continue
```
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #811      +/-   ##
==========================================
- Coverage   74.17%   73.99%   -0.19%
==========================================
  Files         192      192
  Lines       19246    19313      +67
==========================================
+ Hits        14276    14290      +14
- Misses       4970     5023      +53
```

☔ View full report in Codecov by Sentry.
What does this PR do?
Type of change: New feature
Overview:
TensorRT-RTX requires the weights and scales in ONNX models to be stored in column-major format, so each time a model loads, the TRT-RTX JIT transposes the weights and scales, increasing load time.

The proposed feature transposes the weights and scales in the DQ node after quantization and adds a Transpose node right after, i.e.,

A × B = A × (Bᵀ)ᵀ

The transformation is a post-processing step and is disabled by default. It can be enabled by quantizing with --use_column_major.
Usage
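The usage block was not captured here; as a sketch only, a call through the Python API might look like the following. The module path and the `use_column_major` option come from this PR's walkthrough, while the remaining argument names are assumptions, not the library's confirmed signature.

```python
# Hypothetical usage sketch -- argument names other than use_column_major
# are assumptions. CLI equivalent (flag confirmed by this PR): add
# --use_column_major to the quantize.py invocation.
from modelopt.onnx.quantization.int4 import quantize

quantize(
    "model.onnx",                 # input ONNX model (positional form assumed)
    calibration_method="rtn_dq",  # one of the DQ-only modes: rtn_dq, awq_lite, awq_clip
    use_column_major=True,        # new option added by this PR
)
```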
Testing
Tested a few LLMs, comparing their MMLU scores with and without this transformation; no degradations were observed.
Summary by CodeRabbit
Release Notes
Added --use_column_major command-line flag to the ONNX quantization script, enabling column-major weight storage optimization compatible with the NvTensorRtRtx execution provider. This optimization applies to DQ-only quantization modes (rtn_dq, awq_lite, awq_clip).