Skip to content

Commit 463fbe4

Browse files
Add general Aten lowering pass (pytorch#19837)
Adds a simple pass for replacing single Aten ops with corresponding dialect ops to be reused across multiple backends. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent acce7cd commit 463fbe4

3 files changed

Lines changed: 402 additions & 0 deletions

File tree

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import traceback
8+
from collections.abc import Callable
9+
from dataclasses import dataclass
10+
from typing import ClassVar, TypeAlias
11+
12+
import torch
13+
14+
from executorch.backends.xnnpack._passes.xnnpack_pass import ExportPass
15+
16+
from executorch.exir import ExportedProgram
17+
from torch.fx.node import Target
18+
from torch.fx.passes.infra.pass_manager import PassResult
19+
20+
21+
# Expected type to be returned by substitution functions.
22+
@dataclass
23+
class DialectNodeSpec:
24+
op: Target
25+
args: tuple
26+
kwargs: dict = None
27+
28+
29+
# Expected type to be used for substitution functions
30+
SubstitutionFn: TypeAlias = Callable[
31+
[torch.fx.Node, torch.export.ExportedProgram], DialectNodeSpec | None
32+
]
33+
34+
35+
class AtenToDialectPass(ExportPass):
36+
"""
37+
General pass to convert ops 1-1 from ATen to a specific dialect.
38+
39+
Usage:
40+
1. Subclass the pass for a specific dialect
41+
2. For each ATen target to be substituted, implement a function returning a DialectNodeSpec defining the
42+
corresponding dialect op, or None if the substitution does not apply.
43+
3. Register each substitution function for the subclass using the decorator register_dialect_substitution
44+
45+
Only one substitution function can be registered for a given target.
46+
47+
The pass must be initialized with an exported_program to allow substitution functions to modify placeholders,
48+
e.g. if the dialect ops require additional scratch buffers.
49+
"""
50+
51+
_DIALECT_SUBSTITUTIONS: ClassVar[dict[Target, SubstitutionFn]] = {}
52+
53+
def __init__(self, exported_program: ExportedProgram):
54+
super().__init__()
55+
self.exported_program: ExportedProgram = exported_program
56+
57+
# Ensure each subclass has its own substitution registry.
58+
def __init_subclass__(cls, **kwargs):
59+
super().__init_subclass__(**kwargs)
60+
cls._DIALECT_SUBSTITUTIONS = {}
61+
62+
@classmethod
63+
def register_dialect_substitution(
64+
cls, target: Target
65+
) -> Callable[[SubstitutionFn], SubstitutionFn]:
66+
67+
def decorator(func: SubstitutionFn) -> SubstitutionFn:
68+
if target in cls._DIALECT_SUBSTITUTIONS:
69+
raise RuntimeError(
70+
f"Multiple substitutions registered for the same target in {cls.__name__} are not allowed."
71+
)
72+
else:
73+
cls._DIALECT_SUBSTITUTIONS[target] = func
74+
return func
75+
76+
return decorator
77+
78+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
79+
modified = False
80+
81+
for node in graph_module.graph.nodes:
82+
if node.op != "call_function":
83+
continue
84+
85+
substitution_func = self._DIALECT_SUBSTITUTIONS.get(node.target, None)
86+
if substitution_func is None:
87+
continue
88+
89+
dialect_node_spec = substitution_func(node, self.exported_program)
90+
if dialect_node_spec is None:
91+
continue
92+
93+
modified = True
94+
with graph_module.graph.inserting_before(node):
95+
dialect_node = graph_module.graph.create_node(
96+
"call_function",
97+
target=dialect_node_spec.op,
98+
args=dialect_node_spec.args,
99+
kwargs=dialect_node_spec.kwargs or {},
100+
)
101+
102+
node.replace_all_uses_with(dialect_node)
103+
104+
# Keep same meta dict for new node and append new trace
105+
dialect_node.meta = node.meta
106+
old_stack_trace = dialect_node.meta.get("stack_trace", "")
107+
dialect_node.meta["stack_trace"] = (
108+
f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
109+
)
110+
111+
graph_module.graph.erase_node(node)
112+
113+
if modified:
114+
graph_module.graph.eliminate_dead_code()
115+
graph_module.recompile()
116+
graph_module = super().call(graph_module).graph_module
117+
118+
return PassResult(graph_module, modified)
119+
120+
def requires(self, graph_module):
121+
self.ops_before = sum(
122+
1 for node in graph_module.graph.nodes if node.op == "call_function"
123+
)
124+
return super().requires(graph_module)
125+
126+
def ensures(self, graph_module: torch.fx.GraphModule) -> bool:
127+
"""Ensure that there has only been 1-1 substitution of call_function nodes, i.e. that the number of call_function nodes is preserved after the pass."""
128+
129+
self.ops_after = sum(
130+
1 for node in graph_module.graph.nodes if node.op == "call_function"
131+
)
132+
if self.ops_after != self.ops_before:
133+
raise RuntimeError(
134+
f"{self.__class__.__name__} did not preserve the number of call_function nodes: "
135+
f"before={self.ops_before}, after={self.ops_after}"
136+
)
137+
138+
return super().ensures(graph_module)

backends/transforms/targets.bzl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,21 @@ def define_common_targets():
176176
],
177177
)
178178

179+
runtime.python_library(
180+
name = "aten_to_dialect_pass",
181+
srcs = [
182+
"aten_to_dialect_pass.py",
183+
],
184+
visibility = [
185+
"//executorch/backends/...",
186+
],
187+
deps = [
188+
"//caffe2:torch",
189+
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
190+
"//executorch/exir:lib",
191+
],
192+
)
193+
179194
runtime.python_library(
180195
name = "rank_0_to_rank_1",
181196
srcs = [
@@ -243,6 +258,16 @@ def define_common_targets():
243258
],
244259
)
245260

261+
runtime.python_test(
262+
name = "test_aten_to_dialect_pass",
263+
srcs = [
264+
"test/test_aten_to_dialect_pass.py",
265+
],
266+
deps = [
267+
"//caffe2:torch",
268+
":aten_to_dialect_pass",
269+
],
270+
)
246271

247272
runtime.python_test(
248273
name = "test_rank_0_to_rank_1",

0 commit comments

Comments
 (0)