Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 112 additions & 10 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,37 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.op.matmul(inputs[0], inputs[1])


class MatMulInteger16(OnnxOpConverter):
"""Converts an ONNX MatMulInteger16 node into an equivalent Relax expression."""

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
if len(inputs) != 2:
raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}")
a, b = inputs
valid_types = ["int16", "uint16"]
if a.struct_info.dtype not in valid_types:
raise ValueError(
"MatMulInteger16 expects input A to have int16 or uint16 dtype, "
f"but got {a.struct_info.dtype}"
)
if b.struct_info.dtype not in valid_types:
raise ValueError(
"MatMulInteger16 expects input B to have int16 or uint16 dtype, "
f"but got {b.struct_info.dtype}"
)

out_dtype = (
"uint32"
if a.struct_info.dtype == "uint16" and b.struct_info.dtype == "uint16"
else "int32"
)
return relax.op.matmul(
relax.op.astype(a, out_dtype),
relax.op.astype(b, out_dtype),
)


def _to_numpy(x):
if isinstance(x, relax.PrimValue):
x = x.value
Expand All @@ -328,6 +359,19 @@ def _to_numpy(x):
return x.data.numpy()


class _EmptyOptional:
"""Sentinel object that preserves an empty ONNX Optional during import."""

def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto):
self.type_proto = type_proto


def _is_empty_optional(value: Any) -> bool:
"""Returns whether the given value represents an empty ONNX Optional."""

return isinstance(value, _EmptyOptional)


class BinaryBase(OnnxOpConverter):
"""Converts an onnx BinaryBase node into an equivalent Relax expression."""

Expand Down Expand Up @@ -3686,6 +3730,50 @@ def _impl_v1(cls, bb, inputs, attr, params):
)


class Optional_(OnnxOpConverter):
"""Converts an ONNX Optional node into an erased or empty Optional representation."""

@classmethod
def _impl_v15(cls, bb, inputs, attr, params):
if len(inputs) > 1:
raise ValueError(f"Optional accepts at most one input, but got {len(inputs)}")
if len(inputs) == 0 or inputs[0] is None:
if "type" not in attr:
raise ValueError("Optional without an input must specify the type attribute.")
return _EmptyOptional(attr["type"])
return inputs[0]

_impl_v18 = _impl_v15


class OptionalHasElement(OnnxOpConverter):
"""Converts an ONNX OptionalHasElement node into a boolean constant."""

@classmethod
def _impl_v15(cls, bb, inputs, attr, params):
if len(inputs) != 1:
raise ValueError(f"OptionalHasElement expects one input, but got {len(inputs)}")
if inputs[0] is None or _is_empty_optional(inputs[0]):
return relax.const(False, dtype="bool")
return relax.const(True, dtype="bool")

_impl_v18 = _impl_v15


class OptionalGetElement(OnnxOpConverter):
"""Converts an ONNX OptionalGetElement node by unwrapping a non-empty Optional."""

@classmethod
def _impl_v15(cls, bb, inputs, attr, params):
if len(inputs) != 1:
raise ValueError(f"OptionalGetElement expects one input, but got {len(inputs)}")
if inputs[0] is None or _is_empty_optional(inputs[0]):
raise ValueError("OptionalGetElement cannot access an empty optional.")
return inputs[0]

_impl_v18 = _impl_v15


class SequenceConstruct(OnnxOpConverter):
"""Operator converter for sequence construction op."""

Expand Down Expand Up @@ -4111,9 +4199,9 @@ def _impl_v10(cls, bb, inputs, attr, params):
def _get_convert_map():
return {
# defs/experimental
# "Optional": Optional_,
# "OptionalHasElement": OptionalHasElement,
# "OptionalGetElement": OptionalGetElement,
"Optional": Optional_,
"OptionalHasElement": OptionalHasElement,
"OptionalGetElement": OptionalGetElement,
# Binary operators
"Add": Add,
"Sub": Sub,
Expand Down Expand Up @@ -4184,7 +4272,7 @@ def _get_convert_map():
"Gemm": Gemm,
"MatMul": MatMul,
"MatMulInteger": MatMulInteger,
# "MatMulInteger16": MatMulInteger16,
"MatMulInteger16": MatMulInteger16,
"Reshape": Reshape,
"Sigmoid": Sigmoid,
"Softmax": Softmax,
Expand Down Expand Up @@ -4343,7 +4431,18 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule:
self._check_for_unsupported_ops(graph)
self._construct_nodes(graph)

outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
# now return the outputs
output_names = [self._parse_value_proto(output) for output in graph.output]
outputs = []
for output_name in output_names:
output_value = self._nodes[output_name]
if _is_empty_optional(output_value):
raise ValueError(
"ONNX graph output "
f"{output_name} is an empty optional. Empty optional graph outputs "
"are not supported by the Relax ONNX frontend."
)
outputs.append(output_value)
outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs)

if has_if:
Expand Down Expand Up @@ -4515,6 +4614,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
"Squeeze",
]
return_tuple_ops = [
"Optional",
"OptionalGetElement",
"SequenceConstruct",
"SequenceEmpty",
"SequenceErase",
Expand All @@ -4533,7 +4634,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
try:
op = self._convert_operator(op_name, inputs, attr, self.opset)
# Create struct information for the new operator.
op = self.bb.normalize(op)
if isinstance(op, relax.Expr):
op = self.bb.normalize(op)
except TVMError as err:
print(f"Error converting operator {op_name}, with inputs: {inputs}")
raise err
Expand Down Expand Up @@ -4585,11 +4687,11 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> dict[str,
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ["t"]:
if a.HasField(f):
for f in ["t", "tp"]:
if hasattr(a, f) and a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ["tensors"]:
if list(getattr(a, f)):
for f in ["tensors", "type_protos"]:
if hasattr(a, f) and list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ["graphs"]:
Expand Down
Loading
Loading