99from abc import abstractmethod
1010from typing import Any , List , Optional , Set , Type
1111
12+ import torch
1213from executorch .backends .arm .constants import DISALLOW_TFA_META_KEY
1314from executorch .backends .arm .tosa .mapping import TosaSpecialDtype
1415from executorch .exir .dialects ._ops import ops as exir_ops
1516from executorch .exir .pass_base import ExportPass , NodeMetadata , ProxyValue
1617from torch .fx import GraphModule
1718from torch .fx .passes .infra .pass_base import PassResult
19+ from torch .utils import _pytree as pytree
1820
1921
2022class ArmPass (ExportPass ):
@@ -79,6 +81,13 @@ def get_name(pass_) -> str:
7981 )
8082
8183 def call_operator (self , op , args , kwargs , meta , updated : Optional [bool ] = False ):
84+ if (
85+ op == exir_ops .edge .aten .bmm .default
86+ and isinstance (meta , NodeMetadata )
87+ and len (meta .data .get ("input_qparams" , {})) > 0
88+ ):
89+ return self ._call_quantized_bmm_without_fake_kernel (op , args , kwargs , meta )
90+
8291 if not updated :
8392 return super ().call_operator (op , args , kwargs , meta )
8493
@@ -91,6 +100,35 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False)
91100 new_meta ["stack_trace" ] = f"{ old_stack_trace } \n { traceback .format_stack ()[- 2 ]} "
92101 return super ().call_operator (op , args , kwargs , NodeMetadata (new_meta ))
93102
def _call_quantized_bmm_without_fake_kernel(
    self,
    op,
    args: tuple[ProxyValue, ...],
    kwargs: dict[str, Any],
    meta: NodeMetadata,
) -> ProxyValue:
    """Record ``op`` in the traced graph without running it on the argument data.

    The output value is not computed from ``args``. A placeholder tensor is
    allocated with ``torch.empty_like`` from the value already stored under
    ``meta.data["val"]``. Its dtype comes from the first entry of
    ``meta.data["output_qparams"]`` when that mapping is non-empty, and
    otherwise from the stored value itself.

    Args:
        op: Call target for the new ``call_function`` node.
        args: Positional arguments; every ``ProxyValue`` found in them is
            replaced by its ``.proxy`` before the node is created.
        kwargs: Keyword arguments, unwrapped the same way as ``args``.
        meta: Node metadata; must contain ``"val"``. Its whole ``data`` dict
            is copied into the new node's ``meta``.

    Returns:
        A ``ProxyValue`` pairing the placeholder tensor with the new proxy.

    Raises:
        KeyError: If ``meta.data`` has no ``"val"`` entry.
    """
    recorded_val = meta.data["val"]
    qparams_out = meta.data.get("output_qparams", {})
    if qparams_out:
        # Only the first output qparams entry is consulted for the dtype.
        first_qparams = next(iter(qparams_out.values()))
        result_dtype = first_qparams.dtype
    else:
        result_dtype = recorded_val.dtype
    placeholder = torch.empty_like(recorded_val, dtype=result_dtype)

    def _unwrap(value):
        return value.proxy

    args_proxy, kwargs_proxy = pytree.tree_map_only(
        ProxyValue, _unwrap, (args, kwargs)
    )
    node_proxy = self.tracer.create_proxy(
        "call_function",
        op,
        args_proxy,
        kwargs_proxy,
    )
    new_node = node_proxy.node
    # Copy the incoming metadata first; set_metadata is then handed the
    # placeholder so the node is annotated from it rather than from the
    # previously recorded value.
    new_node.meta.update(meta.data)
    self.tracer.set_metadata(new_node, placeholder)
    return ProxyValue(placeholder, node_proxy)
131+
94132 def call_submodule (
95133 self , graph_module : GraphModule , inputs : tuple [Any , ...]
96134 ) -> PassResult :
0 commit comments