Skip to content

Commit e3d2afc

Browse files
authored
Adds QAT ConvBN fuse pass to utils (pytorch#17599)
Summary: An earlier PR added a pass that quantizes the bias produced by QAT ConvBN fusion when the convolution has no initial bias. This PR applies that pass in the NXP calibrate_and_quantize method. Differential Revision: D93904683 cc @robert-kalmar @digantdesai
1 parent 5123efe commit e3d2afc

8 files changed

Lines changed: 144 additions & 8 deletions

File tree

backends/nxp/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ fbcode_target(_kind = runtime.python_library,
5656
deps = [
5757
":aten_passes",
5858
"//caffe2:torch",
59+
"//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
5960
"//pytorch/ao:torchao", # @manual
6061
],
6162
)

backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,41 @@
2525
_is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape)
2626
_is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like)
2727

28+
# Convolution overloads that end the search in `_feeds_into_linear` with a
# negative answer: reaching any of them means the chain belongs to a conv.
_CONV_TARGETS = {
    torch.ops.aten.conv1d.default,
    torch.ops.aten.conv1d.padding,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.conv2d.padding,
    torch.ops.aten.conv_transpose1d.default,
    torch.ops.aten.conv_transpose2d.input,
}


def _feeds_into_linear(node: Node) -> bool:
    """
    BFS from node to check if it eventually feeds into a linear op (not conv).

    This is required because:
    - Linear-BN fusion (added by AddSimulatedLinearBatchNormFusionQATPass, NXP-specific)
    - Conv-BN QAT fusion (added by TorchAO's _fuse_conv_bn_qat inside prepare_qat_pt2e)
    are structurally identical. Without this check, we would incorrectly remove
    Conv-BN scale factor chains, breaking Conv-BN QAT fusion when TorchAO's _fold_conv_bn_qat
    is called during convert_pt2e.

    The search walks the transitive users of `node` in breadth-first order and
    stops at the first `call_function` node whose target is either
    `aten.linear.default` or one of `_CONV_TARGETS`.

    :param node: Node whose downstream consumers are inspected. The node itself
                 is not examined, only its (transitive) users.
    :return: True if a linear op is reached first, False if a convolution op is
             reached first. If neither is reachable (including the case of a
             node without users), True is returned, i.e. only a reachable
             convolution yields a negative answer.
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    visited = set()
    # `deque.popleft()` is O(1), whereas `list.pop(0)` shifts the whole list.
    queue = deque(node.users)
    while queue:
        n = queue.popleft()
        if n in visited:
            continue
        visited.add(n)
        if n.op == "call_function":
            if n.target == torch.ops.aten.linear.default:
                return True
            if n.target in _CONV_TARGETS:
                return False
        # Keep walking through every other node kind / op.
        queue.extend(n.users)
    # Nothing decisive was found downstream.
    return True
62+
2863

2964
def _is_denorm_pattern(node: Node) -> bool:
3065
if not _is_div(node):
@@ -56,6 +91,10 @@ def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule):
5691
for match in matches:
5792
last_pattern_node = match.anchors[0]
5893
last_matched_subgraph_node = match.nodes_map[last_pattern_node]
94+
95+
if not _feeds_into_linear(last_matched_subgraph_node):
96+
continue
97+
5998
weight = match.placeholder_nodes[0]
6099

61100
last_matched_subgraph_node.replace_all_uses_with(weight)

backends/nxp/quantizer/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
AddSimulatedLinearBatchNormFusionQATPass,
2121
RemoveSimulatedLinearBatchNormFusionQATPass,
2222
)
23+
from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
24+
QuantizeFusedConvBnBiasAtenPass,
25+
)
2326
from torch import fx
2427
from torch._ops import OpOverload
2528
from torch.export import ExportedProgram
@@ -205,4 +208,6 @@ def calibrate_and_quantize(
205208

206209
m = convert_pt2e(m)
207210

211+
m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module
212+
208213
return m

backends/nxp/tests/BUCK

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ fbcode_target(_kind = python_pytest,
5454
]
5555
)
5656

57+
fbcode_target(_kind = runtime.python_library,
58+
name = "use_qat",
59+
srcs = [
60+
"use_qat.py",
61+
],
62+
deps = [
63+
"fbsource//third-party/pypi/pytest:pytest",
64+
],
65+
)
66+
5767
fbcode_target(_kind = python_pytest,
5868
name = "test_batch_norm_fusion",
5969
srcs = [
@@ -68,3 +78,34 @@ fbcode_target(_kind = python_pytest,
6878
"fbsource//third-party/pypi/numpy:numpy",
6979
],
7080
)
81+
82+
fbcode_target(_kind = python_pytest,
83+
name = "test_qdq_clustering_conv",
84+
srcs = [
85+
"test_qdq_clustering_conv.py",
86+
],
87+
deps = [
88+
":executorch_pipeline",
89+
":models",
90+
],
91+
)
92+
93+
fbcode_target(_kind = python_pytest,
94+
name = "test_integration",
95+
srcs = [
96+
"test_integration.py",
97+
],
98+
preload_deps = [
99+
"//executorch/kernels/quantized:custom_ops_generated_lib",
100+
],
101+
deps = [
102+
":executorch_pipeline",
103+
":models",
104+
":use_qat",
105+
"//executorch/devtools/backend_debug:delegation_info",
106+
"//executorch/extension/pybindings:portable_lib",
107+
"//executorch/examples/nxp/experimental/cifar_net:cifar_net",
108+
"//executorch/kernels/quantized:custom_ops_generated_lib",
109+
"//executorch/kernels/quantized:quantized_ops_lib",
110+
],
111+
)

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
neutron_target_spec,
2323
to_quantized_edge_program,
2424
)
25-
from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck
25+
from executorch.backends.nxp.tests.executors import (
26+
graph_contains_any_of_ops,
27+
OverrideTargetSupportCheck,
28+
)
29+
30+
from executorch.backends.nxp.tests.models import ConvBNModule
2631
from torch import nn
2732

2833

@@ -229,3 +234,28 @@ def unsupported_target(*_): # Accept all input arguments and return `False`.
229234
node.op == "call_function" and "batch_norm" in node.target.__name__
230235
for node in nodes
231236
)
237+
238+
239+
@pytest.mark.parametrize("conv_module", ["conv2d"])
def test_biasless_convbn_fusion_qat(conv_module):
    """
    Quantize a Conv+BN model whose convolution has no bias using QAT and check
    that the resulting edge program contains an `executorch_call_delegate` call.

    :param conv_module: Name of the convolution variant passed to `ConvBNModule`.
                        Its prefix selects the number of spatial dimensions of
                        the input ("conv1d" -> 1, "conv2d" -> 2, anything else -> 3).
    """
    # Names not starting with "conv1d"/"conv2d" fall through to the 3D case.
    spatial_rank = next(
        (
            rank
            for prefix, rank in (("conv1d", 1), ("conv2d", 2))
            if conv_module.startswith(prefix)
        ),
        3,
    )
    # Batch of 1, 3 channels, and 32 elements along every spatial dimension.
    input_shape = (1, 3) + (32,) * spatial_rank

    model = ConvBNModule(conv_module, conv_bias=False, bn_affine=True)

    quantized = to_quantized_edge_program(
        model,
        input_shape,
        use_qat=True,
        use_neutron_for_format_conversion=False,
    )
    edge_program = quantized.exported_program()

    delegate_ops = [torch.ops.higher_order.executorch_call_delegate]
    assert graph_contains_any_of_ops(edge_program.graph, delegate_ops)
)

backends/nxp/tests/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_conv_fc_softmax__to_executorch_program(use_qat):
2929
delegation_info = get_delegation_info(program.graph_module)
3030
assert delegation_info.num_delegated_subgraphs == 1
3131
assert delegation_info.num_non_delegated_nodes == 11
32-
assert delegation_info.num_delegated_nodes == 13
32+
assert delegation_info.num_delegated_nodes == 14
3333

3434
for node in program.graph.nodes:
3535
# Make sure Convolution and AddMM are delegated

backends/nxp/tests/test_qdq_clustering_conv.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@ def test_conv2d_partitioner():
1616
lowered_module = edge_program.exported_program().graph_module.lowered_module_0
1717
nodes = list(lowered_module.original_module.graph.nodes)
1818

19-
assert len(nodes) == 9
19+
assert len(nodes) == 13
2020

21-
q_x_node = nodes[3]
22-
dq_w_node = nodes[4]
23-
dq_x_node = nodes[5]
24-
conv_node = nodes[6]
25-
q_y_node = nodes[7]
21+
q_x_node = nodes[6]
22+
dq_w_node = nodes[7]
23+
dq_x_node = nodes[8]
24+
dq_bias_node = nodes[9]
25+
conv_node = nodes[10]
26+
q_y_node = nodes[11]
2627

2728
assert "cluster" not in q_x_node.meta
2829
assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster"
2930
assert dq_x_node.meta["cluster"] == "aten_convolution_default_cluster"
31+
assert dq_bias_node.meta["cluster"] == "aten_convolution_default_cluster"
3032
assert conv_node.meta["cluster"] == "aten_convolution_default_cluster"
3133
assert q_y_node.meta["cluster"] == "aten_convolution_default_cluster"
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
oncall("executorch")
5+
6+
fbcode_target(_kind = runtime.python_library,
7+
name = "cifar_net",
8+
srcs = [
9+
"cifar_net.py",
10+
],
11+
deps = [
12+
"//caffe2:torch",
13+
"//executorch/exir:lib",
14+
"//executorch/examples/models:model_base",
15+
"fbsource//third-party/pypi/numpy:numpy",
16+
"//pytorch/vision:torchvision", # @manual
17+
],
18+
)

0 commit comments

Comments
 (0)