Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,16 @@ def max_pool1d(context, node):
# See: rdar://60633736 (Implement dilation for mil op max_pool)
raise ValueError("@max_pool does not support dilation > 1")

spatial_rank = len(pad) // 2
spatial_rank = kernel_sizes.shape[0]

# PyTorch's MaxPool{1,2,3}d accepts unbatched input (C, *spatial); MIL's
# max_pool expects 2 leading (N, C) dims. Prepend the missing dims and
# squeeze them off after the pool. Same pattern used by _adaptive_pool2d.
required_rank = spatial_rank + 2
expand_axes = list(range(required_rank - x.rank)) if x.rank < required_rank else []
if expand_axes:
x = mb.expand_dims(x=x, axes=expand_axes)

if spatial_rank > 2 and ceil_mode is True and list(strides.val) != [1] * len(strides.val):
# since MIL does not support ceil_mode for 3D pool,
# need to adjust padding values if ceil_mode is True
Expand All @@ -1928,9 +1937,11 @@ def max_pool1d(context, node):
strides=strides,
pad_type=pad_type,
pad=pad,
name=node.name,
name=node.name + "_pool" if expand_axes else node.name,
ceil_mode=ceil_mode if spatial_rank <= 2 else False,
)
if expand_axes:
pool = mb.squeeze(x=pool, axes=expand_axes, name=node.name)

if re.match(r"max_pool[123]d_with_indices", node.kind):
# TODO(rdar://117038432) ([Executorch] Handle/Bind other outputs of `max_pool2d_with_indices` op during lowering)
Expand Down
31 changes: 31 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3475,6 +3475,37 @@ def test_max_pool2d_symbolic_input(self, compute_unit, backend, frontend):
torch_export_dynamic_shapes=torch_export_dynamic_shapes,
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, ndim, ceil_mode",
itertools.product(
compute_units,
backends,
frontends,
[1, 2, 3],
[False, True],
),
)
def test_max_pool_unbatched_input(self, compute_unit, backend, frontend, ndim, ceil_mode):
# PyTorch's MaxPool{1,2,3}d accepts both batched (N, C, *spatial) and
# unbatched (C, *spatial) input. Prior to #2148 the unbatched case
# failed conversion with `input_shape (length ...) ... divided by two
# must all be the same length` because the converter passed the
# lower-rank tensor straight to MIL's max_pool, which assumes 2 leading
# (N, C) dims.
input_shape = (3,) + (7,) * ndim
model = {
1: nn.MaxPool1d,
2: nn.MaxPool2d,
3: nn.MaxPool3d,
}[ndim](kernel_size=3, stride=2, padding=1, ceil_mode=ceil_mode)
self.run_compare_torch(
input_shape,
model,
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
)


class TestMaximumMinimum(TorchBaseTest):
@pytest.mark.parametrize(
Expand Down