Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 73 additions & 27 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,27 @@
import logging
import operator
from dataclasses import dataclass, replace
from typing import Callable, cast, List, Optional, Sequence
from typing import Any, Callable, cast, Iterable, List, NamedTuple, Optional, Sequence

import torch
import torch.fx
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.quantizer import QuantizationConfig
from torch._subclasses import FakeTensor

from torch._subclasses import FakeTensor
from torch.fx import Node
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
MovingAveragePerChannelMinMaxObserver,
PartialWrapper,
)

from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
FixedQParamsQuantizationSpec,
QuantizationSpec,
QuantizationSpecBase,
SharedQuantizationSpec,
Expand Down Expand Up @@ -78,6 +80,11 @@ def __init__(self):
self.quant_output: Optional[_QuantProperty] = None


class _QParams(NamedTuple):
scale: float
zero_point: int


def _as_list(x):
"""Return ``x`` wrapped as a list if needed.

Expand Down Expand Up @@ -391,14 +398,16 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):


def _match_pattern(
node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
node: Node,
pattern: Sequence[Iterable[object]],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> bool:
"""Check whether a node chain matches a pattern.

Verify a chain of ancestors -> node -> descendants matches the provided
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
to pass the filter. Each pattern element is a list of disjunctive node
targets.
to pass the filter. Each pattern element is an iterable of disjunctive
node targets.

"""
if len(pattern) < 1:
Expand Down Expand Up @@ -432,16 +441,39 @@ def _match_pattern(
return left_condition and right_condition


_conv_ops = [
_conv_ops = {
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv2d.padding,
torch.ops.aten.conv_transpose2d.input,
torch.ops.aten.conv3d.default,
torch.ops.aten.conv3d.padding,
]
}

# For these ops, we use fixed qspecs, meaning that quantization params for
# these are statically defined. This is to prevent issues with out-of-range
# values when using dynamic quantization.
#
# Dict of operator to a dict of num_bits to qparams for that operator.
# Scale is (max - min) / 2**num_bits with a zero_point of 0, i.e. a symmetric
# mapping of the op's valid input domain onto the integer range.
_fixed_input_qspec_ops: dict[Any, dict[int, _QParams]] = {
    # acos has a valid range of [-1, 1]
    torch.ops.aten.acos.default: {
        8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
        16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
    },
    # asin has a valid range of [-1, 1]
    torch.ops.aten.asin.default: {
        8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
        16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
    },
    # atanh has a valid range of (-1, 1) (excluding -1 and 1).
    torch.ops.aten.atanh.default: {
        8: _QParams((0.999 - (-0.999)) / (1 << 8), 0),
        16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0),
    },
}

_one_to_one = {
torch.ops.aten.abs.default,
torch.ops.aten.ceil.default,
torch.ops.aten.erf.default,
Expand Down Expand Up @@ -472,16 +504,13 @@ def _match_pattern(
torch.ops.aten.log1p.default,
torch.ops.aten.acosh.default,
torch.ops.aten.sign.default,
torch.ops.aten.asin.default,
torch.ops.aten.atanh.default,
torch.ops.aten.asinh.default,
torch.ops.aten.cosh.default,
torch.ops.aten.acos.default,
torch.ops.aten.cumsum.default,
torch.ops.aten.tan.default,
]
}

_one_to_one_shared_input_qspec = [
_one_to_one_shared_input_qspec = {
torch.ops.aten.squeeze.default,
torch.ops.aten.squeeze_copy.default,
torch.ops.aten.squeeze_copy.dim,
Expand Down Expand Up @@ -539,9 +568,9 @@ def _match_pattern(
# dequant -> neg -> requant chain.
torch.ops.aten.neg.default,
torch.ops.aten.detach_copy.default,
]
}

_one_to_one_shared_input_or_input_act_qspec = [
_one_to_one_shared_input_or_input_act_qspec = {
torch.ops.aten.alias.default,
torch.ops.aten.clone.default,
torch.ops.aten.hardtanh.default,
Expand All @@ -562,7 +591,7 @@ def _match_pattern(
torch.ops.aten.alias_copy.default,
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.pixel_unshuffle.default,
]
}


def get_quant_properties( # noqa: C901
Expand Down Expand Up @@ -615,13 +644,13 @@ def any_or_hardtanh_min_zero(n: Node):
node,
[
_conv_ops,
[torch.ops.aten.batch_norm.default],
[
{torch.ops.aten.batch_norm.default},
{
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
],
},
],
filter_fn=any_or_hardtanh_min_zero,
):
Expand All @@ -644,7 +673,7 @@ def any_or_hardtanh_min_zero(n: Node):
node,
[
_conv_ops,
[torch.ops.aten.batch_norm.default],
{torch.ops.aten.batch_norm.default},
],
):
if node.target in _conv_ops:
Expand All @@ -654,23 +683,21 @@ def any_or_hardtanh_min_zero(n: Node):
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
elif node.target in [
torch.ops.aten.batch_norm.default,
]:
elif node.target in {torch.ops.aten.batch_norm.default}:
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif not is_symmetric and _match_pattern(
node,
[
[
{
*_conv_ops,
torch.ops.aten.linear.default,
],
[
},
{
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
],
},
],
any_or_hardtanh_min_zero,
):
Expand Down Expand Up @@ -784,6 +811,25 @@ def any_or_hardtanh_min_zero(n: Node):
elif node.target in _one_to_one:
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in _fixed_input_qspec_ops:
num_bits = torch.iinfo(input_act_qspec.dtype).bits
qparams = _fixed_input_qspec_ops[node.target][num_bits]

quant_properties.quant_inputs = [
_QuantProperty(
0,
FixedQParamsQuantizationSpec(
dtype=input_act_qspec.dtype,
scale=qparams.scale,
zero_point=qparams.zero_point,
quant_min=input_act_qspec.quant_min,
quant_max=input_act_qspec.quant_max,
qscheme=input_act_qspec.qscheme,
is_dynamic=input_act_qspec.is_dynamic,
),
)
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in _one_to_one_shared_input_qspec:
input_node = ensure_type(Node, node.args[0])
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
Expand Down
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_acos.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def test_acos_tosa_INT(test_data: Tuple):
(test_data(),),
aten_op=aten_op,
exir_op=exir_op,
frobenius_threshold=0.5, # MLETORCH-1709
)
pipeline.run()

Expand Down
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_asin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def test_asin_tosa_INT(test_data: Tuple):
(test_data(),),
aten_op=[],
exir_op=[],
frobenius_threshold=0.6, # MLETORCH-1709
cosine_threshold=0.8, # MLETORCH-1709
)
pipeline.run()

Expand Down
9 changes: 5 additions & 4 deletions backends/arm/test/ops/test_atanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
test_data_suite = {
"zeros": torch.zeros(1, 10, 10, 10),
"zeros_alt_shape": torch.zeros(1, 10, 3, 5),
"ones": torch.ones(10, 10, 10),
"rand": torch.rand(10, 10) - 0.5,
"rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5,
"ramp": torch.arange(-1, 1, 0.2),
"near_bounds": torch.tensor([-0.999999, -0.999, -0.9, 0.9, 0.999, 0.999999]),
"near_bounds": torch.tensor([-0.99, -0.9, 0.9, 0.99]),
"on_bounds": torch.tensor([-1.0, 1.0]),
}

Expand Down Expand Up @@ -58,9 +57,11 @@ def test_atanh_tosa_INT(test_data: Tuple):
(test_data,),
aten_op=aten_op,
exir_op=exir_op,
frobenius_threshold=None, # MLETORCH-1709
cosine_threshold=0.7,
)
if torch.any(test_data >= 1) or torch.any(test_data <= -1):
# The quantized model will saturate to max/min values while the
# original model will return inf/-inf, so comparison won't be valid here.
pipeline.pop_stage("run_method_and_compare_outputs.original_model")
pipeline.run()


Expand Down
32 changes: 16 additions & 16 deletions backends/arm/test/tester/analyze_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,22 +337,6 @@ def dump_error_output(
logger.error(f"{atol=}, {rtol=}, {qtol=}")


if __name__ == "__main__":
"""This is expected to produce the example output of print_diff."""
torch.manual_seed(0)
a = torch.rand(3, 3, 2, 2) * 0.01
b = a.clone().detach()
logger.info(b)

# Errors in all channels in element (1,1)
a[1, :, 1, 1] = 0
# Errors in (0,0) and (1,1) in channel 1
a[2, 1, 1, 1] = 0
a[2, 1, 0, 0] = 0

print_error_diffs(a, b)


def compare_rel_frobenius_and_cosine_similarity(
reference_output: torch.Tensor,
test_output: torch.Tensor,
Expand Down Expand Up @@ -452,3 +436,19 @@ def compare_rel_frobenius_and_cosine_similarity(
f"Tensor-wise comparison failed: Cosine similarity {cosine_similarity} is below threshold {cosine_threshold}."
f" (Relative frobenius error: {relative_frobenius_error}, threshold {frobenius_threshold})."
)


if __name__ == "__main__":
    # This is expected to produce the example output of print_diff.
    # (A bare string here would be a no-op statement, so use a comment.)
    torch.manual_seed(0)
    a = torch.rand(3, 3, 2, 2) * 0.01
    b = a.clone().detach()
    logger.info(b)

    # Errors in all channels in element (1,1)
    a[1, :, 1, 1] = 0
    # Errors in (0,0) and (1,1) in channel 1
    a[2, 1, 1, 1] = 0
    a[2, 1, 0, 0] = 0

    print_error_diffs(a, b)
Loading