Skip to content

Commit 495eec7

Browse files
Arm backend: Annotate maximum/minimum ops with independent observers (pytorch#18009)
Previously, shared observers meant that the inputs and the output were all quantized over the same range. In cases where the output is heavily truncated, this left much of the output's quantization range unused, leading to unnecessarily poor accuracy. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 08c3a72 commit 495eec7

3 files changed

Lines changed: 49 additions & 27 deletions

File tree

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,23 @@ def _get_output_qparams(
200200

201201
if target in [
202202
exir_ops.edge.aten.abs.default,
203-
exir_ops.edge.aten.maximum.default,
204-
exir_ops.edge.aten.minimum.default,
205203
exir_ops.edge.aten.sum.dim_IntList,
206204
exir_ops.edge.aten.add.Tensor,
207205
exir_ops.edge.aten.sub.Tensor,
208206
]:
209207
# The op has not altered the scale; the output scale is equal to
210208
# the operands' scales.
211209
return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor())
210+
elif target in [
211+
exir_ops.edge.aten.maximum.default,
212+
exir_ops.edge.aten.minimum.default,
213+
]:
214+
# Min/Max use a shared INT32 accumulator scale for inputs, then
215+
# rescale to the original output activation scale.
216+
min_scale = min(
217+
[qp.get_scale_per_tensor() for qp in inputs_qparams.values()]
218+
)
219+
return self._int32_qargs(min_scale)
212220
elif target in [
213221
exir_ops.edge.aten.eq.Tensor,
214222
exir_ops.edge.aten.ge.Tensor,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -666,16 +666,11 @@ def any_or_hardtanh_min_zero(n: Node):
666666
torch.ops.aten.minimum.default,
667667
torch.ops.aten.maximum.default,
668668
):
669-
lhs_node = ensure_type(Node, node.args[0])
670-
shared_qspec = SharedQuantizationSpec((lhs_node, node))
671669
quant_properties.quant_inputs = [
672670
_QuantProperty(0, input_act_qspec),
673-
_QuantProperty(
674-
1,
675-
input_act_qspec if node.args[0] == node.args[1] else shared_qspec,
676-
),
671+
_QuantProperty(1, input_act_qspec),
677672
]
678-
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
673+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
679674
elif node.target in (torch.ops.aten.where.self,):
680675
true_node = ensure_type(Node, node.args[1])
681676
input_qspec = (

backends/arm/test/misc/test_shared_qspecs.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,12 @@ class SharedQspecInputForkNonShared(torch.nn.Module):
126126
outputs_qspecs = {None: 1}
127127
quant_params = {
128128
"quantized_decomposed.dequantize_per_tensor.default": {
129-
(0.01959827, -26, -128, 127, torch.int8): 4,
129+
(0.015678614, -64, -128, 127, torch.int8): 3,
130+
(0.015678614, 0, -128, 127, torch.int8): 1,
130131
},
131132
"quantized_decomposed.quantize_per_tensor.default": {
132-
(0.01959827, -26, -128, 127, torch.int8): 4,
133+
(0.015678614, -64, -128, 127, torch.int8): 3,
134+
(0.015678614, 0, -128, 127, torch.int8): 1,
133135
},
134136
}
135137

@@ -151,10 +153,12 @@ class SharedQspecInputForkShared(torch.nn.Module):
151153
outputs_qspecs = {None: 1}
152154
quant_params = {
153155
"quantized_decomposed.dequantize_per_tensor.default": {
154-
(0.01959827, -26, -128, 127, torch.int8): 5,
156+
(0.015678614, -64, -128, 127, torch.int8): 2,
157+
(0.015678614, 0, -128, 127, torch.int8): 3,
155158
},
156159
"quantized_decomposed.quantize_per_tensor.default": {
157-
(0.01959827, -26, -128, 127, torch.int8): 5,
160+
(0.015678614, -64, -128, 127, torch.int8): 2,
161+
(0.015678614, 0, -128, 127, torch.int8): 3,
158162
},
159163
}
160164

@@ -178,10 +182,12 @@ class SharedQspecInputForkXShared(torch.nn.Module):
178182
outputs_qspecs = {None: 1}
179183
quant_params = {
180184
"quantized_decomposed.dequantize_per_tensor.default": {
181-
(0.01959827, -26, -128, 127, torch.int8): 4,
185+
(0.015678614, -64, -128, 127, torch.int8): 2,
186+
(0.015678614, 0, -128, 127, torch.int8): 2,
182187
},
183188
"quantized_decomposed.quantize_per_tensor.default": {
184-
(0.01959827, -26, -128, 127, torch.int8): 4,
189+
(0.015678614, -64, -128, 127, torch.int8): 2,
190+
(0.015678614, 0, -128, 127, torch.int8): 2,
185191
},
186192
}
187193

@@ -204,10 +210,12 @@ class SharedQspecInputForkYShared(torch.nn.Module):
204210
outputs_qspecs = {None: 1}
205211
quant_params = {
206212
"quantized_decomposed.dequantize_per_tensor.default": {
207-
(0.01959827, -26, -128, 127, torch.int8): 5,
213+
(0.015678614, -64, -128, 127, torch.int8): 2,
214+
(0.015678614, 0, -128, 127, torch.int8): 3,
208215
},
209216
"quantized_decomposed.quantize_per_tensor.default": {
210-
(0.01959827, -26, -128, 127, torch.int8): 5,
217+
(0.015678614, -64, -128, 127, torch.int8): 2,
218+
(0.015678614, 0, -128, 127, torch.int8): 3,
211219
},
212220
}
213221

@@ -230,10 +238,11 @@ class SharedQspecInputForkXConstant(torch.nn.Module):
230238
outputs_qspecs = {None: 1}
231239
quant_params = {
232240
"quantized_decomposed.dequantize_per_tensor.default": {
233-
(0.027437577, -55, -128, 127, torch.int8): 3,
241+
(0.015678614, 0, -128, 127, torch.int8): 2,
242+
(0.019607844, -128, -128, 127, torch.int8): 1,
234243
},
235244
"quantized_decomposed.quantize_per_tensor.default": {
236-
(0.027437577, -55, -128, 127, torch.int8): 2,
245+
(0.015678614, 0, -128, 127, torch.int8): 2,
237246
},
238247
}
239248
constant = torch.tensor(5.0)
@@ -255,10 +264,12 @@ class SharedQspecInputForkYConstant(torch.nn.Module):
255264
outputs_qspecs = {None: 1}
256265
quant_params = {
257266
"quantized_decomposed.dequantize_per_tensor.default": {
258-
(0.027437577, -55, -128, 127, torch.int8): 3,
267+
(0.015678614, 0, -128, 127, torch.int8): 1,
268+
(0.019607844, -128, -128, 127, torch.int8): 2,
259269
},
260270
"quantized_decomposed.quantize_per_tensor.default": {
261-
(0.027437577, -55, -128, 127, torch.int8): 2,
271+
(0.015678614, 0, -128, 127, torch.int8): 1,
272+
(0.019607844, -128, -128, 127, torch.int8): 1,
262273
},
263274
}
264275

@@ -365,10 +376,14 @@ class SharedQspecSurroundedQuantizedOp(torch.nn.Module):
365376
outputs_qspecs = {None: 1}
366377
quant_params = {
367378
"quantized_decomposed.dequantize_per_tensor.default": {
368-
(1.019109964, 123, -128, 127, torch.int8): 5,
379+
(0.509554982, 123, -128, 127, torch.int8): 3,
380+
(0.517394304, 119, -128, 127, torch.int8): 1,
381+
(1.019109964, 123, -128, 127, torch.int8): 1,
369382
},
370383
"quantized_decomposed.quantize_per_tensor.default": {
371-
(1.019109964, 123, -128, 127, torch.int8): 4,
384+
(0.509554982, 123, -128, 127, torch.int8): 2,
385+
(0.517394304, 119, -128, 127, torch.int8): 1,
386+
(1.019109964, 123, -128, 127, torch.int8): 1,
372387
},
373388
}
374389

@@ -393,11 +408,13 @@ class SharedQspecSurroundedQuantizedOpConstant(torch.nn.Module):
393408
quant_params = {
394409
"quantized_decomposed.dequantize_per_tensor.default": {
395410
(0.003921569, -128, -128, 127, torch.int8): 1,
396-
(0.01959827, -26, -128, 127, torch.int8): 5,
411+
(0.015678614, -64, -128, 127, torch.int8): 2,
412+
(0.015678614, 0, -128, 127, torch.int8): 3,
397413
},
398414
"quantized_decomposed.quantize_per_tensor.default": {
399415
(0.003921569, -128, -128, 127, torch.int8): 1,
400-
(0.01959827, -26, -128, 127, torch.int8): 4,
416+
(0.015678614, -64, -128, 127, torch.int8): 2,
417+
(0.015678614, 0, -128, 127, torch.int8): 2,
401418
},
402419
}
403420

@@ -532,11 +549,13 @@ class MixedMaximumInt8Int16(torch.nn.Module):
532549
output_qspecs = {None: 1}
533550
quant_params = {
534551
"quantized_decomposed.quantize_per_tensor.default": {
535-
(0.015678614, 0, -128, 127, torch.int8): 4,
552+
(0.007839307, -128, -128, 127, torch.int8): 2,
553+
(0.015678614, 0, -128, 127, torch.int8): 2,
536554
(0.000244141, 0, -32767, 32767, torch.int16): 2,
537555
},
538556
"quantized_decomposed.dequantize_per_tensor.default": {
539-
(0.015678614, 0, -128, 127, torch.int8): 4,
557+
(0.007839307, -128, -128, 127, torch.int8): 2,
558+
(0.015678614, 0, -128, 127, torch.int8): 2,
540559
(0.000244141, 0, -32767, 32767, torch.int16): 2,
541560
},
542561
}

0 commit comments

Comments
 (0)