Skip to content

Commit 4ad6012

Browse files
authored
Arm backend: Add fold_quantize option to tester (pytorch#18005)
* Enables tests to be run with fold_quantize = False when required
* This is used when you do not want convert_pt2e to fold constants
* Temporary solution to enable testing of StaticCache in INT8

Change-Id: Ib25ea3949fc5f539c1a0a15565c3cbfe5099b9a1

cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

---------

Signed-off-by: Tom Allsop <tom.allsop@arm.com>
1 parent 122fdef commit 4ad6012

3 files changed

Lines changed: 50 additions & 10 deletions

File tree

backends/arm/quantizer/arm_quantizer.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -665,6 +665,7 @@ def quantize_with_submodules(
665665
model: GraphModule,
666666
calibration_samples: list[tuple],
667667
is_qat: bool = False,
668+
fold_quantize: bool = True,
668669
):
669670
"""Quantizes a GraphModule in a way such that conditional submodules are
670671
handled properly.
@@ -680,6 +681,8 @@ def quantize_with_submodules(
680681
model with submodules, at least one sample per code path is
681682
needed.
682683
is_qat (bool): Whether to do quantization aware training or not.
684+
fold_quantize (bool): Enables or disables constant folding when quantization
685+
is completed.
683686
684687
Returns:
685688
GraphModule: The quantized model.
@@ -694,8 +697,11 @@ def quantize_with_submodules(
694697
prepared(*inp)
695698

696699
for name, submodule, _ in self._get_submodules_not_handled_by_torchao(prepared):
697-
prepared.set_submodule(name, convert_pt2e(submodule), strict=True)
698-
converted = convert_pt2e(prepared)
700+
prepared.set_submodule(
701+
name, convert_pt2e(submodule, fold_quantize=fold_quantize), strict=True
702+
)
703+
converted = convert_pt2e(prepared, fold_quantize=fold_quantize)
704+
699705
return converted
700706

701707

backends/arm/test/tester/quantize.py

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Optional, Tuple
6+
from typing import Any, Optional, Sequence, Tuple
77

88
import torch
99
from executorch.backends.arm.quantizer import TOSAQuantizer
@@ -14,9 +14,29 @@
1414
)
1515

1616
from torch.export import export
17+
from torchao.quantization.pt2e.quantizer import Quantizer
1718

1819

1920
class ArmQuantize(Quantize):
21+
def __init__(
22+
self,
23+
quantizer: Optional[Quantizer] = None,
24+
quantization_config: Optional[Any] = None,
25+
calibrate: bool = True,
26+
calibration_samples: Optional[Sequence[Any]] = None,
27+
is_qat: Optional[bool] = False,
28+
set_global: bool = True,
29+
fold_quantize: bool = True,
30+
):
31+
super().__init__(
32+
quantizer,
33+
quantization_config,
34+
calibrate,
35+
calibration_samples,
36+
is_qat,
37+
set_global,
38+
)
39+
self.fold_quantize = fold_quantize
2040

2141
def run(
2242
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
@@ -31,11 +51,11 @@ def run(
3151

3252
if self.calibration_samples is not None:
3353
converted = self.quantizer.quantize_with_submodules(
34-
captured_graph, self.calibration_samples, bool(self.is_qat) # type: ignore
54+
captured_graph, self.calibration_samples, bool(self.is_qat), self.fold_quantize # type: ignore
3555
)
3656
else:
3757
converted = self.quantizer.quantize_with_submodules(
38-
captured_graph, [inputs], bool(self.is_qat)
58+
captured_graph, [inputs], bool(self.is_qat), self.fold_quantize
3959
)
4060

4161
DuplicateDynamicQuantChainPass()(converted)

backends/arm/test/tester/test_pipeline.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -425,6 +425,7 @@ def __init__(
425425
tosa_version: Optional[str] = "1.0",
426426
tosa_extensions: Optional[List[str]] = None,
427427
epsilon: float = 2**-16,
428+
fold_quantize: bool = True,
428429
):
429430
if tosa_extensions is None:
430431
tosa_extensions = []
@@ -450,7 +451,9 @@ def __init__(
450451
)
451452
if symmetric_io_quantization:
452453
quantizer.set_io(quantization_config)
453-
quant_stage = Quantize(quantizer, quantization_config)
454+
quant_stage = Quantize(
455+
quantizer, quantization_config, fold_quantize=fold_quantize
456+
)
454457

455458
super().__init__(
456459
module,
@@ -622,6 +625,7 @@ def __init__(
622625
rtol: float = 1e-03,
623626
qtol: int = 1,
624627
epsilon: float = 2**-12,
628+
fold_quantize: bool = True,
625629
):
626630
super().__init__(
627631
module,
@@ -644,7 +648,9 @@ def __init__(
644648
)
645649
if symmetric_io_quantization:
646650
quantizer.set_io(quantization_config)
647-
quant_stage = Quantize(quantizer, quantization_config)
651+
quant_stage = Quantize(
652+
quantizer, quantization_config, fold_quantize=fold_quantize
653+
)
648654

649655
self.add_stage(self.tester.quantize, quant_stage, pos=0)
650656

@@ -720,6 +726,7 @@ def __init__(
720726
rtol: float = 1e-03,
721727
qtol: int = 1,
722728
epsilon: float = 2**-12,
729+
fold_quantize: bool = True,
723730
):
724731
compile_spec = common.get_u55_compile_spec(
725732
custom_path=custom_path,
@@ -740,6 +747,7 @@ def __init__(
740747
rtol=rtol,
741748
qtol=qtol,
742749
epsilon=epsilon,
750+
fold_quantize=fold_quantize,
743751
)
744752

745753

@@ -777,6 +785,7 @@ def __init__(
777785
rtol: float = 1e-03,
778786
qtol: int = 1,
779787
epsilon: float = 2**-12,
788+
fold_quantize: bool = True,
780789
):
781790
compile_spec = common.get_u85_compile_spec(
782791
custom_path=custom_path,
@@ -797,6 +806,7 @@ def __init__(
797806
rtol=rtol,
798807
qtol=qtol,
799808
epsilon=epsilon,
809+
fold_quantize=fold_quantize,
800810
)
801811

802812

@@ -982,6 +992,7 @@ def __init__(
982992
input_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None,
983993
output_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None,
984994
custom_path: Optional[str] = None,
995+
fold_quantize: bool = True,
985996
):
986997
tosa_spec = quantizer.tosa_spec
987998
compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path)
@@ -994,7 +1005,7 @@ def __init__(
9941005
use_to_edge_transform_and_lower=True,
9951006
)
9961007
# TODO sort out typing
997-
quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) # type: ignore[arg-type]
1008+
quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config, fold_quantize=fold_quantize) # type: ignore[arg-type]
9981009
self.add_stage(self.tester.quantize, quant_stage, pos=0)
9991010

10001011
# Delete most of the pipeline
@@ -1126,6 +1137,7 @@ def __init__(
11261137
tosa_version: str | None = None,
11271138
tosa_extensions: Optional[List[str]] = None,
11281139
tosa_spec: TosaSpecification | str | None = None,
1140+
fold_quantize: bool = True,
11291141
):
11301142
if tosa_spec is None:
11311143
if tosa_version is None:
@@ -1169,7 +1181,9 @@ def __init__(
11691181
)
11701182
if symmetric_io_quantization:
11711183
quantizer.set_io(quantization_config)
1172-
quant_stage = Quantize(quantizer, quantization_config)
1184+
quant_stage = Quantize(
1185+
quantizer, quantization_config, fold_quantize=fold_quantize
1186+
)
11731187

11741188
self.add_stage(self.tester.quantize, quant_stage, pos=0)
11751189

0 commit comments

Comments
 (0)