Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 123 additions & 4 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,106 @@ def _calculate_pool_output_size(in_dim, kernel, stride, pad_l, pad_r, ceil_mode)
return new_pad


def _max_pool_indices(x, pool, kernel_sizes, strides, pad, ceil_mode, name):
"""Compute the second output of ``max_pool*d_with_indices``.

PyTorch returns, for every pooled element, the index of the selected element
flattened over the input spatial dimensions (per channel, ignoring the batch
and channel offsets). MIL has no max-pool-with-indices op, so we recover the
indices by replaying the pooling windows: slide over the (padded) input,
take ``reduce_argmax`` inside each window, then map the within-window position
back to a flattened input coordinate.

``reduce_argmax`` returns the first occurrence on ties, which matches
PyTorch's tie-breaking. Padded positions are filled with the lowest float so
they are never selected (consistent with PyTorch treating padding as -inf).
"""
spatial_rank = len(kernel_sizes)
if spatial_rank != 2:
# 1D/3D index recovery can be added later; flag clearly until then.
raise NotImplementedError(
"Conversion of max_pool with return_indices=True is only supported "
f"for 2D pooling, but got a {spatial_rank}D pooling op."
)

in_shape = x.shape
if any(is_symbolic(d) for d in in_shape):
raise NotImplementedError(
"Conversion of max_pool with return_indices=True requires a static "
"input shape."
)

N, C, H, W = [int(d) for d in in_shape]
KH, KW = kernel_sizes
SH, SW = strides
PH, PW = pad[0], pad[2]
neg = float(np.finfo(nptype_from_builtin(x.dtype)).min)

# Pad the spatial dimensions so window extraction matches the pooled output.
x_padded = x
if PH or PW:
x_padded = mb.pad(
x=x_padded, pad=[0, 0, 0, 0, PH, PH, PW, PW], mode="constant", constant_val=neg
)
H_pad, W_pad = H + 2 * PH, W + 2 * PW

# Use the pooled output's spatial size as the source of truth. It already
# reflects PyTorch's ceil_mode rule, which drops a trailing window when it
# would start entirely inside the bottom/right padding.
OH, OW = int(pool.shape[2]), int(pool.shape[3])

# ceil_mode can require windows that extend past the padded grid; pad the
# extra cells with -inf so they are never picked.
extra_H = (OH - 1) * SH + KH - H_pad
extra_W = (OW - 1) * SW + KW - W_pad
extra_H = extra_H if extra_H > 0 else 0
extra_W = extra_W if extra_W > 0 else 0
if extra_H or extra_W:
x_padded = mb.pad(
x=x_padded,
pad=[0, 0, 0, 0, 0, extra_H, 0, extra_W],
mode="constant",
constant_val=neg,
)
H_pad += extra_H
W_pad += extra_W

# sliding_windows would push the rank past Core ML's limit of 5 if batch and
# channel were kept separate, so fold them together first.
x_flat = mb.reshape(x=x_padded, shape=[N * C, H_pad, W_pad])
windows_h = mb.sliding_windows(x=x_flat, axis=1, size=KH, stride=SH)
windows = mb.sliding_windows(x=windows_h, axis=3, size=KW, stride=SW)
# (N*C, OH', KH, OW', KW) -> (N*C, OH', OW', KH, KW) -> (N*C, OH', OW', KH * KW)
windows = mb.transpose(x=windows, perm=[0, 1, 3, 2, 4])
# sliding_windows can yield more windows than the pooled output keeps, so
# trim to the pooled spatial size.
if int(windows.shape[1]) != OH or int(windows.shape[2]) != OW:
windows = mb.slice_by_index(
x=windows,
begin=[0, 0, 0, 0, 0],
end=[N * C, OH, OW, KH, KW],
end_mask=[True, False, False, True, True],
)
windows = mb.reshape(x=windows, shape=[N * C, OH, OW, KH * KW])

argmax = mb.reduce_argmax(x=windows, axis=-1, keep_dims=False)
win_row = mb.floor_div(x=argmax, y=KW)
win_col = mb.mod(x=argmax, y=KW)

# Map the window-local position to a flattened input coordinate:
# ((out_h * stride_h + win_row) - pad_h) * W + ((out_w * stride_w + win_col) - pad_w)
out_h_offset = (np.arange(OH, dtype=np.int32) * SH).reshape(1, OH, 1)
out_w_offset = (np.arange(OW, dtype=np.int32) * SW).reshape(1, 1, OW)
row = mb.add(x=win_row, y=mb.const(val=out_h_offset))
col = mb.add(x=win_col, y=mb.const(val=out_w_offset))
row = mb.sub(x=row, y=PH)
col = mb.sub(x=col, y=PW)
flat = mb.add(x=mb.mul(x=row, y=W), y=col)

indices = mb.reshape(x=flat, shape=[N, C, OH, OW])
return mb.cast(x=indices, dtype="int32", name=name)


@register_torch_op(
torch_alias=[
"max_pool2d",
Expand Down Expand Up @@ -1922,6 +2022,8 @@ def max_pool1d(context, node):
x_spatial_dimensions = x.shape[-spatial_rank:]
pad = _adjust_pad_for_ceil_mode(x_spatial_dimensions, kernel_sizes.val, strides.val, pad)

with_indices = bool(re.match(r"max_pool[123]d_with_indices", node.kind))

pool = mb.max_pool(
x=x,
kernel_sizes=kernel_sizes,
Expand All @@ -1932,11 +2034,28 @@ def max_pool1d(context, node):
ceil_mode=ceil_mode if spatial_rank <= 2 else False,
)

if re.match(r"max_pool[123]d_with_indices", node.kind):
# TODO(rdar://117038432) ([Executorch] Handle/Bind other outputs of `max_pool2d_with_indices` op during lowering)
context.add((pool, None), torch_name=node.name)
else:
if not with_indices:
context.add(pool)
return

indices = _max_pool_indices(
x,
pool,
kernel_sizes=[int(k) for k in kernel_sizes.val],
strides=[int(s) for s in strides.val],
pad=[int(p) for p in pad],
ceil_mode=bool(ceil_mode),
name=node.name + "_indices",
)
if len(node.outputs) > 1:
# TorchScript exposes the values and indices as two separate outputs,
# so bind each one by its own name.
context.add(pool, torch_name=node.outputs[0])
context.add(indices, torch_name=node.outputs[1])
else:
# The torch.export based frontends expose a single output that the
# downstream getitem ops index into, so bind it as a tuple.
context.add((pool, indices), torch_name=node.name)


@register_torch_op(torch_alias=["min.other"])
Expand Down
52 changes: 52 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,6 +3430,58 @@ def test_max_pool3d(
input_shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, input_shape, kernel_size, stride, padding, ceil_mode",
itertools.product(
compute_units,
backends,
frontends,
[(1, 3, 15, 15), (1, 1, 8, 8)],
[2, 3],
[1, 2],
[0, 1],
[True, False],
),
)
def test_max_pool2d_return_indices(
self,
compute_unit,
backend,
frontend,
input_shape,
kernel_size,
stride,
padding,
ceil_mode,
):
if padding > kernel_size / 2:
return
# The indices are recovered from an argmax over the pooled windows, so a
# fp16 backend can rank two near-equal window elements differently from
# the fp32 torch reference. Only the fp32 path is bit-exact for indices.
if backend[1] == "fp16":
pytest.skip("indices require fp32 to match torch's argmax tie-breaking")

class Model(nn.Module):
def forward(self, x):
return nn.functional.max_pool2d(
x,
kernel_size=kernel_size,
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
return_indices=True,
)

self.run_compare_torch(
input_shape,
Model(),
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
minimum_deployment_target=ct.target.iOS17,
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(compute_units, backends, frontends),
Expand Down