11from __future__ import annotations
22from collections .abc import Callable
33from typing import ParamSpec , TypeVar , Any , Optional , Literal
4+ import types
45import inspect
56from dataclasses import dataclass
67import utils
@@ -68,12 +69,8 @@ def checkSignature(sig: inspect.Signature, info: location.CallableInfo, cfg: Che
6869 raise errors .WyppTypeError .defaultError (location .CallableName .mk (info ), name ,
6970 locDecl (), ty , p .default )
7071
71- _argCountCache = {}
7272
7373def mandatoryArgCount (sig : inspect .Signature ) -> int :
74- x = _argCountCache .get (sig )
75- if x is not None :
76- return x
7774 required_kinds = {
7875 inspect .Parameter .POSITIONAL_ONLY ,
7976 inspect .Parameter .POSITIONAL_OR_KEYWORD ,
@@ -83,29 +80,37 @@ def mandatoryArgCount(sig: inspect.Signature) -> int:
8380 for p in sig .parameters .values ():
8481 if p .kind in required_kinds and p .default is inspect ._empty :
8582 res = res + 1
86- _argCountCache [sig ] = res
8783 return res
8884
8985def checkArgument (p : inspect .Parameter , name : str , idx : Optional [int ], a : Any ,
9086 getLocArg : Callable [[], Optional [location .Loc ]],
9187 info : location .CallableInfo , cfg : CheckCfg ):
9288 t = p .annotation
9389 if not isEmptyAnnotation (t ):
90+ locDecl = lambda : info .getParamSourceLocation (name )
9491 if p .kind == inspect .Parameter .VAR_POSITIONAL :
92+ if type (t ) == str :
93+ t = eval (t )
9594 argT = None
9695 # For *args annotated as tuple[X, ...], extract the element type X
9796 origin = getattr (t , '__origin__' , None )
9897 if origin is tuple :
9998 args = getattr (t , '__args__' , None )
100- if args :
99+ if args and len (args ) == 2 and args [1 ] is Ellipsis :
100+ # tuple[X, ...] — homogeneous variadic
101101 argT = args [0 ]
102+ elif args :
103+ # tuple[X, Y, ...] — fixed-length tuple, no single element type to extract
104+ raise errors .WyppTypeError .invalidRestArgType (t , locDecl ())
102105 elif t is tuple :
103106 # bare `tuple` without type parameters, nothing to check
104107 return
105108 else :
106- raise ValueError ( f'Invalid type for rest argument: { t } ' )
109+ raise errors . WyppTypeError . invalidRestArgType ( t , locDecl () )
107110 t = argT
108111 elif p .kind == inspect .Parameter .VAR_KEYWORD :
112+ if type (t ) == str :
113+ t = eval (t )
109114 valT = None
110115 # For **kwargs annotated as dict[str, X], extract the value type X
111116 origin = getattr (t , '__origin__' , None )
@@ -116,9 +121,8 @@ def checkArgument(p: inspect.Parameter, name: str, idx: Optional[int], a: Any,
116121 elif t is dict :
117122 return
118123 else :
119- raise ValueError ( f'Invalid type for keyword argument: { t } ' )
124+ raise errors . WyppTypeError . invalidKwArgType ( t , info . getParamSourceLocation ( p . name ) )
120125 t = valT
121- locDecl = lambda : info .getParamSourceLocation (name )
122126 if not handleMatchesTyResult (matchesTy (a , t , cfg .ns ), locDecl ):
123127 cn = location .CallableName .mk (info )
124128 raise errors .WyppTypeError .argumentError (cn ,
@@ -188,7 +192,7 @@ def raiseArgMismatch():
188192 else :
189193 raise errors .WyppTypeError .unknownKeywordArgument (cn , getCallLoc (), name )
190194
191- def checkReturn (sig : inspect .Signature , getReturnFrame : Callable [[], Optional [inspect . FrameInfo ] ],
195+ def checkReturn (sig : inspect .Signature , returnFrameType : Optional [types . FrameType ],
192196 result : Any , info : location .CallableInfo , cfg : CheckCfg ) -> None :
193197 if info .isAsync :
194198 return
@@ -204,7 +208,7 @@ def checkReturn(sig: inspect.Signature, getReturnFrame: Callable[[], Optional[in
204208 locRes = location .Loc .fromFrameInfo (fi )
205209 returnLoc = None
206210 extraFrames = []
207- returnFrame = getReturnFrame ( )
211+ returnFrame = stacktrace . frameTypeToFrameInfo ( returnFrameType )
208212 if returnFrame :
209213 returnLoc = location .Loc .fromFrameInfo (returnFrame )
210214 extraFrames = [returnFrame ]
@@ -249,9 +253,9 @@ def wrapped(*args, **kwargs) -> T:
249253 utils ._call_with_frames_removed (checkArguments , sig , args , kwargs , info , checkCfg )
250254 returnTracker = stacktrace .getReturnTracker ()
251255 result = utils ._call_with_next_frame_removed (f , * args , ** kwargs )
252- getRetFrame = lambda : returnTracker .getReturnFrame (0 )
256+ ft = returnTracker .getReturnFrameType (0 )
253257 utils ._call_with_frames_removed (
254- checkReturn , sig , getRetFrame , result , info , checkCfg
258+ checkReturn , sig , ft , result , info , checkCfg
255259 )
256260 return result
257261 return wrapped
0 commit comments