Skip to content

Commit 410dd83

Browse files
authored
Arm backend: Refactor scalar handling (pytorch#18402)
## Arm backend: use scalars instead of fulls in TFA Scalars are then converted to buffers by the ScalarsToAttribute pass. This both simplifies the code, and allows affected ops to be moved to device with model.to(device=...). Note that this does not solve all issues with device kwargs after TFA, only specifically for scalar cases. ## Arm backend: Clean up some pass inefficiencies. - The ScalarToAttribute pass went through all submodules for each node, it only needs to do it once. - Some exir passes used full_like for scalars. This creates very buffers of the same size as the input, when a single value is enough. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell --------- Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent c652736 commit 410dd83

13 files changed

Lines changed: 126 additions & 167 deletions

backends/arm/_passes/arm_pass.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
1313
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
14+
from executorch.exir.dialects._ops import ops as exir_ops
1415
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
1516
from torch.fx import GraphModule
1617
from torch.fx.passes.infra.pass_base import PassResult
@@ -124,3 +125,31 @@ def call_shape_operator(
124125
shape_meta.data[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
125126
# Call the super (ArmPass) call operator with updated meta
126127
return self.call_operator(op, args, kwargs, shape_meta, updated)
128+
129+
def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]):
130+
"""Return a scalar value for the current pass stage.
131+
132+
In transform-for-annotation passes this returns the Python scalar
133+
directly. In later passes it materializes a `(1,)` `aten.full` node
134+
using the output dtype/device from `meta["val"]` when available.
135+
136+
"""
137+
138+
if self.is_tfa_pass:
139+
return value
140+
141+
kwargs = {}
142+
if "val" in meta:
143+
val = meta["val"]
144+
if isinstance(val, tuple):
145+
val = val[0]
146+
kwargs = {"device": val.device, "dtype": val.dtype}
147+
148+
return ArmPass.call_operator(
149+
self,
150+
op=exir_ops.edge.aten.full.default,
151+
args=((1,), value),
152+
kwargs=kwargs,
153+
meta=meta,
154+
updated=True,
155+
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
557557
DecomposeDivTensorModePass(tfa_pass=True),
558558
DecomposeWhereScalarOtherPass(tfa_pass=True),
559559
RewriteInplaceArithmeticPass(tfa_pass=True),
560+
DecomposeAddSubAlphaPass(tfa_pass=True),
561+
DecomposeLeakyReLUPass(tfa_pass=True),
562+
DecomposeGroupNormPass(tfa_pass=True),
563+
DecomposeLayerNormPass(tfa_pass=True),
564+
DecomposeVarPass(tfa_pass=True),
565+
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
560566
]
561567
)
562568

@@ -573,16 +579,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
573579
self.add_passes(
574580
[
575581
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
576-
DecomposeAddSubAlphaPass(tfa_pass=True),
577-
DecomposeGroupNormPass(tfa_pass=True),
578-
DecomposeLayerNormPass(tfa_pass=True),
579-
DecomposeVarPass(tfa_pass=True),
580-
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
581582
DecomposeNotEqualPass(tfa_pass=True),
582583
DecomposeCosineSimilarityPass(tfa_pass=True),
583584
DecomposeGluPass(tfa_pass=True),
584585
DecomposeDivPass(tfa_pass=True),
585-
DecomposeLeakyReLUPass(tfa_pass=True),
586586
DecomposeLinalgVectorNormPass(tfa_pass=True),
587587
DecomposeSqrtPass(tfa_pass=True),
588588
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,38 @@ def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata:
236236
return NodeMetadata(plain_meta_dict)
237237

238238

239+
def insert_scalar(
240+
graph: torch.fx.Graph,
241+
value: int | float,
242+
meta: NodeMetadata | dict,
243+
from_node: torch.fx.Node,
244+
is_tfa_pass: bool = False,
245+
) -> torch.fx.Node | int | float:
246+
"""Insert an `aten.full` scalar node for direct graph-rewrite passes."""
247+
248+
if is_tfa_pass:
249+
return value
250+
251+
kwargs = {}
252+
val = None
253+
if "val" in meta:
254+
val = meta["val"]
255+
if isinstance(val, tuple):
256+
val = val[0]
257+
kwargs = {"device": val.device, "dtype": val.dtype}
258+
259+
scalar = create_node(
260+
graph=graph,
261+
op_target=exir_ops.edge.aten.full.default,
262+
args=((1,), value),
263+
kwargs=kwargs,
264+
from_node=from_node,
265+
)
266+
if val is not None:
267+
scalar.meta["val"] = torch.full((1,), value, **kwargs)
268+
return scalar
269+
270+
239271
def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
240272
"""Returns a FakeTensor from the meta field of 'node'.
241273

backends/arm/_passes/decompose_add_sub_alpha_pass.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,20 @@ def _get_ops(op):
3030
if op is exir_ops.edge.aten.add.Tensor:
3131
return (
3232
exir_ops.edge.aten.mul.Tensor,
33-
exir_ops.edge.aten.full.default,
3433
exir_ops.edge.aten.add.Tensor,
3534
)
3635
return (
3736
torch.ops.aten.mul.Tensor,
38-
torch.ops.aten.full.default,
3937
torch.ops.aten.add.Tensor,
4038
)
4139
if op in _SUB_OPS:
4240
if op is exir_ops.edge.aten.sub.Tensor:
4341
return (
4442
exir_ops.edge.aten.mul.Tensor,
45-
exir_ops.edge.aten.full.default,
4643
exir_ops.edge.aten.sub.Tensor,
4744
)
4845
return (
4946
torch.ops.aten.mul.Tensor,
50-
torch.ops.aten.full.default,
5147
torch.ops.aten.sub.Tensor,
5248
)
5349
raise RuntimeError(f"Unsupported operator {op}")
@@ -72,19 +68,12 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
7268
if not _should_decompose(alpha):
7369
return super().call_operator(op, args, kwargs, meta, updated)
7470

75-
mul_op, full_op, binary_op = _get_ops(op)
71+
mul_op, binary_op = _get_ops(op)
7672
lhs, rhs = args
7773

78-
alpha_full = super().call_operator(
79-
full_op,
80-
((1,), float(alpha)),
81-
{"device": meta["val"].device, "dtype": meta["val"].dtype},
82-
meta,
83-
updated=True,
84-
)
8574
scaled_rhs = super().call_operator(
8675
mul_op,
87-
(rhs, alpha_full),
76+
(rhs, super().call_scalar(alpha, meta)),
8877
{},
8978
meta,
9079
updated=True,

backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def get_decomposition(op) -> tuple:
4242
exir_ops.edge.aten.gt.Scalar,
4343
exir_ops.edge.aten.lt.Scalar,
4444
exir_ops.edge.aten.sub.Tensor,
45-
exir_ops.edge.aten.full_like.default,
4645
exir_ops.edge.aten.neg.default,
4746
)
4847

@@ -79,15 +78,12 @@ def _build_polynomial(
7978
"""Helper function to build polynomial from coefficients and
8079
variable.
8180
"""
82-
full_like_op, add_op, mul_op_scalar, mul_op = (
83-
exir_ops.edge.aten.full_like.default,
81+
add_op, mul_op_scalar, mul_op = (
8482
exir_ops.edge.aten.add.Tensor,
8583
exir_ops.edge.aten.mul.Scalar,
8684
exir_ops.edge.aten.mul.Tensor,
8785
)
88-
result = super().call_operator(
89-
full_like_op, (variable, coefficients[0]), {}, meta, True
90-
)
86+
result = super().call_scalar(coefficients[0], meta)
9187
for coeff in coefficients[1:]:
9288
result = super().call_operator(
9389
add_op,
@@ -150,7 +146,6 @@ def call_operator(self, op, args, kwargs, meta):
150146
gt_op,
151147
lt_op,
152148
sub_op,
153-
full_like_op,
154149
neg_op,
155150
) = get_decomposition(op)
156151

@@ -179,7 +174,7 @@ def call_operator(self, op, args, kwargs, meta):
179174

180175
# Step 2: Compute the transformed approximation for large values
181176
# Calculate z = -0.5 * (|x| - 1)
182-
tmp_ones = super().call_operator(full_like_op, (x_abs, one), {}, meta, True)
177+
tmp_ones = super().call_scalar(one, meta)
183178
tmp = super().call_operator(sub_op, (x_abs, tmp_ones), {}, meta, True)
184179
z = super().call_operator(mul_op_scalar, (tmp, neg_half), {}, meta, True)
185180

@@ -201,9 +196,7 @@ def call_operator(self, op, args, kwargs, meta):
201196
t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True)
202197

203198
diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True)
204-
tmp_neg_ones = super().call_operator(
205-
full_like_op, (diff, neg_one), {}, meta, True
206-
)
199+
tmp_neg_ones = super().call_scalar(neg_one, meta)
207200
asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True)
208201

209202
asin_unsigned = self._combine_branches(
@@ -218,9 +211,7 @@ def call_operator(self, op, args, kwargs, meta):
218211

219212
if op in edge_acos_op:
220213
# If x <= 0.5: acos(x) = pi/2 - asin(x)
221-
const_tensor = super().call_operator(
222-
full_like_op, (x, pi_over_2), {}, meta, True
223-
)
214+
const_tensor = super().call_scalar(pi_over_2, meta)
224215
acos_small = super().call_operator(
225216
sub_op, (const_tensor, asin), {}, meta, True
226217
)

backends/arm/_passes/decompose_erfinv_pass.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def get_erfinv_decomposition(op) -> tuple:
2626
if op in edge_erfinv_ops:
2727
# Ordered by first use in call_operator below.
2828
return (
29-
exir_ops.edge.aten.full_like.default,
3029
exir_ops.edge.aten.lt.Tensor,
3130
exir_ops.edge.aten.where.self,
3231
exir_ops.edge.aten.abs.default,
@@ -140,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta):
140139
x = args[0]
141140

142141
(
143-
op_full_like,
144142
op_lt_t,
145143
op_where,
146144
op_abs,
@@ -179,12 +177,10 @@ def call_operator(self, op, args, kwargs, meta):
179177
CORR_MAX = 0.5
180178
TWO_OVER_SQRT_PI = 1.1283791670955126
181179

182-
# ---- zeros / ones (tensor-shaped) ----
183-
zeros = super().call_operator(op_full_like, (x, 0.0), {}, meta, updated=True)
184-
ones = super().call_operator(op_full_like, (x, 1.0), {}, meta, updated=True)
185-
neg_ones = super().call_operator(
186-
op_full_like, (x, -1.0), {}, meta, updated=True
187-
)
180+
# ---- zeros / ones constants ----
181+
zeros = super().call_scalar(0.0, meta)
182+
ones = super().call_scalar(1.0, meta)
183+
neg_ones = super().call_scalar(-1.0, meta)
188184

189185
# ---- s = sign(x): -1 for x<0 else +1 ----
190186
x_lt0 = super().call_operator(op_lt_t, (x, zeros), {}, meta, updated=True)

backends/arm/_passes/decompose_gelu_pass.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@ def _get_gelu_ops(op) -> tuple:
2727

2828
if op in edge_gelu:
2929
return (
30-
exir_ops.edge.aten.full.default,
3130
exir_ops.edge.aten.add.Tensor,
3231
exir_ops.edge.aten.mul.Tensor,
3332
exir_ops.edge.aten.tanh.default,
3433
exir_ops.edge.aten.erf.default,
3534
)
3635
if op in torch_gelu:
3736
return (
38-
torch.ops.aten.full.default,
3937
torch.ops.aten.add.Tensor,
4038
torch.ops.aten.mul.Tensor,
4139
torch.ops.aten.tanh.default,
@@ -98,30 +96,18 @@ def call_operator(self, op, args, kwargs, meta):
9896
# If quantized, node should be replace by table op
9997
return super().call_operator(op, args, kwargs, meta)
10098

101-
full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)
99+
add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op)
102100

103101
input = get_node_arg(args, 0)
104102
# If approximate is default (none) it does not appear in kwargs
105103
approximate = get_node_arg(kwargs, "approximate", "none")
106104

107-
shape = meta["val"].size()
108-
dtype = meta["val"].dtype
109-
110-
FULL_0_5 = super().call_operator(
111-
full_op, ([1] * len(shape), 0.5), {"dtype": dtype}, meta
112-
)
113-
FULL_1 = super().call_operator(
114-
full_op, ([1] * len(shape), 1), {"dtype": dtype}, meta
115-
)
105+
FULL_0_5 = super().call_scalar(0.5, meta)
106+
FULL_1 = super().call_scalar(1, meta)
116107

117108
if approximate == "none":
118109
# Constant mirrors ExecuTorch implementation for parity.
119-
FULL_SQRT1_2 = super().call_operator(
120-
full_op,
121-
([1] * len(shape), 0.70710678118654752440),
122-
{"dtype": dtype},
123-
meta,
124-
)
110+
FULL_SQRT1_2 = super().call_scalar(0.70710678118654752440, meta)
125111

126112
op1 = super().call_operator(mul_op, (input, FULL_SQRT1_2), {}, meta)
127113
op2 = super().call_operator(erf_op, (op1,), {}, meta)
@@ -131,21 +117,9 @@ def call_operator(self, op, args, kwargs, meta):
131117

132118
elif approximate == "tanh":
133119
# Constants mirror ExecuTorch implementation for parity.
134-
FULL_SQRT2 = super().call_operator(
135-
full_op,
136-
([1] * len(shape), 1.41421356237309504880),
137-
{"dtype": dtype},
138-
meta,
139-
)
140-
FULL_2_SQRTPI = super().call_operator(
141-
full_op,
142-
([1] * len(shape), 1.12837916709551257390),
143-
{"dtype": dtype},
144-
meta,
145-
)
146-
FULL_CUBE_COEFF = super().call_operator(
147-
full_op, ([1] * len(shape), 0.044715), {"dtype": dtype}, meta
148-
)
120+
FULL_SQRT2 = super().call_scalar(1.41421356237309504880, meta)
121+
FULL_2_SQRTPI = super().call_scalar(1.12837916709551257390, meta)
122+
FULL_CUBE_COEFF = super().call_scalar(0.044715, meta)
149123

150124
# Mirrors ExecuTorch implementations for calculating this value
151125
SQRT_MUL = super().call_operator(

0 commit comments

Comments
 (0)