Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d4b7d60
Replace op_schema with op_signature
justinchuby Jan 15, 2026
78a25d8
wip
justinchuby Jan 16, 2026
f4e47bb
Clean up _to_model_proto
justinchuby Jan 16, 2026
be172c1
wip
justinchuby Jan 16, 2026
38ecac7
update
justinchuby Jan 16, 2026
26151b2
Fixes
justinchuby Jan 16, 2026
0fe2f6d
Fix converter
justinchuby Jan 16, 2026
0e067b6
homogeneous
justinchuby Jan 16, 2026
1109f26
fix call functions
justinchuby Jan 16, 2026
8bd5f52
copilot
justinchuby Jan 16, 2026
b5bba7a
Merge branch 'main' into justinchu/replace-opschema
justinchuby Jan 21, 2026
2e96a0b
update opset import
justinchuby Jan 21, 2026
e1e17ee
Support user functions
justinchuby Jan 21, 2026
367c942
Fix an error where the function return value is not the same as the e…
justinchuby Jan 21, 2026
7bd47e7
Reverse the scope lists for efficiency
justinchuby Jan 21, 2026
59a9eaa
Bump ir version
justinchuby Jan 21, 2026
9c89b31
Merge branch 'main' into justinchu/replace-opschema
justinchuby Jan 21, 2026
dd2396d
wip
justinchuby Jan 21, 2026
7a3581b
Merge branch 'main' into justinchu/replace-opschema
justinchuby Jan 22, 2026
9d933ae
Fix cond
justinchuby Jan 22, 2026
bd09ed3
Get opset version
justinchuby Jan 22, 2026
4d92d81
update
justinchuby Jan 22, 2026
1fd2623
Merge branch 'main' into justinchu/replace-opschema
justinchuby Jan 22, 2026
f2f747a
Handle graph attributes
justinchuby Jan 23, 2026
bcaec1b
Clean ta
justinchuby Jan 23, 2026
078db34
Fix graph attributes
justinchuby Jan 23, 2026
d3b7fd8
Fix
justinchuby Jan 23, 2026
1eefa02
Fix attr conversion
justinchuby Jan 23, 2026
e7b353d
update
justinchuby Jan 23, 2026
eef0934
update
justinchuby Jan 23, 2026
4b58dac
Update onnxscript/_internal/converter.py
justinchuby Jan 23, 2026
838d3c3
lint
justinchuby Jan 23, 2026
8ef9c0b
update
justinchuby Jan 23, 2026
e4fb629
fail text
justinchuby Jan 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"packaging",
"protobuf",
)
ONNX_IR = "onnx_ir==0.1.13"
ONNX_IR = "onnx_ir==0.1.15"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


Expand Down
67 changes: 27 additions & 40 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import numpy as np
import onnx
import onnx.helper # noqa: TID251
from onnx.defs import OpSchema

from onnxscript import ir, tensor
from onnxscript.ir import _schemas

if TYPE_CHECKING:
from onnxscript._internal import converter
Expand All @@ -20,23 +19,15 @@
# python values into ONNX TensorProto, while the runtime converts python values into
# ONNXScript runtime's value-representation (based on Tensor).


# Utilities to convert a python value to TensorProto (for use by the script converter)


def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
    """Convert a Python value into an ONNX TensorProto named *tensor_name*.

    The value is first wrapped in an onnx-ir tensor and then serialized to
    the proto representation via ``ir.serde``.
    """
    ir_tensor = ir.tensor(pyvalue, name=tensor_name)
    return ir.serde.serialize_tensor(ir_tensor)


_REPEATED_ATTRIBUTE_TYPES = frozenset(
{
onnx.AttributeProto.FLOATS,
onnx.AttributeProto.INTS,
onnx.AttributeProto.STRINGS,
onnx.AttributeProto.TENSORS,
onnx.AttributeProto.GRAPHS,
onnx.AttributeProto.SPARSE_TENSORS,
onnx.AttributeProto.TYPE_PROTOS,
ir.AttributeType.FLOATS,
ir.AttributeType.INTS,
ir.AttributeType.STRINGS,
ir.AttributeType.TENSORS,
ir.AttributeType.GRAPHS,
ir.AttributeType.SPARSE_TENSORS,
ir.AttributeType.TYPE_PROTOS,
}
)

Expand All @@ -45,33 +36,28 @@ def pyvalue_to_onnx_attribute(
key: str,
value: Any,
name_generator: Callable[[], str],
attr_type: onnx.AttributeProto.AttributeType | None = None,
) -> onnx.AttributeProto:
attr_type: ir.AttributeType | None = None,
) -> ir.Attr:
"""Helper function to create an ONNX AttributeProto.

This is a refinement of onnx.helper.make_attribute that works with ONNX Script
conventions for allowed types for attribute-values. In particular, it allows
* Empty lists as attribute values, provided the attribute type is specified
* Empty lists can be attribute values, provided the attribute type is specified
and is a list type.
* Scalar-values like 1.0 as well as lists like [1, -1] to be specified
when the attribute type is TensorProto by automatically converting the value
into a 0-D or 1-D tensor respectively.
"""
# TODO(justinchuby): Remove this function and use onnx-ir directly.
if isinstance(value, list) and not value:
# Empty list value:
if attr_type is None:
raise ValueError("Attribute type must be specified for empty list value.")
if attr_type not in _REPEATED_ATTRIBUTE_TYPES:
raise ValueError("Empty list value is only allowed for repeated attribute types.")
return onnx.AttributeProto(name=key, type=attr_type)
elif attr_type == onnx.AttributeProto.TENSOR and not isinstance(value, onnx.TensorProto):
return onnx.AttributeProto(
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
)
return ir.Attr(name=key, type=attr_type, value=[])
elif attr_type == ir.AttributeType.TENSOR and not isinstance(value, onnx.TensorProto):
return ir.AttrTensor(name=key, value=ir.tensor(value, name=name_generator()))
else:
# When the value is a subgraph, ONNX IR will complain that some values are
# not found from the scope.
return onnx.helper.make_attribute(key, value) # noqa: TID251
return ir.convenience.convert_attribute(key, value, attr_type=attr_type)


# Utilities to convert python values into onnxscript tensors.
Expand Down Expand Up @@ -126,7 +112,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None):
def cast_inputs(
get_type_info: Callable[[Any], Any],
cast: Callable[[Any, Any], Any],
op_schema: OpSchema | None,
op_signature: _schemas.OpSignature | None,
args,
) -> tuple[Any, ...]:
"""Uses schema specification to support a limited form of auto-casting.
Expand All @@ -140,12 +126,13 @@ def cast_inputs(
This is used by the converter in a static-mode, as well as by the eager-mode
execution in a dynamic-mode.
"""
if op_schema is None:
if op_signature is None:
# Either an error or a custom op.
# No checks/casts in this case.
return tuple(cast(x, None) for x in args)

expected_inputs = op_schema.inputs
# Filter to get only input parameters (not AttributeParameters)
expected_inputs = op_signature.inputs
# We make two passes. In the first pass, we identify known type-bindings for
# type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}.
# In the second pass, we use these bindings to cast scalar-values to
Expand All @@ -156,17 +143,17 @@ def cast_inputs(
for i, x in enumerate(args):
if i < len(expected_inputs):
expected = expected_inputs[i]
elif expected_inputs[-1].option == OpSchema.FormalParameterOption.Variadic:
elif expected_inputs[-1].variadic:
expected = expected_inputs[-1]
if not expected.is_homogeneous:
if not expected.homogeneous:
args_typevars.append((x, None))
continue
else:
raise ValueError(
f"Number of actual parameters {len(args)} "
f"exceeds number of formal parameters {len(expected_inputs)}."
)
typevar = expected.type_str
typevar = expected.type_constraint.name
if "(" not in typevar:
# typevar is an identifier, like "T"
typeinfo = get_type_info(x)
Expand All @@ -177,18 +164,18 @@ def cast_inputs(
return tuple(cast_args)


def dynamic_cast_inputs(op_schema: OpSchema, args):
def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args):
"""Used for autocast during eager-mode execution."""

def get_type_info(x):
return x.dtype if isinstance(x, tensor.Tensor) else None

return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_schema, args)
return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_signature, args)


def static_cast_inputs(
converter_: converter.Converter,
op_schema: Optional[OpSchema],
op_signature: Optional[_schemas.OpSignature],
args: Sequence[Optional[ir.Value]],
) -> tuple[str, ...]:
"""Used for autocast during script-translation.
Expand All @@ -212,4 +199,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
return converter_.emit1([x_cast], "CastLike", [x, y])
return x

return cast_inputs(get_type_info, cast_like, op_schema, args)
return cast_inputs(get_type_info, cast_like, op_signature, args)
Loading
Loading