Support return_indices for max_pool2d in the torch frontend #2717

Open
adityasingh2400 wants to merge 1 commit into apple:main from adityasingh2400:max-pool2d-return-indices
@adityasingh2400
Fixes #2456.

Problem

max_pool2d with return_indices=True could not be converted. The torch frontend bound the indices output to None, so any model that read the indices failed with ValueError: Torch var <n> not found in context:

import coremltools as ct
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def forward(self, image):
        # Discard the pooled values and return only the argmax indices.
        _, ixs = F.max_pool2d(image, kernel_size=3, stride=1, padding=1, return_indices=True)
        return ixs

ct.convert(torch.jit.trace(Model().eval(), torch.rand(1, 3, 224, 224)),
           inputs=[ct.TensorType(shape=(1, 3, 224, 224))])

Fix

MIL has no max-pool-with-indices op, so the indices are recovered from the pooling windows in four steps. First, the input is padded so that window extraction lines up with the pooled output. Second, the windows are materialized with sliding_windows; batch and channel are folded together to stay within Core ML's rank-5 limit. Third, reduce_argmax finds the selected element inside each window. Finally, the within-window position is mapped back to a flattened input coordinate, which is what PyTorch returns.
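The recovery steps can be illustrated outside MIL. The following is a minimal pure-Python sketch, not the code in this PR: it handles a single channel given as a list of rows, assumes a square kernel and ceil_mode=False, and the function name max_pool2d_indices is hypothetical. It pads with the lowest finite float, extracts each window, takes the first-occurrence argmax, and maps the within-window position back to a flattened row * width + col coordinate.

```python
import sys


def max_pool2d_indices(image, kernel, stride, padding):
    """Return (values, indices) for one channel; illustrative helper only."""
    h, w = len(image), len(image[0])
    lowest = -sys.float_info.max  # padded cells must never win the argmax

    # Pad so that window (i, j) starts at padded row i*stride, col j*stride.
    padded = [[lowest] * (w + 2 * padding) for _ in range(h + 2 * padding)]
    for r in range(h):
        for c in range(w):
            padded[r + padding][c + padding] = image[r][c]

    # Floor-mode output size (ceil_mode=False is assumed in this sketch).
    out_h = (h + 2 * padding - kernel) // stride + 1
    out_w = (w + 2 * padding - kernel) // stride + 1

    values, indices = [], []
    for i in range(out_h):
        value_row, index_row = [], []
        for j in range(out_w):
            # Materialize the window in row-major order.
            window = [
                padded[i * stride + kr][j * stride + kc]
                for kr in range(kernel)
                for kc in range(kernel)
            ]
            # max() keeps the first maximal item, i.e. first occurrence on ties.
            best = max(range(len(window)), key=window.__getitem__)
            kr, kc = divmod(best, kernel)
            # Undo the padding offset to get unpadded input coordinates.
            row = i * stride + kr - padding
            col = j * stride + kc - padding
            value_row.append(window[best])
            index_row.append(row * w + col)
        values.append(value_row)
        indices.append(index_row)
    return values, indices


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
values, indices = max_pool2d_indices(image, kernel=3, stride=1, padding=1)
```

For the 3x3 example, every window that contains the bottom-right element selects flattened index 8, and no window ever selects a padded cell.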

A few details match PyTorch exactly:

  • The pooled output's spatial size is used as the source of truth, so the ceil_mode rule that drops a trailing window starting inside the bottom/right padding is honored.
  • reduce_argmax returns the first occurrence on ties, matching PyTorch's tie-breaking.
  • Padded cells are filled with the lowest representable float so they are never selected, matching PyTorch treating padding as -inf.
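The ceil_mode rule in the first bullet can be written out for one spatial dimension. This is a sketch of PyTorch's documented output-size formula, not code from this PR, which reads the size off the pooled output instead of recomputing it; the function name pool_output_size is hypothetical.

```python
def pool_output_size(in_size, kernel, stride, padding, ceil_mode=False, dilation=1):
    """Output length along one spatial dimension of a max pool."""
    numerator = in_size + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return numerator // stride + 1

    # Integer ceiling division, then count the window at offset 0.
    out = -(-numerator // stride) + 1
    # Drop the last window if it would start inside the right/bottom padding.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out


same_size = pool_output_size(224, kernel=3, stride=1, padding=1)
extra_window = pool_output_size(5, kernel=2, stride=2, padding=0, ceil_mode=True)
dropped_window = pool_output_size(3, kernel=2, stride=2, padding=1, ceil_mode=True)
```

In the last call the ceiling alone would give three windows, but the third would start at padded offset 4, which is entirely inside the padding, so it is dropped and the result is 2.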

The values output is unchanged. 1D/3D return_indices and dynamic input shapes raise a clear NotImplementedError instead of silently producing wrong results.

Testing

Added test_max_pool2d_return_indices, parametrized over kernel size, stride, padding, and ceil_mode for both the TorchScript and torch.export frontends. Indices are compared against the torch reference on the fp32 path only: an fp16 backend can rank two near-equal window elements differently from torch's fp32 argmax, so the index comparison is skipped for that combination. The existing max_pool tests continue to pass.
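The fp16 caveat can be reproduced with the standard library alone. The sketch below is illustrative and not part of the test suite: it rounds values to IEEE half precision through struct's "e" format, and the helper names to_fp16 and first_argmax are hypothetical. Python floats are doubles, so they stand in for the fp32 path here. Two values that differ at full precision collapse to the same half-precision value, and a first-occurrence argmax then picks a different position.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def first_argmax(values):
    """Index of the maximum, returning the first occurrence on ties."""
    return max(range(len(values)), key=values.__getitem__)


# Near 1.0 the fp16 spacing is 2**-10, so both values round to exactly 1.0.
window = [1.0001, 1.0002]
full_precision_choice = first_argmax(window)
half_precision_choice = first_argmax([to_fp16(v) for v in window])
```

At full precision the second element wins; after rounding to half precision the two elements tie and the first one wins, so the returned index differs even though the pooled value is almost identical.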

max_pool2d_with_indices bound its indices output to None, so any model
that consumed the indices failed conversion with a missing-var error
(issue apple#2456). Recover the indices by replaying the pooling windows with
sliding_windows, taking reduce_argmax inside each window, and mapping the
within-window position back to a flattened input coordinate. The pooled
output's spatial size is used directly so the ceil_mode window-dropping
rule matches PyTorch. reduce_argmax's first-occurrence tie-break matches
PyTorch, and padded cells are filled with the lowest float so they are
never selected.

Adds a parametrized test covering kernel size, stride, padding, and
ceil_mode for both the TorchScript and torch.export frontends.

Development

Successfully merging this pull request may close these issues.

Error Converting max_pool2d