Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions backends/arm/test/ops/test_as_strided_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _make_case(
delegated_cases = {
"reshape_2d": lambda: _make_case((4, 6), (3, 8)),
"flatten": lambda: _make_case((2, 3, 4), (6, 4)),
"expand_rank": lambda: _make_case((2, 3, 4), (2, 3, 4)),
"expand_rank": lambda: _make_case((2, 3, 4), (1, 2, 3, 4)),
}

unsupported_cases = {
Expand All @@ -67,11 +67,12 @@ def _make_case(
contiguous_strides((4, 4)),
4,
),
"noop": lambda: _make_case((2, 3, 4), (2, 3, 4)), # Single noop is not delegated
}


@common.parametrize("test_data", delegated_cases)
def test_as_strided_copy_tosa_FP(test_data):
def test_as_strided_tosa_FP(test_data):
tensor, size, stride = test_data()
module = AsStridedCopyModule(size, stride)
pipeline = TosaPipelineFP[input_t](
Expand All @@ -83,7 +84,7 @@ def test_as_strided_copy_tosa_FP(test_data):


@common.parametrize("test_data", delegated_cases)
def test_as_strided_copy_tosa_INT(test_data):
def test_as_strided_tosa_INT(test_data):
tensor, size, stride = test_data()
module = AsStridedCopyModule(size, stride)
pipeline = TosaPipelineINT[input_t](
Expand All @@ -96,7 +97,7 @@ def test_as_strided_copy_tosa_INT(test_data):

@common.parametrize("test_data", delegated_cases)
@common.SkipIfNoModelConverter
def test_as_strided_copy_vgf_no_quant(test_data):
def test_as_strided_vgf_no_quant(test_data):
tensor, size, stride = test_data()
module = AsStridedCopyModule(size, stride)
pipeline = VgfPipeline[input_t](
Expand All @@ -111,7 +112,7 @@ def test_as_strided_copy_vgf_no_quant(test_data):

@common.parametrize("test_data", delegated_cases)
@common.SkipIfNoModelConverter
def test_as_strided_copy_vgf_quant(test_data):
def test_as_strided_vgf_quant(test_data):
tensor, size, stride = test_data()
module = AsStridedCopyModule(size, stride)
pipeline = VgfPipeline[input_t](
Expand All @@ -124,7 +125,7 @@ def test_as_strided_copy_vgf_quant(test_data):


@common.parametrize("test_data", unsupported_cases)
def test_as_strided_copy_not_delegated(test_data):
def test_as_strided_no_target_not_delegated(test_data):
tensor, size, stride, *rest = test_data()
storage_offset = rest[0] if rest else 0
module = AsStridedCopyModule(size, stride, storage_offset=storage_offset)
Expand Down
14 changes: 14 additions & 0 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ def _is_noop_detach_copy(node: torch.fx.Node) -> bool:
return node.target == exir_ops.edge.aten.detach_copy.default


def _is_noop_as_strided_copy(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is an ``as_strided_copy`` that changes nothing.

    A no-op here means the faked output tensor has the same shape, strides,
    and storage offset as the faked input tensor, so the op can be treated
    as an identity when deciding partition support.
    """
    if node.target != exir_ops.edge.aten.as_strided_copy.default:
        return False
    src = get_first_fake_tensor(ensure_type(torch.fx.Node, node.args[0]))
    dst = get_first_fake_tensor(node)
    if src.shape != dst.shape:
        return False
    if src.stride() != dst.stride():
        return False
    return src.storage_offset() == dst.storage_offset()


def _is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
return False
Expand Down Expand Up @@ -263,6 +276,7 @@ def _tag_module( # noqa
or _is_noop_detach_copy(node)
or _is_noop_to_dim_order_copy(node)
or _is_view_copy(node)
or _is_noop_as_strided_copy(node)
or node.target in Q_OPS
or node.target in DQ_OPS
for node in partition.nodes
Expand Down
178 changes: 90 additions & 88 deletions backends/transforms/fuse_view_copy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -14,101 +14,103 @@
from executorch.exir.pass_base import ExportPass, PassResult


# Ops that a view_copy may be moved across when merging view chains.
# Per the list name these are treated as one-input, one-output elementwise
# (or pure copy/layout) ops; view_copy itself is first so that chains of
# consecutive views also merge.  NOTE(review): quantize/dequantize entries
# presumably qualify because they are per-tensor — confirm before adding
# per-channel variants.
UNARY_ELEMENTWISE_OPS = [
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.alias_copy.default,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.dim_order_ops._clone_dim_order.default,
    exir_ops.edge.aten._to_copy.default,
    exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.ceil.default,
    exir_ops.edge.aten.floor.default,
    exir_ops.edge.aten.neg.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.round.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten.silu.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.tanh.default,
    exir_ops.edge.aten.sign.default,
    exir_ops.edge.aten.reciprocal.default,
    exir_ops.edge.aten.gelu.default,
    exir_ops.edge.aten.rsqrt.default,
    exir_ops.edge.aten.exp.default,
    exir_ops.edge.aten.log.default,
]


def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
    """Rewrite chains of view_copy nodes to all target the final shape.

    Starting from each view_copy, walk forward through single-user chains of
    ops listed in ``UNARY_ELEMENTWISE_OPS``, collecting every view_copy seen.
    When a chain contains more than one view, every view in it is given the
    last view's shape argument, which turns the earlier ones into no-ops that
    ``remove_noop_view_copy`` can later delete.

    Only views whose entire chain is single-use are merged.

    Returns the (mutated) graph and whether any node was rewritten.
    """
    view_op = exir_ops.edge.aten.view_copy.default
    modified = False
    for start in graph.nodes:
        if start.op != "call_function" or start.target != view_op:
            continue
        # Walk the single-user chain, remembering every view_copy on the way.
        chain_views = [start]
        cursor = start
        while (
            cursor.op == "call_function"
            and cursor.target in UNARY_ELEMENTWISE_OPS
            and len(cursor.users) == 1
        ):
            successor = next(iter(cursor.users))
            if successor.target not in UNARY_ELEMENTWISE_OPS:
                break
            cursor = successor
            if cursor.target == view_op:
                chain_views.append(cursor)

        # Retarget every collected view at the chain's final shape.
        if len(chain_views) > 1:
            final_shape = chain_views[-1].args[1]
            for view_node in chain_views:
                view_node.args = (view_node.args[0], final_shape)
            modified = True

    graph.eliminate_dead_code()
    return graph, modified


def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
    """Remove view_copy nodes whose output shape equals their input shape.

    Such views are identities, so all their uses are redirected to the view's
    input and the dead nodes are cleaned up afterwards.

    Returns the (possibly mutated) graph and whether anything was removed.
    """
    view_op = exir_ops.edge.aten.view_copy.default
    modified = False
    for node in graph.nodes:
        if node.op == "call_function" and node.target == view_op:
            input_shape = list(node.args[0].meta["val"].shape)
            # Compare against the node's own output meta rather than the raw
            # size argument: args[1] may be a tuple or hold symbolic ints,
            # which never compare equal to a list of plain ints and would make
            # the pass silently skip genuine no-op views.
            target_shape = list(node.meta["val"].shape)
            if input_shape == target_shape:
                node.replace_all_uses_with(node.args[0])
                modified = True
    # Only run DCE when we actually rewrote something; otherwise the graph
    # would be mutated while callers (which recompile only on `modified`)
    # are told nothing changed.
    if modified:
        graph.eliminate_dead_code()
    return graph, modified


class FuseViewCopyTransform(ExportPass):
_passes_required_after: Set[Type[ExportPass]] = set()

VIEW_OP = exir_ops.edge.aten.view_copy.default

UNARY_ELEMENTWISE_OPS = [
exir_ops.edge.aten.alias_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.round.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.silu.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.sign.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.log.default,
]

def merge_view_copy_chains(
self, graph: torch.fx.Graph
) -> tuple[torch.fx.Graph, bool]:
"""
Find chains of view_copy nodes and unary elementwise ops and set all
view_copy nodes to have the final shape. The views will then be removed
by the remove_noop_view_copy call.

Only merges view_copy nodes that are not used by any other nodes.
"""
view_op = self.VIEW_OP
modified = False
ops = self.UNARY_ELEMENTWISE_OPS + [view_op]
for node in graph.nodes:
if node.op == "call_function" and node.target == view_op:
# Find a chain of unary elementwise ops and save all view_copy nodes
end_node = node
view_ops = [node]
while (
end_node.op == "call_function"
and end_node.target in ops
and len(end_node.users) == 1
and list(end_node.users)[0].target in ops
):
end_node = list(end_node.users)[0]
if end_node.target == view_op:
view_ops.append(end_node)

# Set all view_copy nodes to have the final shape
if len(view_ops) > 1:
final_shape = view_ops[-1].args[1]
for node in view_ops:
new_args = (node.args[0], final_shape)
node.args = new_args
modified = True
if modified:
graph.eliminate_dead_code()
return graph, modified

def remove_noop_view_copy(
self, graph: torch.fx.Graph
) -> tuple[torch.fx.Graph, bool]:
"""
Remove view_copy nodes that are no-ops.
"""
view_op = self.VIEW_OP
modified = False
for node in graph.nodes:
if node.op == "call_function" and node.target == view_op:
input_shape = list(node.args[0].meta["val"].shape)
target_shape = list(node.meta["val"].shape)
if input_shape == target_shape:
node.replace_all_uses_with(node.args[0])
modified = True
if modified:
graph.eliminate_dead_code()
return graph, modified

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.graph, modified = merge_view_copy_chains(graph_module.graph)
graph_module.graph, modified = self.merge_view_copy_chains(graph_module.graph)
if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

graph_module.graph, modified = remove_noop_view_copy(graph_module.graph)
graph_module.graph, modified = self.remove_noop_view_copy(graph_module.graph)
if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
Loading