Skip to content

Commit f7e3800

Browse files
Arm backend: Support permute-removal for TOSA ops (pytorch#19277)
Adds a TOSA specific variant of RemovePermutesAroundElementwiseOps that makes sure that elementwise TOSA backend dialect operators also are covered by this pass. As of now this includes TABLE and RESCALE. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent d939b9b commit f7e3800

4 files changed

Lines changed: 81 additions & 4 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@
140140
from .remove_getitem_pass import RemoveGetItemPass # noqa
141141
from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
142142
from .remove_noop_pass import RemoveNoopPass # noqa
143+
from .remove_permutes_around_elementwise_tosa_ops import ( # noqa
144+
RemovePermutesAroundElementwiseTosaOps,
145+
)
143146
from .replace_scalar_with_tensor_pass import ( # noqa
144147
ReplaceScalarWithTensorByProfilePass,
145148
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
RemoveGetItemPass,
126126
RemoveGraphAssertsPass,
127127
RemoveNoopPass,
128+
RemovePermutesAroundElementwiseTosaOps,
128129
ReplaceInfAndLimitValuesPass,
129130
ReplaceScalarWithTensorByProfilePass,
130131
RewriteAvgPool2dPass,
@@ -164,9 +165,6 @@
164165
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
165166
)
166167

167-
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
168-
RemovePermutesAroundElementwiseOps,
169-
)
170168
from executorch.exir import ExportedProgram
171169
from executorch.exir.pass_base import ExportPass
172170
from executorch.exir.pass_manager import PassManager
@@ -538,7 +536,7 @@ def _tosa_pipeline(
538536
RewriteMatmulPass(),
539537
RewritePadPass(),
540538
FuseViewCopyTransformPass(),
541-
RemovePermutesAroundElementwiseOps(),
539+
RemovePermutesAroundElementwiseTosaOps(),
542540
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
543541
FuseCascadedTransposeOrPermuteOps(),
544542
ConvertPermuteSingletonToViewPass(),
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
7+
RemovePermutesAroundElementwiseOps,
8+
)
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
11+
12+
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
13+
permutable_ops = {
14+
*RemovePermutesAroundElementwiseOps.permutable_ops,
15+
exir_ops.backend.tosa.RESCALE.default,
16+
exir_ops.backend.tosa.TABLE.default,
17+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import (
8+
RemovePermutesAroundElementwiseTosaOps,
9+
)
10+
from executorch.backends.arm.tosa.specification import (
11+
TosaLoweringContext,
12+
TosaSpecification,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
16+
TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT")
17+
PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default
18+
RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default
19+
TABLE_TARGET = exir_ops.backend.tosa.TABLE.default
20+
21+
22+
def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int:
23+
return sum(
24+
1
25+
for node in graph_module.graph.nodes
26+
if node.op == "call_function" and node.target == target
27+
)
28+
29+
30+
def test_remove_permutes_around_rescale_tosa_INT() -> None:
31+
graph = torch.fx.Graph()
32+
x = graph.placeholder("x")
33+
x.meta["val"] = torch.randn(1, 3, 4, 5)
34+
35+
permute_in = graph.create_node(
36+
"call_function",
37+
PERMUTE_TARGET,
38+
args=(x, [0, 2, 3, 1]),
39+
)
40+
rescale = graph.create_node(
41+
"call_function",
42+
RESCALE_TARGET,
43+
args=(permute_in, torch.int8, [1.0], 0, 0),
44+
)
45+
permute_out = graph.create_node(
46+
"call_function",
47+
PERMUTE_TARGET,
48+
args=(rescale, [0, 3, 1, 2]),
49+
)
50+
graph.output(permute_out)
51+
52+
graph_module = torch.fx.GraphModule({}, graph)
53+
54+
with TosaLoweringContext(TOSA_INT_SPEC):
55+
result = RemovePermutesAroundElementwiseTosaOps().call(graph_module)
56+
57+
assert result.modified
58+
assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0
59+
assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1

0 commit comments

Comments
 (0)