Skip to content

Commit 5363438

Browse files
Arm backend: Improve permute fusion over elementwise ops (pytorch#19451)
Adds handling of all ops handled by insert_table_ops:
- For FP, add all of these ops to remove_permutes_around_elementwise_tosa_ops.
- For INT, ensure that the tosa.TABLE op is treated properly.
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent af1f7d4 commit 5363438

3 files changed

Lines changed: 49 additions & 3 deletions

File tree

backends/arm/_passes/insert_table_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
278278
out_quantargs=output_qparams[0],
279279
)
280280
# Register buffer in self.exported_program.state_dict
281+
# The b_ prefix is required for the node to be recognized as a constant by RemovePermutesAroundElementwiseOps
281282
const_table_node = create_constant_placeholder(
282283
exp_program=self.exported_program,
283284
graph=node.graph,
284285
kind=InputKind.BUFFER,
285-
name=node.name + "_table_constant",
286+
name="b_" + node.name + "_table_constant",
286287
data=buffer,
287288
persistent_buffer=True,
288289
)

backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from executorch.backends.arm._passes.insert_table_ops import TableOps
67
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
78
RemovePermutesAroundElementwiseOps,
89
)
@@ -12,6 +13,21 @@
1213
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
    """Permute-removal pass extended with the ops used by the Arm/TOSA backend.

    In addition to the ops the base pass already lists as permutable, this
    class registers every op in ``TableOps.unary_table_ops`` and
    ``TableOps.special_table_ops`` together with the TOSA ``RESCALE`` and
    ``TABLE`` backend ops.
    """

    permutable_ops = {
        *RemovePermutesAroundElementwiseOps.permutable_ops,
        *TableOps.unary_table_ops.keys(),
        *TableOps.special_table_ops,
        exir_ops.backend.tosa.RESCALE.default,
        exir_ops.backend.tosa.TABLE.default,
    }

    def permute_subgraph(self, subgraph):
        """Permute ``subgraph`` without touching constants that feed tosa.TABLE.

        The inherited implementation always permutes constant nodes, which is
        wrong for table ops. The constant edges whose user is a tosa.TABLE
        node are therefore dropped from ``subgraph.constant_edges_in`` (the
        attribute is replaced in place) before delegating to the base class.
        """
        table_target = exir_ops.backend.tosa.TABLE.default
        subgraph.constant_edges_in = {
            (const_node, user_node)
            for const_node, user_node in subgraph.constant_edges_in
            if user_node.target != table_target
        }
        super().permute_subgraph(subgraph)

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
import torch
1010

1111
from executorch.backends.arm.test import common
12-
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineFP
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
TosaPipelineFP,
14+
TosaPipelineINT,
15+
)
1316

1417

1518
InputT = Tuple[Any, ...]
@@ -330,6 +333,17 @@ def forward(self, x):
330333
return torch.cat((a, b), dim=-1)
331334

332335

336+
class PermuteSiluPermute(torch.nn.Module):
    """SiLU activation placed between a permute and the inverse permute."""

    def __init__(self):
        super().__init__()
        self.silu = torch.nn.SiLU()

    def forward(self, x: torch.Tensor):
        """Permute dims to [0, 2, 3, 1], apply SiLU, then permute back."""
        channels_moved = torch.permute(x, [0, 2, 3, 1])
        activated = self.silu(channels_moved)
        # [0, 3, 1, 2] undoes the [0, 2, 3, 1] permutation applied above.
        return torch.permute(activated, [0, 3, 1, 2])
345+
346+
333347
cases = {
334348
"conv1d_rank2": TransposeCountCase(Conv1dModule(), (torch.randn(2, 8),), 2),
335349
"conv1d_rank3": TransposeCountCase(Conv1dModule(), (torch.randn(1, 2, 8),), 2),
@@ -458,6 +472,14 @@ def forward(self, x):
458472
),
459473
}
460474

475+
# Cases run through the INT pipeline test; the FP test also picks them up
# through the ``cases | cases_int`` union in its parametrization.
cases_int = dict(
    permute_silu_permute=TransposeCountCase(
        PermuteSiluPermute(),
        (torch.randn(1, 2, 3, 4),),
        0,
    ),
)
482+
461483

462484
cases_channels_last = {
463485
"conv2d_rank4_channels_last": TransposeCountCase(
@@ -531,13 +553,20 @@ def forward(self, x):
531553
}
532554

533555

534-
@common.parametrize("case", cases)
556+
@common.parametrize("case", cases | cases_int)
def test_transpose_counts_tosa_FP(case: TransposeCountCase) -> None:
    """Run the FP TOSA pipeline for ``case`` with its expected TRANSPOSE count."""
    fp_pipeline = TosaPipelineFP[InputT](case.module, case.inputs, aten_op=[])
    expected_op_counts = {"TRANSPOSE": case.expected_transposes}
    fp_pipeline.count_tosa_ops(expected_op_counts)
    fp_pipeline.run()
539561

540562

563+
@common.parametrize("case", cases_int)
def test_transpose_counts_tosa_INT(case: TransposeCountCase) -> None:
    """Run the INT TOSA pipeline for ``case`` with its expected TRANSPOSE count."""
    int_pipeline = TosaPipelineINT[InputT](case.module, case.inputs, aten_op=[])
    expected_op_counts = {"TRANSPOSE": case.expected_transposes}
    int_pipeline.count_tosa_ops(expected_op_counts)
    int_pipeline.run()
568+
569+
541570
xfails = {
542571
"conv3d_rank5_channels_last": "Numerical error",
543572
"views_channels_last": "Torch.export: View not supported by torch.export in channels last format",

0 commit comments

Comments
 (0)