Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import onnx
import onnx.reference.ops
import onnx_ir as ir
import onnx_ir.passes.common as ir_passes_common

import onnxscript.utils.utils as utils
from onnxscript._internal.tape_builder import BuilderBase, TapeBuilder
Expand Down Expand Up @@ -1394,6 +1395,11 @@ def call(self, model: ir.Model) -> FoldConstantsResult:
for function in model.functions.values():
# TODO(rama): Should we specialize functions?
self.visit_function(function)
if self._modified:
# TapeBuilder may create values with names that clash with existing graph
# values when nodes are inserted via replace_nodes_and_values.
# NameFixPass ensures all value names are unique before returning.
ir_passes_common.NameFixPass()(model)
return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)


Expand Down
56 changes: 56 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,5 +797,61 @@ def test_initializer_as_graph_output_is_not_removed(self):
self.assertIn("z", output_names)


def _all_value_names_unique(model: ir.Model) -> bool:
"""Return True if all named values in the top-level graph have unique names."""
names = []
for v in model.graph.inputs:
if v.name:
names.append(v.name)
for v in model.graph.initializers.values():
if v.name:
names.append(v.name)
for node in model.graph:
for output in node.outputs:
if output.name:
names.append(output.name)
return len(names) == len(set(names))


class NameClashAfterFoldTest(unittest.TestCase):
"""Tests that fold_constants calls NameFixPass to deduplicate value names.

TapeBuilder may assign names that collide with existing graph values when
new nodes are inserted via replace_nodes_and_values. NameFixPass, invoked
by FoldConstantsPass.call when the model was modified, resolves the
duplicates.
"""

def test_fold_constants_deduplicates_names(self):
"""Duplicate value names present alongside a constant-fold are fixed."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
extra = Relu(x)
z = Mul(extra, four)
}
"""
)

# Simulate the name clash that TapeBuilder can introduce: 'extra' (a
# non-folded node that survives) is given the same name as 'four' (the
# folded Add output) because NameAuthority does not check for conflicts
# when registering pre-named values inserted by TapeBuilder.
four_node = next(n for n in model.graph if n.op_type == "Add")
extra_node = next(n for n in model.graph if n.op_type == "Relu")
extra_node.outputs[0].name = four_node.outputs[0].name # inject clash

result = _constant_folding.fold_constants(model)

self.assertTrue(result.modified, "Folding must have modified the model")
self.assertTrue(
_all_value_names_unique(model),
"All value names must be unique after fold_constants",
)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
TypeVar,
)

import onnx_ir.passes.common as ir_passes_common

import onnxscript.optimizer
import onnxscript.rewriter._basics as _basics
import onnxscript.rewriter._context as _context
Expand Down Expand Up @@ -835,6 +837,11 @@ def apply_to_model(
)
if self.remove_unused_nodes:
onnxscript.optimizer.remove_unused_nodes(model)
if count > 0:
# TapeBuilder may create values with names that clash with existing graph
# values when nodes are inserted via replace_nodes_and_values.
# NameFixPass ensures all value names are unique before returning.
ir_passes_common.NameFixPass()(model)
return count

def __iter__(self):
Expand Down
65 changes: 65 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,5 +989,70 @@ def test_pattern_builder_context(self):
self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"])


def _all_value_names_unique(model: ir.Model) -> bool:
"""Return True if all named values in the top-level graph have unique names."""
names = []
for v in model.graph.inputs:
if v.name:
names.append(v.name)
for v in model.graph.initializers.values():
if v.name:
names.append(v.name)
for node in model.graph:
for output in node.outputs:
if output.name:
names.append(output.name)
return len(names) == len(set(names))


class NameClashAfterRewriteTest(unittest.TestCase):
"""Tests that apply_to_model calls NameFixPass to deduplicate value names.

TapeBuilder may assign names that collide with existing graph values when
new nodes are inserted via replace_nodes_and_values. NameFixPass, invoked
by apply_to_model when at least one rewrite fires, resolves the duplicates.
"""

def test_apply_to_model_deduplicates_names(self):
"""Duplicate value names introduced alongside a rewrite are fixed."""
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x, float[N] y, float[N] p) => (float[N] z)
{
c1 = Constant<value_float = 1.0>()
t1 = Div(c1, x)
z1 = Mul(t1, y)
extra = Add(z1, p)
z = Identity(extra)
}
"""
)
model = ir.serde.deserialize_model(model_proto)

# Simulate the name clash that TapeBuilder can introduce: two values that
# survive the rewrite (not part of the matched pattern) end up sharing a
# name because NameAuthority does not check for conflicts when registering
# pre-named values from TapeBuilder.
extra_node = next(n for n in model.graph if n.op_type == "Add")
identity_node = next(n for n in model.graph if n.op_type == "Identity")
identity_node.outputs[0].name = extra_node.outputs[0].name # inject clash

def reciprocal_mul_pattern(op, x, y):
return (1 / x) * y

def div_replacement(op, x, y):
return op.Div(y, x)

rule = pattern.RewriteRule(reciprocal_mul_pattern, div_replacement)
count = rule.apply_to_model(model)

self.assertGreater(count, 0, "Rewrite rule must have fired to exercise the fix")
self.assertTrue(
_all_value_names_unique(model),
"All value names must be unique after apply_to_model",
)


if __name__ == "__main__":
unittest.main()
8 changes: 8 additions & 0 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Callable, Sequence, Union

import onnx_ir.convenience as ir_convenience
import onnx_ir.passes.common as ir_passes_common

import onnxscript.utils.metadata_merger as metadata_merger
from onnxscript import ir
Expand Down Expand Up @@ -239,6 +240,7 @@ def groupnormalization_20_21(node: ir.Node, op):
class _VersionConverter:
def __init__(self, target_version: int):
self._target_version = target_version
self._modified: bool = False
# Default metadata merger: no merging should be needed; keep the first value.
self._default_metadata_merger: metadata_merger.MetadataMerger = (
metadata_merger.MetadataMerger(
Expand Down Expand Up @@ -269,6 +271,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
ir_convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)
self._modified = True

def visit_attribute(self, attr: ir.Attr) -> None:
if attr.is_ref():
Expand Down Expand Up @@ -341,6 +344,11 @@ def visit_model(self, model: ir.Model) -> None:
self.visit_graph_or_function(function)
_set_onnx_opset_version(function, self._target_version)
_set_onnx_opset_version(model, self._target_version)
if self._modified:
# TapeBuilder may create values with names that clash with existing graph
# values when nodes are inserted via replace_nodes_and_values.
# NameFixPass ensures all value names are unique before returning.
ir_passes_common.NameFixPass()(model)


def convert_version(model: ir.Model, target_version: int) -> None:
Expand Down
60 changes: 60 additions & 0 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,5 +538,65 @@ def test_version_convert_compatible(self):
version_converter.convert_version(model, target_version=target_version)


def _all_value_names_unique(model: ir.Model) -> bool:
"""Return True if all named values in the top-level graph have unique names."""
names = []
for v in model.graph.inputs:
if v.name:
names.append(v.name)
for v in model.graph.initializers.values():
if v.name:
names.append(v.name)
for node in model.graph:
for output in node.outputs:
if output.name:
names.append(output.name)
return len(names) == len(set(names))


class NameClashAfterConversionTest(unittest.TestCase):
"""Tests that convert_version calls NameFixPass to deduplicate value names.

TapeBuilder may assign names that collide with existing graph values when
new nodes are inserted via replace_nodes_and_values. NameFixPass, invoked
inside _VersionConverter.visit_model when nodes were modified, resolves
the duplicates.
"""

def test_convert_version_deduplicates_names(self):
"""Duplicate value names present after conversion are fixed by NameFixPass."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
{
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
reshape_x = Reshape (input_x, shape_a)
shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
reshape_y = Reshape (input_x, shape_b)
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
shape_c = Constant<value: tensor = int64[4] {4, 1024, 1024}>()
output = Reshape (gridsample, shape_c)
}
"""
)

# Simulate the name clash that TapeBuilder can introduce: two Constant
# node outputs (not touched by the GridSample adapter) receive the same
# name because NameAuthority does not check for conflicts when registering
# pre-named values inserted by TapeBuilder.
shape_a_output = model.graph.node(0).outputs[0]
shape_c_output = model.graph.node(5).outputs[0]
shape_c_output.name = shape_a_output.name # inject clash

version_converter.convert_version(model, target_version=20)

self.assertEqual(model.opset_imports[""], 20)
self.assertTrue(
_all_value_names_unique(model),
"All value names must be unique after convert_version",
)


if __name__ == "__main__":
unittest.main()
Loading