Skip to content

Commit bf438ce

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for tan core ATen op (pytorch#19301)
### Summary Added support for the core aten op `tan` using a decomposition pass and the identity: ``` tan(x) = sin(x) / cos(x) ``` ### Test plan ``` python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_tan --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_tan --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android ```
1 parent 2874dcb commit bf438ce

7 files changed

Lines changed: 98 additions & 0 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .decompose_remainder import DecomposeRemainder
3333
from .decompose_roll import DecomposeRoll
3434
from .decompose_silu import DecomposeSilu
35+
from .decompose_tan import DecomposeTan
3536
from .decompose_threshold import DecomposeThreshold
3637
from .decompose_triu import DecomposeTriu
3738
from .decompose_trunc import DecomposeTrunc
@@ -88,6 +89,7 @@
8889
DecomposeRemainder,
8990
DecomposeRoll,
9091
DecomposeSilu,
92+
DecomposeTan,
9193
DecomposeThreshold,
9294
DecomposeTriu,
9395
DecomposeTrunc,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta
13+
14+
15+
class DecomposeTan(ExportPass):
16+
"""
17+
Decompose tan(x) = sin(x) / cos(x)
18+
"""
19+
20+
def __init__(self):
21+
super(DecomposeTan, self).__init__()
22+
self.targets = {
23+
torch.ops.aten.tan.default,
24+
exir_ops.edge.aten.tan.default,
25+
}
26+
27+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
28+
graph = graph_module.graph
29+
30+
for node in list(graph.nodes):
31+
if node.op == "call_function" and node.target in self.targets:
32+
is_edge = isinstance(node.target, EdgeOpOverload)
33+
34+
sin_op = (
35+
exir_ops.edge.aten.sin.default
36+
if is_edge
37+
else torch.ops.aten.sin.default
38+
)
39+
cos_op = (
40+
exir_ops.edge.aten.cos.default
41+
if is_edge
42+
else torch.ops.aten.cos.default
43+
)
44+
div_op = (
45+
exir_ops.edge.aten.div.Tensor
46+
if is_edge
47+
else torch.ops.aten.div.Tensor
48+
)
49+
50+
with graph.inserting_before(node):
51+
sin_node = graph.create_node(
52+
"call_function", sin_op, (node.args[0],)
53+
)
54+
sin_node.meta = copy_meta(node.meta)
55+
56+
cos_node = graph.create_node(
57+
"call_function", cos_op, (node.args[0],)
58+
)
59+
cos_node.meta = copy_meta(node.meta)
60+
61+
div_node = graph.create_node(
62+
"call_function", div_op, (sin_node, cos_node)
63+
)
64+
div_node.meta = copy_meta(node.meta)
65+
66+
for user in node.users.copy():
67+
user.replace_input_with(node, div_node)
68+
69+
graph.eliminate_dead_code()
70+
graph_module.recompile()
71+
return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
DecomposeRemainder,
3838
DecomposeRoll,
3939
DecomposeSilu,
40+
DecomposeTan,
4041
DecomposeThreshold,
4142
DecomposeTriu,
4243
DecomposeTrunc,
@@ -112,6 +113,7 @@ def get_capture_program_passes():
112113
(DecomposeMinMaxDim, True),
113114
(DecomposePad, True),
114115
(DecomposeRemainder, True),
116+
(DecomposeTan, True),
115117
(DecomposeTrunc, True),
116118
(ExpandBroadcastTensorShape, True),
117119
(FixedLinearKeepDim, True),
@@ -236,6 +238,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
236238
self.add_pass(DecomposeScaledDotProductAttention())
237239
self.add_pass(DecomposeRoll())
238240
self.add_pass(DecomposeSilu())
241+
self.add_pass(DecomposeTan())
239242
self.add_pass(DecomposeThreshold())
240243
self.add_pass(DecomposeTriu())
241244
self.add_pass(DecomposeTrunc())

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def get_passes_dependency_for_capture_program():
7474
DecomposeMaxPool3d,
7575
DecomposePad,
7676
DecomposeRemainder,
77+
DecomposeTan,
7778
DecomposeTrunc,
7879
ExpandBroadcastTensorShape,
7980
FixedLinearKeepDim,
@@ -107,6 +108,7 @@ def get_passes_dependency_for_capture_program():
107108
DecomposeMaxPool3d: [RemoveRedundancy],
108109
DecomposePad: [RemoveRedundancy],
109110
DecomposeRemainder: [RemoveRedundancy],
111+
DecomposeTan: [RemoveRedundancy],
110112
DecomposeTrunc: [RemoveRedundancy],
111113
ExpandBroadcastTensorShape: [FoldQDQ],
112114
FixedLinearKeepDim: [FoldQDQ],

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ The following PyTorch operators are supported through decomposition or annotatio
518518
| `aten.remainder.Scalar`, `aten.remainder.Tensor` | `DecomposeRemainder` |
519519
| `aten.roll` | `DecomposeRoll` |
520520
| `aten.silu` | `DecomposeSilu` |
521+
| `aten.tan` | `DecomposeTan` |
521522
| `aten.threshold` | `DecomposeThreshold` |
522523
| `aten.triu` | `DecomposeTriu` |
523524
| `aten.trunc` | `DecomposeTrunc` |

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,6 +2458,14 @@ def forward(self, x):
24582458
return torch.swapaxes(x, axis0=self.axis0, axis1=self.axis1)
24592459

24602460

2461+
class Tan(torch.nn.Module):
2462+
def __init__(self):
2463+
super().__init__()
2464+
2465+
def forward(self, x):
2466+
return torch.tan(x)
2467+
2468+
24612469
class Tanh(torch.nn.Module):
24622470
def __init__(self):
24632471
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,11 @@ def test_qnn_backend_swapaxes(self):
20522052
sample_input = (torch.randn([1, 2, 3, 4]),)
20532053
self.lower_module_and_test_output(module, sample_input)
20542054

2055+
def test_qnn_backend_tan(self):
2056+
module = Tan() # noqa: F405
2057+
sample_input = (torch.rand(2, 5, 1, 3) * 2 - 1,)
2058+
self.lower_module_and_test_output(module, sample_input)
2059+
20552060
def test_qnn_backend_tanh(self):
20562061
module = Tanh() # noqa: F405
20572062
sample_input = (torch.randn(2, 5, 1, 3),)
@@ -4854,6 +4859,12 @@ def test_qnn_backend_swapaxes(self):
48544859
module = self.get_qdq_module(module, sample_input)
48554860
self.lower_module_and_test_output(module, sample_input)
48564861

4862+
def test_qnn_backend_tan(self):
4863+
module = Tan() # noqa: F405
4864+
sample_input = (torch.rand(2, 5, 1, 3) * 2 - 1,)
4865+
module = self.get_qdq_module(module, sample_input)
4866+
self.lower_module_and_test_output(module, sample_input)
4867+
48574868
def test_qnn_backend_tanh(self):
48584869
module = Tanh() # noqa: F405
48594870
sample_input = (torch.randn(2, 5, 1, 3),)

0 commit comments

Comments
 (0)