Skip to content

Commit f63324e

Browse files
Arm backend: Add count_tosa_ops test helper (pytorch#18006)
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 4ad6012 commit f63324e

3 files changed

Lines changed: 83 additions & 0 deletions

File tree

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import pytest
7+
import torch
8+
9+
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineFP
10+
11+
12+
class AddModule(torch.nn.Module):
13+
def forward(self, x, y):
14+
return x + y
15+
16+
17+
def test_count_tosa_ops_add_no_target():
18+
model = AddModule()
19+
test_data = (torch.randn(1, 8, 8, 8), torch.randn(1, 8, 8, 8))
20+
pipeline = TosaPipelineFP[type(test_data)](
21+
model,
22+
test_data,
23+
["torch.ops.aten.add.Tensor"],
24+
run_on_tosa_ref_model=False,
25+
)
26+
pipeline.count_tosa_ops({"ADD": 1, "SUB": 0})
27+
pipeline.run()
28+
29+
30+
def test_count_tosa_ops_2_adds_no_target():
31+
model = AddModule()
32+
test_data = (torch.randn(1, 8, 8, 8), torch.randn(1, 8, 8, 8))
33+
pipeline = TosaPipelineFP[type(test_data)](
34+
model,
35+
test_data,
36+
["torch.ops.aten.add.Tensor"],
37+
run_on_tosa_ref_model=False,
38+
)
39+
pipeline.count_tosa_ops({"ADD": 2})
40+
with pytest.raises(AssertionError, match="Expected 2 occurrences of TOSA op ADD"):
41+
pipeline.run()

backends/arm/test/tester/arm_tester.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,3 +1182,24 @@ def _format_dict(to_print: dict, print_table: bool = True) -> str:
11821182
)
11831183
else:
11841184
return pformat(to_print, compact=True, indent=1)
1185+
1186+
1187+
def count_tosa_ops(graph_module: torch.fx.GraphModule, expected_ops: Dict[str, int]):
1188+
"""Asserts that the number of occurrences of TOSA operators in the graph of
1189+
a partitioned module matches the expected counts.
1190+
"""
1191+
op_counts = dict(_get_tosa_operator_distribution(graph_module))
1192+
for op, expected_count in expected_ops.items():
1193+
actual_count = op_counts.get(op, 0)
1194+
1195+
if expected_count != actual_count:
1196+
if expected_count == 0:
1197+
raise AssertionError(
1198+
f"Expected no occurrences of TOSA op {op} but found {actual_count}."
1199+
)
1200+
elif actual_count == 0:
1201+
raise AssertionError(f"Expected TOSA op {op} but it was not found.")
1202+
else:
1203+
raise AssertionError(
1204+
f"Expected {expected_count} occurrences of TOSA op {op} but found {actual_count}."
1205+
)

backends/arm/test/tester/test_pipeline.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
Union,
2121
)
2222

23+
import executorch.backends.arm.test.tester.arm_tester as arm_tester_module
24+
2325
import torch
2426
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
2527
from executorch.backends.arm.ethosu import EthosUCompileSpec
@@ -301,6 +303,25 @@ def visualize(self, stage_id: str, suffix: str | None = None):
301303
self.add_stage_after(stage_id, self.tester.visualize, suffix=suffix)
302304
return self
303305

306+
def count_tosa_ops(self, expected_ops: Dict[str, int]):
307+
"""Assert the number of TOSA ops in the graph,"""
308+
if not self.has_stage("to_edge_transform_and_lower"):
309+
raise RuntimeError(
310+
"count_tosa_ops requires to_edge_transform_and_lower in the pipeline."
311+
)
312+
313+
def _count_tosa_ops():
314+
stage = self.tester.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
315+
graph_module = stage.graph_module
316+
arm_tester_module.count_tosa_ops(graph_module, expected_ops)
317+
318+
self.add_stage_after(
319+
"to_edge_transform_and_lower",
320+
_count_tosa_ops,
321+
suffix="tosa_ops",
322+
)
323+
return self
324+
304325
def change_args(self, stage_id: str, *args, **kwargs):
305326
"""Updates the args to the given stage id."""
306327
pos = self.find_pos(stage_id)

0 commit comments

Comments
 (0)