
Fix max_pool conversion for unbatched torch input #2723

Open: devin-lai wants to merge 1 commit into apple:main from devin-lai:fix-maxpool-unbatched-input-2148


Summary

  • Support unbatched (C, *spatial) inputs when converting torch MaxPool1d/2d/3d ops.
  • Temporarily prepend missing leading dims before MIL max_pool, then squeeze them back off after pooling.
  • Derive the pool spatial rank from kernel_sizes, and leave the existing batched path unchanged.

Details

PyTorch MaxPool1d, MaxPool2d, and MaxPool3d accept both batched (N, C, *spatial) inputs and unbatched (C, *spatial) inputs. MIL max_pool expects two leading (N, C) dimensions, so an unbatched MaxPool2d input such as (C, H, W) reaches MIL as a rank-3 tensor and its type inference only sees one spatial dimension. That causes the spatial-shape consistency check to fail before conversion completes.
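The rank mismatch described above can be illustrated with plain shape arithmetic. This is a minimal sketch, not the converter's code: the helper names `mil_spatial_dims` and `spatial_rank_matches` are hypothetical, and the only assumption is that an op expecting `(N, C, *spatial)` treats everything after the first two dims as spatial.

```python
def mil_spatial_dims(input_shape):
    """Dims that an op expecting (N, C, *spatial) would treat as spatial.

    Hypothetical helper: it simply drops the first two dims, whatever they are.
    """
    return tuple(input_shape[2:])


def spatial_rank_matches(input_shape, kernel_sizes):
    """True when the inferred spatial rank equals the kernel rank."""
    return len(mil_spatial_dims(input_shape)) == len(kernel_sizes)


batched = (1, 3, 32, 32)   # (N, C, H, W)
unbatched = (3, 32, 32)    # (C, H, W), also accepted by torch MaxPool2d
kernel_sizes = (2, 2)      # a 2D pooling kernel

# Batched input: two spatial dims line up with the two kernel sizes.
batched_ok = spatial_rank_matches(batched, kernel_sizes)

# Unbatched input: C is read as N and H as C, leaving only W as "spatial".
unbatched_ok = spatial_rank_matches(unbatched, kernel_sizes)
```

Here `batched_ok` is true and `unbatched_ok` is false, which is the kind of inconsistency that makes the spatial-shape check fail for `(C, H, W)` input.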

This change mirrors the existing adaptive-pooling handling: when the input rank is lower than spatial_rank + 2, the converter inserts the missing leading dims with expand_dims, applies max_pool, and removes those temporary dims with squeeze. Batched inputs still go through the original max_pool path and keep the original node name.
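The expand, pool, squeeze sequence can be sketched at the shape level as follows. This is an illustration under stated assumptions rather than the code in this PR: `plan_leading_dims` and `pooled_shape` are hypothetical names, and the pooling step assumes stride equal to the kernel size, no padding, and floor rounding.

```python
def plan_leading_dims(input_rank, kernel_sizes):
    """Axes to prepend so the input reaches rank spatial_rank + 2.

    The spatial rank is derived from the number of kernel sizes.
    Returns an empty list for already-batched input.
    """
    spatial_rank = len(kernel_sizes)
    missing = max(0, spatial_rank + 2 - input_rank)
    return list(range(missing))


def pooled_shape(input_shape, kernel_sizes):
    """Output shape after expand_dims -> max_pool -> squeeze.

    Assumes stride == kernel size, no padding, floor rounding.
    """
    axes = plan_leading_dims(len(input_shape), kernel_sizes)

    # expand_dims: prepend one size-1 dim per missing leading axis.
    expanded = (1,) * len(axes) + tuple(input_shape)

    # max_pool: the first two dims pass through, spatial dims shrink.
    pooled = expanded[:2] + tuple(
        dim // k for dim, k in zip(expanded[2:], kernel_sizes)
    )

    # squeeze: drop exactly the dims that were prepended.
    return pooled[len(axes):]


unbatched_2d = pooled_shape((3, 32, 32), (2, 2))    # (C, H, W)
batched_2d = pooled_shape((4, 3, 32, 32), (2, 2))   # (N, C, H, W)
unbatched_1d = pooled_shape((3, 8), (2,))           # (C, L)
```

For batched input `plan_leading_dims` returns no axes, so nothing is inserted or squeezed, which matches the description of the batched path being left unchanged.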

Tests

  • python -m pytest -q --log-cli-level=CRITICAL --log-level=CRITICAL coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestMaxPool::test_max_pool_unbatched_input

Fixes #2148.

PyTorch's MaxPool{1,2,3}d accepts unbatched (C, *spatial) input as
well as the batched (N, C, *spatial) form. MIL's max_pool requires
two leading (N, C) dims, so the converter has to bridge. Wrap the
input in expand_dims when its rank is below spatial_rank + 2, run
the pool, then squeeze the prepended dims back off, mirroring the
existing _adaptive_pool2d handling in the same file. The batched
path is unchanged.

Fixes apple#2148

Development

Successfully merging this pull request may close these issues.

Experiencing "divided by two must all be the same length" when trying to convert a PyTorch MaxPool2d
