Skip to content

Commit 94f9719

Browse files
[exir] Materialize alloc shapes in ToOutVarPass (pytorch#19806)
Fix a dynamic-shape lowering bug in exir. ConstraintBasedSymShapeEvalPass concretizes TensorSpec metadata, but ToOutVarPass was still building memory.alloc nodes from symbolic FakeTensor/tensor_meta shapes. That let symbolic dims leak into the generated ExecuTorch GraphModule and caused runtime failures when the lowered module was executed in Python. Build memory.alloc specs from concrete upper-bounded integer shapes instead. If an alloc shape is still not concretely bounded, raise a clear error. Add an EXIR regression test that exports a dynamic-shape model, runs ConstraintBasedSymShapeEvalPass + ToOutVarPass, and verifies that memory.alloc shapes are concrete integers. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell @rascani --------- Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent c505aa5 commit 94f9719

3 files changed

Lines changed: 67 additions & 14 deletions

File tree

backends/arm/test/models/test_torch_functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,6 @@ def forward(self, *args):
9797
"test_data",
9898
test_parameters,
9999
xfails={
100-
"nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
101-
"Requires dynamic output shape.",
102100
"topk": "NotImplementedError: No registered serialization name for <class 'torch.return_types.topk'> found",
103101
"sort": "NotImplementedError: No registered serialization name for <class 'torch.return_types.sort'> found",
104102
},
@@ -124,8 +122,6 @@ def test_torch_functions_tosa_FP(test_data):
124122
"test_data",
125123
test_parameters,
126124
xfails={
127-
"nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
128-
"Requires dynamic output shape.",
129125
"topk": "NotImplementedError: No registered serialization name for <class 'torch.return_types.topk'> found",
130126
"sort": "NotImplementedError: No registered serialization name for <class 'torch.return_types.sort'> found",
131127
},

exir/passes/__init__.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
from executorch.exir.passes.to_device_pass import ToDevicePass
6464
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
65+
from executorch.exir.sym_util import eval_shape_upper_bound
6566
from torch import fx
6667
from torch._subclasses import FakeTensor
6768
from torch.fx.passes.infra.pass_base import PassBase, PassResult
@@ -281,31 +282,38 @@ def make_alloc_node(
281282
Note: tensor_metadata is only used in the case of a Tensor subclass, since
282283
fakifying a tensor subclass is not supported right now
283284
"""
285+
286+
def materialize_alloc_spec(
287+
shape: Union[torch.Size, Tuple[int, ...], List[int]],
288+
dtype: torch.dtype,
289+
) -> memory.AllocSpec:
290+
concrete_shape = eval_shape_upper_bound(shape)
291+
if any(not isinstance(dim, int) for dim in concrete_shape):
292+
raise RuntimeError(
293+
"Memory allocator node requires concrete upper-bounded dimensions. "
294+
f"Got shape {shape} and evaluated upper bounds {concrete_shape}."
295+
)
296+
return (tuple(concrete_shape), dtype)
297+
284298
if val is None:
285299
if tensor_meta is not None:
286300
assert isinstance(tensor_meta, TensorMetadata)
287-
alloc_spec = (tensor_meta.shape, tensor_meta.dtype)
301+
alloc_spec = materialize_alloc_spec(tensor_meta.shape, tensor_meta.dtype)
288302
else:
289303
raise InternalError(
290304
"Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
291305
)
292306
elif isinstance(val, FakeTensor):
293-
alloc_spec = (val.shape, val.dtype)
307+
alloc_spec = materialize_alloc_spec(val.shape, val.dtype)
294308
else:
295309
assert isinstance(val, list) or isinstance(val, tuple)
296310
assert isinstance(tensor_meta, list) or isinstance(tensor_meta, tuple)
297311
alloc_spec: List[memory.AllocSpec] = []
298312
for v, t in zip(val, tensor_meta):
299313
if v is not None:
300-
# pyre-fixme[6]: For 1st argument expected
301-
# `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but
302-
# got `Tuple[Size, dtype]`.
303-
alloc_spec.append((v.shape, v.dtype))
314+
alloc_spec.append(materialize_alloc_spec(v.shape, v.dtype))
304315
elif t is not None:
305-
# pyre-fixme[6]: For 1st argument expected
306-
# `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but
307-
# got `Tuple[Size, dtype]`.
308-
alloc_spec.append((t.shape, t.dtype))
316+
alloc_spec.append(materialize_alloc_spec(t.shape, t.dtype))
309317
else:
310318
raise InternalError(
311319
"Memory allocator node needs FakeTensor val or TensorMetadata to proceed"

exir/tests/test_passes.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -74,6 +75,7 @@
7475
)
7576
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
7677
from executorch.exir.passes.spec_prop_pass import SpecPropPass
78+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
7779
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
7880
from executorch.exir.program._program import lift_constant_tensor_pass
7981
from executorch.exir.schema import TensorShapeDynamism
@@ -1036,6 +1038,53 @@ def test_alloc_node_spec(self) -> None:
10361038
for node in alloc_nodes:
10371039
self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec))
10381040

1041+
def test_to_out_var_dynamic_alloc_uses_concrete_upper_bounds(self) -> None:
1042+
class DynamicRelu(nn.Module):
1043+
def forward(self, x):
1044+
return torch.relu(x)
1045+
1046+
eager_model = DynamicRelu()
1047+
inputs = (torch.randn(2, 4, 8, 3),)
1048+
dynamic_shapes = {
1049+
"x": {
1050+
0: torch.export.Dim("batch", min=0, max=2),
1051+
2: torch.export.Dim("height", min=0, max=8),
1052+
3: torch.export.Dim("width", min=0, max=8),
1053+
}
1054+
}
1055+
prog = to_edge(
1056+
export(
1057+
eager_model,
1058+
inputs,
1059+
dynamic_shapes=dynamic_shapes,
1060+
strict=True,
1061+
),
1062+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1063+
)
1064+
new_prog = prog.transform(
1065+
[
1066+
SpecPropPass(),
1067+
ConstraintBasedSymShapeEvalPass(),
1068+
]
1069+
)
1070+
1071+
new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
1072+
self.assertIsNotNone(new_gm_res)
1073+
new_gm = new_gm_res.graph_module
1074+
1075+
alloc_nodes = []
1076+
for node in new_gm.graph.nodes:
1077+
if node.target == memory.alloc:
1078+
alloc_nodes.append(node)
1079+
1080+
self.assertTrue(len(alloc_nodes) > 0)
1081+
for node in alloc_nodes:
1082+
alloc_spec = node.args[0]
1083+
self.assertIsInstance(alloc_spec, tuple)
1084+
shape, _dtype = alloc_spec
1085+
for dim in shape:
1086+
self.assertIsInstance(dim, int)
1087+
10391088
def test_debug_pass_file_log(self) -> None:
10401089
eager_model = Mul()
10411090
inputs = eager_model.get_random_inputs()

0 commit comments

Comments
 (0)