Skip to content

Commit 824cbff

Browse files
authored
Prevent _safe_softmax decomposition in trace and rewire replaceSafeSoftmaxWithSoftmax
Differential Revision: D105367634 Pull Request resolved: pytorch#19619
1 parent 7dbd972 commit 824cbff

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def trace(
3535
model.eval()
3636

3737
decomp_table = torch.export.default_decompositions()
38+
ops_to_keep = [*(ops_to_keep or []), torch.ops.aten._safe_softmax.default]
3839
# pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
3940
remove_decompositions(decomp_table, ops_to_keep)
4041
program = torch.export.export(model, inputs, strict=strict).run_decompositions(

backends/cadence/aot/passes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from executorch.backends.cadence.aot.replace_ops import (
3434
CadenceReplaceOpsInGraph,
3535
ReplaceMulTensorWithMulAndFullOpsPass,
36+
ReplaceSafeSoftmaxWithSoftmax,
3637
)
3738
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
3839
from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass
@@ -131,7 +132,8 @@ def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
131132
"""
132133

133134
aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
134-
ReplaceMulTensorWithMulAndFullOpsPass()
135+
ReplaceSafeSoftmaxWithSoftmax(),
136+
ReplaceMulTensorWithMulAndFullOpsPass(),
135137
]
136138
# TODO(T230417247): Use PassResult which is currently ignored.
137139
PassManager(aten_passes)(expo_program.graph_module)

0 commit comments

Comments
 (0)