Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions modelopt/onnx/op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,67 @@ def is_data_dependent_shape_op(op_type: str):
"NonZero",
"RoiAlign",
]


def get_bool_ops():
    """Return the set of ONNX boolean logic operator names."""
    return set(("Not", "And", "Or", "Xor"))


def get_bitwise_ops():
    """Return the set of ONNX bitwise operator names."""
    # A fresh set is built on every call so callers may mutate their copy freely.
    return {"BitShift", "BitwiseAnd", "BitwiseOr", "BitwiseXor"}


def get_value_check_ops():
    """Return the set of ONNX operators that inspect or normalize element values."""
    ops = ("IsNaN", "IsInf", "Sign", "Abs")
    return set(ops)


def get_comparison_ops():
    """Return the set of ONNX element-wise comparison operator names."""
    return {
        "Less",
        "LessOrEqual",
        "Equal",
        "Greater",
        "GreaterOrEqual",
    }


def get_conditional_ops():
    """Return the set of ONNX conditional-selection operator names."""
    return set(("Where",))


def get_aggregation_ops():
    """Return the set of ONNX boolean-reduction (aggregation) operator names."""
    return {"Any", "All"}


def get_set_ops():
    """Return the set of ONNX set/search operator names."""
    members = ["Unique", "NonZero"]
    return set(members)
325 changes: 325 additions & 0 deletions modelopt/onnx/quantization/autotune/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common data structures and types for the QDQ Autotuner."""

import hashlib
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

import onnx_graphsurgeon as gs

from modelopt.onnx.logging_config import logger
from modelopt.onnx.quantization.autotune.insertion_points import (
ChildRegionInputInsertionPoint,
NodeInputInsertionPoint,
RegionOutputInsertionPoint,
)


class AutotunerError(Exception):
    """Root of the autotuner exception hierarchy; catch this to handle any autotuner failure."""


class AutotunerNotInitializedError(AutotunerError):
    """Raised when an autotuner operation is attempted before initialization has run."""


class InvalidSchemeError(AutotunerError):
    """Raised when a referenced insertion scheme does not exist or is malformed."""


class RegionType(Enum):
    """Classifies a region's role in the hierarchical graph decomposition.

    - LEAF: atomic region holding only direct nodes, with no child regions
    - COMPOSITE: region built from child regions (and optionally direct nodes)
    - ROOT: the single top-level region covering the whole computation graph
    """

    # Values mirror the member names so they serialize readably.
    LEAF = "LEAF"
    COMPOSITE = "COMPOSITE"
    ROOT = "ROOT"


class Region:
"""A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion.

Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions
contain child regions, and LEAF regions contain only nodes. Each region tracks
its direct nodes, input/output tensors, and a pattern signature for matching
regions with identical structure.
"""

def __init__(self, region_id: int, level: int, region_type: RegionType):
"""Initialize a new region.

Args:
region_id: Unique identifier within the region hierarchy
level: Hierarchical level (0 = leaf, higher = more composite)
region_type: Type classification (LEAF, COMPOSITE, or ROOT)
"""
self.id = region_id
self.level = level
self.type = region_type
self.parent: Region | None = None
self.children: list[Region] = []
self.nodes: set[int] = set()
self.inputs: list[str] = []
self.outputs: list[str] = []
self.metadata: dict[str, str] = {}

def get_children(self, *, sort: bool = False) -> list["Region"]:
"""Get all child regions."""
if sort:
return sorted(
self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants())
)
return self.children

def remove_child(self, child: "Region") -> bool:
"""Remove a child region from this region's children list."""
if child not in self.children:
return False
self.children.remove(child)
if child.parent and child.parent.id == self.id:
child.parent = None
return True

def add_child(self, child: "Region") -> None:
"""Add a child sub-region."""
if child.id == self.id:
logger.warning(f"Cannot add region {self.id} as its own child")
return

if self.is_descendant_of(child):
logger.warning(
f"Cycle detected: region {self.id} is already a descendant of region {child.id}"
)
return

if child.parent is not None and child.parent.id != self.id:
old_parent_id = child.parent.id
logger.debug(
f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}"
)
child.parent.remove_child(child)

if any(c.id == child.id for c in self.children):
logger.debug(f"Region {child.id} already child of {self.id}")
return

self.children.append(child)
child.parent = self

def is_descendant_of(self, potential_ancestor: "Region") -> bool:
"""Check if this region is a descendant of potential_ancestor."""
visited = set()
current = self.parent
while current:
if current.id in visited:
return False
visited.add(current.id)
if current.id == potential_ancestor.id:
return True
current = current.parent
return False

def add_node(self, node_index: int) -> None:
"""Add a node index to this region."""
self.nodes.add(node_index)

def add_nodes(self, node_indices: list[int]) -> None:
"""Add multiple node indices to this region."""
self.nodes.update(node_indices)

def get_nodes(self, *, sort: bool = False) -> list[int]:
"""Get direct node indices in this region only."""
if sort:
return sorted(self.nodes)
return list(self.nodes)

def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]:
"""Get all node indices recursively, including descendants."""
if _visited is None:
_visited = set()

# Detect cycles
assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal"

_visited.add(self.id)
all_nodes = set(self.nodes)
for child in self.children:
all_nodes.update(child.get_region_nodes_and_descendants(_visited))
return all_nodes

def contains_node(self, node_index: int) -> bool:
"""Check if region contains a specific node (direct only)."""
return node_index in self.nodes

def contains_node_within_region_and_descendants(self, node_index: int) -> bool:
"""Check if region contains a node recursively."""
return node_index in self.get_region_nodes_and_descendants()

def add_input(self, tensor_name: str) -> None:
"""Add an input tensor name."""
if tensor_name not in self.inputs:
self.inputs.append(tensor_name)

def add_output(self, tensor_name: str) -> None:
"""Add an output tensor name."""
if tensor_name not in self.outputs:
self.outputs.append(tensor_name)

def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int:
"""Get total node count recursively including all descendants."""
if _visited is None:
_visited = set()

# Detect cycles
assert self.id not in _visited, (
f"Cycle detected in region {self.id} during size calculation"
)

_visited.add(self.id)
total = len(self.nodes)
for child in self.children:
total += child.get_size_of_region_and_descendants(_visited)
return total

def merge(self, other: "Region") -> None:
"""Merge another region into this one."""
if not other:
return
self.nodes.update(other.nodes)
for child in other.children:
self.add_child(child)

def __repr__(self) -> str:
type_str = self.type.value
return (
f"Region[id={self.id}, level={self.level}, type={type_str}, "
f"nodes={len(self.nodes)}, children={len(self.children)}, "
f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]"
)

def compute_structural_signature(self, graph: gs.Graph) -> str:
"""Compute deterministic structural signature for pattern matching.

Creates a signature that uniquely identifies the region's topology,
node operations, and hierarchical structure. Regions with identical
signatures can share Q/DQ insertion schemes.

The signature captures:
- Node operation types and key parameters
- Hierarchical structure (child regions)
- Deterministic ordering (sorted for consistency)

Args:
graph: The ONNX graph containing the region's nodes

Returns:
Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)")
"""
raise NotImplementedError("Not implemented")


@dataclass
class InsertionScheme:
    """Q/DQ insertion specification applied to all regions matching a pattern."""

    node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list)
    child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list)
    region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list)
    latency_ms: float = float("inf")
    error: bool = False
    profile_timestamp: str | None = None

    @property
    def hash(self) -> str:
        """Deterministic 32-hex-char identity derived from the sorted insertion points."""
        node_part = sorted((pt.node_index, pt.input_index) for pt in self.node_inputs)
        child_part = sorted((pt.region_index, pt.input_index) for pt in self.child_region_inputs)
        output_part = sorted(
            (pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs
        )
        digest_source = "|".join(str(part) for part in (node_part, child_part, output_part))
        return hashlib.sha256(digest_source.encode("utf-8")).hexdigest()[:32]

    @property
    def is_empty(self) -> bool:
        """True when the scheme inserts no Q/DQ pairs at all (the baseline scheme)."""
        return not (self.node_inputs or self.child_region_inputs or self.region_outputs)

    @property
    def is_profiled(self) -> bool:
        """True once a measurement was attempted: it errored, or a finite latency exists."""
        if self.error:
            return True
        return self.latency_ms != float("inf")

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary (inverse of :meth:`from_dict`)."""
        result: dict[str, Any] = {
            "latency_ms": self.latency_ms,
            "error": self.error,
            "profile_timestamp": self.profile_timestamp,
        }
        result["nodes_insertion_points"] = [pt.to_dict() for pt in self.node_inputs]
        result["child_region_inputs"] = [pt.to_dict() for pt in self.child_region_inputs]
        result["region_outputs"] = [pt.to_dict() for pt in self.region_outputs]
        # Stored for integrity checks on reload; from_dict ignores it.
        result["hash"] = self.hash
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme":
        """Create an InsertionScheme from a dictionary produced by :meth:`to_dict`.

        Missing keys fall back to field defaults; unknown keys (e.g. "hash")
        are ignored.
        """
        return cls(
            node_inputs=[
                NodeInputInsertionPoint.from_dict(pt)
                for pt in data.get("nodes_insertion_points", [])
            ],
            child_region_inputs=[
                ChildRegionInputInsertionPoint.from_dict(pt)
                for pt in data.get("child_region_inputs", [])
            ],
            region_outputs=[
                RegionOutputInsertionPoint.from_dict(pt)
                for pt in data.get("region_outputs", [])
            ],
            latency_ms=data.get("latency_ms", float("inf")),
            error=data.get("error", False),
            profile_timestamp=data.get("profile_timestamp"),
        )

    def distance(self, other: "InsertionScheme") -> int:
        """Edit distance: number of insertion points present in exactly one scheme."""
        pairs = (
            (self.node_inputs, other.node_inputs),
            (self.child_region_inputs, other.child_region_inputs),
            (self.region_outputs, other.region_outputs),
        )
        return sum(len(set(mine) ^ set(theirs)) for mine, theirs in pairs)

    def __str__(self) -> str:
        """Compact one-line summary for logs and debugging."""
        suffix = ", error=True" if self.error else ""
        return (
            f"InsertionScheme(node_insertions={len(self.node_inputs)}, "
            f"region_insertions={len(self.child_region_inputs)}, "
            f"region_output_insertions={len(self.region_outputs)}, "
            f"latency={self.latency_ms:.3f}ms{suffix})"
        )
Loading