Skip to content

Commit fa4c6a0

Browse files
authored
Arm backend: Create more general GetDecompositionPass (pytorch#18011)
* Add GetDecompositiobPass for passes that use get_decomposition() Change-Id: Ifd1ec40dda3d225a84dc37c0c60bac7051d7e5bf cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell --------- Signed-off-by: Tom Allsop <tom.allsop@arm.com>
1 parent c1c9a5e commit fa4c6a0

2 files changed

Lines changed: 138 additions & 85 deletions

File tree

backends/arm/_passes/decompose_matmul.py

Lines changed: 12 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
from typing import Set, Type
99

1010
import torch
11-
from executorch.backends.arm._passes import ArmPass
12-
from executorch.exir.pass_base import ExportPass, PassResult
13-
from torch._decomp import get_decompositions
14-
from torch.fx.experimental.proxy_tensor import make_fx
11+
from executorch.backends.arm._passes.get_decomposition_pass import GetDecompositionPass
12+
from executorch.exir.pass_base import ExportPass
1513

1614

17-
class DecomposeMatmulPass(ArmPass):
15+
class DecomposeMatmulPass(GetDecompositionPass):
1816
"""Decompose aten.matmul into more primitive ops using PyTorch decomposition table.
1917
By defualt, to_edge will decompose torch.matmul into 1d (dot), 2d (mm), 3d (bmm), ops
2018
along with required reshape operators and expands/repeats. For a quanitzed matmul this
@@ -37,84 +35,13 @@ class DecomposeMatmulPass(ArmPass):
3735
torch.ops.aten.matmul.default,
3836
]
3937

40-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
41-
modified = False
42-
for node in graph_module.graph.nodes:
43-
if (
44-
node.op != "call_function"
45-
or node.target not in self.targeted_ops
46-
or not self.allowed_to_transform(node.meta)
47-
):
48-
continue
38+
def _skip_pass(self, input_tensors: list) -> bool:
39+
# TODO Add support for multiplication with vectors
40+
if len(input_tensors) > 1:
41+
if input_tensors[1].dim() == 1:
42+
return True
43+
else:
44+
if input_tensors[0].dim() == 1:
45+
return True
4946

50-
input_tensors = [arg.meta["val"] for arg in node.args]
51-
# TODO Add support for mutliplication with vectors
52-
if len(input_tensors) > 1:
53-
if input_tensors[1].dim() == 1:
54-
continue
55-
else:
56-
if input_tensors[0].dim() == 1:
57-
continue
58-
59-
# refer to pytorch/test/test_decomp.py
60-
decomposed_module = make_fx(
61-
node.target,
62-
decomposition_table=get_decompositions(self.targeted_ops), # type: ignore[arg-type]
63-
tracing_mode="fake",
64-
_allow_non_fake_inputs=False,
65-
)(*input_tensors)
66-
with graph_module.graph.inserting_before(node):
67-
name_to_input_tensor_map = {}
68-
for i, arg in enumerate(node.args):
69-
name_to_input_tensor_map[f"arg{i}_1"] = arg
70-
71-
decomposed_node_to_subgraph_node = {}
72-
last_decomposed_node = None
73-
# Create a mapping from input nodes in decomposed module to original nodes.
74-
# In decomposed module, there are only input tensors for placeholder op.
75-
for decomposed_node in decomposed_module.graph.nodes:
76-
if decomposed_node.op == "placeholder":
77-
decomposed_node_to_subgraph_node[decomposed_node] = (
78-
name_to_input_tensor_map[decomposed_node.name]
79-
)
80-
81-
if decomposed_node.op == "output":
82-
last_decomposed_node = decomposed_node.args[0]
83-
84-
# Copy node from decompose graph module
85-
for decomposed_node in decomposed_module.graph.nodes:
86-
decomposed_node.meta["nn_module_stack"] = node.meta.get(
87-
"nn_module_stack"
88-
)
89-
if decomposed_node.op == "placeholder":
90-
continue
91-
92-
if (
93-
decomposed_node.op == "output"
94-
and last_decomposed_node is not None
95-
):
96-
for user in node.users.copy():
97-
user.replace_input_with(
98-
node,
99-
decomposed_node_to_subgraph_node[last_decomposed_node],
100-
)
101-
continue
102-
103-
subgraph_node = graph_module.graph.node_copy(
104-
decomposed_node,
105-
arg_transform=lambda x: decomposed_node_to_subgraph_node[ # noqa: B023
106-
x
107-
],
108-
)
109-
subgraph_node.meta["source_fn_stack"] = [
110-
(subgraph_node, subgraph_node.target)
111-
]
112-
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
113-
114-
graph_module.graph.erase_node(node)
115-
116-
modified = True
117-
if modified:
118-
graph_module.graph.eliminate_dead_code()
119-
graph_module.recompile()
120-
return PassResult(graph_module, modified)
47+
return False
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
#
8+
from typing import Set, Type
9+
10+
import torch
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
from torch._decomp import get_decompositions
14+
from torch._ops import OpOverload
15+
from torch.fx.experimental.proxy_tensor import make_fx
16+
17+
18+
class GetDecompositionPass(ArmPass):
19+
20+
_passes_required_after: Set[Type[ExportPass]] = set()
21+
22+
targeted_ops: list[OpOverload] = []
23+
24+
def __init__(self, tfa_pass=False, *args, **kwargs):
25+
super().__init__(tfa_pass, *args, **kwargs)
26+
27+
self.__decomposition = None
28+
29+
if type(self) is GetDecompositionPass:
30+
raise TypeError(
31+
"Base class GetDecompositionPass cannot be instantiated directly."
32+
)
33+
34+
def _skip_pass(self, input_tensors: list) -> bool:
35+
return False
36+
37+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
38+
modified = False
39+
for node in graph_module.graph.nodes:
40+
if (
41+
node.op != "call_function"
42+
or node.target not in self.targeted_ops
43+
or not self.allowed_to_transform(node.meta)
44+
):
45+
continue
46+
47+
input_tensors = []
48+
for arg in node.args:
49+
if hasattr(arg, "meta"):
50+
input_tensors.append(arg.meta["val"])
51+
52+
elif isinstance(arg, int):
53+
input_tensors.append(arg)
54+
55+
if self._skip_pass(input_tensors):
56+
continue
57+
58+
decomposition = (
59+
self.__decomposition
60+
if self.__decomposition is not None
61+
else get_decompositions(self.targeted_ops)
62+
)
63+
64+
# refer to pytorch/test/test_decomp.py
65+
decomposed_module = make_fx(
66+
node.target,
67+
decomposition_table=decomposition, # type: ignore[arg-type]
68+
tracing_mode="fake",
69+
_allow_non_fake_inputs=False,
70+
)(*input_tensors)
71+
72+
with graph_module.graph.inserting_before(node):
73+
name_to_input_tensor_map = {}
74+
for i, arg in enumerate(node.args):
75+
name_to_input_tensor_map[f"arg{i}_1"] = arg
76+
77+
decomposed_node_to_subgraph_node = {}
78+
last_decomposed_node = None
79+
# Create a mapping from input nodes in decomposed module to original nodes.
80+
# In decomposed module, there are only input tensors for placeholder op.
81+
for decomposed_node in decomposed_module.graph.nodes:
82+
if decomposed_node.op == "placeholder":
83+
decomposed_node_to_subgraph_node[decomposed_node] = (
84+
name_to_input_tensor_map[decomposed_node.name]
85+
)
86+
87+
if decomposed_node.op == "output":
88+
last_decomposed_node = decomposed_node.args[0]
89+
90+
# Copy node from decompose graph module
91+
for decomposed_node in decomposed_module.graph.nodes:
92+
decomposed_node.meta["nn_module_stack"] = node.meta.get(
93+
"nn_module_stack"
94+
)
95+
if decomposed_node.op == "placeholder":
96+
continue
97+
98+
if (
99+
decomposed_node.op == "output"
100+
and last_decomposed_node is not None
101+
):
102+
for user in node.users.copy():
103+
user.replace_input_with(
104+
node,
105+
decomposed_node_to_subgraph_node[last_decomposed_node],
106+
)
107+
continue
108+
109+
subgraph_node = graph_module.graph.node_copy(
110+
decomposed_node,
111+
arg_transform=lambda x: decomposed_node_to_subgraph_node[ # noqa: B023
112+
x
113+
],
114+
)
115+
subgraph_node.meta["source_fn_stack"] = [
116+
(subgraph_node, subgraph_node.target)
117+
]
118+
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
119+
120+
graph_module.graph.erase_node(node)
121+
122+
modified = True
123+
if modified:
124+
graph_module.graph.eliminate_dead_code()
125+
graph_module.recompile()
126+
return PassResult(graph_module, modified)

0 commit comments

Comments
 (0)