Skip to content

Commit 6c74cdc

Browse files
NXP backend: Add support for sigmoid with the new Neutron flow. (pytorch#19687)
### Summary This PR adds support for the aten.sigmoid operator with the new Neutron MLIR flow. ### Test plan Unit tests provided. cc @robert-kalmar @JakeStevens @digantdesai @rascani
1 parent fdd4f43 commit 6c74cdc

4 files changed

Lines changed: 108 additions & 3 deletions

File tree

backends/nxp/backend/ir/converter/builder/model_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,10 @@ def _validate_new_tensor_name(self, name: str) -> str:
742742
return new_name
743743

744744
def op_code_index_for_op_type(
745-
self, op_type: BuiltinOperator, version: int = 1, custom_code: str = None
745+
self,
746+
op_type: BuiltinOperator | int,
747+
version: int = 1,
748+
custom_code: str | None = None,
746749
):
747750
"""
748751
Return the index to the 'operator_codes' vector in the TFLite model for the operator

backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import torch
7+
68
from executorch.backends.nxp.backend.ir.converter.node_converter import (
79
CustomDelegationOptions,
810
NodeConverter,
911
)
1012
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
1113
BuiltinOperator,
1214
)
15+
16+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1317
from torch.fx import Node
1418
from torch.nn import Parameter
1519

@@ -24,7 +28,37 @@ def _is_supported_in_IR(
2428
) -> bool:
2529
return True
2630

31+
@staticmethod
32+
def _is_supported_on_target(
33+
node: Node,
34+
neutron_target_spec: NeutronTargetSpec,
35+
parameters_mapping: dict[str, Parameter],
36+
custom_delegation_options: CustomDelegationOptions,
37+
) -> bool:
38+
if custom_delegation_options.use_new_flow_neutron_c:
39+
# Requirements specified by the new Neutron flow documentation.
40+
41+
if not NodeConverter.uses_quantization_type_for_io(
42+
node,
43+
supported_types=[torch.int8, torch.uint8],
44+
input_indices=[0],
45+
output_indices=[0],
46+
):
47+
return False
48+
49+
return True
50+
51+
else:
52+
# Requirements of the old Neutron flow.
53+
return True
54+
2755
def convert(self, node: Node):
56+
"""Convert the `aten.sigmoid.default` node to NeutronIR `Logistic` operator.
57+
The ExecuTorch schema is:
58+
sigmoid(
59+
Tensor self
60+
) -> Tensor
61+
"""
2862
self.assert_convertible(node)
2963

3064
t_op = self._create_tflite_op_with_io_tensors(node)

backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66

77
import numpy as np
8+
9+
# noinspection PyUnusedImports
810
import pytest
911
import torch
1012

1113
from executorch.backends.nxp.backend.edge_program_converter import (
1214
EdgeProgramToIRConverter,
1315
)
16+
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
17+
from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator
1418
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
1519
from executorch.backends.nxp.tests.executors import (
1620
convert_run_compare,
1721
ToNCHWPreprocess,
1822
ToNHWCPreprocess,
1923
)
24+
from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier
25+
from executorch.backends.nxp.tests.model_output_comparator import (
26+
AllCloseOutputComparator,
27+
)
2028
from executorch.backends.nxp.tests.models import ConvWithSigmoid
29+
from executorch.backends.nxp.tests.nsys_testing import lower_run_compare
30+
from executorch.backends.nxp.tests.ops_aliases import DequantizePerTensor, Sigmoid
2131
from torch import nn
2232
from torch.export import ExportedProgram
2333
from executorch.backends.nxp.tests.use_qat import * # noqa F403
@@ -76,3 +86,60 @@ def test_sigmoid_only(mocker, use_qat, input_shape):
7686
convert_run_compare(
7787
exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
7888
)
89+
90+
91+
class TestSigmoidNewNeutronFlow:
92+
# noinspection PyMethodMayBeStatic
93+
def assert_delegated(self, model, input_shape, mocker, use_qat=False, atol=None):
94+
graph_verifier = DetailedGraphVerifier(
95+
mocker,
96+
expected_delegated_ops={Sigmoid: 1},
97+
expected_non_delegated_ops={},
98+
)
99+
100+
# Create a RandomDatasetCreator that covers also negative numbers to properly test the operator.
101+
dataset_creator = RandomDatasetCreator(low=-2, high=2)
102+
103+
kwargs = {"atol": atol} if atol is not None else {}
104+
output_comparator = AllCloseOutputComparator(**kwargs)
105+
106+
lower_run_compare(
107+
model,
108+
input_shape,
109+
graph_verifier,
110+
dataset_creator,
111+
output_comparator,
112+
use_qat=use_qat,
113+
use_new_flow_neutron_c=True, # Use the new flow.
114+
)
115+
116+
def test__basic_nsys_inference__qat(self, mocker, use_qat):
117+
input_shape = (23,)
118+
model = nn.Sigmoid()
119+
self.assert_delegated(model, input_shape, mocker, use_qat=use_qat)
120+
121+
@pytest.mark.parametrize(
122+
"input_shape",
123+
[
124+
(2,),
125+
(2, 3),
126+
(2, 3, 4),
127+
(2, 3, 4, 5),
128+
(2, 3, 4, 5, 6),
129+
],
130+
ids=lambda shape: f"{len(shape)}D",
131+
)
132+
def test__input_shapes(self, mocker, input_shape):
133+
model = nn.Sigmoid()
134+
135+
output_scale = 1.0 / 256.0
136+
lowering_spy = mocker.spy(NeutronPartitioner, "partition")
137+
self.assert_delegated(
138+
model, input_shape, mocker, atol=output_scale
139+
) # Allow single bit error.
140+
141+
# Verify that the `atol` is indeed equal to the output scale.
142+
# In the near future, we would like to add support for testing with int8 IO, where this check will be trivial.
143+
nodes = list(lowering_spy.spy_return.tagged_exported_program.graph.nodes)
144+
assert nodes[-2].target == DequantizePerTensor
145+
assert nodes[-2].args[1] == output_scale

backends/nxp/tests/ops_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
QuantizePerChannel = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
2828
QuantizePerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
2929
Relu = exir_ops.edge.aten.relu.default
30+
Sigmoid = exir_ops.edge.aten.sigmoid.default
3031
Slice = exir_ops.edge.aten.slice.Tensor
3132
SliceCopy = exir_ops.edge.aten.slice_copy.Tensor
3233
Softmax = exir_ops.edge.aten._softmax.default

0 commit comments

Comments
 (0)