Skip to content

Commit 554aecf

Browse files
authored
Arm backend: Decompose integral float pow exponents (pytorch#19693)
Treat positive integral float scalar exponents like integer exponents in DecomposeIntPowPass. This avoids lowering pow(x, 2.0) to TOSA POW, whose reference model rejects negative bases even when the exponent is mathematically integral. Test Plan: - lintrunner on changed files - pytest test_decompose_int_pow_pass.py - TinySwin2SR FP/INT TOSA reference smoke cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Usamah Zaheer <usamah.zaheer@arm.com>
1 parent 161376d commit 554aecf

3 files changed

Lines changed: 55 additions & 16 deletions

File tree

backends/arm/_passes/decompose_int_pow_pass.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import Set, Type
7+
from typing import Optional, Set, Type
88

99
from executorch.backends.arm._passes import ArmPass
1010
from executorch.exir.dialects._ops import ops as exir_ops
@@ -21,6 +21,18 @@ class DecomposeIntPowPass(ArmPass):
2121

2222
_passes_required_after: Set[Type[ExportPass]] = set()
2323

24+
@staticmethod
25+
def _get_decomposable_integer_exponent(exp) -> Optional[int]:
26+
if isinstance(exp, int):
27+
return exp
28+
# Exported models can represent positive integer-valued exponents as
29+
# floats, for example pow(x, 2.0). Only exact values are decomposed:
30+
# rounding near-integer floats would change fractional pow semantics,
31+
# especially for negative bases.
32+
if isinstance(exp, float) and exp > 0 and exp.is_integer():
33+
return int(exp)
34+
return None
35+
2436
def call_operator(self, op, args, kwargs, meta):
2537
if op != exir_ops.edge.aten.pow.Tensor_Scalar:
2638
return super().call_operator(op, args, kwargs, meta)
@@ -43,9 +55,18 @@ def call_operator(self, op, args, kwargs, meta):
4355
exir_ops.edge.aten.add.Tensor, (zeros, ones), {}, meta, True
4456
)
4557

46-
if not isinstance(exp, int):
58+
exp = self._get_decomposable_integer_exponent(exp)
59+
if exp is None:
4760
return super().call_operator(op, args, kwargs, meta)
4861

62+
if exp == 1:
63+
ones = super().call_operator(
64+
exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True
65+
)
66+
return super().call_operator(
67+
exir_ops.edge.aten.mul.Tensor, (x, ones), {}, meta, True
68+
)
69+
4970
# Handle negative exponent
5071
if exp < 0:
5172
x = super().call_operator(

backends/arm/test/ops/test_pow.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,11 @@ def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t):
141141
pipeline.run()
142142

143143

144-
x_fail = {
145-
"exp_two": "TOSA constraints: If x <0 .",
146-
}
147-
148-
x_fail_FP = {
149-
"exp_two": "TOSA constraints: If x <0 .",
150-
}
151-
152-
153144
@common.parametrize(
154145
"test_data",
155146
Pow_TensorScalar.test_data
156147
| Pow_TensorScalar.test_data_fp16
157148
| Pow_TensorScalar.test_data_bf16,
158-
xfails=x_fail_FP,
159-
strict=False,
160149
)
161150
def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t):
162151
base, exp = test_data()
@@ -211,8 +200,6 @@ def test_pow_tensor_scalar_u85_INT(test_data: Pow_TensorScalar.input_t):
211200
@common.parametrize(
212201
"test_data",
213202
Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16,
214-
x_fail_FP,
215-
strict=False,
216203
)
217204
@common.SkipIfNoModelConverter
218205
def test_pow_tensor_scalar_vgf_no_quant(test_data: Pow_TensorScalar.input_t):

backends/arm/test/passes/test_decompose_int_pow_pass.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_inputs(self) -> input_t:
3535
class Pow(torch.nn.Module):
3636
"""Basic squaring."""
3737

38-
def __init__(self, exponent: int) -> None:
38+
def __init__(self, exponent: int | float) -> None:
3939
super().__init__()
4040
self.exponent = exponent
4141

@@ -48,12 +48,20 @@ def get_inputs(self) -> input_t:
4848

4949
test_data: Dict[str, TestParam] = {
5050
"square": (Square(), 1),
51+
"pow_1": (Pow(1), 1),
52+
"pow_1_float": (Pow(1.0), 1),
5153
"pow_2": (Pow(2), 1),
54+
"pow_2_float": (Pow(2.0), 1),
5255
"pow_3": (Pow(3), 2),
5356
"pow_0": (Pow(0), 0),
5457
"pow_neg_2": (Pow(-2), 1),
5558
}
5659

60+
non_integer_float_test_data: Dict[str, ModuleWithInputs] = {
61+
"pow_1_999999999": Pow(1.999999999),
62+
"pow_2_000000001": Pow(2.000000001),
63+
}
64+
5765

5866
@common.parametrize("data", test_data)
5967
def test_decompose_int_pow_tosa_FP(data: TestParam) -> None:
@@ -74,3 +82,26 @@ def test_decompose_int_pow_tosa_FP(data: TestParam) -> None:
7482
pass_list=[DecomposeIntPowPass],
7583
)
7684
pipeline.run()
85+
86+
87+
@common.parametrize("module_with_inputs", non_integer_float_test_data)
88+
def test_decompose_int_pow_tosa_FP_non_integer_float(
89+
module_with_inputs: ModuleWithInputs,
90+
) -> None:
91+
module = cast(torch.nn.Module, module_with_inputs)
92+
pow_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"
93+
pipeline = PassPipeline[input_t](
94+
module,
95+
module_with_inputs.get_inputs(),
96+
quantize=False,
97+
ops_before_pass={
98+
pow_op: 1,
99+
},
100+
ops_not_before_pass=[],
101+
ops_after_pass={
102+
pow_op: 1,
103+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 0,
104+
},
105+
pass_list=[DecomposeIntPowPass],
106+
)
107+
pipeline.run()

0 commit comments

Comments
 (0)