Skip to content

Commit 36a1952

Browse files
authored
Arm backend: fix the quantized scalar remainder issue (pytorch#18401)
- Keep scalar remainder opaque through quantization and lower it through the LUT/table path, to avoid the inaccurate div/floor/mul/sub decomposition in INT mode.
- Keep the tensor-tensor INT remainder test marked as xfail.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 410dd83 commit 36a1952

6 files changed

Lines changed: 32 additions & 10 deletions

File tree

backends/arm/_passes/decompose_remainder_pass.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -59,6 +59,15 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
5959
)
6060
if op not in supported_ops:
6161
return super().call_operator(op, args, kwargs, meta, updated)
62+
# Keep scalar remainder opaque during transform-for-annotation so the
63+
# quantizer can wrap the original op directly. In the backend pipeline,
64+
# also preserve quantized scalar remainder so InsertTableOpsPass can
65+
# lower it as a lookup table instead of expanding to div/floor/mul/sub.
66+
if op in (
67+
exir_ops.edge.aten.remainder.Scalar,
68+
torch.ops.aten.remainder.Scalar,
69+
) and (self.is_tfa_pass or self._is_quantized_meta(meta)):
70+
return super().call_operator(op, args, kwargs, meta, updated)
6271

6372
div_op, mul_op, sub_op = _decomposition_ops[op]
6473
x, y = args[0], args[1]

backends/arm/_passes/insert_table_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class TableOps:
6464
exir_ops.edge.aten.pow.Tensor_Scalar,
6565
exir_ops.edge.aten.gelu.default,
6666
exir_ops.edge.aten.elu.default,
67+
exir_ops.edge.aten.remainder.Scalar,
6768
}
6869

6970
def __init__(self, exported_program: ExportedProgram):
@@ -102,6 +103,9 @@ def __getitem__(self, node: Node):
102103
return lambda x: torch.nn.functional.elu(
103104
x, alpha=input_alpha
104105
).flatten()
106+
case exir_ops.edge.aten.remainder.Scalar:
107+
divisor = cast(float | int, node.args[1])
108+
return lambda x: torch.remainder(x, divisor).flatten()
105109
case _:
106110
# Op must be handled if it's inside self.special_ops
107111
raise AssertionError("Unhandled table operation")

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -79,6 +79,11 @@
7979
_fp_profile_ops | _int_profile_ops
8080
)
8181

82+
_preserve_in_tfa = {
83+
torch.ops.aten.remainder.Scalar,
84+
exir_ops.edge.aten.remainder.Scalar,
85+
}
86+
8287

8388
class ReplaceScalarWithTensorByProfilePass(ArmPass, ReplaceScalarWithTensorArgPass):
8489
"""Profile-aware scalar-to-tensor replacement pass for binary ops."""
@@ -94,6 +99,9 @@ def __init__(self, tfa_pass=False, *args, **kwargs):
9499
super().__init__(tfa_pass, _all_ops, *args, **kwargs)
95100

96101
def call_operator(self, op, args, kwargs, meta):
102+
if self.is_tfa_pass and op in _preserve_in_tfa:
103+
return ExportPass.call_operator(self, op, args, kwargs, meta)
104+
97105
tosa_spec = get_context_spec()
98106

99107
included_ops = {}
@@ -108,7 +116,7 @@ def call_operator(self, op, args, kwargs, meta):
108116
if op in TableOps.included_ops():
109117
# Do not handle quantized table ops; forward unchanged.
110118
input_qparams = meta.data.get("input_qparams", {})
111-
output_qparams = meta.data.get("input_qparams", {})
119+
output_qparams = meta.data.get("output_qparams", {})
112120
if len(input_qparams) > 0 and len(output_qparams) > 0:
113121
# Do not handle; forward unchanged.
114122
return ExportPass.call_operator(self, op, args, kwargs, meta)

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
exir_ops.edge.aten.repeat.default,
8181
exir_ops.edge.aten.reciprocal.default,
8282
exir_ops.edge.aten.relu.default,
83+
exir_ops.edge.aten.remainder.Scalar,
8384
exir_ops.edge.aten.remainder.Tensor,
8485
exir_ops.edge.aten.rsqrt.default,
8586
exir_ops.edge.aten.select_copy.int,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ def _match_pattern(
507507
torch.ops.aten.asinh.default,
508508
torch.ops.aten.cosh.default,
509509
torch.ops.aten.cumsum.default,
510+
torch.ops.aten.remainder.Scalar,
510511
torch.ops.aten.tan.default,
511512
}
512513

backends/arm/test/ops/test_remainder.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Remainder(torch.nn.Module):
2929
exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
3030
aten_op_scalar = "torch.ops.aten.remainder.Scalar"
3131
exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
32+
lowered_exir_ops = [exir_op_scalar, exir_op_tensor]
3233

3334
test_cases_tensor = {
3435
"rank2_tensors": lambda: (
@@ -97,18 +98,14 @@ def test_remainder_tensor_tosa_INT(test_data):
9798
pipeline.run()
9899

99100

100-
@common.parametrize(
101-
"test_data",
102-
Remainder.test_cases_scalar,
103-
xfails={
104-
"scalar_pos": "MLETORCH-1832 - Quantized remainder with scalar divisor produces incorrect results for certain inputs"
105-
},
106-
)
101+
@common.parametrize("test_data", Remainder.test_cases_scalar)
107102
def test_remainder_scalar_tosa_INT(test_data):
108103
pipeline = TosaPipelineINT[Remainder.input_t](
109104
Remainder(),
110105
test_data(),
111106
[],
107+
Remainder.lowered_exir_ops,
108+
frobenius_threshold=0.4,
112109
)
113110
pipeline.run()
114111

@@ -131,6 +128,7 @@ def test_remainder_scalar_u55_INT(test_data):
131128
Remainder(),
132129
test_data(),
133130
[],
131+
Remainder.lowered_exir_ops,
134132
)
135133
pipeline.run()
136134

@@ -153,6 +151,7 @@ def test_remainder_scalar_u85_INT(test_data):
153151
Remainder(),
154152
test_data(),
155153
[],
154+
Remainder.lowered_exir_ops,
156155
)
157156
pipeline.run()
158157

0 commit comments

Comments
 (0)