File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -35,6 +35,7 @@ def trace(
3535 model .eval ()
3636
3737 decomp_table = torch .export .default_decompositions ()
38+ ops_to_keep = [* (ops_to_keep or []), torch .ops .aten ._safe_softmax .default ]
3839 # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
3940 remove_decompositions (decomp_table , ops_to_keep )
4041 program = torch .export .export (model , inputs , strict = strict ).run_decompositions (
Original file line number Diff line number Diff line change 3333from executorch .backends .cadence .aot .replace_ops import (
3434 CadenceReplaceOpsInGraph ,
3535 ReplaceMulTensorWithMulAndFullOpsPass ,
36+ ReplaceSafeSoftmaxWithSoftmax ,
3637)
3738from executorch .backends .cadence .aot .simplify_ops import CadenceSimplifyOpsInGraph
3839from executorch .backends .cadence .aot .type_dispatch import CompileTimeTypeDispatchPass
@@ -131,7 +132,8 @@ def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
131132 """
132133
133134 aten_passes : List [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]] = [
134- ReplaceMulTensorWithMulAndFullOpsPass ()
135+ ReplaceSafeSoftmaxWithSoftmax (),
136+ ReplaceMulTensorWithMulAndFullOpsPass (),
135137 ]
136138 # TODO(T230417247): Use PassResult which is currently ignored.
137139 PassManager (aten_passes )(expo_program .graph_module )
You can’t perform that action at this time.
0 commit comments