Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions backends/cadence/aot/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,6 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
)


# Return a chain of nodes with target in op_targets
def get_cascaded_ops(
    nodes: List[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_targets: Iterable[Union[Callable[..., Any], str]],
) -> Sequence[torch.fx.Node]:
    """
    Extend the op chain in `nodes` as far as it goes.

    `nodes` must contain a non-empty chain of ops whose targets are in
    `op_targets`. The chain is repeatedly extended (in place) with the sole
    user of the last node, as long as that user's target is also in
    `op_targets`.

    Args:
        nodes: non-empty chain of fx nodes; mutated in place.
        op_targets: op targets (callables or method-name strings) that are
            allowed to continue the chain.

    Returns:
        The extended `nodes` list (the same list object that was passed in).
    """
    # Iterative rather than recursive: a long chain of cascaded ops would
    # otherwise add one Python stack frame per node and risk RecursionError.
    while True:
        users = list(nodes[-1].users.keys())
        # The chain continues only if (a) the current tail has exactly one
        # user, and (b) that user's target is one of the requested ops.
        if len(users) != 1 or users[0].target not in op_targets:
            return nodes
        nodes.append(users[0])


def get_transposed_dims(
node: torch.fx.Node, dims: Optional[List[int]] = None
) -> List[int]:
Expand Down
72 changes: 30 additions & 42 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import (
broadcastable,
get_cascaded_ops,
get_permuted_dims,
get_scale,
get_tensor_from_attr,
Expand Down Expand Up @@ -581,7 +580,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
"""
Fuse a cascaded chain of transpose and permute ops
Fuse a chain of transpose and permute ops into a single permute or a no-op.
"""

transpose_or_permute_target = {
Expand All @@ -594,55 +593,44 @@ def targets(self) -> list[EdgeOpOverload]:
return list(self.transpose_or_permute_target)

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
    """
    Fuse the chain of transpose/permute ops that ends at `node`.

    Walks producers upward from `node`, collecting consecutive
    transpose/permute ops, and composes their combined dim reordering.
    If the composition is the identity, `node` is replaced by the chain's
    input (the whole chain was a no-op); otherwise a single equivalent
    permute_copy is inserted. Returns True iff the graph was changed, so
    the pass interface removes `node`.
    """
    # Rank of the output tensor is needed to build the identity dim order.
    if "val" not in node.meta:
        return False
    rank = len(node.meta["val"].shape)

    # Walk up the graph collecting consecutive permute/transpose ops.
    # `chain` is populated tail-first: chain[0] is `node`, chain[-1] is the
    # earliest op in the sequence.
    chain = [node]
    input_node = node.args[0]
    while (
        isinstance(input_node, torch.fx.Node)
        and input_node.target in self.transpose_or_permute_target
    ):
        chain.append(input_node)
        input_node = input_node.args[0]

    # A single op is not a cascade; nothing to fuse.
    if len(chain) < 2:
        return False

    # Compute the combined effect of the chain on the trivial dim order.
    # Iterate in execution order, i.e. reversed collection order.
    dims = list(range(rank))
    for op in reversed(chain):
        if op.target == exir_ops.edge.aten.transpose_copy.int:
            dims = get_transposed_dims(op, dims)
        else:
            assert op.target == exir_ops.edge.aten.permute_copy.default
            dims = get_permuted_dims(op, dims)

    if dims == list(range(rank)):
        # The permutes cancelled each other out: the chain is a no-op, so
        # forward the chain's original input directly.
        node.replace_all_uses_with(cast(torch.fx.Node, input_node))
    else:
        # Replace the chain with one permute that has the combined effect.
        with node.graph.inserting_before(node):
            new_permute = node.graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(input_node, dims),
            )
        new_permute.meta = node.meta
        node.replace_all_uses_with(new_permute)

    # `node` is now dead; the pass interface erases it. Intermediate chain
    # ops with no other users are cleaned up by dead-code elimination.
    return True


Expand Down
45 changes: 45 additions & 0 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,51 @@ def test_permute_transpose_fusion(self) -> None:
graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps"
)

def test_cascaded_permutes_multiple_users(self) -> None:
    """Fusion when an intermediate permute feeds multiple users.

            x
            |
        permute1
        /       \
    permute2  permute3
        |        |
      out0      out1
    """
    inp = torch.randn(1, 3, 8, 8, dtype=torch.float32)
    gb = GraphBuilder()
    x = gb.placeholder("x", inp)
    first = gb.call_operator(
        op=exir_ops.edge.aten.permute_copy.default,
        args=(x, [0, 2, 3, 1]),
    )
    left = gb.call_operator(
        op=exir_ops.edge.aten.permute_copy.default,
        args=(first, [0, 3, 1, 2]),
    )
    right = gb.call_operator(
        op=exir_ops.edge.aten.permute_copy.default,
        args=(first, [0, 1, 3, 2]),
    )
    gb.output([left, right])
    gm = gb.get_graph_module()
    gm_copy = copy.deepcopy(gm)

    result = FuseCascadedTransposeOrPermuteOps().call(gm)
    self.assertTrue(result.modified)
    fused = result.graph_module
    fused.graph.eliminate_dead_code()

    # first∘left cancels to the identity (no-op); first∘right collapses
    # into a single permute — so exactly one permute_copy remains.
    self.assertEqual(
        count_node(fused, exir_ops.edge.aten.permute_copy.default), 1
    )
    validate_numerics(
        gm_copy,
        fused,
        (inp,),
        "FuseCascadedTransposeOrPermuteOps_multiple_users",
    )

def test_view_fusion(self) -> None:
builder = GraphBuilder()
x_input = torch.randn(8, 5, 3, dtype=torch.float32)
Expand Down
Loading