Skip to content

Commit 7824373

Browse files
authored
Qualcomm AI Engine Direct - Backend awareness quantizer (pytorch#17665)
1 parent 2cd88e7 commit 7824373

53 files changed

Lines changed: 3235 additions & 1707 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/scripts/test_wheel_package_qnn.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import argparse
1818
1919
import torch
2020
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
21+
from executorch.backends.qualcomm.serialization.qc_schema import (
22+
QnnExecuTorchBackendType,
23+
)
2124
from executorch.backends.qualcomm.utils.utils import (
2225
generate_htp_compiler_spec,
2326
generate_qnn_executorch_compiler_spec,
@@ -50,7 +53,7 @@ def main() -> None:
5053
example_inputs = model.get_example_inputs()
5154
5255
if args.quantization:
53-
quantizer = QnnQuantizer()
56+
quantizer = QnnQuantizer(backend=QnnExecuTorchBackendType.kHtpBackend, soc_model=get_soc_to_chipset_map()[args.soc])
5457
m = torch.export.export(model.eval(), example_inputs, strict=True).module()
5558
if args.quantization == "qat":
5659
m = prepare_qat_pt2e(m, quantizer)

backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7-
from executorch.backends.qualcomm.builders.node_visitor import q_ops
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
88
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
9+
from executorch.exir.dialects._ops import ops as exir_ops
910
from executorch.exir.pass_base import ExportPass, PassResult
1011
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1112

@@ -25,17 +26,27 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
2526

2627
def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
2728
partitions = get_source_partitions(
28-
graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
29+
graph_module.graph,
30+
[torch.ops.aten.adaptive_avg_pool1d.default, torch.adaptive_avg_pool1d],
2931
)
3032
for src_partitions in partitions.values():
3133
for src_partition in src_partitions:
34+
input_node = src_partition.input_nodes[0]
35+
if input_node.target in dq_ops:
36+
quant_attrs = get_quant_attrs(self.edge_program, input_node)
37+
for n in src_partition.nodes:
38+
if n.target == exir_ops.edge.aten.unsqueeze_copy.default:
39+
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
40+
3241
output = src_partition.output_nodes[0]
3342
if (list(output.users)[0].target) in q_ops:
3443
quant_attrs = get_quant_attrs(
3544
self.edge_program, list(output.users)[0]
3645
)
3746
for n in src_partition.nodes:
38-
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
47+
# For adaptive_avg_pool2d and squeeze
48+
if n.target != exir_ops.edge.aten.unsqueeze_copy.default:
49+
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
3950

4051
def call(self, graph_module: torch.fx.GraphModule):
4152
self._annotate_adaptive_avg_pool1d(graph_module)

backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,10 @@ std::string GetQnnSdkBuildId(std::string library_path) {
163163
if (err != QNN_SUCCESS || id == nullptr) {
164164
throw std::runtime_error("Failed to get QNN backend build ID");
165165
}
166+
// Copy id to avoid dangling pointer.
167+
std::string build_id(id);
166168
qnn_loaded_backend.Unload();
167-
return std::string(id);
169+
return build_id;
168170
}
169171

170172
py::array_t<char> StripProtocol(const py::bytes& preprocessed_binary) {

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ class OpElementWiseFloor:
167167
op_name: str = "ElementWiseFloor"
168168

169169

170+
@dataclass(init=False, frozen=True)
171+
class OpElementWiseFloorDiv:
172+
op_name: str = "ElementWiseFloorDiv"
173+
174+
170175
@dataclass(init=False, frozen=True)
171176
class OpElementWiseGreater:
172177
op_name: str = "ElementWiseGreater"

backends/qualcomm/quantizer/README.md

Lines changed: 169 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,61 @@ In order to conduct PTQ for floating point precision graph, observers are requir
4343
The Qualcomm backend will consume the generated encodings and lower operators with fixed precision. This tutorial will guide you through the details of inserting observers and some useful utilities.
4444
4545
### Register Annotation via Operator Type
46-
Let's start with hooking callback for designated operator target:
46+
Let's start by hooking a callback for the designated operator targets in `annotators/{backend}_rules.py`:
4747
```python
48-
def register_annotator(ops: List[OpOverload]):
49-
def decorator(annotator: Callable):
50-
for op in ops:
51-
OP_ANNOTATOR[op] = annotator
48+
def register_annotator(aten_ops: List[OpOverload], qnn_op: Optional[str]):
49+
def _wrap(op_def: GeneralOpDef):
50+
for aten_op in aten_ops:
51+
annotate_fn = op_def.annotate
52+
validate_fn = op_def.validate
53+
rule = OpQuantRule(
54+
aten_op=aten_op,
55+
qnn_op=qnn_op,
56+
annotate_fn=annotate_fn,
57+
validate_fn=validate_fn,
58+
)
59+
_RULES[rule.aten_op] = rule
60+
return rule
5261
53-
return decorator
62+
return _wrap
5463
```
55-
The `register_annotator` decorator provides a convenient way to attach your own annotation logic, which requires list of operator type as its input argument.<br/> For example, the torch activation functions have `copy`, `in-place` implementation with small difference appears in naming (an extra `_` postfix), which will map to the same [Core ATen](https://pytorch.org/docs/stable/torch.compiler_ir.html) operators after `to_edge`:
64+
The `register_annotator` decorator provides a convenient way to attach your own annotation and validation logic. It takes a list of operator types and a QNN operation name as its input arguments.<br/> For example, the torch activation functions have `copy` and `in-place` implementations whose names differ only slightly (an extra `_` postfix); both map to the same [Core ATen](https://pytorch.org/docs/stable/torch.compiler_ir.html) operators after `to_edge`:
5665
```python
57-
@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
66+
@register_annotator(
67+
[torch.ops.aten.relu.default, torch.ops.aten.relu_.default],
68+
QnnConstants.OpRelu.op_name,
69+
)
70+
```
71+
Where `torch.ops.aten.relu.default` / `torch.ops.aten.relu_.default` map to `copy` / `in-place` version and both will be converted into `torch.ops.aten.relu.default` ultimately.<br/>
72+
The `qnn_op` is used to specify quantization constraints for validation with the `BackendOpInfo` library. If an operator doesn’t directly correspond to a QNN operator, you can set its value to `None`, which will skip validation for that operator.
73+
```python
74+
@register_annotator([operator.getitem], qnn_op=None)
75+
```
76+
The `operator.getitem` function acts as a skip operator in the QNN backend and does not correspond to any QNN operator. Therefore, we assign `qnn_op=None`.<br/><br>
77+
78+
Create a base class `GeneralOpDef` that establishes the standard annotation and validation function behaviors.
79+
```python
80+
class GeneralOpDef:
81+
@staticmethod
82+
def annotate(node: Node, quantization_config: QuantizationConfig):
83+
annotate_single_in_single_out(node, quantization_config)
84+
85+
@staticmethod
86+
def validate(
87+
node: Node, constraints_list: List[NormalizedConstraints], soc_info: SocInfo
88+
) -> bool:
89+
valid = True
90+
# If there's no quantization annotation, we can't validate against constraints.
91+
if not _is_annotated([node]):
92+
return valid
93+
valid &= validate_against_backend_constraints(node, constraints_list)
94+
return valid
5895
```
59-
Where `torch.ops.aten.relu.default` / `torch.ops.aten.relu_.default` map to `copy` / `in-place` version and both will be converted into `torch.ops.aten.relu.default` ultimately.<br/><br>
6096

61-
The function signature is defined as follow with two arguments:
97+
The `annotate` function signature is defined as follows, with two arguments:
6298
```python
63-
def annotate_xxx(node: Node, quantization_config: QuantizationConfig) -> None:
99+
@staticmethod
100+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
64101
```
65102
- __node__: graph node required to be observed
66103
- __quantization_config__: data structure describing quantization configurations for IO activation / weight / bias
@@ -112,75 +149,161 @@ Now, we can start to fill in the function body:
112149
```python
113150
@register_annotator(
114151
[
115-
torch.ops.aten.conv2d.default,
116152
torch.ops.aten.conv1d.default,
117-
torch.ops.aten.conv_transpose2d.input,
153+
torch.ops.aten.conv2d.default,
154+
torch.ops.aten.conv2d.padding,
155+
torch.ops.aten.convolution.default,
118156
]
119157
)
120-
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
158+
class Conv2d(GeneralOpDef):
121159
```
122160
Multiple targets are expected to meet the same annotation criteria, so registering them under a single annotator is encouraged for code reuse.
123-
161+
- Define an annotation function interface
162+
```python
163+
@staticmethod
164+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
165+
```
124166
- Define map of input quantization spec
125167
```python
126-
if _is_annotated([node]):
127-
return
168+
if _is_annotated([node]):
169+
return
170+
171+
# block quantization
172+
if quantization_config.block_size is not None:
173+
quantization_config.weight.observer_or_fake_quant_ctr.p.keywords.update(
174+
{QCOM_BLOCK_SIZE: quantization_config.block_size}
175+
)
128176

129-
input_qspec_map = {}
177+
input_qspec_map = {}
130178

131-
# annotate input activation
132-
input_act = node.args[0]
133-
input_spec = quantization_config.input_activation
134-
input_qspec_map[input_act] = input_spec
179+
# annotate input activation
180+
input_act = node.args[0]
181+
input_spec = quantization_config.input_activation
182+
input_qspec_map[input_act] = input_spec
135183

136-
# annotate kernel
137-
kernel = node.args[1]
138-
input_qspec_map[kernel] = quantization_config.weight
184+
# annotate kernel
185+
kernel = node.args[1]
186+
input_qspec_map[kernel] = quantization_config.weight
139187

140-
# annotate bias
141-
if len(node.args) > 2:
142-
bias = node.args[2]
143-
input_qspec_map[bias] = quantization_config.bias(node)
188+
# annotate bias
189+
if len(node.args) > 2:
190+
bias = node.args[2]
191+
input_qspec_map[bias] = quantization_config.bias(node)
144192
```
145193
We first check if current graph node has been annotated. If not, an `input_qspec_map` dictionary required by PyTorch framework will be declared for providing mapping between graph nodes and their configurations.<br/>
146194
The parameter order can be found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Convolution.cpp), as mentioned in [ATen Operator Definitions](#pytorch). Since the bias node is optional, the implementation will invoke `_derived_bias_quant_spec` to calculate the per-channel bias encoding only if it exists.
147195

148196
- Update node's meta with framework compatible data structure
149197
```python
150-
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
151-
input_qspec_map=input_qspec_map,
152-
output_qspec=quantization_config.output_activation,
153-
_annotated=True,
154-
)
198+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
199+
input_qspec_map=input_qspec_map,
200+
output_qspec=quantization_config.output_activation,
201+
_annotated=True,
202+
)
155203
```
156204
After processing `input_qspec_map`, it must be stored in the node's meta under the special tag (`Q_ANNOTATION_KEY`) so that `prepare_pt2e` can properly insert observers.
157205

206+
- Define a validation function interface
207+
```python
208+
@staticmethod
209+
def validate(
210+
node: Node, constraints_list: List[NormalizedConstraints], soc_info: SocInfo
211+
) -> bool:
212+
```
213+
- Check if current node is annotated
214+
```python
215+
valid = True
216+
if not _is_annotated([node]):
217+
return valid
218+
```
219+
- Check if current node supports LPBQ
220+
```python
221+
weight_node = node.args[1]
222+
weight_qspec = node.meta[Q_ANNOTATION_KEY].input_qspec_map.get(
223+
weight_node, None
224+
)
225+
if (
226+
weight_qspec
227+
and weight_qspec.observer_or_fake_quant_ctr.p.keywords.get(
228+
QCOM_BLOCK_SIZE, None
229+
)
230+
is not None
231+
):
232+
valid &= validate_lpbq_support(soc_info)
233+
if not valid:
234+
logging.warning(
235+
f"LPBQ (16a4w block-wise quantization) requires V69 or newer for {node.name}"
236+
)
237+
```
238+
- Check if current node supports 16a16w quantization
239+
```python
240+
act_node = node.args[0]
241+
act_qspec = node.meta[Q_ANNOTATION_KEY].input_qspec_map.get(act_node, None)
242+
if (
243+
act_qspec
244+
and act_qspec.dtype == torch.int32
245+
and weight_qspec
246+
and weight_qspec.dtype == torch.int32
247+
):
248+
valid &= validate_16a16w_support(soc_info)
249+
if not valid:
250+
logging.warning(
251+
f"16-bit activations + 16-bit weights requires V73 or newer for {node.name}"
252+
)
253+
```
254+
- Validate the current node against the backend constraints obtained from `BackendOpInfo` based on the `qnn_op`.
255+
```python
256+
valid &= validate_against_backend_constraints(node, constraints_list)
257+
return valid
258+
```
259+
- Validate against the backend constraints by doing the following:
260+
- Make sure that `SharedQuantizationSpec` is applied for `is_math_invariant` operators, such as view operations.
261+
- Check the `scale` and `zero_point` values for specific operations. For example, sigmoid op requires `scale = 1 / (q_max - q_min + 1)` and `zero_point = 0`.
262+
- Ensure that the `qscheme` satisfies symmetric constraints.
263+
- Verify that the input and output `dtype` are supported.
158264
### Common Annotators
159265
For operators without extra parameters to be observed, there are pre-defined annotation methods for convenience:
160266
- Single in single out operators, e.g.:
161267
```python
162-
@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
163-
def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
164-
annotate_single_in_single_out(node, quantization_config)
268+
@register_annotator(
269+
[torch.ops.aten.relu.default, torch.ops.aten.relu_.default],
270+
QnnConstants.OpRelu.op_name,
271+
)
272+
class Relu(GeneralOpDef):
273+
pass
165274
```
166275

167276
- Binary in single out operators, e.g.:
168277
```python
169-
@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
170-
def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
171-
annotate_binary(node, quantization_config)
278+
@register_annotator(
279+
[torch.ops.aten.add, torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor],
280+
QnnConstants.OpElementWiseAdd.op_name,
281+
)
282+
class Add(GeneralOpDef):
283+
@staticmethod
284+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
285+
annotate_binary(node, quantization_config)
172286
```
173287

174288
- Shared encodings between input / output, e.g.:<br/>
175289
```python
176290
# For operators without arithmetical function, IOs are expected to own the same encodings.
177-
@register_annotator([torch.ops.aten.transpose.int])
178-
def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
179-
annotate_in_out_obs_sharing_op(node, quantization_config)
180-
if not _is_annotated([node]):
181-
annotate_single_in_single_out(node, quantization_config)
291+
@register_annotator(
292+
[
293+
torch.ops.aten.permute.default,
294+
torch.ops.aten.swapaxes.default,
295+
torch.ops.aten.transpose.int,
296+
],
297+
QnnConstants.OpTranspose.op_name,
298+
)
299+
class Permute(GeneralOpDef):
300+
@staticmethod
301+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
302+
annotate_in_out_obs_sharing_op(node, quantization_config)
303+
if not _is_annotated([node]):
304+
annotate_single_in_share_out(node, quantization_config)
182305
```
183-
This annotator only works for single-in-single-out scenario with node's input that has already been annotated. If not, we still need to invoke `annotate_single_in_single_out` again (this path should be less likely).
306+
This annotator only works for the single-in-single-out scenario where the node's input has already been annotated. If it has not, we fall back to invoking `annotate_single_in_share_out` (this path should be less likely).
184307

185308
## Issues
186309
Please refer to the [issue section](../README.md#issues) for more information.

0 commit comments

Comments
 (0)