Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion coremltools/converters/mil/backend/mil/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from . import (adjust_io_to_supported_types, fuse_activation_silu,
insert_image_preprocessing_op, sanitize_name_strings)
insert_image_preprocessing_op, sanitize_name_strings,
split_non_constant_pads)
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2025, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import numpy as np

from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="mil_backend")
class split_non_constant_pads(AbstractGraphPass):
"""
Split ``pad`` ops that use non-constant modes (``reflect``, ``replicate``)
and pad more than two dimensions, because the CoreML ML Program runtime
rejects such ops with the error:
"Padding for more than two dimensions only supports `constant` mode".

Each split step pads at most two dimensions, which CoreML supports for all
padding modes.

**Limitation**: ``pad`` ops whose padding values are computed at runtime
(i.e. ``op.inputs["pad"].val is None``) are skipped. Such models will still
fail at CoreML runtime with the same error.

.. code-block::

Input:
x(1, 3, 4, 4, 4) -> pad([0,0, 0,0, 2,2, 2,2, 2,2], mode="replicate") -> (1, 3, 8, 8, 8)

Output:
x(1, 3, 4, 4, 4) -> pad([0,0, 0,0, 2,2, 2,2, 0,0], mode="replicate") -> (1, 3, 8, 8, 4)
-> pad([0,0, 0,0, 0,0, 0,0, 2,2], mode="replicate") -> (1, 3, 8, 8, 8)
"""

def apply(self, prog):
for f in prog.functions.values():
self._split_pads_block(f)

@block_context_manager
def _split_pads_block(self, block):
for op in list(block.operations):
for b in op.blocks:
self._split_pads_block(b)

if op.op_type != "pad":
continue

mode = op.inputs["mode"].val
if mode == "constant":
continue

pad = op.inputs["pad"].val
if pad is None:
continue

# Find dimensions with non-zero padding
pad_pairs = pad.reshape(-1, 2)
nonzero_dims = [
i for i, (before, after) in enumerate(pad_pairs) if before != 0 or after != 0
]

if len(nonzero_dims) <= 2:
continue

# Split into sequential pads, each covering at most 2 dimensions
x = op.inputs["x"]
constant_val = op.inputs["constant_val"].val
num_chunks = (len(nonzero_dims) + 1) // 2
result = x
for chunk_idx, chunk_start in enumerate(range(0, len(nonzero_dims), 2)):
chunk_dims = nonzero_dims[chunk_start : chunk_start + 2]
chunk_pad = np.zeros_like(pad)
for dim in chunk_dims:
chunk_pad[2 * dim] = pad_pairs[dim][0]
chunk_pad[2 * dim + 1] = pad_pairs[dim][1]

is_last = chunk_idx == num_chunks - 1
step_name = op.name if is_last else f"{op.name}_split_{chunk_idx}"
result = mb.pad(
x=result,
pad=chunk_pad,
mode=mode,
constant_val=constant_val,
before_op=op,
name=step_name,
)

op.enclosing_block.replace_uses_of_var_after_op(
anchor_op=op, old_var=op.outputs[0], new_var=result
)
op.enclosing_block.remove_ops([op])
94 changes: 92 additions & 2 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11156,7 +11156,7 @@ class TestPad(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, rank, mode",
itertools.product(
compute_units, backends, frontends, range(3, 5), ["reflect", "replicate"]
compute_units, backends, frontends, range(3, 6), ["reflect", "replicate"]
),
)
def test_pad_reflect_replicate(self, compute_unit, backend, frontend, rank: int, mode: str):
Expand All @@ -11172,9 +11172,22 @@ def test_pad_reflect_replicate(self, compute_unit, backend, frontend, rank: int,
elif rank == 4:
pad_len = 4
input_shape = (10, 5, 5, 10)
elif rank == 5:
if backend[0] == "neuralnetwork":
# The NN backend's add_padding supports reflect/replicate only when
# all padding is confined to the last 2 dimensions (pad[:-4] must be
# all zeros). Rank-5 with pad_len=6 pads 3 spatial dims, so the D
# dimension's padding falls outside that window. Ranks 3 and 4 only
# pad 1 or 2 spatial dims respectively and do pass the check.
pytest.skip(
"NeuralNetwork backend only supports reflect/replicate padding on the "
"last 2 dimensions; rank-5 with 3 spatial dims padded is not supported"
)
pad_len = 6
input_shape = (2, 3, 5, 5, 10)
else:
raise NotImplementedError(
"Only 3D, 4D padding with non-constant padding are supported for now"
"Only 3D, 4D, 5D padding with non-constant padding are supported for now"
)
max_pad = min(input_shape[-1], input_shape[-2])
pad = list(np.random.randint(low=0, high=max_pad, size=pad_len))
Expand Down Expand Up @@ -11257,6 +11270,83 @@ def test_constant_pad_3d(self, compute_unit, backend, frontend):
input_shape, model, backend=backend, compute_unit=compute_unit, frontend=frontend
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(
compute_units,
backends,
frontends,
),
)
def test_replication_pad3d(self, compute_unit, backend, frontend):
"""Regression test for https://github.com/apple/coremltools/issues/2571.
ReplicationPad3d should work correctly rather than producing a model that
fails at CoreML runtime with 'Padding for more than two dimensions only
supports constant mode'."""
if backend[0] == "neuralnetwork":
pytest.skip("NeuralNetwork backend does not support replicate padding for >2 spatial dims")
input_shape = (1, 3, 4, 4, 4)
model = torch.nn.ReplicationPad3d(2).eval()
self.run_compare_torch(
input_shape, model, backend=backend, compute_unit=compute_unit, frontend=frontend
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(
compute_units,
backends,
frontends,
),
)
def test_replication_pad3d_asymmetric(self, compute_unit, backend, frontend):
"""Test ReplicationPad3d with asymmetric padding (different values per side)."""
if backend[0] == "neuralnetwork":
pytest.skip("NeuralNetwork backend does not support replicate padding for >2 spatial dims")
input_shape = (1, 3, 4, 5, 6)
model = torch.nn.ReplicationPad3d((1, 2, 3, 1, 2, 3)).eval()
self.run_compare_torch(
input_shape, model, backend=backend, compute_unit=compute_unit, frontend=frontend
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(
compute_units,
backends,
frontends,
),
)
def test_reflection_pad3d(self, compute_unit, backend, frontend):
"""Regression test for reflect mode: ReflectionPad3d should work correctly
rather than failing at CoreML runtime with 'Padding for more than two dimensions
only supports constant mode'."""
if backend[0] == "neuralnetwork":
pytest.skip("NeuralNetwork backend does not support reflect padding for >2 spatial dims")
input_shape = (1, 3, 5, 5, 5)
model = torch.nn.ReflectionPad3d(2).eval()
self.run_compare_torch(
input_shape, model, backend=backend, compute_unit=compute_unit, frontend=frontend
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(
compute_units,
backends,
frontends,
),
)
def test_reflection_pad3d_asymmetric(self, compute_unit, backend, frontend):
"""Test ReflectionPad3d with asymmetric padding (different values per side)."""
if backend[0] == "neuralnetwork":
pytest.skip("NeuralNetwork backend does not support reflect padding for >2 spatial dims")
input_shape = (1, 3, 5, 6, 7)
model = torch.nn.ReflectionPad3d((1, 2, 1, 2, 1, 2)).eval()
self.run_compare_torch(
input_shape, model, backend=backend, compute_unit=compute_unit, frontend=frontend
)


class TestMaskedFill(TorchBaseTest):
@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions coremltools/converters/mil/mil/passes/pass_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
"common::const_deduplication", # after all consts have been settled
"common::cast_optimization",
"common::dead_code_elimination",
"mil_backend::split_non_constant_pads", # must come before sanitize_name_strings
"mil_backend::sanitize_name_strings",
"common::dedup_op_and_var_names",
"nn_backend::handle_unused_inputs", # must come after dce.
Expand Down