Skip to content

Commit fda2451

Browse files
Arm backend: Up tolerance when running on Aarch64 (pytorch#19110)
Increase the atol by 10% when running on Aarch64. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell Signed-off-by: Christoffer J.L <christoffer.johanssonlundqvist@arm.com>
1 parent 69c3728 commit fda2451

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

backends/arm/test/tester/arm_tester.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import inspect
99

1010
import logging
11+
import platform
1112

1213
from collections import Counter, defaultdict
1314
from pprint import pformat
@@ -105,6 +106,21 @@
105106
logger = logging.getLogger(__name__)
106107

107108

109+
# TODO(MLETORCH-2048: Remove if possible or rework this to match minimal tolerance diff between architectures when TOSA is updated, or investigate/update atol in the failing tests)
110+
def _adjust_tosa_aarch64_atol(compile_spec: ArmCompileSpec, atol: float) -> float:
111+
"""Increase tolerance for aarch64 when running on TOSA.
112+
113+
This is due to the TOSA ref model being experimental on Aarch64.
114+
115+
"""
116+
if isinstance(compile_spec, TosaCompileSpec) and platform.machine().lower() in (
117+
"aarch64",
118+
"arm64",
119+
):
120+
return atol * 1.1
121+
return atol
122+
123+
108124
def _dump_lowered_modules_artifact(
109125
path_to_dump: Optional[str],
110126
artifact: Union[EdgeProgramManager, ExecutorchProgramManager],
@@ -573,6 +589,8 @@ def run_method_and_compare_outputs(
573589
574590
"""
575591

592+
atol = _adjust_tosa_aarch64_atol(self.compile_spec, atol)
593+
576594
# backward-compatible ordering (accept inputs as the first positional argument)
577595
inputs, reference_stage, test_stage = self._get_input_and_stages(
578596
inputs, stage, reference_stage_type, run_eager_mode

0 commit comments

Comments
 (0)