Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from .decompose_int16_activation_conv_pass import ( # noqa
DecomposeConvWithInt16ActivationPass,
)
from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
Expand All @@ -75,6 +74,9 @@
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
from .decompose_sum_pass import DecomposeSumPass # noqa
from .decompose_tosa_unsupported_clamp_pass import ( # noqa
DecomposeTOSAUnsupportedClampPass,
)
from .decompose_var_pass import DecomposeVarPass # noqa
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
DecomposeGluPass,
DecomposeGroupedConvPass,
DecomposeGroupNormPass,
DecomposeInt32ClampPass,
DecomposeIntPowPass,
DecomposeLayerNormPass,
DecomposeLeakyReLUPass,
Expand All @@ -78,6 +77,7 @@
DecomposeSoftmaxUnstablePass,
DecomposeSqrtPass,
DecomposeSumPass,
DecomposeTOSAUnsupportedClampPass,
DecomposeVarPass,
DecorateFp32toInt32CastingPass,
FoldAndAnnotateQParamsPass,
Expand Down Expand Up @@ -220,7 +220,7 @@ def _tosa_pipeline(
[
FuseQuantizedActivationPass(),
ConvertToClampPass(),
DecomposeInt32ClampPass(),
DecomposeTOSAUnsupportedClampPass(),
DecomposeGroupNormPass(),
DecomposeLayerNormPass(),
DecomposeVarPass(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
from executorch.exir.pass_base import ExportPass


class DecomposeInt32ClampPass(ArmPass):
"""Rewrite int32 clamp into min/max chain since TOSA lacks int32 clamp support."""
class DecomposeTOSAUnsupportedClampPass(ArmPass):
"""Rewrite TOSA unsupported clamp into min/max chain since TOSA lacks int32 clamp support
and only supports scalar min/max values."""

_passes_required_after: Set[Type[ExportPass]] = set()
_supported_ops = {
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.clamp.Tensor,
torch.ops.aten.clamp.default,
torch.ops.aten.clamp.Tensor,
}

def _ensure_tensor(
Expand All @@ -40,31 +43,53 @@ def _ensure_tensor(

def call_operator(self, op, args, kwargs, meta):
val = meta["val"]
if op not in self._supported_ops or val.dtype != torch.int32:

is_scalar_clamp = op in {
exir_ops.edge.aten.clamp.default,
torch.ops.aten.clamp.default,
}
is_tensor_clamp = op in {
exir_ops.edge.aten.clamp.Tensor,
torch.ops.aten.clamp.Tensor,
}

if op not in self._supported_ops:
return super().call_operator(op, args, kwargs, meta)

# Only rewrite scalar clamp for int32
if is_scalar_clamp and val.dtype != torch.int32:
return super().call_operator(op, args, kwargs, meta)

input_tensor = args[0]
min_arg = args[1] if len(args) > 1 else None
max_arg = args[2] if len(args) > 2 else None
dtype = val.dtype
rank = len(val.shape)
min_arg = args[1] if len(args) > 1 else None
max_arg = args[2] if len(args) > 2 else None

min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)
if is_scalar_clamp:
# Scalar min/max -> make them tensors for min/max ops
min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)
else:
# Tensor variant: arguments are already tensors; nothing extra to do
if not is_tensor_clamp:
raise RuntimeError(
f"DecomposeTOSAUnsupportedClampPass: unexpected op {op} in tensor clamp branch"
)

current = input_tensor
if max_arg is not None:
if min_arg is not None:
current = super().call_operator(
exir_ops.edge.aten.minimum.default,
(current, max_arg),
exir_ops.edge.aten.maximum.default,
(current, min_arg),
{},
meta,
updated=True,
)
if min_arg is not None:
if max_arg is not None:
current = super().call_operator(
exir_ops.edge.aten.maximum.default,
(current, min_arg),
exir_ops.edge.aten.minimum.default,
(current, max_arg),
{},
meta,
updated=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.clamp.Tensor,
exir_ops.edge.aten.cumsum.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.permute_copy.default,
Expand Down Expand Up @@ -138,6 +139,7 @@
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.clamp.Tensor,
exir_ops.edge.aten.cos.default,
exir_ops.edge.aten.cumsum.default,
exir_ops.edge.aten.bmm.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/scripts/parse_test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"pixel_shuffle.default",
"pixel_unshuffle.default",
"while_loop.default",
"clamp.Tensor",
]
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

Expand Down
Loading
Loading