Skip to content

Commit 0fc4d6d

Browse files
Arm backend: Fuse consecutive CONCAT_SHAPES (pytorch#18519)
Make sure that the arguments to CONCAT_SHAPE have rank==1 to make dim_order permutation work. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent e7c9c9e commit 0fc4d6d

4 files changed

Lines changed: 176 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
QuantizeClampArgumentsPass,
103103
)
104104
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
105+
from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa
105106
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
106107
from .fuse_constant_ops_pass import ( # noqa
107108
ComputeConstantOpsAOTPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
DecorateFp32toInt32CastingPass,
9999
FoldAndAnnotateQParamsPass,
100100
FuseBatchNorm2dPass,
101+
FuseConsecutiveConcatShapesPass,
101102
FuseConsecutiveRescalesPass,
102103
FuseConstantArgsPass,
103104
FuseDuplicateUsersPass,
@@ -503,6 +504,7 @@ def _tosa_pipeline(
503504
[
504505
CastInt64BuffersToInt32Pass(exported_program),
505506
FuseEqualPlaceholdersPass(exported_program),
507+
FuseConsecutiveConcatShapesPass(),
506508
ToTosaMemoryFormatPass(exported_program),
507509
RemoveNoopPass(),
508510
InsertRescalePass(),
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import NodeMetadata, ProxyValue
12+
13+
14+
class FuseConsecutiveConcatShapesPass(ArmPass):
    """Flatten nested tosa.CONCAT_SHAPE operations into a single operation.

    When the input list of a tosa.CONCAT_SHAPE contains the result of another
    tosa.CONCAT_SHAPE, the inputs of the nested operation are spliced in at
    that position. E.g.
    tosa.CONCAT_SHAPE([shape1, tosa.CONCAT_SHAPE([shape2, shape3]), shape4])
    becomes tosa.CONCAT_SHAPE([shape1, shape2, shape3, shape4])

    Flattening is needed for dim-order propagation to work correctly: with
    e.g. dim-order==(0, 2, 3, 1) the input shapes have to be permuted
    accordingly, which is much easier on a flat list of inputs.

    """

    _passes_required_after = set()

    def _to_proxy_value(
        self, arg: ProxyValue | torch.fx.Node | Any
    ) -> ProxyValue | Any:
        """Wrap a raw fx node in a ProxyValue; return anything else as is.

        Args:
            arg: An input taken from the argument list of a nested
                CONCAT_SHAPE node.

        Returns:
            A ProxyValue built from the node's ``meta["val"]`` and a tracer
            proxy when ``arg`` is a ``torch.fx.Node``; otherwise ``arg``
            itself, untouched (this includes values that already are a
            ProxyValue).

        """
        if not isinstance(arg, torch.fx.Node):
            return arg
        return ProxyValue(arg.meta["val"], self.tracer.proxy(arg))

    def call_operator(
        self,
        op: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        meta: NodeMetadata,
        updated: bool | None = False,
    ) -> ProxyValue:
        """Splice the inputs of nested CONCAT_SHAPE operands into this one.

        Operators other than tosa.CONCAT_SHAPE are forwarded to the base
        class unchanged. For tosa.CONCAT_SHAPE only ``args[0]`` (the list of
        input shapes) is used to build the replacement call.

        Args:
            op: Operator being visited.
            args: Positional arguments of the operator.
            kwargs: Keyword arguments of the operator.
            meta: Metadata of the visited node.
            updated: Not read by this method; the flag passed on for a
                CONCAT_SHAPE reflects only whether an operand was flattened
                here.

        Returns:
            The value produced by the base class ``call_operator``.

        """
        concat_shape = exir_ops.backend.tosa.CONCAT_SHAPE.default
        if op != concat_shape:
            return super().call_operator(op, args, kwargs, meta)

        flattened: list[Any] = []
        fused = False
        for operand in args[0]:
            is_nested_concat = (
                hasattr(operand, "node") and operand.node.target == concat_shape
            )
            if not is_nested_concat:
                flattened.append(operand)
                continue
            # The nested node's inputs may be raw fx nodes, so each one is
            # run through _to_proxy_value before being passed on.
            for nested in operand.node.args[0]:
                flattened.append(self._to_proxy_value(nested))
            fused = True

        return super().call_operator(op, (flattened,), kwargs, meta, updated=fused)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import executorch.backends.arm.tosa.dialect # noqa: F401
6+
from executorch.backends.arm._passes.fuse_consecutive_concat_shapes import (
7+
FuseConsecutiveConcatShapesPass,
8+
)
9+
from executorch.backends.arm.tosa.specification import (
10+
TosaLoweringContext,
11+
TosaSpecification,
12+
)
13+
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass
16+
from torch.fx import GraphModule
17+
from torch.fx.passes.infra.pass_base import PassResult
18+
19+
20+
def _graph_module_with_nested_concat():
    """Build a graph in which one CONCAT_SHAPE consumes another.

    Four CONST_SHAPE nodes holding [0], [1], [2] and [3] are created. An inner
    CONCAT_SHAPE joins the [1] and [2] constants, and the outer CONCAT_SHAPE,
    which is the graph output, joins the [0] constant, the inner concat and
    the [3] constant.

    Returns:
        The graph module obtained by running a plain ExportPass over the
        built graph.
    """
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        builder = GraphBuilder()
        # Constants are created in ascending order: [0], [1], [2], [3].
        shapes = [
            builder.call_operator(
                exir_ops.backend.tosa.CONST_SHAPE.default, ([value],)
            )
            for value in range(4)
        ]
        nested = builder.call_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default, ([shapes[1], shapes[2]],)
        )
        top = builder.call_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default,
            ([shapes[0], nested, shapes[3]],),
        )
        builder.output([top])
        traced = ExportPass().call(builder.get_graph_module())
        return traced.graph_module
43+
44+
45+
def _graph_module_with_flat_concat():
    """Build a graph with a single CONCAT_SHAPE over three constants.

    The CONCAT_SHAPE, which is the graph output, joins CONST_SHAPE nodes
    holding [4], [5] and [6]; none of its inputs is itself a CONCAT_SHAPE.

    Returns:
        The graph module obtained by running a plain ExportPass over the
        built graph.
    """
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        builder = GraphBuilder()
        # Constants are created in the order [4], [5], [6].
        shapes = [
            builder.call_operator(
                exir_ops.backend.tosa.CONST_SHAPE.default, ([value],)
            )
            for value in (4, 5, 6)
        ]
        top = builder.call_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default, (shapes,)
        )
        builder.output([top])
        traced = ExportPass().call(builder.get_graph_module())
        return traced.graph_module
62+
63+
64+
def _concat_shape_nodes(graph_module):
    """Return the CONCAT_SHAPE call_function nodes of a graph, in graph order."""
    matches = []
    for candidate in graph_module.graph.nodes:
        if candidate.op != "call_function":
            continue
        if candidate.target == exir_ops.backend.tosa.CONCAT_SHAPE.default:
            matches.append(candidate)
    return matches
71+
72+
73+
def _const_shape_values(shape_list_nodes):
    """Return, for each node, the first element of its first positional arg."""
    values = []
    for shape_node in shape_list_nodes:
        first_arg = shape_node.args[0]
        values.append(first_arg[0])
    return values
75+
76+
77+
def _run_fuse_pass(graph_module: GraphModule):
    """Run FuseConsecutiveConcatShapesPass on a graph and drop dead nodes.

    Args:
        graph_module: Graph module to run the pass on.

    Returns:
        The graph module carried by the pass result when the pass returns a
        PassResult, otherwise the input graph module; in both cases after
        dead code elimination has been run on its graph.
    """
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        outcome = FuseConsecutiveConcatShapesPass()(graph_module)
    fused_module = (
        outcome.graph_module if isinstance(outcome, PassResult) else graph_module
    )
    # Remove nodes left without users, e.g. a concat whose inputs were
    # spliced into its consumer.
    fused_module.graph.eliminate_dead_code()
    return fused_module
84+
85+
86+
def test_fuse_consecutive_concat_shapes_no_target_flattens_nested_concat_inputs():
    """A nested CONCAT_SHAPE is fused into one concat over the constants."""
    fused_module = _run_fuse_pass(_graph_module_with_nested_concat())

    remaining_concats = _concat_shape_nodes(fused_module)
    fused_inputs = remaining_concats[-1].args[0]

    # Only the flattened concat survives, fed by the constants in order.
    assert len(remaining_concats) == 1
    assert _const_shape_values(fused_inputs) == [0, 1, 2, 3]
    assert all(
        operand.target == exir_ops.backend.tosa.CONST_SHAPE.default
        for operand in fused_inputs
    )
100+
101+
102+
def test_fuse_consecutive_concat_shapes_no_target_leaves_flat_concat_unchanged():
    """A CONCAT_SHAPE without nested concats keeps its inputs as they were."""
    fused_module = _run_fuse_pass(_graph_module_with_flat_concat())

    remaining_concats = _concat_shape_nodes(fused_module)
    concat_inputs = remaining_concats[-1].args[0]

    assert len(remaining_concats) == 1
    assert _const_shape_values(concat_inputs) == [4, 5, 6]

0 commit comments

Comments
 (0)