11from __future__ import annotations
22from collections .abc import Callable
33from typing import ParamSpec , TypeVar , Any , Optional , Literal
4+ import types
45import inspect
56from dataclasses import dataclass
67import utils
@@ -24,14 +25,14 @@ def isEmptySignature(sig: inspect.Signature) -> bool:
2425 return False
2526 return isEmptyAnnotation (sig .return_annotation )
2627
27- def handleMatchesTyResult (res : MatchesTyResult , tyLoc : Optional [location .Loc ]) -> bool :
28+ def handleMatchesTyResult (res : MatchesTyResult , getTyLoc : Callable [[], Optional [location .Loc ] ]) -> bool :
2829 match res :
2930 case MatchesTyFailure (exc , ty ):
3031 if isDebug ():
3132 debug (f'Exception occurred while calling matchesTy with type { ty } , re-raising' )
3233 raise exc
3334 else :
34- raise errors .WyppTypeError .invalidType (ty , tyLoc )
35+ raise errors .WyppTypeError .invalidType (ty , getTyLoc () )
3536 case b :
3637 return b
3738
@@ -63,9 +64,11 @@ def checkSignature(sig: inspect.Signature, info: location.CallableInfo, cfg: Che
6364 locDecl = info .getParamSourceLocation (name )
6465 raise errors .WyppTypeError .partialAnnotationError (location .CallableName .mk (info ), name , locDecl )
6566 if p .default is not inspect .Parameter .empty :
66- locDecl = info .getParamSourceLocation (name )
67+ locDecl = lambda : info .getParamSourceLocation (name )
6768 if not handleMatchesTyResult (matchesTy (p .default , ty , cfg .ns ), locDecl ):
68- raise errors .WyppTypeError .defaultError (location .CallableName .mk (info ), name , locDecl , ty , p .default )
69+ raise errors .WyppTypeError .defaultError (location .CallableName .mk (info ), name ,
70+ locDecl (), ty , p .default )
71+
6972
7073def mandatoryArgCount (sig : inspect .Signature ) -> int :
7174 required_kinds = {
@@ -80,24 +83,34 @@ def mandatoryArgCount(sig: inspect.Signature) -> int:
8083 return res
8184
8285def checkArgument (p : inspect .Parameter , name : str , idx : Optional [int ], a : Any ,
83- locArg : Optional [location .Loc ], info : location .CallableInfo , cfg : CheckCfg ):
86+ getLocArg : Callable [[], Optional [location .Loc ]],
87+ info : location .CallableInfo , cfg : CheckCfg ):
8488 t = p .annotation
8589 if not isEmptyAnnotation (t ):
90+ locDecl = lambda : info .getParamSourceLocation (name )
8691 if p .kind == inspect .Parameter .VAR_POSITIONAL :
92+ if type (t ) == str :
93+ t = eval (t )
8794 argT = None
8895 # For *args annotated as tuple[X, ...], extract the element type X
8996 origin = getattr (t , '__origin__' , None )
9097 if origin is tuple :
9198 args = getattr (t , '__args__' , None )
92- if args :
99+ if args and len (args ) == 2 and args [1 ] is Ellipsis :
100+ # tuple[X, ...] — homogeneous variadic
93101 argT = args [0 ]
102+ elif args :
103+ # tuple[X, Y, ...] — fixed-length tuple, no single element type to extract
104+ raise errors .WyppTypeError .invalidRestArgType (t , locDecl ())
94105 elif t is tuple :
95106 # bare `tuple` without type parameters, nothing to check
96107 return
97108 else :
98- raise ValueError ( f'Invalid type for rest argument: { t } ' )
109+ raise errors . WyppTypeError . invalidRestArgType ( t , locDecl () )
99110 t = argT
100111 elif p .kind == inspect .Parameter .VAR_KEYWORD :
112+ if type (t ) == str :
113+ t = eval (t )
101114 valT = None
102115 # For **kwargs annotated as dict[str, X], extract the value type X
103116 origin = getattr (t , '__origin__' , None )
@@ -108,31 +121,39 @@ def checkArgument(p: inspect.Parameter, name: str, idx: Optional[int], a: Any,
108121 elif t is dict :
109122 return
110123 else :
111- raise ValueError ( f'Invalid type for keyword argument: { t } ' )
124+ raise errors . WyppTypeError . invalidKwArgType ( t , info . getParamSourceLocation ( p . name ) )
112125 t = valT
113- locDecl = info .getParamSourceLocation (name )
114126 if not handleMatchesTyResult (matchesTy (a , t , cfg .ns ), locDecl ):
115127 cn = location .CallableName .mk (info )
116128 raise errors .WyppTypeError .argumentError (cn ,
117129 name ,
118130 idx ,
119- locDecl ,
131+ locDecl () ,
120132 t ,
121133 a ,
122- locArg )
134+ getLocArg ())
135+
123136def checkArguments (sig : inspect .Signature , args : tuple , kwargs : dict ,
124137 info : location .CallableInfo , cfg : CheckCfg ) -> None :
125- debug (f'Checking arguments when calling { info } ' )
138+ if isDebug ():
139+ debug (f'Checking arguments when calling { info } ' )
126140 paramNames = list (sig .parameters )
127141 mandatory = mandatoryArgCount (sig )
128142 kind = getKind (cfg , paramNames )
129143 offset = 1 if kind == 'method' else 0
130- fi = stacktrace .callerOutsideWypp ()
131- callLoc = None if not fi else location .Loc .fromFrameInfo (fi )
132144 cn = location .CallableName .mk (info )
145+ # stacktrace.callerOutsideWypp() is expensive, only access it lazily
146+ def getCallLoc () -> Optional [location .Loc ]:
147+ fi = stacktrace .callerOutsideWypp ()
148+ return None if not fi else location .Loc .fromFrameInfo (fi )
149+ def getLocArg (idxOrName : int | str ) -> Callable [[], Optional [location .Loc ]]:
150+ def f ():
151+ fi = stacktrace .callerOutsideWypp ()
152+ return None if fi is None else location .locationOfArgument (fi , i )
153+ return f
133154 def raiseArgMismatch ():
134155 raise errors .WyppTypeError .argCountMismatch (cn ,
135- callLoc ,
156+ getCallLoc () ,
136157 len (paramNames ) - offset ,
137158 mandatory - offset ,
138159 len (args ) - offset )
@@ -152,46 +173,46 @@ def raiseArgMismatch():
152173 raiseArgMismatch ()
153174 # Check positional args
154175 for i in range (len (args )):
155- locArg = None if fi is None else location .locationOfArgument (fi , i )
156176 if i < len (positionalNames ):
157177 name = positionalNames [i ]
158178 p = sig .parameters [name ]
159- checkArgument (p , name , i - offset , args [i ], locArg , info , cfg )
179+ checkArgument (p , name , i - offset , args [i ], getLocArg ( i ) , info , cfg )
160180 elif varPositionalParam is not None :
161- checkArgument (varPositionalParam , varPositionalParam .name , i - offset , args [i ], locArg , info , cfg )
181+ checkArgument (varPositionalParam , varPositionalParam .name , i - offset , args [i ], getLocArg ( i ) , info , cfg )
162182 else :
163183 raiseArgMismatch ()
164184 # Check keyword args
165185 for name in kwargs :
166- locArg = None if fi is None else location .locationOfArgument (fi , name )
167186 if name in sig .parameters and sig .parameters [name ].kind not in (
168187 inspect .Parameter .VAR_POSITIONAL , inspect .Parameter .VAR_KEYWORD
169188 ):
170- checkArgument (sig .parameters [name ], name , None , kwargs [name ], locArg , info , cfg )
189+ checkArgument (sig .parameters [name ], name , None , kwargs [name ], getLocArg ( name ) , info , cfg )
171190 elif varKeywordParam is not None :
172- checkArgument (varKeywordParam , name , None , kwargs [name ], locArg , info , cfg )
191+ checkArgument (varKeywordParam , name , None , kwargs [name ], getLocArg ( name ) , info , cfg )
173192 else :
174- raise errors .WyppTypeError .unknownKeywordArgument (cn , callLoc , name )
193+ raise errors .WyppTypeError .unknownKeywordArgument (cn , getCallLoc () , name )
175194
176- def checkReturn (sig : inspect .Signature , returnFrame : Optional [inspect . FrameInfo ],
195+ def checkReturn (sig : inspect .Signature , returnFrameType : Optional [types . FrameType ],
177196 result : Any , info : location .CallableInfo , cfg : CheckCfg ) -> None :
178197 if info .isAsync :
179198 return
180199 t = sig .return_annotation
181200 if isEmptyAnnotation (t ):
182201 t = None
183- debug (f'Checking return value when calling { info } , return type: { t } ' )
184- locDecl = info .getResultTypeLocation ()
202+ if isDebug ():
203+ debug (f'Checking return value when calling { info } , return type: { t } ' )
204+ locDecl = lambda : info .getResultTypeLocation ()
185205 if not handleMatchesTyResult (matchesTy (result , t , cfg .ns ), locDecl ):
186206 fi = stacktrace .callerOutsideWypp ()
187207 if fi is not None :
188208 locRes = location .Loc .fromFrameInfo (fi )
189209 returnLoc = None
190210 extraFrames = []
211+ returnFrame = stacktrace .frameTypeToFrameInfo (returnFrameType )
191212 if returnFrame :
192213 returnLoc = location .Loc .fromFrameInfo (returnFrame )
193214 extraFrames = [returnFrame ]
194- raise errors .WyppTypeError .resultError (location .CallableName .mk (info ), locDecl , t , returnLoc , result ,
215+ raise errors .WyppTypeError .resultError (location .CallableName .mk (info ), locDecl () , t , returnLoc , result ,
195216 locRes , extraFrames )
196217
197218
@@ -232,9 +253,9 @@ def wrapped(*args, **kwargs) -> T:
232253 utils ._call_with_frames_removed (checkArguments , sig , args , kwargs , info , checkCfg )
233254 returnTracker = stacktrace .getReturnTracker ()
234255 result = utils ._call_with_next_frame_removed (f , * args , ** kwargs )
235- retFrame = returnTracker .getReturnFrame (0 )
256+ ft = returnTracker .getReturnFrameType (0 )
236257 utils ._call_with_frames_removed (
237- checkReturn , sig , retFrame , result , info , checkCfg
258+ checkReturn , sig , ft , result , info , checkCfg
238259 )
239260 return result
240261 return wrapped
0 commit comments