Skip to content

Commit 01ef73b

Browse files
exir: Prevent implicit symbolic ProxyValue conversion (pytorch#19661)
Stop pass code from silently turning symbolic ProxyValues into Python values through bool, int, float, or index conversion as this can lead to graph breakage. Add regressions for bool, int, float, and index cases using symbolic metadata from an exported dynamic graph. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent c7bbba4 commit 01ef73b

2 files changed

Lines changed: 64 additions & 1 deletion

File tree

exir/pass_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,33 @@ def __iter__(self):
150150
yield from self.data
151151

152152
def __bool__(self) -> bool:
153+
if isinstance(self.data, (torch.SymInt, torch.SymFloat, torch.SymBool)):
154+
raise ExportPassBaseError(
155+
"ProxyValue with symbolic data cannot be used in boolean context."
156+
)
153157
return bool(self.data)
154158

159+
def __int__(self):
160+
if isinstance(self.data, torch.SymInt):
161+
raise ExportPassBaseError(
162+
"ProxyValue with SymInt data cannot be converted to int."
163+
)
164+
return int(self.data)
165+
166+
def __float__(self):
167+
if isinstance(self.data, torch.SymFloat):
168+
raise ExportPassBaseError(
169+
"ProxyValue with SymFloat data cannot be converted to float."
170+
)
171+
return float(self.data)
172+
173+
def __index__(self):
174+
if isinstance(self.data, torch.SymInt):
175+
raise ExportPassBaseError(
176+
"ProxyValue with SymInt data cannot be used in index context."
177+
)
178+
return self.__int__()
179+
155180

156181
class ExportPassBaseError(RuntimeError):
157182
pass

exir/tests/test_pass_infra.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,10 +11,11 @@
1011

1112
import torch
1213
from executorch.exir import to_edge
14+
from executorch.exir.pass_base import ExportPassBaseError, ProxyValue
1315
from executorch.exir.pass_manager import PassManager
1416
from executorch.exir.passes import ScalarToTensorPass
1517
from executorch.exir.passes.pass_registry import PassRegistry
16-
from torch.export import export
18+
from torch.export import Dim, export
1719
from torch.fx.passes.infra.pass_base import PassBase
1820

1921

@@ -178,3 +180,39 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
178180
for node in new_gm.graph.nodes:
179181
if node.target != "output":
180182
self.assertIn("val", node.meta)
183+
184+
185+
class TestProxyValueSymbolicCoercions(unittest.TestCase):
186+
@staticmethod
187+
def _symbolic_values() -> tuple[torch.SymInt, torch.SymFloat]:
188+
class ViewModule(torch.nn.Module):
189+
def forward(self, x: torch.Tensor) -> torch.Tensor:
190+
return x.view(x.size(0), -1)
191+
192+
exported = export(
193+
ViewModule(),
194+
(torch.randn(2, 3),),
195+
dynamic_shapes=({0: Dim("batch", min=1, max=8)},),
196+
strict=True,
197+
)
198+
gm = to_edge(exported).exported_program().graph_module
199+
for node in gm.graph.nodes:
200+
value = node.meta.get("val")
201+
if isinstance(value, torch.SymInt):
202+
return value, torch.sym_float(value)
203+
raise AssertionError("Expected a symbolic scalar in exported graph metadata")
204+
205+
def test_rejects_implicit_symbolic_scalar_coercions(self) -> None:
206+
sym_int, sym_float = self._symbolic_values()
207+
208+
with self.assertRaisesRegex(ExportPassBaseError, "boolean context"):
209+
bool(ProxyValue(sym_int, torch.fx.Graph().placeholder("x")))
210+
211+
with self.assertRaisesRegex(ExportPassBaseError, "converted to int"):
212+
int(ProxyValue(sym_int, torch.fx.Graph().placeholder("x")))
213+
214+
with self.assertRaisesRegex(ExportPassBaseError, "used in index context"):
215+
ProxyValue(sym_int, torch.fx.Graph().placeholder("x")).__index__()
216+
217+
with self.assertRaisesRegex(ExportPassBaseError, "converted to float"):
218+
float(ProxyValue(sym_float, torch.fx.Graph().placeholder("x")))

0 commit comments

Comments
 (0)