66
77import torch
88from executorch .backends .arm ._passes .arm_pass_utils import get_first_fake_tensor
9- from executorch .backends .arm .test .common import parametrize
9+ from executorch .backends .arm .test .common import parametrize , xfail_type
1010from executorch .backends .cortex_m .test .tester import (
1111 CortexMTester ,
1212 McuTestCase ,
@@ -141,8 +141,8 @@ def forward(self, x, y):
141141class SharedQspecInputForkXConstant (torch .nn .Module ):
142142 """Shared qspec cluster with an input fork with left input as global constant."""
143143
144- ops_before_transforms = {}
145- ops_after_transforms = {}
144+ ops_before_transforms : dict [ str , int ] = {}
145+ ops_after_transforms : dict [ str , int ] = {}
146146 constant = torch .tensor (5.0 )
147147
148148 def forward (self , x ):
@@ -152,8 +152,8 @@ def forward(self, x):
152152class SharedQspecInputForkYConstant (torch .nn .Module ):
153153 """Shared qspec cluster with an input fork with left input as local constant."""
154154
155- ops_before_transforms = {}
156- ops_after_transforms = {}
155+ ops_before_transforms : dict [ str , int ] = {}
156+ ops_after_transforms : dict [ str , int ] = {}
157157
158158 def forward (self , x ):
159159 return torch .maximum (x , torch .tensor (5.0 ))
@@ -259,8 +259,8 @@ def forward(self, x):
259259
260260
261261class SharedQspecSurroundedQuantizedOpConstant (torch .nn .Module ):
262- ops_before_transforms = {}
263- ops_after_transforms = {}
262+ ops_before_transforms : dict [ str , int ] = {}
263+ ops_after_transforms : dict [ str , int ] = {}
264264
265265 def forward (self , x ):
266266 x1 = torch .clone (x )
@@ -270,16 +270,16 @@ def forward(self, x):
270270
271271
272272class SharedQspecSub (torch .nn .Module ):
273- ops_before_transforms = {}
274- ops_after_transforms = {}
273+ ops_before_transforms : dict [ str , int ] = {}
274+ ops_after_transforms : dict [ str , int ] = {}
275275
276276 def forward (self , x , y ):
277277 return torch .clone (x - y )
278278
279279
280280class SharedQspecCompetingQspecs (torch .nn .Module ):
281- ops_before_transforms = {}
282- ops_after_transforms = {}
281+ ops_before_transforms : dict [ str , int ] = {}
282+ ops_after_transforms : dict [ str , int ] = {}
283283
284284 def __init__ (self ):
285285 super ().__init__ ()
@@ -299,8 +299,8 @@ def forward(self, x):
299299
300300
301301class SharedQspecNoQspecs (torch .nn .Module ):
302- ops_before_transforms = {}
303- ops_after_transforms = {}
302+ ops_before_transforms : dict [ str , int ] = {}
303+ ops_after_transforms : dict [ str , int ] = {}
304304
305305 def forward (self , x ):
306306 z = torch .clone (x - x )
@@ -358,7 +358,7 @@ def forward(self, x):
358358 ),
359359}
360360
361- xfails = {
361+ xfails : dict [ str , xfail_type ] = {
362362 "surrounded_quantized_op_constant" : "Numerical error since the add is forced to have non-correct qparams." ,
363363}
364364
0 commit comments