Skip to content

Commit a76d9cd

Browse files
NXP backend: Add support for leaky_relu with the new Neutron flow. (pytorch#19667)
### Summary This PR adds support for the `aten.leaky_relu` operator with the new Neutron MLIR flow. ### Test plan Unit tests provided. cc @robert-kalmar @JakeStevens @digantdesai @rascani
1 parent 6c74cdc commit a76d9cd

3 files changed

Lines changed: 96 additions & 7 deletions

File tree

backends/nxp/backend/ir/converter/node_converters/ops_converters/leaky_relu_converter.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import torch
7+
68
from executorch.backends.nxp.backend.ir.converter.node_converter import (
79
CustomDelegationOptions,
810
NodeConverter,
911
)
1012
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.leaky_relu_options import (
1113
LeakyRelu,
1214
)
15+
16+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1317
from torch.fx import Node
1418
from torch.nn import Parameter
1519

@@ -24,6 +28,29 @@ def _is_supported_in_IR(
2428
) -> bool:
2529
return True
2630

31+
@staticmethod
32+
def _is_supported_on_target(
33+
node: Node,
34+
neutron_target_spec: NeutronTargetSpec,
35+
parameters_mapping: dict[str, Parameter],
36+
custom_delegation_options: CustomDelegationOptions,
37+
) -> bool:
38+
if custom_delegation_options.use_new_flow_neutron_c:
39+
# Requirements specified by the new Neutron flow documentation.
40+
41+
if not NodeConverter.uses_quantization_type_for_io(
42+
node,
43+
supported_types=[torch.int8, torch.uint8],
44+
input_indices=[0],
45+
output_indices=[0],
46+
):
47+
return False
48+
49+
return True
50+
else:
51+
52+
return True
53+
2754
def convert(self, node: Node):
2855
"""Convert the `aten.leaky_relu.default` operator to Neutron IR `LeakyRelu`.
2956
The schema is:

backends/nxp/tests/ir/converter/node_converter/test_leaky_relu_converter.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import numpy as np
7+
8+
# noinspection PyUnusedImports
79
import pytest
810
import torch
911

1012
from executorch.backends.nxp.backend.edge_program_converter import (
1113
EdgeProgramToIRConverter,
1214
)
15+
from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator
1316
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
1417
from executorch.backends.nxp.tests.executors import (
1518
convert_run_compare,
1619
graph_contains_any_of_ops,
1720
)
18-
from executorch.exir.dialects._ops import ops as exir_ops
21+
from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier
22+
from executorch.backends.nxp.tests.nsys_testing import lower_run_compare
23+
from executorch.backends.nxp.tests.ops_aliases import ExecutorchDelegateCall, LeakyRelu
24+
from executorch.backends.nxp.tests.use_qat import * # noqa F403
1925

2026

2127
@pytest.fixture(autouse=True)
@@ -24,17 +30,13 @@ def reseed_model_per_test_run():
2430
np.random.seed(23)
2531

2632

27-
ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate
28-
LeakyRelu2D = exir_ops.edge.aten.leaky_relu.default
29-
30-
3133
def _assert_successful_delegation(model, input_shape, mocker, atol=0):
3234
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
3335
delegated_ep = to_quantized_edge_program(model, input_shape).exported_program()
3436

3537
# Make sure the `leaky_relu` was delegated.
3638
assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall])
37-
assert not graph_contains_any_of_ops(delegated_ep.graph, [LeakyRelu2D])
39+
assert not graph_contains_any_of_ops(delegated_ep.graph, [LeakyRelu])
3840

3941
# Verify correct behavior of the converted NeutronIR model.
4042
intermediate_ep = converter_spy.call_args.args[1]
@@ -45,7 +47,7 @@ def _assert_successful_delegation(model, input_shape, mocker, atol=0):
4547
).astype(np.int8)
4648

4749
# Make sure the tested program contains the `leaky_relu`.
48-
assert graph_contains_any_of_ops(intermediate_ep.graph, [LeakyRelu2D])
50+
assert graph_contains_any_of_ops(intermediate_ep.graph, [LeakyRelu])
4951

5052
convert_run_compare(
5153
intermediate_ep, tfl_model=neutron_ir_model, input_data=input_data, atol=atol
@@ -121,3 +123,62 @@ def test_convert_leaky_relu__ranks(mocker, input_shape: tuple[int, ...]):
121123
mocker,
122124
atol=1, # Common quantization rounding error.
123125
)
126+
127+
128+
class TestLeakyReluNewNeutronFlow:
129+
# noinspection PyMethodMayBeStatic
130+
def assert_delegated(self, model, input_shape, mocker, use_qat=False):
131+
graph_verifier = DetailedGraphVerifier(
132+
mocker,
133+
expected_delegated_ops={LeakyRelu: 1},
134+
expected_non_delegated_ops={},
135+
)
136+
137+
# Create a RandomDatasetCreator that covers also negative numbers to properly test the operator.
138+
dataset_creator = RandomDatasetCreator(low=-2, high=2)
139+
140+
lower_run_compare(
141+
model,
142+
input_shape,
143+
graph_verifier,
144+
dataset_creator,
145+
use_qat=use_qat,
146+
use_new_flow_neutron_c=True, # Use the new flow.
147+
)
148+
149+
@pytest.mark.parametrize(
150+
"input_shape",
151+
[
152+
(2,),
153+
(2, 3),
154+
(2, 3, 4),
155+
(2, 3, 4, 5),
156+
(2, 3, 4, 5, 6),
157+
],
158+
ids=lambda shape: f"{len(shape)}D",
159+
)
160+
def test__default_alpha__input_shapes(self, mocker, input_shape):
161+
model = LeakyReluModule()
162+
self.assert_delegated(model, input_shape, mocker)
163+
164+
def test__default_alpha__qat(self, mocker, use_qat):
165+
model = LeakyReluModule()
166+
input_shape = (23,)
167+
self.assert_delegated(model, input_shape, mocker, use_qat)
168+
169+
@pytest.mark.parametrize(
170+
"alpha",
171+
[0.01, 3.14159, 0, 1, float("inf")],
172+
ids=lambda alpha: f"alpha = {alpha}",
173+
)
174+
def test__specific_alpha(self, mocker, alpha):
175+
model = LeakyReluModule(negative_slope=alpha)
176+
self.assert_delegated(model, (23,), mocker)
177+
178+
def test__inplace(self, mocker):
179+
model = LeakyReluModule(inplace=True)
180+
self.assert_delegated(
181+
model,
182+
(23,),
183+
mocker,
184+
)

backends/nxp/tests/ops_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
GetItem = operator.getitem
2323
HardTanh = exir_ops.edge.aten.hardtanh.default
2424
HardTanh_ = exir_ops.edge.aten.hardtanh_.default
25+
LeakyRelu = exir_ops.edge.aten.leaky_relu.default
2526
MaxPool2DWithIndices = exir_ops.edge.aten.max_pool2d_with_indices.default
2627
MulTensor = exir_ops.edge.aten.mul.Tensor
2728
QuantizePerChannel = exir_ops.edge.quantized_decomposed.quantize_per_channel.default

0 commit comments

Comments
 (0)