Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7323,7 +7323,14 @@ def floor(context, node):
@register_torch_op
def reciprocal(context, node):
inputs = _get_inputs(context, node, expected=1)
context.add(mb.inverse(x=inputs[0], name=node.name))
x = inputs[0]
# PyTorch's reciprocal promotes int inputs to float; mb.inverse only
# accepts fp16/fp32. Without this cast, common patterns like
# `1 / x.shape[0]` (which TorchScript traces as
# reciprocal(prim::NumToTensor(int))) fail to convert.
if types.is_int(x.dtype):
x = mb.cast(x=x, dtype="fp32")
context.add(mb.inverse(x=x, name=node.name))


@register_torch_op
Expand Down
29 changes: 29 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6676,6 +6676,35 @@ def test_div(self, compute_unit, backend, frontend, rounding_mode, x2_type):
)


class TestReciprocal(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(compute_units, backends, frontends),
)
def test_reciprocal_int_shape(self, compute_unit, backend, frontend):
# Regression test for #2579: TorchScript traces `16 / x.shape[0]`
# as reciprocal(int) -> mul(16), and reciprocal previously rejected
# the int input because mb.inverse only accepts fp16/fp32.
if frontend in TORCH_EXPORT_BASED_FRONTENDS:
pytest.skip("torch.export folds shape-derived constants")

class TestModel(nn.Module):
def forward(self, x):
return 16 / x.shape[0] * x

# mb.inverse uses hardware reciprocal with limited precision; loosen
# tolerance to accommodate fp16 backends.
self.run_compare_torch(
(2, 16, 11),
TestModel(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
atol=1e-2,
rtol=1e-2,
)


class TestElementWiseUnary(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, shape, op_string",
Expand Down