Skip to content

Commit 17cb87c

Browse files
authored
Arm backend: Correct per channel axis for transpose conv (pytorch#17842)
Adjust per-channel weight axis for conv_transpose2d based on group(s). The corrected axis is propagated to: - QuantizationSpec - QAT fake-quant constructors wrapped in PartialWrapper - Non-QAT observer/fake-quant constructors (via with_args) cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
1 parent e248a99 commit 17cb87c

2 files changed

Lines changed: 88 additions & 13 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
1010
"""
1111

12+
import functools
1213
import logging
1314
import operator
1415
from dataclasses import dataclass, replace
@@ -22,6 +23,7 @@
2223
from torch._subclasses import FakeTensor
2324

2425
from torch.fx import Node
26+
from torchao.quantization.pt2e import PartialWrapper
2527
from torchao.quantization.pt2e.quantizer import (
2628
annotate_input_qspec_map,
2729
annotate_output_qspec,
@@ -85,21 +87,51 @@ def _as_list(x):
8587

8688
def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
8789
if (
88-
node.target == torch.ops.aten.conv_transpose2d.input
89-
and isinstance(weight_qspec, QuantizationSpec)
90-
and weight_qspec.qscheme == torch.per_channel_symmetric
91-
and weight_qspec.ch_axis != 1
90+
node.target != torch.ops.aten.conv_transpose2d.input
91+
or not isinstance(weight_qspec, QuantizationSpec)
92+
or weight_qspec.qscheme != torch.per_channel_symmetric
9293
):
93-
return QuantizationSpec(
94-
dtype=weight_qspec.dtype,
95-
observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
96-
quant_min=weight_qspec.quant_min,
97-
quant_max=weight_qspec.quant_max,
98-
qscheme=weight_qspec.qscheme,
99-
ch_axis=1,
100-
is_dynamic=weight_qspec.is_dynamic,
94+
return weight_qspec
95+
96+
# For now skip axis adjustment for a8w4 per-channel configs (int4 weights).
97+
if weight_qspec.quant_min == -7 and weight_qspec.quant_max == 7:
98+
return weight_qspec
99+
100+
groups = 1
101+
if len(node.args) > 6 and isinstance(node.args[6], int):
102+
groups = node.args[6]
103+
expected_axis = 0 if groups != 1 else 1
104+
if weight_qspec.ch_axis == expected_axis:
105+
return weight_qspec
106+
107+
observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr
108+
# TorchAO PT2e QAT commonly represents the ctor as PartialWrapper(partial(...)).
109+
# Rebuild it to update ch_axis while preserving callable_args.
110+
if isinstance(observer_or_fake_quant_ctr, PartialWrapper):
111+
original_callable_args = dict(observer_or_fake_quant_ctr.callable_args)
112+
base_partial = observer_or_fake_quant_ctr.p
113+
if isinstance(base_partial, functools.partial):
114+
base_keywords = dict(base_partial.keywords or {})
115+
base_keywords["ch_axis"] = expected_axis
116+
observer_or_fake_quant_ctr = PartialWrapper(
117+
functools.partial(base_partial.func, **base_keywords)
118+
)
119+
observer_or_fake_quant_ctr.callable_args = original_callable_args
120+
# Non-QAT observer/fake-quant constructors can be updated via with_args.
121+
elif hasattr(observer_or_fake_quant_ctr, "with_args"):
122+
observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
123+
ch_axis=expected_axis
101124
)
102-
return weight_qspec
125+
126+
return QuantizationSpec(
127+
dtype=weight_qspec.dtype,
128+
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
129+
quant_min=weight_qspec.quant_min,
130+
quant_max=weight_qspec.quant_max,
131+
qscheme=weight_qspec.qscheme,
132+
ch_axis=expected_axis,
133+
is_dynamic=weight_qspec.is_dynamic,
134+
)
103135

104136

105137
def _is_ok_for_quantization(

backends/arm/test/ops/test_transpose_conv2d.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,21 @@
1111
from executorch.backends.arm.quantizer.arm_quantizer import (
1212
get_symmetric_a16w8_quantization_config,
1313
get_symmetric_a8w4_quantization_config,
14+
get_symmetric_quantization_config,
15+
TOSAQuantizer,
1416
)
1517
from executorch.backends.arm.test import common
1618
from executorch.backends.arm.test.tester.test_pipeline import (
1719
EthosU55PipelineINT,
1820
EthosU85PipelineINT,
1921
OpNotSupportedPipeline,
22+
QuantizationPipeline,
2023
TosaPipelineFP,
2124
TosaPipelineINT,
2225
VgfPipeline,
2326
)
27+
from executorch.backends.arm.tosa.specification import TosaSpecification
28+
from executorch.backends.test.harness.stages.quantize import Quantize
2429

2530
aten_op = "torch.ops.aten.conv_transpose2d.input"
2631
exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" # No edge transpoe conv
@@ -94,6 +99,21 @@ def forward(self, x):
9499
for q in [True, False]
95100
}
96101

102+
test_data_QAT = {
103+
"qat_basic": lambda: (
104+
TransposeConv2d(
105+
in_channels=16,
106+
out_channels=4,
107+
kernel_size=4,
108+
stride=2,
109+
padding=1,
110+
groups=1,
111+
),
112+
True,
113+
True,
114+
),
115+
}
116+
97117
u55_supported_test_data_INT = {
98118
k: v
99119
for k, v in test_data_INT.items()
@@ -150,6 +170,29 @@ def test_conv_transpose2d_tosa_INT(test_data):
150170
pipeline.run()
151171

152172

173+
@common.parametrize("test_data", test_data_QAT)
174+
def test_conv_transpose2d_tosa_INT_qat_per_channel_quantization_pipeline(test_data):
175+
model, is_per_channel, is_qat = test_data()
176+
inputs = model.get_inputs()
177+
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
178+
quantizer.set_global(
179+
get_symmetric_quantization_config(
180+
is_per_channel=is_per_channel,
181+
is_qat=is_qat,
182+
)
183+
)
184+
pipeline = QuantizationPipeline[input_t](model, inputs, quantizer)
185+
pipeline.change_args(
186+
"quantize",
187+
Quantize(
188+
quantizer,
189+
quantization_config=quantizer.global_config,
190+
is_qat=is_qat,
191+
),
192+
)
193+
pipeline.run()
194+
195+
153196
_a8w4_transpose_conv_xfails = {
154197
k: "per-channel int4 weight quantization is not supported for transpose conv yet."
155198
for k in test_data_INT

0 commit comments

Comments
 (0)