Skip to content

Commit c71823c

Browse files
authored
NXP backend: Add QuantizeFusedConvBnBiasAtenPass call to integration tests pipeline (pytorch#18904)
Adds call to `QuantizeFusedConvBnBiasAtenPass` in conversion pipeline for integration tests. The pass enables QAT support for models containing biasless convolutions in integration tests. ### Test plan Covered by NXP internal tests. cc @robert-kalmar @JakeStevens @digantdesai
1 parent a49171d commit c71823c

5 files changed

Lines changed: 42 additions & 19 deletions

File tree

backends/nxp/quantizer/patterns.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,17 @@ def get_anchors(
543543

544544
# If the following node is a fusable activation, quantize together with activation
545545
output = [(conv_node,)]
546-
if len(
547-
conv_node.users
548-
) == 1 and self.neutron_target_info.is_supported_fused_activation__aten(
549-
activation := next(iter(conv_node.users))
546+
if len(conv_node.users) == 1 and (
547+
self.neutron_target_info.is_supported_fused_activation__aten(
548+
activation := next(iter(conv_node.users))
549+
)
550+
or (
551+
self.is_qat
552+
and _is_batch_norm(activation)
553+
and self.neutron_target_info.is_supported_fused_activation__aten(
554+
activation := next(iter(activation.users))
555+
)
556+
)
550557
):
551558
activation_quantizer = self.neutron_quantizer.op_to_quantizer[
552559
activation.target
@@ -555,6 +562,14 @@ def get_anchors(
555562
output = []
556563
activation.meta["quantization_annotation"].input_qspec_map = {}
557564

565+
if isinstance(bn := next(iter(conv_node.users)), Node) and _is_batch_norm(
566+
bn
567+
):
568+
bn_quantizer = self.neutron_quantizer.op_to_quantizer[bn.target]
569+
bn_quantizer.annotate(gm)
570+
bn.meta["quantization_annotation"].input_qspec_map = {}
571+
bn.meta["quantization_annotation"].output_qspec = None
572+
558573
# In order for QAT to be numerically correct, there should be no quantization between
559574
# convolution node and batch norm node.
560575
if self.is_qat:

backends/nxp/quantizer/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ def calibrate_and_quantize(
219219

220220
m = convert_pt2e(m)
221221

222-
m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module
222+
if is_qat:
223+
m = QuantizeFusedConvBnBiasAtenPass(
224+
default_zero_bias=False, symmetric_quant=True
225+
)(m).graph_module
223226

224227
return m

backends/nxp/tests/generic_tests/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_conv_fc_softmax__to_executorch_program(use_qat):
2929
delegation_info = get_delegation_info(program.graph_module)
3030
assert delegation_info.num_delegated_subgraphs == 1
3131
assert delegation_info.num_non_delegated_nodes == 11
32-
assert delegation_info.num_delegated_nodes == 14
32+
assert delegation_info.num_delegated_nodes == 13
3333

3434
for node in program.graph.nodes:
3535
# Make sure Convolution and AddMM are delegated

backends/nxp/tests/generic_tests/test_qdq_clustering_conv.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,16 @@ def test_conv2d_partitioner():
1616
lowered_module = edge_program.exported_program().graph_module.lowered_module_0
1717
nodes = list(lowered_module.original_module.graph.nodes)
1818

19-
assert len(nodes) == 13
19+
assert len(nodes) == 9
2020

21-
q_x_node = nodes[6]
22-
dq_w_node = nodes[7]
23-
dq_x_node = nodes[8]
24-
dq_bias_node = nodes[9]
25-
conv_node = nodes[10]
26-
q_y_node = nodes[11]
21+
q_x_node = nodes[3]
22+
dq_x_node = nodes[4]
23+
dq_w_node = nodes[5]
24+
conv_node = nodes[6]
25+
q_y_node = nodes[7]
2726

2827
assert "cluster" not in q_x_node.meta
2928
assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster"
3029
assert dq_x_node.meta["cluster"] == "aten_convolution_default_cluster"
31-
assert dq_bias_node.meta["cluster"] == "aten_convolution_default_cluster"
3230
assert conv_node.meta["cluster"] == "aten_convolution_default_cluster"
3331
assert q_y_node.meta["cluster"] == "aten_convolution_default_cluster"

backends/transforms/quantize_fused_convbn_bias_pass.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _quantize_fused_conv_bias(
171171
set_param,
172172
get_weight_scale_tensor,
173173
default_zero_bias=False,
174+
use_symmetric_quantization=False,
174175
):
175176
"""Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
176177
@@ -188,6 +189,7 @@ def _quantize_fused_conv_bias(
188189
set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
189190
get_weight_scale_tensor: Callable(node) -> Tensor.
190191
default_zero_bias: If True, create zero bias for conv nodes without bias.
192+
use_symmetric_quantization: If True, uses symmetric quantization range.
191193
192194
Returns:
193195
True if any modifications were made.
@@ -236,6 +238,7 @@ def _quantize_fused_conv_bias(
236238
else torch.empty(bias.shape, dtype=torch.float32)
237239
)
238240

241+
quant_min = -(2**31) + 1 if use_symmetric_quantization else -(2**31)
239242
if isinstance(weight_dequant.args[1], torch.fx.node.Node):
240243
weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
241244
bias_scale = input_dequant.args[1] * weight_scale
@@ -246,7 +249,7 @@ def _quantize_fused_conv_bias(
246249
bias_scale,
247250
bias_zp,
248251
0,
249-
-(2**31),
252+
quant_min,
250253
2**31 - 1,
251254
torch.int32,
252255
)
@@ -267,7 +270,7 @@ def _quantize_fused_conv_bias(
267270
scale_node,
268271
zp_node,
269272
0,
270-
-(2**31),
273+
quant_min,
271274
2**31 - 1,
272275
torch.int32,
273276
),
@@ -279,14 +282,14 @@ def _quantize_fused_conv_bias(
279282
bias_scale = input_dequant.args[1] * weight_scale
280283

281284
qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
282-
bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
285+
bias, bias_scale, 0, quant_min, 2**31 - 1, torch.int32
283286
)
284287
set_param(bias_node, qbias)
285288

286289
with graph_module.graph.inserting_before(node):
287290
bias_dequant = graph_module.graph.call_function(
288291
dq_per_tensor,
289-
(bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
292+
(bias_node, bias_scale, 0, quant_min, 2**31 - 1, torch.int32),
290293
)
291294
bias_dequant.meta["val"] = dequant_val
292295
node.replace_input_with(bias_node, bias_dequant)
@@ -306,9 +309,12 @@ class QuantizeFusedConvBnBiasAtenPass(PassBase):
306309
exported_program can be omitted.
307310
"""
308311

309-
def __init__(self, exported_program=None, default_zero_bias=False) -> None:
312+
def __init__(
313+
self, exported_program=None, default_zero_bias=False, symmetric_quant=False
314+
) -> None:
310315
self.exported_program = exported_program
311316
self.default_zero_bias = default_zero_bias
317+
self.symmetric_quantization = symmetric_quant
312318

313319
def call(self, graph_module: fx.GraphModule) -> PassResult:
314320
ep = self.exported_program
@@ -351,5 +357,6 @@ def get_scale(node):
351357
set_param=set_param,
352358
get_weight_scale_tensor=get_scale,
353359
default_zero_bias=self.default_zero_bias,
360+
use_symmetric_quantization=self.symmetric_quantization,
354361
)
355362
return PassResult(graph_module, modified)

0 commit comments

Comments
 (0)