Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions onnxscript/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"TracedOnnxFunction",
"GraphBuilder",
"OpBuilder",
"OpBuilderBase",
"BuilderBase",
"TapeBuilder",
"build_function",
"build_graph",
Expand Down Expand Up @@ -69,6 +69,7 @@
"opset_ai_onnx_ml4",
"opset_ai_onnx_ml5",
"DEBUG",
"BuilderFeature",
]

import importlib.metadata
Expand Down Expand Up @@ -135,7 +136,11 @@

from . import ir, nn, optimizer, rewriter, version_converter
from ._internal.builder import GraphBuilder, OpBuilder, build_function, build_graph
from ._internal.tape_builder import OpBuilderBase, TapeBuilder
from ._internal.tape_builder import (
BuilderBase,
BuilderFeature,
TapeBuilder,
)
from ._internal.utils import external_tensor
from ._internal.values import OnnxFunction, TracedOnnxFunction

Expand Down
282 changes: 84 additions & 198 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@
"""Graph builder for constructing ONNX IR graphs imperatively.

This module provides imperative builders for constructing ONNX IR graphs with automatic
constant promotion, type casting, and shape inference. The GraphBuilder class enables
programmatic construction of graphs with proper scoping, constant management, and node
creation. The OpBuilder class provides dynamic op dispatching via attribute access.
constant promotion, type casting, and shape inference. The GraphBuilder class inherits
from BuilderBase and enables programmatic construction of graphs with proper scoping,
constant management, and node creation. The OpBuilder class provides dynamic op
dispatching via attribute access.
"""

from __future__ import annotations

from typing import Any, Callable, Mapping, Sequence, Union

import onnx
import onnx_ir as ir

import onnxscript._internal._inference as inference
import onnxscript.optimizer
from onnxscript._internal import _inliner, param_manipulation
import onnxscript
from onnxscript._internal import _inliner
from onnxscript._internal.tape_builder import (
BuilderBase,
BuilderFeature,
)

# A permissible value for an op input, which can be converted to an ir.Value.
VALUE_LIKE = Union[
Expand Down Expand Up @@ -418,10 +421,11 @@
)


class GraphBuilder:
class GraphBuilder(BuilderBase):
"""Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference."""

def __init__(self, graph: ir.Graph, *, parent: GraphBuilder | None = None) -> None:
super().__init__(features=BuilderFeature.FULL)
self._graph = graph
self._parent = parent
self._root: GraphBuilder = parent._root if parent is not None else self
Expand Down Expand Up @@ -453,7 +457,7 @@
return OpBuilder(self, domain, version)

@property
def op(self) -> OpBuilder:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method requires 1 positional argument, whereas overridden
BuilderBase.op
requires at least 2.
return self._op_builder

@property
Expand All @@ -474,6 +478,71 @@
def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
return self._root._functions

# ------------------------------------------------------------------
# BuilderBase abstract method implementations
# ------------------------------------------------------------------

def _add_node(self, node: ir.Node) -> None:
"""Append a node to the graph."""
self.graph.append(node)

def _add_initializer(self, value: ir.Value) -> None:
"""Register an initializer in the root graph."""
self._root._graph.register_initializer(value)

def _record_opset(self, domain: str, version: int | None) -> None:
    """Hook invoked when an op from ``(domain, version)`` is emitted; no-op here."""
    # The ir.Graph already tracks opset imports itself; nothing to do.
    pass

# ------------------------------------------------------------------
# BuilderBase hook overrides
# ------------------------------------------------------------------

def _promote_constant(self, value: Any, dtype: ir.DataType | None) -> ir.Value:
"""Cache-based constant promotion.

Delegates to the root builder so that all constant initializers
live in the root graph (outer-scope initializers are visible to
subgraphs per the ONNX spec).
"""
return self._get_or_create_constant(value, dtype)

def _generate_node_name(self, op_type: str) -> str:
count = self.graph.num_nodes()
return self._qualify_node_name(f"{op_type}_node_{count}")

def _adapt_outputs(
    self, outputs: int | Sequence[str | ir.Value], op_type: str
) -> Sequence[ir.Value]:
    """Pre-create named output ir.Value objects for the graph.

    An ``int`` requests that many freshly named outputs; a sequence of
    names/values is delegated to the base-class adapter.
    """
    if not isinstance(outputs, int):
        # Explicit names/ir.Values: BuilderBase knows how to adapt these.
        adapted = super()._adapt_outputs(outputs, op_type)
        assert adapted is not None
        return adapted
    node_index = self.graph.num_nodes()
    if outputs < 0:
        raise ValueError(f"Number of outputs must be non-negative, got {outputs}")
    stem = f"{op_type}_{node_index}" if op_type else f"{node_index}"
    if outputs == 1:
        return [ir.Value(name=self._qualify_value_name(stem))]
    return [
        ir.Value(name=self._qualify_value_name(f"{stem}_{i}")) for i in range(outputs)
    ]

def _annotate_node(self, node: ir.Node) -> None:
"""Attach scope metadata to the node."""
node.metadata_props["namespace"] = self._build_namespace()
node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())

# ------------------------------------------------------------------
# GraphBuilder-specific public API
# ------------------------------------------------------------------

def initializer(
self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True
) -> ir.Value:
Expand Down Expand Up @@ -594,158 +663,15 @@
# TODO(rama): Consider caching for other tensor values.
return self.initializer(ir.tensor(value, dtype=dtype))

def _input_to_ir_value(
    self, value: VALUE_LIKE, like_type: ir.Value | None = None
) -> ir.Value | None:
    """Convert a permissible input (for a call to an op) into an ir.Value.

    ir.Value inputs (and None) pass through unchanged.  Python constants are
    promoted to ONNX constant tensors; when given, *like_type* determines
    the target ONNX type of the promoted constant.
    """
    if isinstance(value, ir.Value):
        return value
    if value is None:
        return value
    if like_type is not None and like_type.type is not None:
        dtype = like_type.type.dtype
    else:
        dtype = None
    # like_type given but its dtype unknown statically: cast dynamically
    # after promotion instead.
    must_cast_dynamically = like_type is not None and dtype is None
    promoted = self._get_or_create_constant(value, dtype)
    if must_cast_dynamically:
        # The CastLike node is created in THIS builder's graph (not root),
        # so it lives in the correct scope (subgraph or function body).
        promoted = self.op.CastLike(promoted, like_type)
    return promoted

def _adapt_outputs(
    self, outputs: int | Sequence[str | ir.Value], op_type: str = ""
) -> Sequence[ir.Value]:
    """Turn an output count, or a list of names/values, into output ir.Values."""
    if isinstance(outputs, int):
        node_index = self.graph.num_nodes()
        if outputs < 0:
            raise ValueError(f"Number of outputs must be non-negative, got {outputs}")
        stem = f"{op_type}_{node_index}" if op_type else f"{node_index}"
        if outputs == 1:
            return [ir.Value(name=self._qualify_value_name(stem))]
        return [
            ir.Value(name=self._qualify_value_name(f"{stem}_{i}"))
            for i in range(outputs)
        ]
    # Sequence case: qualify explicit names, wrap strings in fresh ir.Values.
    adapted: list[ir.Value] = []
    for item in outputs:
        if isinstance(item, ir.Value):
            if item.name:
                item.name = self._qualify_value_name(item.name)
            adapted.append(item)
        elif isinstance(item, str):
            adapted.append(ir.Value(name=self._qualify_value_name(item)))
        else:
            raise TypeError("Output type not supported.")
    return adapted

def _get_schema(
    self, op_type: str, domain: str, version: int | None
) -> onnx.defs.OpSchema | None:
    """Look up the ONNX op schema, or return None when it cannot be found."""
    if version is None:
        return None
    try:
        return onnx.defs.get_schema(op_type, version, domain)
    except onnx.defs.SchemaError:
        # Unknown op/version combination: callers treat None as "no schema".
        return None

def _partition_inputs_attributes(
self,
schema: onnx.defs.OpSchema | None,
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
kwargs: dict[str, Any],
) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
if schema is None:
return inputs, kwargs
op_signature = ir.schemas.OpSignature.from_op_schema(schema)
return param_manipulation.separate_input_attributes_from_arguments(
op_signature,
list(inputs),
kwargs,
fill_defaults=False,
allow_extra_args=False,
)
def add_node(self, node: ir.Node) -> None:
"""Append a node to the graph, run constant propagation and shape inference.

def _cast_inputs(
    self,
    schema: onnx.defs.OpSchema | None,
    inputs: Sequence[VALUE_LIKE],
) -> Sequence[ir.Value | None]:
    """Uses schema specification to support a limited form of auto-casting.

    * Scalars are promoted to tensors.
    * Further, they are cast to the required type when used in ops with other
      tensor inputs that are required to be of same type.
      Thus, in "A+1" or "Add(A, 1)", the value 1 will be converted to the same
      type as A.

    This is a backward-compatible public method used by call_inline and
    other code that creates nodes manually.
    """
    if schema is None:
        # No schema: promote each input independently, with no target type.
        return [self._input_to_ir_value(i) for i in inputs]

    expected_inputs = schema.inputs
    # We make two passes. In the first pass, we identify known type-bindings for
    # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}.
    # In the second pass, we use these bindings to cast scalar-values to
    # tensors of appropriate types. The two passes are needed to handle cases
    # like "Add(1, X)" where 1 must be cast to the same type as X.
    type_bindings: dict[str, ir.Value] = {}
    args_typevars: list[tuple[ir.Value | None, str | None]] = []
    for i, x in enumerate(inputs):
        if i < len(expected_inputs):
            expected = expected_inputs[i]
        elif expected_inputs and (
            expected_inputs[-1].option == onnx.defs.OpSchema.FormalParameterOption.Variadic
        ):
            # Extra actuals map onto a trailing variadic formal parameter.
            expected = expected_inputs[-1]
            if not expected.is_homogeneous:
                # Heterogeneous variadic inputs share no type variable:
                # no binding can be inferred, so record None.
                args_typevars.append((x, None))
                continue
        else:
            raise ValueError(
                f"Number of actual parameters {len(inputs)} "
                f"exceeds number of formal parameters {len(expected_inputs)}."
            )
        typevar = expected.type_str
        if ("(" not in typevar) and (typevar not in type_bindings):
            # typevar is an identifier, like "T" (not a concrete type such
            # as "tensor(float)"); bind it to the first ir.Value seen.
            if isinstance(x, ir.Value):
                type_bindings[typevar] = x
        args_typevars.append((x, typevar))

    def adapt(x, typevar: str | None) -> ir.Value | None:
        # Second pass: promote/cast each input using the bound type, if any.
        if x is None:
            return None
        if typevar is None:
            return self._input_to_ir_value(x)
        type_like = type_bindings.get(typevar)
        return self._input_to_ir_value(x, type_like)

    return [adapt(x, typevar) for x, typevar in args_typevars]

def _cast_attributes(
self,
schema: onnx.defs.OpSchema | None,
attributes: dict[str, Any],
) -> dict[str, Any]:
del schema # Not implemented yet
return attributes if attributes is not None else {}

def add_node(self, node: ir.Node) -> None:
"""Append a node to the graph, run constant propagation and shape inference."""
self.graph.append(node)
onnxscript.optimizer.basic_constant_propagation([node])
inference.infer_outputs(node)
self._add_node(node)
self._constant_propagation(node)
self._infer_shapes(node)

def subgraph(
self,
Expand Down Expand Up @@ -796,46 +722,6 @@
parent=self,
)

def call_op(
    self,
    op_type: str,
    inputs: Sequence[ir.Value | ir.TensorProtocol | None],
    kwargs: dict[str, Any],
    /,
    domain: str = "",
    version: int | None = None,
    outputs: int | Sequence[str | ir.Value] = 1,
):
    """Create an ONNX node, add it to the graph, and return its output value(s).

    A single ir.Value is returned for a one-output node; a sequence of
    ir.Values is returned for multi-output nodes.
    """
    # Name the node by its position in the graph, qualified by scope.
    node_name = self._qualify_node_name(f"{op_type}_node_{self.graph.num_nodes()}")
    output_values = self._adapt_outputs(outputs, op_type)

    # Use the op schema (when available) to split inputs from attributes
    # and to auto-cast scalar inputs to matching tensor types.
    schema = self._get_schema(op_type, domain, version)
    op_inputs, op_attrs = self._partition_inputs_attributes(schema, inputs, kwargs)
    op_inputs = self._cast_inputs(schema, op_inputs)
    op_attrs = self._cast_attributes(schema, op_attrs)

    node = ir.node(
        op_type,
        op_inputs,
        attributes=op_attrs or None,
        domain=domain,
        outputs=output_values,
        version=version,
        name=node_name,
    )

    # Attach scope metadata to the node
    node.metadata_props["namespace"] = self._build_namespace()
    node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
    node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())

    self.add_node(node)

    if len(node.outputs) > 1:
        return node.outputs
    return node.outputs[0]

def call(
self,
function: ir.Function | onnxscript.OnnxFunction,
Expand Down
Loading
Loading