Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/windows/onnx_ptq/genai_llm/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def main(args):
layers_8bit=args.layers_8bit,
gather_block_size=args.gather_block_size,
gather_quantize_axis=args.gather_quantize_axis,
use_column_major=args.use_column_major,
)
logging.info(f"\nQuantization process took {time.time() - t} seconds")

Expand Down Expand Up @@ -629,5 +630,14 @@ def main(args):
default="",
help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"),
)
parser.add_argument(
"--use_column_major",
default=False,
action="store_true",
help=(
"Apply column-major storage optimization for NvTensorRtRtx execution provider. "
"Only applicable for DQ-only quantization (e.g., rtn_dq, awq_lite, awq_clip)."
),
)
args = parser.parse_args()
main(args)
60 changes: 57 additions & 3 deletions modelopt/onnx/quantization/int4.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,21 @@ def quantize_rtn(

Always selects the first dimension (0) to block over. This is because we must batch over the Cin
dimension, and in ONNX, weights are always plugged into the RHS (i.e. y = x @ W).

Args:
use_column_major: If True, apply column-major storage optimization for NvTensorRtRtx.
Passed via kwargs.
"""
use_column_major = kwargs.get("use_column_major", False)

# Column-major only makes sense for DQ-only mode
if use_column_major and not dq_only:
logger.warning(
"use_column_major=True has no effect in QDQ mode. "
"Column-major optimization only applies to DQ-only quantization."
)
use_column_major = False

logger.info("Starting RTN quantization")
t_start = time.time()

Expand Down Expand Up @@ -295,8 +309,15 @@ def quantize_rtn(
qw = np.asnumpy(qw)
scales[name] = np.asnumpy(scales[name])
gemm_weights_quantized[name] = numpy.asarray(qw)
# Apply column-major optimization if flag is set
if use_column_major:
dq_node_attributes = qdq.apply_column_major_transformation(
gemm_weights_quantized, scales, graph, block_size
)
else:
dq_node_attributes = {"axis": 0, "block_size": block_size}

scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
dq_node_attributes = {"axis": 0, "block_size": block_size}
qdq.insert_dq_nodes(
graph,
scales,
Expand All @@ -305,6 +326,10 @@ def quantize_rtn(
layer_info=layer_info,
)

# Add transpose nodes for column-major if needed
if use_column_major:
qdq.add_transpose_nodes_for_column_major(graph)

if gather_w_map is not None:
gather_dq_node_attributes = {
"axis": gather_quantize_axis,
Expand Down Expand Up @@ -605,7 +630,14 @@ def _quantize_awq_clip(
)

t = time.time()
dq_node_attributes = {"axis": 0, "block_size": block_size}
# Apply column-major optimization if flag is set
use_column_major = kwargs.get("use_column_major", False)
if use_column_major:
dq_node_attributes = qdq.apply_column_major_transformation(
gemm_weights_quantized, scales, graph_gs, block_size
)
else:
dq_node_attributes = {"axis": 0, "block_size": block_size}
scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
qdq.insert_dq_nodes(
graph_gs,
Expand All @@ -614,6 +646,9 @@ def _quantize_awq_clip(
attributes=dq_node_attributes,
layer_info=layer_info,
)
# Add transpose nodes for column-major if needed
if use_column_major:
qdq.add_transpose_nodes_for_column_major(graph_gs)
if gather_w_map is not None:
assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size}
Expand Down Expand Up @@ -1308,7 +1343,14 @@ def _quantize_awq_lite(
)

t = time.time()
dq_node_attributes = {"axis": 0, "block_size": block_size}
# Apply column-major optimization if flag is set
use_column_major = kwargs.get("use_column_major", False)
if use_column_major:
dq_node_attributes = qdq.apply_column_major_transformation(
gemm_weights_quantized, scales, graph_gs, block_size
)
else:
dq_node_attributes = {"axis": 0, "block_size": block_size}
scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
qdq.insert_dq_nodes(
graph_gs,
Expand All @@ -1318,6 +1360,9 @@ def _quantize_awq_lite(
zero_points=zero_points if use_zero_point else None,
layer_info=layer_info,
)
# Add transpose nodes for column-major if needed
if use_column_major:
qdq.add_transpose_nodes_for_column_major(graph_gs)
if gather_w_map is not None:
assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
assert not use_zero_point or gather_zp_map, (
Expand Down Expand Up @@ -1420,10 +1465,19 @@ def quantize(
Default: False.
- **layers_8bit** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
Default: [].
- **use_column_major** (bool): If True, apply column-major storage optimization for
NvTensorRtRtx execution provider. This transposes weights
and adds Transpose nodes around MatMul operations.
Only applies to DQ-only quantization mode.
Default: False.
**Returns**: A quantized ONNX model in ONNX ModelProto format.
"""
configure_logging(level=log_level.upper())
logger.info(f"Starting INT4 quantization with method: {calibration_method}")

# Log if column-major optimization is enabled (works for all methods)
if kwargs.get("use_column_major", False):
logger.info("Column-major storage optimization enabled via use_column_major flag")
t_start = time.time()

if cupy_warning_msg:
Expand Down
148 changes: 148 additions & 0 deletions modelopt/onnx/quantization/qdq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,154 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn
return onnx_model


# =============================================================================
# Column-major weight storage transformation for NvTensorRtRtx execution provider
# =============================================================================


def _apply_transpose_perm_to_shape(shape, perm):
"""Apply transpose permutation to a shape to get the output shape.

Args:
shape: Input shape as a list/tuple
perm: Permutation indices

Returns:
Transposed shape or None if inputs are None
"""
if shape is None or perm is None:
return None
return [shape[i] for i in perm]


def add_transpose_nodes_for_column_major(graph: gs.Graph):
"""Add a single Transpose node after each DequantizeLinear for column-major weights.

This implements the simple transformation: A @ B = A @ ((B^T)^T)
where B^T is stored in the DequantizeLinear node, and we add a Transpose
node after DQ to recover B before the MatMul.

Graph transformation:
Before: DQ(W) -> MatMul/Gemm
After: DQ(W^T) -> Transpose -> W -> MatMul/Gemm

Args:
graph: ONNX GraphSurgeon graph to modify in-place
"""
nodes_to_add = []
dq_nodes_processed = set()

for node in graph.nodes:
if node.op in ["MatMul", "Gemm"]:
# Check if second input (weight) is from DequantizeLinear
weight_input = node.inputs[1]
if not isinstance(weight_input, gs.Variable):
continue

# Find the producer of the weight input
producer_nodes = [n for n in graph.nodes if weight_input in n.outputs]
if not producer_nodes:
continue

producer_node = producer_nodes[0]
if producer_node.op != DEQUANTIZE_NODE_NAME:
continue

# Skip if we already processed this DQ node
if producer_node.name in dq_nodes_processed:
continue
dq_nodes_processed.add(producer_node.name)

# For Gemm nodes, check if transB is already set
if node.op == "Gemm":
trans_b = False
if hasattr(node, "attrs") and "transB" in node.attrs:
trans_b = node.attrs["transB"] > 0
if trans_b:
logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
continue

Comment on lines +1083 to +1091
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Gemm transB=1 skip breaks correctness with column‑major weights.
apply_column_major_transformation already transposes weights. If a Gemm has transB=1, skipping the transpose‑back makes Gemm consume B instead of B^T, changing outputs. Either always insert the transpose‑back or flip transB to 0 so Gemm consumes B^T directly.

🐛 Proposed fix (flip transB to 0 and keep semantics)
-            # For Gemm nodes, check if transB is already set
-            if node.op == "Gemm":
-                trans_b = False
-                if hasattr(node, "attrs") and "transB" in node.attrs:
-                    trans_b = node.attrs["transB"] > 0
-                if trans_b:
-                    logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
-                    continue
+            # For Gemm nodes with transB=1, flip to 0 since weights are already transposed
+            if node.op == "Gemm":
+                trans_b = bool((node.attrs or {}).get("transB", 0))
+                if trans_b:
+                    node.attrs = node.attrs or {}
+                    node.attrs["transB"] = 0
+                    logger.debug(
+                        f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
+                    )
+                    continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# For Gemm nodes, check if transB is already set
if node.op == "Gemm":
trans_b = False
if hasattr(node, "attrs") and "transB" in node.attrs:
trans_b = node.attrs["transB"] > 0
if trans_b:
logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
continue
# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
if node.op == "Gemm":
trans_b = bool((node.attrs or {}).get("transB", 0))
if trans_b:
node.attrs = node.attrs or {}
node.attrs["transB"] = 0
logger.debug(
f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
)
continue
🤖 Prompt for AI Agents
In `@modelopt/onnx/quantization/qdq_utils.py` around lines 1083 - 1091, The
current Gemm handling in apply_column_major_transformation (qdq_utils.py) skips
nodes with node.op == "Gemm" when node.attrs contains transB=1, which breaks
semantics for column-major weights; instead, when encountering a Gemm with
transB set, update the node.attrs transB to 0 (or remove/normalize it to zero)
so the graph expects B^T (matching the earlier weight transpose) and do not skip
inserting the transpose-back; locate the Gemm handling block (check for node.op
== "Gemm" and the transB logic) and replace the early continue with logic that
flips node.attrs["transB"] to 0 (or deletes the attr) so outputs remain correct
while keeping the transpose-back insertion.

# Get weight shape and dtype from DQ output
# DQ outputs W^T (transposed), shape is [N, K] instead of [K, N]
weight_shape = weight_input.shape if hasattr(weight_input, "shape") else None
weight_dtype = weight_input.dtype if hasattr(weight_input, "dtype") else None

# Permutation for 2D weights: [1, 0] to transpose back
# The stored weight is B^T (transposed), we need to get B back
# For 2D [N, K] (stored as transposed): perm [1, 0] -> [K, N] (original)
perm = [1, 0]

# Compute the transposed shape (original weight shape)
transposed_weight_shape = _apply_transpose_perm_to_shape(weight_shape, perm)

# Create output variable for the transpose node
transpose_out = gs.Variable(
f"{producer_node.name}_transposed_back",
dtype=weight_dtype,
shape=transposed_weight_shape,
)

# Create transpose node: (B^T)^T = B
transpose_node = gs.Node(
op="Transpose",
name=f"{producer_node.name}_transpose_back",
inputs=[weight_input],
outputs=[transpose_out],
attrs={"perm": perm},
)

# Update MatMul/Gemm to use the transposed weight
node.inputs[1] = transpose_out

# Add transpose node to list
nodes_to_add.append(transpose_node)

# Add all new nodes to graph
if nodes_to_add:
graph.nodes.extend(nodes_to_add)
logger.info(f"Added {len(nodes_to_add)} transpose nodes for column-major optimization")

# Clean up and reorder graph
graph.cleanup().toposort()


def apply_column_major_transformation(
gemm_weights_quantized: dict,
scales: dict,
graph: gs.Graph,
block_size: int,
) -> dict:
"""Apply full column-major transformation to quantized weights and graph.

This is a convenience function that:
1. Transposes the quantized weights and scales
2. Returns the updated DQ node attributes (axis=1 instead of 0)

Note: After calling this function and inserting DQ nodes, you should call
add_transpose_nodes_for_column_major() on the graph.

Args:
gemm_weights_quantized: Dictionary mapping weight names to quantized weight arrays
scales: Dictionary mapping weight names to scale arrays
graph: ONNX GraphSurgeon graph (for reference, not modified here)
block_size: Block size for quantization

Returns:
Dictionary with DQ node attributes (axis=1 for column-major)
"""
logger.info("Applying column-major storage optimization")

# Transpose weights and scales in-place
for name in list(gemm_weights_quantized.keys()):
gemm_weights_quantized[name] = gemm_weights_quantized[name].T

for name in list(scales.keys()):
scales[name] = scales[name].T

# Return updated DQ node attributes with axis=1 (column-major)
return {"axis": 1, "block_size": block_size}


def cast_initializer_to_dtype(
node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
):
Expand Down