Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions modelopt/onnx/quantization/autotune/autotuner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,112 @@ def _is_region_profiled(self, region: Region) -> bool:
for p in self.profiled_patterns
)

def _sample_concat_group_mutation(
self,
selected_points: list,
all_points: list,
region: "Region",
) -> list:
"""Probabilistically add or remove a full Concat input group as an atomic unit.

TRT requires ALL inputs of a Concat to be INT8 for INT8 Concat fusion;
partial quantization has no benefit. This method treats Concat input sets
as atomic groups — randomly choosing to add all inputs of one Concat, or
remove all inputs of one Concat, as a single mutation step.

Called with adaptive probability (min 5%, scaling with budget) during scheme
generation to inject Concat-aware samples into the search space without forcing
all schemes to have full groups.

Args:
selected_points: Currently selected NodeInputInsertionPoint list
all_points: Full set of available NodeInputInsertionPoint list
region: The region being profiled

Returns:
Updated list with one Concat group atomically added or removed
"""
# Identify which local node indices are Concat ops
node_indices = region.get_nodes(sort=True)
concat_local_indices = set()
for local_idx, node_idx in enumerate(node_indices):
node = self.graph.nodes[node_idx]
if node.op == "Concat":
concat_local_indices.add(local_idx)

if not concat_local_indices:
return selected_points

# Build groups: concat_node_index -> all available insertion points for that Concat
concat_groups: dict[int, list] = {}
for p in all_points:
if p.node_index in concat_local_indices:
concat_groups.setdefault(p.node_index, []).append(p)

if not concat_groups:
return selected_points
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# Determine the real arity of each Concat node (number of actual inputs).
# If some inputs were filtered out earlier (e.g. non-float, small tensors),
# the group in all_points is incomplete and can never satisfy TRT's "all
# Concat inputs quantized" requirement — skip such groups entirely.
complete_concat_groups: dict[int, list] = {}
for concat_idx, group_points in concat_groups.items():
global_node_idx = node_indices[concat_idx]
concat_node = self.graph.nodes[global_node_idx]
real_arity = len(concat_node.inputs)
if len(group_points) == real_arity:
complete_concat_groups[concat_idx] = group_points

if not complete_concat_groups:
return selected_points

# Identify fully-present and absent Concat groups in current selection
selected_keys = {(p.node_index, p.input_index) for p in selected_points}
full_groups = [] # Concat groups fully present (can remove)
absent_groups = [] # Concat groups fully absent (can add)

for concat_idx, group_points in complete_concat_groups.items():
group_keys = {(p.node_index, p.input_index) for p in group_points}
present = group_keys & selected_keys
if len(present) == len(group_keys):
full_groups.append(concat_idx)
elif len(present) == 0:
absent_groups.append(concat_idx)
# Partial groups: also eligible for completion (add missing siblings)
elif len(present) < len(group_keys):
absent_groups.append(concat_idx)

# Choose action only from feasible options to avoid no-op mutations
actions = []
if absent_groups:
actions.append("add")
if full_groups:
actions.append("remove")
if not actions:
return selected_points
action = random.choice(actions)

if action == "add":
target = random.choice(absent_groups)
points_to_add = []
for p in complete_concat_groups[target]:
if (p.node_index, p.input_index) not in selected_keys:
points_to_add.append(p)
logger.debug(
f"Concat group mutation: added {len(points_to_add)} points for Concat node {target}"
)
return selected_points + points_to_add

elif action == "remove":
target = random.choice(full_groups)
group_keys = {(p.node_index, p.input_index) for p in complete_concat_groups[target]}
result = [p for p in selected_points if (p.node_index, p.input_index) not in group_keys]
logger.debug(
f"Concat group mutation: removed {len(group_keys)} points for Concat node {target}"
)
return result

def _mutate_insertion_points(
self, base_points, all_points, point_type: str, max_mutations: int
) -> list:
Expand Down Expand Up @@ -1057,6 +1163,17 @@ def _generate_next_insertion_sample(self) -> InsertionScheme:
),
)

# Probabilistically apply Concat-group-aware mutation: atomically add or remove
# all inputs of a Concat as a group. Probability adapts to budget:
# - Large budget (>=100): 5% → at least 5 samples
# - Small budget (<100): min_samples/budget, clamped to [0.05, 0.5]
num_schemes = max(len(pattern_schemes.schemes), 1)
concat_prob = min(max(self.config.concat_group_min_samples / num_schemes, 0.05), 0.5)
if random.random() < concat_prob:
Comment thread
willg-nv marked this conversation as resolved.
scheme.node_inputs = self._sample_concat_group_mutation(
scheme.node_inputs, full_insertion_scheme.node_inputs, region
)

return scheme

def _copy_graph(self) -> gs.Graph:
Expand Down
1 change: 1 addition & 0 deletions modelopt/onnx/quantization/autotune/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ class Config:
minimum_schemes_to_mutate: int = 10
maximum_mutations: int = 3
maximum_generation_attempts: int = 100
concat_group_min_samples: int = 5 # Minimum Concat-group-aware mutations per region
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was 5 chosen here? What are the benefits or a larger or smaller number? @willg-nv

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this change is to mutate the insertion spec so that Q/DQ is inserted for all Concat inputs. This probability should not be too high (5%) to avoid impacting the normal mutation process. But when the sample count per region is small (e.g. 30), the expected number of samples that trigger "all Concat inputs Q/DQ" would be ~1, which is too few — sometimes it is never triggered. So concat_group_min_samples sets the minimum expected number of samples that can trigger "all Concat inputs Q/DQ". I expect this case to be covered, but with a small probability.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is the 5 in concat_group_min_samples: int = 5 a percentage? If so, can we add that in the comment? Thanks!


# Pattern Cache Settings
pattern_cache_minimum_distance: int = 4
Expand Down
66 changes: 64 additions & 2 deletions modelopt/onnx/quantization/autotune/insertion_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,41 @@ def skip_invalid_insertion_points(
producer = node.inputs[0].inputs[0]
if producer.op in ["Conv", "ConvTranspose"]:
return True
# Conv -> [BN ->] Add -> Relu: skip quantizing the main-path Conv
# output feeding Add to preserve TRT Conv+Add+Relu INT8 fusion.
# Guards:
# 1. The Add output has a single consumer and that consumer is Relu
# (otherwise TRT cannot fuse, and skipping removes a legitimate
# quantization point).
# 2. The Conv feeding Add is a "main-path" Conv (its activation input
# has a single consumer), not a downsample/projection Conv (whose
# activation input fans out to multiple consumers).
if node.op == "Add":
# Guard 1: Add must feed exactly one Relu
add_out = node.outputs[0] if node.outputs else None
if add_out is None or len(add_out.outputs) != 1:
pass # Add fans out or has no consumer — skip not applicable
elif add_out.outputs[0].op != "Relu":
pass # Add does not feed Relu — fusion impossible
elif inp.inputs:
producer = inp.inputs[0]
# Unwrap optional BN
conv_node = None
if producer.op in ["Conv", "ConvTranspose"]:
conv_node = producer
elif producer.op == "BatchNormalization":
bn_act = producer.inputs[0] if producer.inputs else None
if (
bn_act
and bn_act.inputs
and bn_act.inputs[0].op in ["Conv", "ConvTranspose"]
):
conv_node = bn_act.inputs[0]
# Guard 2: main-path Conv (single consumer on activation input)
if conv_node is not None and conv_node.inputs:
conv_act_input = conv_node.inputs[0]
if len(conv_act_input.outputs) == 1:
return True
# Filter 1: out boolean operations
if node.op in (
get_bool_ops()
Expand Down Expand Up @@ -472,6 +507,11 @@ def merge_resolved_insertion_points(
to insert Q/DQ once at the tensor level rather than at each individual node input.
This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme.

Additionally, when a tensor has Q/DQ at some consumers and the remaining uncovered
consumers are all Concat nodes, the insertion is promoted to tensor-level. Concat is
a byte-level copy in TRT — quantizing its input has no accuracy cost and enables
INT8 Concat fusion when all Concat inputs are INT8.

Args:
graph: The ONNX graph containing the nodes
resolved_insertion_points: Set of resolved insertion points to optimize
Expand All @@ -486,18 +526,40 @@ def merge_resolved_insertion_points(
for tensor_name in {ip.tensor_name for ip in node_ips}:
all_users = set(tensor_users_map.get(tensor_name, []))
qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name}
if all_users == {ip.node_index for ip in qdq_users}:
covered_nodes = {ip.node_index for ip in qdq_users}

if all_users == covered_nodes:
# All consumers have Q/DQ — merge to tensor-level
results.add(
ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None)
)
elif covered_nodes and all_users - covered_nodes:
# Some consumers lack Q/DQ — check if all uncovered ones are Concat
uncovered = all_users - covered_nodes
uncovered_all_concat = all(
node_idx < len(graph.nodes) and graph.nodes[node_idx].op == "Concat"
for node_idx in uncovered
)
if uncovered_all_concat:
# Promote to tensor-level: Concat is byte-copy, safe to quantize
results.add(
ResolvedInsertionPoint(
tensor_name=tensor_name, node_index=None, input_index=None
)
)
else:
results.update(qdq_users)
else:
results.update(qdq_users)
return results


def get_autotuner_skip_ops():
"""Returns set of shape/structural operations that are not quantizable."""
return set(get_copy_ops()) | {
# Concat is excluded: it can pass INT8 data through in TRT (byte-level copy).
# Blocking Concat prevents tensor-level Q/DQ when a quantizable op's output
# fans out to both a compute op (e.g. Conv) and a Concat, breaking INT8 fusion.
return (set(get_copy_ops()) - {"Concat"}) | {
# Additional indexing/scatter/reshape ops
"Compress",
"Scatter",
Expand Down
Loading