Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from __future__ import annotations

import warnings
from typing import Sequence
from typing import Optional, Sequence

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_opset import opset19
from onnxscript.onnx_types import FLOAT, INT64

_INT64_MAX = 0x7FFFFFFFFFFFFFFF
Expand Down Expand Up @@ -91,3 +93,34 @@ def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale
pooled_shape=(pooled_height, pooled_width),
spatial_scale=spatial_scale,
)


@torch_op("torchvision::deform_conv2d", trace_only=True)
def torchvision_deform_conv2d(
input: TFloat,
offset: TFloat,
weight: TFloat,
bias: Optional[TFloat] = None,
stride: tuple[int, int] = (1, 1),
padding: tuple[int, int] = (0, 0),
dilation: tuple[int, int] = (1, 1),
mask: Optional[TFloat] = None,
):
"""deform_conv2d(input: torch.Tensor, offset: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: tuple[int, int] = (1, 1), padding: tuple[int, int] = (0, 0), dilation: tuple[int, int] = (1, 1), mask: Optional[torch.Tensor] = None) → torch.Tensor"""

kernel_h, kernel_w = weight.shape[-2:]
group = input.shape[1] // weight.shape[1]
offset_group = offset.shape[1] // (2 * kernel_h * kernel_w)
pads = (padding[0], padding[1], padding[0], padding[1])
return opset19.DeformConv(
X=input,
W=weight,
offset=offset,
B=bias,
mask=mask,
dilations=dilation,
strides=stride,
pads=pads,
group=group,
offset_group=offset_group,
)
70 changes: 70 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,69 @@ def make_x():
)


def sample_inputs_deform_conv2d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

def make_arg(shape):
return torch_testing.make_tensor(
shape, device=device, dtype=dtype, requires_grad=requires_grad
)

# Basic no-bias/no-mask case
input = make_arg((1, 2, 5, 5))
weight = make_arg((3, 2, 3, 3))
out_h = out_w = 3
offset = make_arg((1, 2 * 3 * 3, out_h, out_w))
yield opinfo_core.SampleInput(input, args=(offset, weight))

# With bias
input = make_arg((1, 2, 5, 5))
weight = make_arg((3, 2, 3, 3))
bias = make_arg((3,))
offset = make_arg((1, 2 * 3 * 3, out_h, out_w))
yield opinfo_core.SampleInput(input, args=(offset, weight, bias))

# With mask
input = make_arg((1, 2, 5, 5))
weight = make_arg((3, 2, 3, 3))
offset = make_arg((1, 2 * 3 * 3, out_h, out_w))
mask = make_arg((1, 3 * 3, out_h, out_w))
yield opinfo_core.SampleInput(
input,
args=(offset, weight),
kwargs={"mask": mask},
)

# Nonzero padding
input = make_arg((1, 2, 5, 5))
weight = make_arg((3, 2, 3, 3))
out_h = out_w = 5
offset = make_arg((1, 2 * 3 * 3, out_h, out_w))
yield opinfo_core.SampleInput(
input,
args=(offset, weight),
kwargs={"padding": (1, 1)},
)

# Grouped convolution
input = make_arg((1, 4, 5, 5))
weight = make_arg((4, 2, 3, 3))
offset = make_arg((1, 2 * 3 * 3, 3, 3))
yield opinfo_core.SampleInput(input, args=(offset, weight))

# Multiple offset groups
input = make_arg((1, 4, 5, 5))
weight = make_arg((4, 4, 3, 3))
offset = make_arg((1, 2 * 2 * 3 * 3, 3, 3))
mask = make_arg((1, 2 * 3 * 3, 3, 3))
yield opinfo_core.SampleInput(
input,
args=(offset, weight),
kwargs={"mask": mask},
)


def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs
Expand Down Expand Up @@ -3101,4 +3164,11 @@ def __init__(self):
sample_inputs_func=sample_inputs_roi_pool,
supports_out=False,
),
opinfo_core.OpInfo(
"torchvision.ops.deform_conv2d",
op=torchvision.ops.deform_conv2d,
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs_deform_conv2d,
supports_out=False,
),
]
42 changes: 41 additions & 1 deletion tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from __future__ import annotations

import functools
import os
import unittest
from typing import Callable, Optional, Sequence, Tuple
Expand Down Expand Up @@ -276,7 +277,12 @@ def setUp(self) -> None:
skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
)
@common_device_type.ops( # type: ignore[misc]
[info for info in ops_test_data.OPS_DB if info.name in ops_test_data.TESTED_OPS],
[
info
for info in ops_test_data.OPS_DB
if info.name in ops_test_data.TESTED_OPS
and info.name != "torchvision.ops.deform_conv2d"
],
allowed_dtypes=TESTED_DTYPES,
)
def test_output_match_opinfo_(
Expand Down Expand Up @@ -324,5 +330,39 @@ def test_complex_output_match_opinfo_(
TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"]
)


class TestOutputConsistencyDeformConvOpset19(unittest.TestCase):
"""Test deform_conv2d output consistency using an opset that supports DeformConv."""

def setUp(self) -> None:
torch.manual_seed(42)
np.random.seed(42)
ort.set_seed(42)

@common_device_type.ops( # type: ignore[misc]
[
info
for info in ops_test_data.OPS_DB
if info.name == "torchvision.ops.deform_conv2d"
],
allowed_dtypes=TESTED_DTYPES,
)
def test_output_match_opinfo_(
self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
):
run_test_output_match(
self,
device,
dtype,
op,
functools.partial(ops_test_common.graph_executor, opset_version=19),
ops_test_data.TORCHLIB_OPINFO_MAPPING,
)


common_device_type.instantiate_device_type_tests(
TestOutputConsistencyDeformConvOpset19, globals(), only_for=["cpu", "cuda"]
)

if __name__ == "__main__":
unittest.main()
6 changes: 4 additions & 2 deletions tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: ir.schemas.OpSignatur
def graph_executor(
test_name: str,
outputs: Sequence[Any],
*,
opset_version: int = TEST_OPSET_VERSION,
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
"""Eagerly executes a function."""

Expand All @@ -481,14 +483,14 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
(),
nodes=(),
opset_imports={
"": 18,
"": opset_version,
"pkg.torch.onnx": 1,
"pkg.onnxscript.torch_lib.common": 1,
"pkg.onnxscript.torch_lib": 1,
},
name="main_graph",
)
opset = onnxscript.opset18
opset = getattr(onnxscript, f"opset{opset_version}")
tracer = _building.OpRecorder(opset, {})
ort_inputs = {}
onnxscript_args: list[Any] = []
Expand Down
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,7 @@ def _where_input_wrangler(
TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align),
TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool),
TorchLibOpInfo("torchvision.ops.deform_conv2d", vision_ops.torchvision_deform_conv2d),
)

ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
Expand Down