Skip to content

Commit 5dbbec3

Browse files
committed
Allow symints to be created for arguments
Add test for creating args of SymInt type to be able to use them in view_copy nodes in the Arm TOSA backend together with the fix to make the pass work. Signed-off-by: Per Åstrand <per.astrand@arm.com> Change-Id: Ia947b8426af1b473df415a17e10f3db1582b84fd
1 parent 176c162 commit 5dbbec3

File tree

2 files changed

+112
-2
lines changed

2 files changed

+112
-2
lines changed

exir/pass_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3-
# Copyright 2025 Arm Limited and/or its affiliates.
3+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -191,6 +191,11 @@ def create_arg(self, a: Argument) -> torch.fx.Node:
191191
if not hasattr(a, "constant") or a.constant is None:
192192
raise ExportPassBaseError(f"Cannot add {a} to graph.")
193193
a = a.constant
194+
elif isinstance(a, torch.SymInt):
195+
if a.node.constant is not None:
196+
return a.node.constant
197+
else:
198+
return a
194199
node = super().create_arg(a)
195200
if (
196201
isinstance(a, torch.Tensor)

exir/tests/test_dynamic_shape_propagation.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
33
#
4+
# Copyright 2026 Arm Limited and/or its affiliates.
5+
#
46
# This source code is licensed under the BSD-style license found in the
57
# LICENSE file in the root directory of this source tree.
68

79
# pyre-unsafe
810

911
from unittest import TestCase
1012

13+
import torch
14+
1115
from executorch import exir
1216
from executorch.exir import to_edge
13-
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
17+
from executorch.exir.passes import (
18+
DebugPass,
19+
ExportPass,
20+
HintBasedSymShapeEvalPass,
21+
SpecPropPass,
22+
)
23+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
1424
from executorch.exir.tests.models import Repeat, TensorItem
1525
from torch.export import export
1626

@@ -67,3 +77,98 @@ def test_unbacked_symint(self):
6777
self.assertEqual(
6878
speclist[0].shape, [100, 100]
6979
) # upper bound of TensorItem model
80+
81+
82+
class TestSymIntViewArgs(TestCase):
83+
class Conv1dToConv2d(torch.nn.Module):
84+
def __init__(self) -> None:
85+
super().__init__()
86+
87+
def forward(self, input: torch.Tensor) -> torch.Tensor:
88+
# Use view to make sure edge view handle symint shapes correctly.
89+
# input = input.view(input.size(0), input.size(1), input.size(2), 1) # (N, C, H, W)
90+
# weight = torch.randn(1, 16, 3, 1) # (out_channels, in_channels, kH, kW)
91+
# return torch.nn.functional.conv2d(input, weight)
92+
93+
return torch.nn.functional.conv1d(
94+
input, torch.randn(1, 16, 3)
95+
) # (out_channels, in_channels, kW)
96+
97+
def get_random_inputs(self) -> tuple[torch.Tensor]:
98+
return (torch.randn(1, 16, 50),) # (batch_size, channels, width)
99+
100+
def get_dynamic_shape(self) -> tuple[dict[int, torch.export.Dim]]:
101+
dim = torch.export.Dim("width", min=10, max=100)
102+
return ({2: dim},)
103+
104+
def test_symint_viewargs(self):
105+
eager_model = TestSymIntViewArgs.Conv1dToConv2d()
106+
inputs = eager_model.get_random_inputs()
107+
108+
class TestViewCopyPass(ExportPass):
109+
def call_operator(self, op, args, kwargs, meta):
110+
from executorch.exir.dialects._ops import ops as exir_ops
111+
112+
if op != exir_ops.edge.aten.convolution.default:
113+
return super().call_operator(op, args, kwargs, meta)
114+
115+
x = args[0]
116+
x = super().call_operator(
117+
exir_ops.edge.aten.view_copy.default,
118+
(x, list(x.data.shape) + [1]),
119+
{},
120+
meta,
121+
)
122+
123+
w = args[1]
124+
w = super().call_operator(
125+
exir_ops.edge.aten.view_copy.default,
126+
(w, list(w.data.shape) + [1]),
127+
{},
128+
meta,
129+
)
130+
131+
new_args = (
132+
x,
133+
w,
134+
args[2],
135+
args[3] + [1], # stride
136+
args[4] + [0], # padding
137+
args[5] + [1], # dilation
138+
args[6],
139+
args[7] + [0],
140+
args[8],
141+
)
142+
x = super().call_operator(
143+
exir_ops.edge.aten.convolution.default, new_args, kwargs, meta
144+
)
145+
x = super().call_operator(
146+
exir_ops.edge.aten.view_copy.default,
147+
(x, list(x.data.shape)[:-1]),
148+
{},
149+
meta,
150+
)
151+
152+
return x
153+
154+
prog = to_edge(
155+
export(
156+
eager_model,
157+
inputs,
158+
dynamic_shapes=eager_model.get_dynamic_shape(),
159+
strict=True,
160+
),
161+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
162+
)
163+
new_prog = prog.transform(
164+
[SpecPropPass(), ConstraintBasedSymShapeEvalPass(), TestViewCopyPass()]
165+
)
166+
gm = new_prog.exported_program().graph_module
167+
DebugPass(show_spec=True)(gm)
168+
*_, return_node = gm.graph.nodes
169+
speclist = return_node.meta["spec"]
170+
171+
self.assertEqual(len(speclist), 1)
172+
out_spec = speclist[0]
173+
self.assertTrue(out_spec.is_upper_bound_tensor)
174+
self.assertEqual(out_spec.shape, [1, 1, 98])

0 commit comments

Comments
 (0)