99from abc import abstractmethod
1010from typing import Any , List , Optional , Set , Type
1111
12- import torch
1312from executorch .backends .arm .constants import DISALLOW_TFA_META_KEY
1413from executorch .backends .arm .tosa .mapping import TosaSpecialDtype
1514from executorch .exir .dialects ._ops import ops as exir_ops
1615from executorch .exir .pass_base import ExportPass , NodeMetadata , ProxyValue
1716from torch .fx import GraphModule
1817from torch .fx .passes .infra .pass_base import PassResult
19- from torch .utils import _pytree as pytree
2018
2119
2220class ArmPass (ExportPass ):
@@ -81,13 +79,6 @@ def get_name(pass_) -> str:
8179 )
8280
8381 def call_operator (self , op , args , kwargs , meta , updated : Optional [bool ] = False ):
84- if (
85- op == exir_ops .edge .aten .bmm .default
86- and isinstance (meta , NodeMetadata )
87- and len (meta .data .get ("input_qparams" , {})) > 0
88- ):
89- return self ._call_quantized_bmm_without_fake_kernel (op , args , kwargs , meta )
90-
9182 if not updated :
9283 return super ().call_operator (op , args , kwargs , meta )
9384
@@ -100,35 +91,6 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False)
10091 new_meta ["stack_trace" ] = f"{ old_stack_trace } \n { traceback .format_stack ()[- 2 ]} "
10192 return super ().call_operator (op , args , kwargs , NodeMetadata (new_meta ))
10293
103- def _call_quantized_bmm_without_fake_kernel (
104- self ,
105- op ,
106- args : tuple [ProxyValue , ...],
107- kwargs : dict [str , Any ],
108- meta : NodeMetadata ,
109- ) -> ProxyValue :
110- old_val = meta .data ["val" ]
111- output_qparams = meta .data .get ("output_qparams" , {})
112- dtype = (
113- next (iter (output_qparams .values ())).dtype
114- if len (output_qparams ) > 0
115- else old_val .dtype
116- )
117- res_data = torch .empty_like (old_val , dtype = dtype )
118-
119- args_proxy , kwargs_proxy = pytree .tree_map_only (
120- ProxyValue , lambda x : x .proxy , (args , kwargs )
121- )
122- res_proxy = self .tracer .create_proxy (
123- "call_function" ,
124- op ,
125- args_proxy ,
126- kwargs_proxy ,
127- )
128- res_proxy .node .meta .update (meta .data )
129- self .tracer .set_metadata (res_proxy .node , res_data )
130- return ProxyValue (res_data , res_proxy )
131-
13294 def call_submodule (
13395 self , graph_module : GraphModule , inputs : tuple [Any , ...]
13496 ) -> PassResult :
0 commit comments