Skip to content

Commit f220e71

Browse files
john-rocky and John Rocky authored
Skip argmin/argmax with dim=None in CoreML partitioner (pytorch#19247)
### Summary

`argmax(x, dim=None)` / `argmin(x, dim=None)` reduces over the flattened tensor. CoreML does not support this reduction, and the resulting model intermittently crashes the process at runtime (the issue reproducer crashes 100% of the time on M1 Pro when the cell is run twice).

Detect the `dim is None` case in `should_override_support` so the op falls back to the portable backend. The ordinary `dim=int` form is unaffected and still gets delegated.

Fixes pytorch#11715.

### Test plan

Added `test_argmax_argmin_dim_none_is_skipped` covering both branches:

- `argmax(x, dim=None) + argmin(x, dim=None)` — neither op is delegated.
- `argmax(x, dim=1)` — gets delegated as before.

```
$ python -m unittest -v executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner.test_argmax_argmin_dim_none_is_skipped
Ran 1 test in 1.042s
OK
```

Authored with Claude.

cc @metascroy

---------

Co-authored-by: John Rocky <samuraibrothersmail@gmail.com>
1 parent f8cfc73 commit f220e71

2 files changed

Lines changed: 74 additions & 0 deletions

File tree

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,26 @@
6363
)
6464

6565

66+
# Op overloads inspected for the ``dim=None`` case, in both the ATen and the
# edge-dialect namespaces.
_ARG_MIN_MAX_TARGETS = (
    torch.ops.aten.argmax.default,
    torch.ops.aten.argmin.default,
    exir_ops.edge.aten.argmax.default,
    exir_ops.edge.aten.argmin.default,
)


def _is_arg_min_max_over_flattened_input(node: torch.fx.Node) -> bool:
    """Tell whether ``node`` is an ``argmin``/``argmax`` call with ``dim=None``.

    With ``dim=None`` the reduction runs over the flattened input. CoreML
    doesn't support that reduction shape and intermittently crashes the
    process at runtime — see pytorch/executorch#11715.

    Args:
        node: The FX node to inspect.

    Returns:
        ``True`` if ``node.target`` is one of ``_ARG_MIN_MAX_TARGETS`` and its
        ``dim`` argument is ``None`` or absent; ``False`` otherwise.
    """
    if node.target in _ARG_MIN_MAX_TARGETS:
        positional = node.args
        # ``dim`` is the second positional argument when passed positionally.
        if len(positional) >= 2:
            return positional[1] is None
        # Otherwise look for the keyword form; a missing ``dim`` counts as None.
        return node.kwargs.get("dim", None) is None
    return False
84+
85+
6686
def _is_view_op(op: torch._ops.OpOverload) -> bool:
6787
schema = op._schema
6888
if len(schema.arguments) == 0:
@@ -132,6 +152,13 @@ def should_override_support(self, node) -> bool:
132152
)
133153
return True
134154

155+
if _is_arg_min_max_over_flattened_input(node):
156+
self.log_once(
157+
"torch.ops.aten.{argmax, argmin}.default with dim=None is "
158+
"not supported by CoreML. Overriding op support."
159+
)
160+
return True
161+
135162
# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
136163
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
137164
# # in the placeholders due to partitioning, which CoreML does not support

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,53 @@ def forward(self, x):
386386
self.assertIn("executorch_call_delegate", op_names)
387387
self.assertNotIn("aten.randn.default", op_names)
388388

389+
def test_argmax_argmin_dim_none_is_skipped(self):
    """
    Regression test for https://github.com/pytorch/executorch/issues/11715.

    With dim=None, argmax/argmin reduce over the flattened tensor. CoreML
    does not support that, and the resulting model intermittently crashes
    the process at runtime. The partitioner therefore has to reject these
    ops so they fall back to the portable backend, while the ordinary
    dim=int form must still be delegated.
    """

    def lowered_call_targets(module):
        # Export the module, lower it with the CoreML partitioner, and
        # return the target names of every call_function node.
        exported = torch.export.export(
            module.eval(), (torch.randn(10, 10),), strict=True
        )
        lowered = executorch.exir.to_edge_transform_and_lower(
            exported, partitioner=[CoreMLPartitioner()]
        )
        names = []
        for node in lowered.exported_program().graph.nodes:
            if node.op == "call_function":
                names.append(node.target.__name__)
        return names

    class FlatModel(torch.nn.Module):
        def forward(self, x):
            return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(
                x, dim=None
            )

    # dim=None: both reductions must remain in the graph (not delegated).
    flat_targets = lowered_call_targets(FlatModel())
    self.assertIn("aten.argmax.default", flat_targets)
    self.assertIn("aten.argmin.default", flat_targets)

    class DimModel(torch.nn.Module):
        def forward(self, x):
            return torch.argmax(x, dim=1)

    # dim=int: argmax is delegated, so only the delegate call is left.
    dim_targets = lowered_call_targets(DimModel())
    self.assertIn("executorch_call_delegate", dim_targets)
    self.assertNotIn("aten.argmax.default", dim_targets)
435+
389436
def test_deprecation_warning_for_to_backend_workflow(self):
390437
"""
391438
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.

0 commit comments

Comments
 (0)