1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
7979 _fp_profile_ops | _int_profile_ops
8080)
8181
82+ _preserve_in_tfa = {
83+ torch .ops .aten .remainder .Scalar ,
84+ exir_ops .edge .aten .remainder .Scalar ,
85+ }
86+
8287
8388class ReplaceScalarWithTensorByProfilePass (ArmPass , ReplaceScalarWithTensorArgPass ):
8489 """Profile-aware scalar-to-tensor replacement pass for binary ops."""
@@ -94,6 +99,9 @@ def __init__(self, tfa_pass=False, *args, **kwargs):
9499 super ().__init__ (tfa_pass , _all_ops , * args , ** kwargs )
95100
96101 def call_operator (self , op , args , kwargs , meta ):
102+ if self .is_tfa_pass and op in _preserve_in_tfa :
103+ return ExportPass .call_operator (self , op , args , kwargs , meta )
104+
97105 tosa_spec = get_context_spec ()
98106
99107 included_ops = {}
@@ -108,7 +116,7 @@ def call_operator(self, op, args, kwargs, meta):
108116 if op in TableOps .included_ops ():
109117 # Do not handle quantized table ops; forward unchanged.
110118 input_qparams = meta .data .get ("input_qparams" , {})
111- output_qparams = meta .data .get ("input_qparams " , {})
119+ output_qparams = meta .data .get ("output_qparams " , {})
112120 if len (input_qparams ) > 0 and len (output_qparams ) > 0 :
113121 # Do not handle; forward unchanged.
114122 return ExportPass .call_operator (self , op , args , kwargs , meta )
0 commit comments