Merged
33 changes: 21 additions & 12 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -42,23 +42,32 @@
 PassType = type[Callable[[torch.fx.GraphModule], PassResult]]
 
 
+def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[PassType]:
+    passes = [
+        DecomposeSplitToSlicesPass(),
+        SplitGroupConvolution(),
+        SplitGRUBasedOnNumLayers(),
+        RemoveNodesWithKnownOutputs(),
+        FuseLinearAndAddPass(),
+        MoveActivationBeforeConcat(neutron_target_spec),
+        ConvertUnsqueezeToViewPass(),
+    ]
+
+    if not qat_mode:
+        # In QAT mode, the fusing should happen after the training
+        # to preserve the batch norm stats updating mechanism.
+        passes.append(FuseBatchNormWithConvPass())
+        passes.append(FuseBatchNormWithLinearPass())
+
+    return passes
+
+
 class NeutronAtenPassManager(PassManager):
 
     def __init__(
         self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
     ):
-        passes: list[PassType] = passes or [
-            DecomposeSplitToSlicesPass(),
-            FuseBatchNormWithConvPass(),
-            FuseBatchNormWithLinearPass(),
-            SplitGroupConvolution(),
-            SplitGRUBasedOnNumLayers(),
-            RemoveNodesWithKnownOutputs(),
-            FuseLinearAndAddPass(),
-            MoveActivationBeforeConcat(neutron_target_spec),
-            ConvertUnsqueezeToViewPass(),
-        ]
-
+        passes: list[PassType] = passes or _get_default_passes(neutron_target_spec)
         super().__init__(passes)
 
     def __call__(self, module: nn.Module) -> PassResult:
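For orientation, a minimal usage sketch of the helper added above (a sketch only: neutron_target_spec stands in for a real NeutronTargetSpec instance, and the count check assumes the default pass list shown in this hunk):

from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
    _get_default_passes,
)

# PTQ flow: the two batch norm fusing passes are included.
ptq_passes = _get_default_passes(neutron_target_spec)

# QAT flow: fusing is deferred until after training, so batch norm
# statistics keep updating during the QAT epochs.
qat_passes = _get_default_passes(neutron_target_spec, qat_mode=True)
assert len(ptq_passes) - len(qat_passes) == 2  # the FuseBatchNorm* passes

pass_manager = NeutronAtenPassManager(neutron_target_spec, qat_passes)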
12 changes: 10 additions & 2 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -1,11 +1,12 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 import torch
 from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    _get_default_passes,
     NeutronAtenPassManager,
 )
 
@@ -17,6 +18,7 @@
     AddmmPattern,
     AddTensorPattern,
     AvgPoolPattern,
+    BatchNormPattern,
     CatPattern,
     Conv1dPattern,
     Conv2dPattern,
@@ -245,6 +247,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
             OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
             OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+            OpQuantizer(BatchNormPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
@@ -293,7 +296,12 @@ def transform_for_annotation(
     ) -> torch.fx.GraphModule:
         model.graph.eliminate_dead_code()  # Remove dead code to simplify the graph for the passes.
 
-        model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module
+        pass_manager = NeutronAtenPassManager(
+            self.neutron_target_spec,
+            _get_default_passes(self.neutron_target_spec, self.is_qat),
+        )
+
+        model = pass_manager(model).graph_module
 
         model.graph.eliminate_dead_code()  # Remove dead code again, in case it was created by the passes.
 
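For context, a condensed sketch of how the quantizer's QAT mode is driven end to end; model, example_inputs and neutron_target_spec are placeholders, while prepare_qat_pt2e and convert_pt2e are the torchao entry points this PR imports elsewhere:

from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

exported = export(model, example_inputs, strict=True).module()

# PTQ: batch norm is fused into conv before annotation, then calibration runs.
# QAT: transform_for_annotation keeps conv + batch norm separate, and the new
# BatchNormPattern annotates any standalone batch norm outputs.
quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)
prepared = prepare_qat_pt2e(exported, quantizer)
# ... training loop updating batch norm statistics ...
quantized = convert_pt2e(prepared)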
61 changes: 58 additions & 3 deletions backends/nxp/quantizer/patterns.py
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2025 NXP
+# Copyright 2025-2026 NXP
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -153,6 +153,27 @@ def get_anchors(
         )
 
 
+class BatchNormPattern(QuantizationPattern):
+    def __init__(self, is_qat: bool):
+        super().__init__(is_qat=is_qat)
+
+    def partition_types(self) -> list[OpOverload]:
+        # BatchNorm quantization is needed only in QAT mode.
+        return [torch.ops.aten.batch_norm.default] if self.is_qat else []
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+
 def get_anchors_for_fixed_quant_specs(
     fused_partition: list[fx.GraphModule],
     scale: float,
@@ -356,6 +377,14 @@ def get_anchors(
         )
 
 
+def _is_batch_norm(node_: Node) -> bool:
+    return node_.op == "call_function" and node_.target in [
+        torch.ops.aten.batch_norm.default,
+        torch.ops.aten.native_batch_norm.default,
+        torch.ops.aten._native_batch_norm_legit_no_training.default,
+    ]
+
+
 class ConvPattern(QuantizationPattern):
     @abstractmethod
     def partition_types(self) -> list[OpOverload]:
@@ -398,11 +427,20 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
 
+        output_specs = [(conv_node,)]
+        # For QAT to be numerically correct, there must be no quantization
+        # between the convolution node and the batch norm node.
+        if self.is_qat:
+            conv_users = conv_node.users
+            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output_specs = []
+
         return PartitionAnchors(
             inputs=[(conv_node, NodeArgsIdx(0))],
             weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv_node,)],
+            output=output_specs,
         )


@@ -479,6 +517,14 @@ def get_anchors(
             output = []
             activation.meta["quantization_annotation"].input_qspec_map = {}
 
+        # For QAT to be numerically correct, there must be no quantization
+        # between the convolution node and the batch norm node.
+        if self.is_qat:
+            conv_users = conv_node.users
+            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output = []
+
         return PartitionAnchors(
             inputs=[(conv_node, NodeArgsIdx(0))],
             weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
@@ -524,11 +570,20 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
 
+        output_specs = [(conv_node,)]
+        # For QAT to be numerically correct, there must be no quantization
+        # between the convolution node and the batch norm node.
+        if self.is_qat:
+            conv_users = conv_node.users
+            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output_specs = []
+
         return PartitionAnchors(
             inputs=[(conv_node, NodeArgsIdx(0))],
             weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv_node,)],
+            output=output_specs,
         )


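Restating the two behavioral points of this file as a hedged sketch (conv_node is a hypothetical fx.Node; everything else comes from the code above):

# BatchNormPattern only matches in QAT mode; in PTQ the batch norm is fused
# away before annotation, so there is nothing left to match.
assert BatchNormPattern(is_qat=False).partition_types() == []
assert BatchNormPattern(is_qat=True).partition_types() == [
    torch.ops.aten.batch_norm.default
]

# The conv patterns drop their output anchor only when the conv's sole user
# is a batch norm, so no fake-quant lands between conv and BN during QAT.
users = list(conv_node.users)  # fx.Node.users is a dict keyed by user nodes
if len(users) == 1 and _is_batch_norm(users[0]):
    output_specs = []  # let torchao's QAT machinery fuse conv + BN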
24 changes: 24 additions & 0 deletions backends/nxp/tests/models.py
@@ -457,6 +457,30 @@ def forward(self, x):
         return self.pool(x)
 
 
+class ConvBNModule(torch.nn.Module):
+    def __init__(self, conv_module, conv_bias, bn_affine):
+        super().__init__()
+
+        if conv_module == "conv1d":
+            self.conv = torch.nn.Conv1d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
+        elif conv_module == "conv2d":
+            self.conv = torch.nn.Conv2d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
+        elif conv_module == "conv1d_t":
+            self.conv = torch.nn.ConvTranspose1d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
+        elif conv_module == "conv2d_t":
+            self.conv = torch.nn.ConvTranspose2d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
+        else:
+            raise ValueError(f"Unknown conv_module: {conv_module}")
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.bn(x)
+
+
 class MulTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
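A quick smoke-test sketch of the new helper module; the expected shape follows from the fixed 3-to-64 channel counts and padding=1 above:

import torch
from executorch.backends.nxp.tests.models import ConvBNModule

m = ConvBNModule("conv2d", conv_bias=True, bn_affine=True).eval()
y = m(torch.randn(1, 3, 32, 32))  # Conv2d(3 -> 64, k=3, p=1) keeps H x W
assert y.shape == (1, 64, 32, 32)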
67 changes: 66 additions & 1 deletion backends/nxp/tests/test_quantizer.py
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -66,6 +66,14 @@
     ("sigmoid", False, False),
 ]
 
+batch_norm_ops = (
+    exir_ops.edge.aten._native_batch_norm_legit.no_stats,
+    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+    torch.ops.aten._native_batch_norm_legit_no_training.default,
+    torch.ops.aten.batch_norm.default,
+    torch.ops.aten.native_batch_norm.default,
+)
+
 
 @pytest.fixture(autouse=True)
 def reseed_model_per_test_run():
@@ -636,3 +644,60 @@ def test_qat_produces_same_graph_as_ptq():
            qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes
        )
    )
+
+
+# TODO: conv1d_t is currently unsupported, add when resolved
+@pytest.mark.parametrize("conv_module", ["conv1d", "conv2d", "conv2d_t"])
+@pytest.mark.parametrize("conv_bias", [True, False])
+@pytest.mark.parametrize("bn_affine", [True, False])
+def test_torchao_native_conv_bn_qat_fusing(conv_module, conv_bias, bn_affine):
+    if not conv_bias:
+        pytest.skip("Conv without bias is not supported.")
+
+    if conv_module.startswith("conv1d"):
+        input_shape = (1, 3, 32)
+    elif conv_module.startswith("conv2d"):
+        input_shape = (1, 3, 32, 32)
+
+    model = models.ConvBNModule(
+        conv_module=conv_module,
+        conv_bias=conv_bias,
+        bn_affine=bn_affine,
+    )
+    model.eval()
+
+    exported_model = export(model, (torch.randn(*input_shape),), strict=True)
+    prepared_model = _prepare_for_quantization(exported_model, is_qat=True)
+    quantized_model = convert_pt2e(prepared_model)
+
+    def is_conv(node):
+        return node.op == "call_function" and node.target in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv_transpose2d.input,
+        ]
+
+    graph_nodes = list(quantized_model.graph.nodes)
+    conv_node = next(n for n in graph_nodes if is_conv(n))
+    conv_node_args = conv_node.args
+
+    if len(conv_node_args) > 3:
+        conv_node_args = conv_node_args[:3]
+
+    assert not any(
+        n.target in batch_norm_ops for n in graph_nodes if hasattr(n, "target")
+    )
+    assert (
+        len(conv_node.users) == 1
+        and list(conv_node.users.keys())[0].target
+        == torch.ops.quantized_decomposed.quantize_per_tensor.default
+    )
+    assert all(
+        arg.target
+        in (
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        )
+        for arg in conv_node_args
+    )
+    assert len(graph_nodes) == 15
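Spelled out, the graph shape those assertions pin down (a schematic reading of the expected result, not tool output):

# After convert_pt2e with QAT fusing, batch norm is folded into the conv
# weights, so the subgraph around the conv should look like:
#
#   input  -> quantize_per_tensor -> dequantize_per_tensor --\
#   weight -> dequantize_per_channel -------------------------> conv -> quantize_per_tensor -> ...
#   bias   -> dequantize_per_tensor -------------------------/
#
# No batch_norm op survives, the conv's only user is a quantize node, and
# the 15-node total pins the whole graph to exactly this shape.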
48 changes: 39 additions & 9 deletions examples/nxp/aot_neutron_compile.py
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -39,9 +39,18 @@
     to_edge_transform_and_lower,
 )
 from executorch.extension.export_util import save_pte_program
+from torch.ao.quantization import (
+    move_exported_model_to_eval,
+    move_exported_model_to_train,
+)
 from torch.export import export
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
 
-from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
+from .experimental.cifar_net.cifar_net import (
+    CifarNet,
+    test_cifarnet_model,
+    train_cifarnet_model,
+)
 from .models.mobilenet_v2 import MobilenetV2
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -149,6 +158,13 @@ def get_model_and_inputs_from_name(model_name: str):
         default=False,
         help="Produce a quantized model",
     )
+    parser.add_argument(
+        "--use_qat",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Use QAT mode for quantization (performs two QAT training epochs)",
+    )
     parser.add_argument(
         "-s",
         "--so_library",
@@ -230,13 +246,27 @@ def get_model_and_inputs_from_name(model_name: str):
 
     # 3. Quantize if required
     if args.quantize:
-        if calibration_inputs is None:
-            logging.warning(
-                "No calibration inputs available, using the example inputs instead"
-            )
-            calibration_inputs = example_inputs
-        quantizer = NeutronQuantizer(neutron_target_spec)
-        module = calibrate_and_quantize(module, calibration_inputs, quantizer)
+        quantizer = NeutronQuantizer(neutron_target_spec, is_qat=args.use_qat)
+        if args.use_qat:
+            match args.model_name:
+                case "cifar10":
+                    print("Starting two epochs of QAT training with the CifarNet model...")
+                    module = prepare_qat_pt2e(module, quantizer)
+                    module = move_exported_model_to_train(module)
+                    module = train_cifarnet_model(module, num_epochs=2)
+                    module = move_exported_model_to_eval(module)
+                    module = convert_pt2e(module)
+                case _:
+                    raise ValueError(
+                        f"QAT training is not supported for model '{args.model_name}'"
+                    )
+        else:
+            if calibration_inputs is None:
+                logging.warning(
+                    "No calibration inputs available, using the example inputs instead"
+                )
+                calibration_inputs = example_inputs
+            module = calibrate_and_quantize(module, calibration_inputs, quantizer)
 
     if args.so_library is not None:
         logging.debug(f"Loading libraries: {args.so_library}")
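Condensed, the QAT branch above performs these steps in order (a sketch; module and neutron_target_spec are as prepared earlier in the script). On the command line this corresponds to selecting the cifar10 model and passing both --quantize and --use_qat:

quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)
module = prepare_qat_pt2e(module, quantizer)          # insert fake-quant, keep BN live
module = move_exported_model_to_train(module)         # enable BN statistics updates
module = train_cifarnet_model(module, num_epochs=2)   # short fine-tuning run
module = move_exported_model_to_eval(module)          # freeze BN for inference
module = convert_pt2e(module)                         # materialize q/dq ops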