Skip to content

Commit d80a82c

Browse files
committed
Arm backend: Validate TOSA resize parameters
Share TOSA RESIZE parameter validation between upsample support checks and fake RESIZE lowering so invalid nearest and bilinear resize parameters are rejected before delegation. Signed-off-by: Per Held <per.held@arm.com> Change-Id: I57c267aca96d733879ae90329267e44adce399c6
1 parent 45fe55c commit d80a82c

5 files changed

Lines changed: 376 additions & 64 deletions

File tree

backends/arm/operator_support/upsample_support.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,65 @@
1313
SupportedTOSAOperatorCheck,
1414
)
1515
from executorch.backends.arm.tosa import TosaSpecification
16+
from executorch.backends.arm.tosa.resize_utils import get_tosa_resize_validation_error
1617
from executorch.exir.dialects._ops import ops as exir_ops
1718

1819

20+
def _is_upsample_node_tosa_supported(
21+
support_check: SupportedTOSAOperatorCheck,
22+
node: fx.Node,
23+
tosa_spec: TosaSpecification,
24+
*,
25+
align_corners: bool,
26+
) -> bool:
27+
input_node = ensure_type(fx.Node, node.args[0])
28+
input_size_yx = get_first_fake_tensor(input_node).shape[2:]
29+
output_size_yx = get_first_fake_tensor(node).shape[2:]
30+
31+
try:
32+
scale_y_n, scale_y_d, offset_y, border_y = (
33+
RewriteUpsamplePass.get_resize_parameters_1d(
34+
input_size_yx[0], output_size_yx[0], align_corners
35+
)
36+
)
37+
scale_x_n, scale_x_d, offset_x, border_x = (
38+
RewriteUpsamplePass.get_resize_parameters_1d(
39+
input_size_yx[1], output_size_yx[1], align_corners
40+
)
41+
)
42+
except RuntimeError as err:
43+
support_check.reporter.report_reject(node, str(err))
44+
return False
45+
46+
# Validate the exact TOSA RESIZE parameters that RewriteUpsamplePass will
47+
# emit so support checks and fake-op validation reject the same cases.
48+
validation_error = get_tosa_resize_validation_error(
49+
input_hw=input_size_yx,
50+
output_hw=output_size_yx,
51+
scale=[scale_y_n, scale_y_d, scale_x_n, scale_x_d],
52+
offset=[offset_y, offset_x],
53+
border=[border_y, border_x],
54+
tosa_spec=tosa_spec,
55+
)
56+
if validation_error is not None:
57+
support_check.reporter.report_reject(node, validation_error)
58+
return False
59+
60+
return True
61+
62+
1963
@register_tosa_support_check
2064
class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck):
2165
"""Provide the explicit TOSA support gate for nearest upsample."""
2266

2367
targets = [exir_ops.edge.aten.upsample_nearest2d.vec]
2468

2569
def is_node_tosa_supported(
26-
self, _node: fx.Node, _tosa_spec: TosaSpecification
70+
self, node: fx.Node, tosa_spec: TosaSpecification
2771
) -> bool: # type: ignore[override, misc]
28-
return True
72+
return _is_upsample_node_tosa_supported(
73+
self, node, tosa_spec, align_corners=False
74+
)
2975

3076

3177
@register_tosa_support_check
@@ -37,33 +83,9 @@ class UpsampleBilinear2dSupported(SupportedTOSAOperatorCheck):
3783
targets = [exir_ops.edge.aten.upsample_bilinear2d.vec]
3884

3985
def is_node_tosa_supported(
40-
self, node: fx.Node, _tosa_spec: TosaSpecification
86+
self, node: fx.Node, tosa_spec: TosaSpecification
4187
) -> bool: # type: ignore[override, misc]
42-
input_node = ensure_type(fx.Node, node.args[0])
4388
align_corners = ensure_type(bool, node.args[2])
44-
input_size_yx = get_first_fake_tensor(input_node).shape[2:]
45-
output_size_yx = get_first_fake_tensor(node).shape[2:]
46-
47-
try:
48-
scale_y_n, scale_y_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d(
49-
input_size_yx[0], output_size_yx[0], align_corners
50-
)
51-
scale_x_n, scale_x_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d(
52-
input_size_yx[1], output_size_yx[1], align_corners
53-
)
54-
except RuntimeError as err:
55-
self.reporter.report_reject(node, str(err))
56-
return False
57-
58-
# get_resize_parameters_1d() returns the TOSA RESIZE scale fraction for
59-
# each spatial dimension. For align_corners=False, this is the effective
60-
# output_size / input_size ratio, so the 1/16 boundary is checked
61-
# directly in the same representation that RESIZE lowering will use.
62-
if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n:
63-
self.reporter.report_reject(
64-
node,
65-
"Bilinear RESIZE downscale must be strictly greater than 1/16",
66-
)
67-
return False
68-
69-
return True
89+
return _is_upsample_node_tosa_supported(
90+
self, node, tosa_spec, align_corners=align_corners
91+
)

backends/arm/test/misc/tosa_dialect/test_tosa_resize.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ def _expr(sym: torch.SymInt) -> sympy.Expr:
3333
return sympy.sympify(getattr(sym.node, "expr", sym.node._expr))
3434

3535

36-
def test_bilinear_resize_rejects_exact_one_sixteenth_downscale():
36+
@pytest.mark.parametrize("resize_mode", ("nearest", "bilinear"))
37+
def test_resize_rejects_exact_one_sixteenth_downscale(resize_mode: str):
3738
with TosaLoweringContext(
3839
TosaSpecification.create_from_string("TOSA-1.0+INT")
3940
), FakeTensorMode() as mode:
4041
with pytest.raises(
4142
TosaValueError,
42-
match="Bilinear RESIZE downscale must be strictly greater than 1/16",
43+
match="RESIZE downscale must be strictly greater than 1/16",
4344
):
4445
exir_ops.backend.tosa.RESIZE.default(
4546
mode.from_tensor(
@@ -48,7 +49,26 @@ def test_bilinear_resize_rejects_exact_one_sixteenth_downscale():
4849
[2, 32, 2, 32],
4950
[15, 15],
5051
[-15, -15],
51-
resize_mode="bilinear",
52+
resize_mode=resize_mode,
53+
)
54+
55+
56+
def test_resize_rejects_scale_numerator_over_tosa_limit():
57+
with TosaLoweringContext(
58+
TosaSpecification.create_from_string("TOSA-1.0+INT")
59+
), FakeTensorMode() as mode:
60+
with pytest.raises(
61+
TosaValueError,
62+
match="RESIZE scale numerator must be <= 2048",
63+
):
64+
exir_ops.backend.tosa.RESIZE.default(
65+
mode.from_tensor(torch.randint(0, 10, (1, 3, 4, 2), dtype=torch.int8)),
66+
# 2049 violates scale_n <= 1 << 11, while 2049/2 still stays
67+
# within MAX_SCALE so this test isolates the numerator rule.
68+
[2049, 2, 4, 2],
69+
[0, 0],
70+
[0, 0],
71+
resize_mode="nearest",
5272
)
5373

5474

backends/arm/test/ops/test_upsample_nearest2d.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,17 @@ def test_upsample_nearest2d_vec_tosa_FP_interpolate(test_data: torch.Tensor):
198198
pipeline.run()
199199

200200

201+
def test_upsample_nearest2d_vec_tosa_does_not_delegate_exact_one_sixteenth_downscale():
202+
pipeline = OpNotSupportedPipeline[input_t1](
203+
Interpolate(size=None, scale_factor=1.0 / 16.0),
204+
(torch.randn(1, 3, 256, 448),),
205+
{exir_op: 1},
206+
n_expected_delegates=0,
207+
)
208+
209+
pipeline.run()
210+
211+
201212
@common.parametrize("test_data", test_data_suite)
202213
def test_upsample_nearest2d_vec_tosa_INT(test_data: torch.Tensor):
203214
test_data, size, scale_factor, compare_outputs = test_data()

backends/arm/tosa/dialect/ops/resize.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
import torch
99
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
1010
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
11+
from executorch.backends.arm.tosa.resize_utils import (
12+
calculate_tosa_resize_output_hw,
13+
get_tosa_resize_validation_error,
14+
)
1115

1216
from executorch.backends.arm.tosa.specification import (
1317
get_context_spec,
@@ -50,23 +54,17 @@ def _get_output_dtype(
5054
return output_dtype
5155

5256

53-
def _validate_resize_parameters(scale, border, resize_mode):
54-
def in_int16_range(values):
55-
return all(
56-
(x >= -(2**15)) and (x <= 2**15 - 1) for x in values if isinstance(x, int)
57-
)
58-
59-
if not in_int16_range(scale):
60-
raise TosaValueError("scale is out of the int16 range", op="RESIZE")
61-
if not in_int16_range(border):
62-
raise TosaValueError("border is out of the int16 range", op="RESIZE")
63-
if resize_mode == "bilinear":
64-
scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale
65-
if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n:
66-
raise TosaValueError(
67-
"Bilinear RESIZE downscale must be strictly greater than 1/16",
68-
op="RESIZE",
69-
)
57+
def _validate_resize_parameters(input_hw, output_hw, scale, offset, border, tosa_spec):
58+
validation_error = get_tosa_resize_validation_error(
59+
input_hw=input_hw,
60+
output_hw=output_hw,
61+
scale=scale,
62+
offset=offset,
63+
border=border,
64+
tosa_spec=tosa_spec,
65+
)
66+
if validation_error is not None:
67+
raise TosaValueError(validation_error, op="RESIZE")
7068

7169

7270
@register_fake_tosa_op(
@@ -88,24 +86,26 @@ def RESIZE(
8886
f"Input tensor must be 4D, but got {x.dim()}D", op="RESIZE"
8987
)
9088
_validate_resize_mode(resize_mode)
91-
_validate_resize_parameters(scale, border, resize_mode)
9289
output_dtype = _get_output_dtype(x.dtype, tosa_spec, resize_mode)
9390

9491
input_shape = x.shape
95-
scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale
96-
offset_y, offset_x = offset
97-
border_y, border_x = border
9892
H, W = input_shape[1], input_shape[2]
99-
# RESIZE first upscales the input by an integer value, to "upscale space".
100-
H_upscaled = (H - 1) * scale_y_n
101-
# offset and border are provided in this scale, therefore adjust for these while in this space.
102-
H_shifted = H_upscaled - offset_y + border_y
103-
# Then, complete the RESIZE by downscaling with another integer value, approximating multplication with a fraction.
104-
OH = (H_shifted // scale_y_d) + 1
105-
# Mirror the same computation horizontally for the output width.
106-
W_upscaled = (W - 1) * scale_x_n
107-
W_shifted = W_upscaled - offset_x + border_x
108-
OW = (W_shifted // scale_x_d) + 1
93+
_validate_resize_parameters((H, W), None, scale, offset, border, tosa_spec)
94+
output_hw = calculate_tosa_resize_output_hw((H, W), scale, offset, border)
95+
_validate_resize_parameters((H, W), output_hw, scale, offset, border, tosa_spec)
96+
if output_hw is None:
97+
scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale
98+
offset_y, offset_x = offset
99+
border_y, border_x = border
100+
# RESIZE first upscales the input by an integer value to "upscale
101+
# space". Offset and border are encoded in that space, then RESIZE
102+
# completes by downscaling with another integer value, approximating
103+
# multiplication by a fraction.
104+
OH = ((H - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1
105+
OW = ((W - 1) * scale_x_n - offset_x + border_x) // scale_x_d + 1
106+
else:
107+
OH, OW = output_hw
108+
109109
fake_aten_tensor = torch.empty(
110110
size=(input_shape[0], OH, OW, input_shape[3]), dtype=output_dtype
111111
)

0 commit comments

Comments
 (0)