Skip to content

Commit b0cc15a

Browse files
authored
Arm backend: Skip concating single tensors in AdaptiveAvgPool (pytorch#18520)
It was inserting no-op concats. It was discovered in pytorch#18500
1 parent 28b4813 commit b0cc15a

1 file changed

Lines changed: 14 additions & 9 deletions

File tree

backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,20 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
9898
avg_pool2d_op, pool_args, kwargs, meta, True
9999
)
100100
row.append(pooled)
101-
102-
# Concatenate row results along width (dim=3)
103-
row_tensor = super().call_operator(
104-
cat_op, (row, 3), kwargs, meta_with_no_qparams, True
105-
)
101+
# Concatenate row results along width (dim=3) if more than one.
102+
if len(row) > 1:
103+
row_tensor = super().call_operator(
104+
cat_op, (row, 3), kwargs, meta_with_no_qparams, True
105+
)
106+
else:
107+
row_tensor = row[0]
106108
res.append(row_tensor)
107109

108-
# Concatenate all rows along height (dim=2)
109-
out = super().call_operator(
110-
cat_op, (res, 2), kwargs, meta_with_no_qparams, True
111-
)
110+
# Concatenate all rows along height (dim=2) if more than one.
111+
if len(res) > 1:
112+
out = super().call_operator(
113+
cat_op, (res, 2), kwargs, meta_with_no_qparams, True
114+
)
115+
else:
116+
out = res[0]
112117
return out

0 commit comments

Comments
 (0)