Skip to content

Commit 41a38d8

Browse files
Arm backend: Add multiple get_attr folding crash workaround (pytorch#19663)
See description in the added test. The workaround implemented is to create multiple attributes pointing to the same data source. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 6bc1762 commit 41a38d8

5 files changed

Lines changed: 121 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
from .decompose_var_pass import DecomposeVarPass # noqa
9898
from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa
9999
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
100+
from .deduplicate_get_attr_pass import DeduplicateGetAttrPass # noqa
100101
from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa
101102
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
102103
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
DecomposeVarPass,
9898
DecomposeWhereScalarOtherPass,
9999
DecorateFp32toInt32CastingPass,
100+
DeduplicateGetAttrPass,
100101
EnsureUniqueOutputNodesPass,
101102
FoldAndAnnotateQParamsPass,
102103
FuseBatchNorm2dPass,
@@ -651,6 +652,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
651652
[
652653
ReplaceInfAndLimitValuesPass(tfa_pass=True),
653654
DecomposeMaskedFillPass(tfa_pass=True),
655+
DeduplicateGetAttrPass(tfa_pass=True),
654656
]
655657
)
656658

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
from torch.fx import GraphModule, Node
12+
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
13+
14+
15+
class DeduplicateGetAttrPass(ArmPass):
    """Ensure no two get_attr nodes in the graph share the same target name.

    Torchao's constant folder can delete a shared backing attribute while
    another get_attr node still refers to it. To work around that, the first
    get_attr node seen for a target is left untouched and every later node
    with the same target is pointed at a freshly registered attribute that
    refers to the same underlying object. The graph nodes themselves stay
    separate so PT2E can attach per-use observers and backend lowering can
    process constants per use.

    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def _get_attr(self, graph_module: GraphModule, target: str) -> Any:
        """Resolve a dotted ``target`` path by walking attributes from
        ``graph_module``.

        """
        resolved: Any = graph_module
        for part in target.split("."):
            resolved = getattr(resolved, part)
        return resolved

    def _copy_attr(self, graph_module: GraphModule, node: Node) -> str:
        """Register an alias on ``graph_module`` for the object ``node``
        targets and return the alias name.

        The alias refers to the very same object; no data is cloned.

        """
        original_target = node.target
        assert isinstance(original_target, str)
        source = self._get_attr(graph_module, original_target)

        # The name generator is asked for a name based on this prefix, which
        # embeds the node name.
        make_alias_name = get_new_attr_name_with_prefix(
            f"_deduplicated_get_attr_{node.name}_"
        )
        alias = make_alias_name(graph_module)

        # Parameter is checked before Tensor so that parameters are
        # registered as parameters rather than as buffers.
        if isinstance(source, torch.nn.Parameter):
            graph_module.register_parameter(alias, source)
            return alias
        if isinstance(source, torch.Tensor):
            graph_module.register_buffer(alias, source)
            return alias

        setattr(graph_module, alias, source)
        return alias

    def call(self, graph_module: GraphModule) -> PassResult:
        """Retarget every repeated get_attr node to its own alias attribute.

        Returns a ``PassResult`` whose flag is True only if at least one node
        was retargeted; the graph is linted and recompiled in that case.

        """
        claimed_targets: set[str] = set()
        changed = False

        for get_attr_node in graph_module.graph.find_nodes(op="get_attr"):
            if get_attr_node.target in claimed_targets:
                # A previous node already uses this target: give this one
                # its own attribute name.
                get_attr_node.target = self._copy_attr(graph_module, get_attr_node)
                changed = True
            else:
                claimed_targets.add(get_attr_node.target)

        if changed:
            graph_module.graph.lint()
            graph_module.recompile()

        return PassResult(graph_module, changed)

backends/arm/test/quantizer/test_selective_quantization.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.backends.arm.test import common
1818
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
1919
from executorch.backends.arm.tosa import TosaSpecification
20+
from executorch.backends.cortex_m.test.tester import ramp_tensor
2021
from executorch.backends.test.harness.stages import StageType
2122
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2223
from torchvision import models, transforms # type: ignore[import-untyped]
@@ -229,6 +230,20 @@ def test_composable_global_none_linear_graph_tail_tosa_INT():
229230
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
230231

231232

233+
class SharedBufferEmbeddingLinearConstantFold(torch.nn.Module):
    """Small model in which an embedding and a bias-free linear layer use the
    same weight parameter object.

    """

    def __init__(self):
        super().__init__()
        self.shared = torch.nn.Embedding(4, 4)
        self.lm_head = torch.nn.Linear(4, 4, bias=False)
        # Tie the weights: both layers now hold the identical Parameter.
        self.lm_head.weight = self.shared.weight

    def forward(self, ids, x):
        embedded = self.shared(ids)
        pooled = embedded.sum(dim=1)
        projected = self.lm_head(x)
        return pooled + projected
245+
246+
232247
def test_mv3_selective_quant_int16_tosa_INT():
233248
model = mv3
234249
inputs = (normalize(torch.randn(1, 3, 224, 224)),)
@@ -302,3 +317,33 @@ def test_mv3_io_quant_tosa_INT():
302317
)
303318

304319
pipeline.run()
320+
321+
322+
def test_multiple_folded_get_attr():
    """Regression test for several get_attr nodes sharing one target.

    In torchao/quantization/pt2e/constant_fold.py:constant_fold, get_attr
    node targets are deleted as soon as there is one get_attr node without
    users using the target.

    If there are multiple get_attr nodes referring to the same target, as in
    this test, the function crashes if no workaround is present.

    """

    tied_weight_model = SharedBufferEmbeddingLinearConstantFold()

    token_ids = torch.tensor([[0, 1]], dtype=torch.long)
    activations = ramp_tensor(-2, 2, (1, 4))
    sample_inputs = (token_ids, activations)

    selective_quantizer = get_quantizer()
    selective_quantizer.set_module_type(torch.nn.Embedding, None)

    QuantizationPipeline(
        tied_weight_model,
        sample_inputs,
        quantizer=selective_quantizer,
        qspecs=None,
        input_qspecs=None,
        output_qspecs=None,
    ).run()

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Optional, Type
99

1010
from executorch.backends.arm._passes import (
11+
DeduplicateGetAttrPass,
1112
FoldAndAnnotateQParamsPass,
1213
ScalarsToAttributePass,
1314
)
@@ -52,6 +53,7 @@ class CortexMPassManager(PassManager):
5253
ReplaceScalarWithTensorArgPass,
5354
ClampHardswishPass,
5455
DecomposeMeanPass,
56+
DeduplicateGetAttrPass,
5557
]
5658

5759
def __init__(

0 commit comments

Comments
 (0)