Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,27 @@ def _translate_torch_args(pad: Var, mode: Var, value: Var) -> Tuple[Var]:
mode = mode.val
assert mode in ("circular", "constant", "reflect", "replicate")

# MIL `mb.pad` only supports `reflect` / `replicate` when at most the
# final two dimensions carry non-zero padding. Forwarding a larger
# number of non-constant-padded dims surfaces later as a Core ML
# Framework model-compile error inside `predict()`; raise the same
# message at conversion time instead. (Covers torch.export-decomposed
# ReflectionPad3d / ReplicationPad3d -> aten.pad with
# mode="reflect" / "replicate".)
if (
mode in ("reflect", "replicate")
and isinstance(pad, list)
and all(isinstance(p, (int, float)) for p in pad)
):
padded_dims = sum(
1 for i in range(len(pad) // 2)
if pad[2 * i] != 0 or pad[2 * i + 1] != 0
)
if padded_dims > 2:
raise NotImplementedError(
"Padding for more than two dimensions only supports constant mode"
)

if value is None:
value = 0.0
elif isinstance(value, Var):
Expand Down
31 changes: 31 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11183,6 +11183,37 @@ def test_pad_reflect_replicate(self, compute_unit, backend, frontend, rank: int,
input_shape, model, backend=backend, compute_unit=compute_unit, frontend=frontend
)

@pytest.mark.parametrize(
"mode, torch_module",
[
("reflect", torch.nn.ReflectionPad3d),
("replicate", torch.nn.ReplicationPad3d),
],
)
def test_pad_reflect_replicate_3d_raises(self, mode, torch_module):
# Regression test for issues #2576 and #2571: padding more than two
# dimensions only supports constant mode. Previously a `reflect` /
# `replicate` pad on >2 dims surfaced as a Core ML Framework
# model-compile error inside `predict()`; now it must raise
# NotImplementedError at conversion time with the same message.
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.pad = torch_module(padding=2)

def forward(self, x):
return self.pad(x)

model = Model().eval()
inputs = (torch.randn(1, 6, 6, 6, 6),)
exported = torch.export.export(model, inputs).run_decompositions({})

with pytest.raises(
NotImplementedError,
match="Padding for more than two dimensions only supports constant mode",
):
ct.convert(exported)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, rank",
itertools.product(compute_units, backends, frontends, range(1, 6)),
Expand Down