27 changes: 27 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -72,6 +72,18 @@ def _relax_dtype_is_floating_point(dtype: str) -> bool:
    )


def _const_integer_expr_has_zero(expr: relax.Expr) -> bool | None:
Contributor:
I don't think we need to create a new three-line function if it is only called once from another function. (A sketch of the inlined version follows this function.)

"""Return whether a constant integer expression contains a zero value.

Returns None when expression is not statically inspectable.
"""

if isinstance(expr, relax.Constant):
return bool(_np.any(expr.data.numpy() == 0))
Comment on lines +81 to +82
Contributor (medium):
The helper function _const_integer_expr_has_zero should also handle relax.PrimValue expressions, as they can also represent constant integer divisors in Relax. This ensures that scalar constant zero divisors are also caught during import validation.

Suggested change:
     if isinstance(expr, relax.Constant):
         return bool(_np.any(expr.data.numpy() == 0))
+    if isinstance(expr, relax.PrimValue) and isinstance(expr.value, tirx.IntImm):
+        return int(expr.value.value) == 0


    return None
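
A minimal sketch of the inlined alternative the comment above describes, assuming the _np and relax aliases this file already imports; the divisor variable name is illustrative:

# Hypothetical inlining inside Div._impl_v7, replacing the helper call:
divisor = inputs[1]
if isinstance(divisor, relax.Constant) and bool(_np.any(divisor.data.numpy() == 0)):
    raise ValueError("ONNX Div with integer inputs encountered divisor value 0.")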


def get_type(elem_type: str | int) -> str:
"""Converts onnx integer datatype to numpy datatype"""
# If a string was passed instead of a tensor type, it does not need
@@ -526,6 +538,21 @@ class Div(BinaryBase):

    @classmethod
    def _impl_v7(cls, bb, inputs, attr, params):
        # If the input dtypes cannot be inspected statically, defer to the
        # generic implementation.
        try:
            lhs_code = DataType(inputs[0].struct_info.dtype).type_code
            rhs_code = DataType(inputs[1].struct_info.dtype).type_code
        except (AttributeError, ValueError, TypeError, TVMError):
            return cls.base_impl(bb, inputs, attr, params)

        lhs_is_integer = lhs_code == DataTypeCode.INT or lhs_code == DataTypeCode.UINT
        rhs_is_integer = rhs_code == DataTypeCode.INT or rhs_code == DataTypeCode.UINT
        if not (lhs_is_integer and rhs_is_integer):
            return cls.base_impl(bb, inputs, attr, params)

        # Reject integer division by a statically known zero divisor.
        rhs_has_zero = _const_integer_expr_has_zero(inputs[1])
        if rhs_has_zero:
            raise ValueError("ONNX Div with integer inputs encountered divisor value 0.")

        return cls.base_impl(bb, inputs, attr, params)


45 changes: 45 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
@@ -591,6 +591,51 @@ def test_binary(op_name: str):
    verify_binary_scalar(op_name)


def test_div_integer_constant_zero_divisor_raises_valueerror():
    b_init = numpy_helper.from_array(np.array([3, 0, -2, 1], dtype=np.int32), name="b")
    node = helper.make_node("Div", ["a", "b"], ["y"])
    graph = helper.make_graph(
        [node],
        "div_const_zero",
        [helper.make_tensor_value_info("a", TensorProto.INT32, [4])],
        [helper.make_tensor_value_info("y", TensorProto.INT32, [4])],
        initializer=[b_init],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    model.ir_version = 9

    with pytest.raises(
        ValueError, match="ONNX Div with integer inputs encountered divisor value 0"
    ):
        from_onnx(model, opset=18, keep_params_in_input=False)


def test_div_integer_dynamic_nonzero_matches_ort():
Contributor (@mshr-h, May 16, 2026):
Do we really need to add this test? It looks like it is already covered by test_binary.

    node = helper.make_node("Div", ["a", "b"], ["y"])
    graph = helper.make_graph(
        [node],
        "div_dynamic_nonzero",
        [
            helper.make_tensor_value_info("a", TensorProto.INT32, [4]),
            helper.make_tensor_value_info("b", TensorProto.INT32, [4]),
        ],
        [helper.make_tensor_value_info("y", TensorProto.INT32, [4])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    model.ir_version = 9

    check_correctness(
        model,
        inputs={
            "a": np.array([42, 99, -50, 7], dtype=np.int32),
            "b": np.array([3, -2, 5, 1], dtype=np.int32),
        },
        ir_version=9,
        opset=18,
        check_dtypes=True,
    )


@pytest.mark.parametrize("int_mode", [True, False])
def test_mod(int_mode: bool):
    if int_mode:
97 changes: 69 additions & 28 deletions tests/python/relax/test_frontend_onnx_backend.py
@@ -77,12 +77,10 @@ def run(self, inputs, **kwargs):
         self._vm.invoke_stateful("main")
         output = self._vm.get_outputs("main")

-        if isinstance(output, (tvm.runtime.Tensor, np.ndarray)):
+        if isinstance(output, tvm.runtime.Tensor | np.ndarray):
             return (output.numpy() if hasattr(output, "numpy") else output,)
-        if isinstance(output, (tuple, list)):
-            return tuple(
-                o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output
-            )
+        if isinstance(output, tuple | list):
+            return tuple(o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output)
         return (np.array(output),)


@@ -110,9 +108,7 @@ def prepare(cls, model, device="CPU", **kwargs):
         func_param_names = [p.name_hint for p in func.params]
         graph_input_names = [inp.name for inp in model.graph.input]

-        return TVMRelaxBackendRep(
-            tvm_model, params, func_param_names, graph_input_names
-        )
+        return TVMRelaxBackendRep(tvm_model, params, func_param_names, graph_input_names)

     @classmethod
     def supports_device(cls, device: str) -> bool:
@@ -133,32 +129,77 @@ def supports_device(cls, device: str) -> bool:
 # validated against the ONNX Backend Test Suite. They can be added
 # incrementally as the importer improves.
 _INCLUDE_OPS = [
-    "abs", "acos", "acosh", "add", "and", "argmax", "argmin",
-    "averagepool", "bitshift",
-    "bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor",
-    "ceil", "clip", "compress", "concat",
-    "conv", "cos", "cosh",
-    "depthtospace", "div",
-    "einsum", "erf", "exp",
-    "flatten", "floor",
-    "gathernd", "gemm",
-    "globalaveragepool", "globalmaxpool", "greater", "greater_equal",
-    "hardmax", "hardswish",
+    "abs",
+    "acos",
+    "acosh",
+    "add",
+    "and",
+    "argmax",
+    "argmin",
+    "averagepool",
+    "bitshift",
+    "bitwise_and",
+    "bitwise_not",
+    "bitwise_or",
+    "bitwise_xor",
+    "ceil",
+    "clip",
+    "compress",
+    "concat",
+    "conv",
+    "cos",
+    "cosh",
+    "depthtospace",
+    "div",
+    "einsum",
+    "erf",
+    "exp",
+    "flatten",
+    "floor",
+    "gathernd",
+    "gemm",
+    "globalaveragepool",
+    "globalmaxpool",
+    "greater",
+    "greater_equal",
+    "hardmax",
+    "hardswish",
     "isnan",
-    "less", "less_equal", "lrn",
-    "matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg",
-    "nonzero", "not",
+    "less",
+    "less_equal",
+    "lrn",
+    "matmul",
+    "matmulinteger",
+    "mean",
+    "min",
+    "mod",
+    "mul",
+    "neg",
+    "nonzero",
+    "not",
     "or",
     "reciprocal",
     "round",
     "scatternd",
-    "sigmoid", "sign",
-    "sin", "sinh", "size", "slice",
+    "sigmoid",
+    "sign",
+    "sin",
+    "sinh",
+    "size",
+    "slice",
     "spacetodepth",
-    "sqrt", "squeeze", "sub", "sum",
-    "tan", "tanh", "tile", "transpose",
-    "unique", "unsqueeze",
-    "where", "xor",
+    "sqrt",
+    "squeeze",
+    "sub",
+    "sum",
+    "tan",
+    "tanh",
+    "tile",
+    "transpose",
+    "unique",
+    "unsqueeze",
+    "where",
+    "xor",
 ]

 for _op in _INCLUDE_OPS: