Skip to content

Commit b1446cc

Browse files
committed
Arm backend: Simplify fake RESIZE validation
Avoid revalidating RESIZE output shape against dimensions computed by the same formula. Validate parameters once, compute the fake output shape, and directly validate the computed output dimensions. Signed-off-by: Per Held <per.held@arm.com> Change-Id: I97bb91f9fc440c980782955692056196038d5de0
1 parent daa7ad2 commit b1446cc

3 files changed

Lines changed: 47 additions & 1 deletion

File tree

backends/arm/test/misc/tosa_dialect/test_tosa_resize.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,30 @@ def test_resize_rejects_scale_numerator_over_tosa_limit():
7272
)
7373

7474

75+
@pytest.mark.parametrize(
76+
"offset,border",
77+
(
78+
([1, 0], [-1, 0]),
79+
([0, 1], [0, -1]),
80+
),
81+
)
82+
def test_resize_rejects_non_positive_output_dimensions(offset, border):
83+
with TosaLoweringContext(
84+
TosaSpecification.create_from_string("TOSA-1.0+INT")
85+
), FakeTensorMode() as mode:
86+
with pytest.raises(
87+
TosaValueError,
88+
match="RESIZE output dimensions must be positive",
89+
):
90+
exir_ops.backend.tosa.RESIZE.default(
91+
mode.from_tensor(torch.randint(0, 10, (1, 1, 1, 1), dtype=torch.int8)),
92+
[1, 1, 1, 1],
93+
offset,
94+
border,
95+
resize_mode="nearest",
96+
)
97+
98+
7599
def test_resize_accepts_symbolic_scale_and_border_values():
76100
shape_env = ShapeEnv()
77101
scale_y_n = _make_symint(shape_env, "scale_y_n", hint=2, min=1, max=8)

backends/arm/tosa/dialect/ops/resize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
1111
from executorch.backends.arm.tosa.resize_utils import (
1212
calculate_tosa_resize_output_hw,
13+
get_tosa_resize_output_hw_validation_error,
1314
get_tosa_resize_validation_error,
1415
)
1516

@@ -92,7 +93,9 @@ def RESIZE(
9293
H, W = input_shape[1], input_shape[2]
9394
_validate_resize_parameters((H, W), None, scale, offset, border, tosa_spec)
9495
output_hw = calculate_tosa_resize_output_hw((H, W), scale, offset, border)
95-
_validate_resize_parameters((H, W), output_hw, scale, offset, border, tosa_spec)
96+
validation_error = get_tosa_resize_output_hw_validation_error(output_hw)
97+
if validation_error is not None:
98+
raise TosaValueError(validation_error, op="RESIZE")
9699
if output_hw is None:
97100
scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale
98101
offset_y, offset_x = offset

backends/arm/tosa/resize_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,25 @@ def _validate_dimensions(
6767
return None
6868

6969

70+
def get_tosa_resize_output_hw_validation_error(
71+
output_hw: Sequence[int | torch.SymInt] | None,
72+
) -> str | None:
73+
if output_hw is None:
74+
return None
75+
76+
output_hw_ints = _as_concrete_ints(output_hw)
77+
if output_hw_ints is None:
78+
return None
79+
80+
invalid_dimension = next(
81+
(dimension for dimension in output_hw_ints if dimension <= 0), None
82+
)
83+
if invalid_dimension is not None:
84+
return f"RESIZE output dimensions must be positive; got {invalid_dimension}"
85+
86+
return _validate_dimensions((), output_hw)
87+
88+
7089
def _validate_scale(
7190
scale: Sequence[int | torch.SymInt],
7291
tosa_spec: TosaSpecification,

0 commit comments

Comments
 (0)