Skip to content

Commit 55affde

Browse files
committed
fix typechecking of rest and kw args
1 parent 0e617c9 commit 55affde

31 files changed

Lines changed: 257 additions & 8 deletions

python/code/wypp/typecheck.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,33 @@ def checkArgument(p: inspect.Parameter, name: str, idx: Optional[int], a: Any,
8383
locArg: Optional[location.Loc], info: location.CallableInfo, cfg: CheckCfg):
8484
t = p.annotation
8585
if not isEmptyAnnotation(t):
86+
if p.kind == inspect.Parameter.VAR_POSITIONAL:
87+
argT = None
88+
# For *args annotated as tuple[X, ...], extract the element type X
89+
origin = getattr(t, '__origin__', None)
90+
if origin is tuple:
91+
args = getattr(t, '__args__', None)
92+
if args:
93+
argT = args[0]
94+
elif t is tuple:
95+
# bare `tuple` without type parameters, nothing to check
96+
return
97+
else:
98+
raise ValueError(f'Invalid type for rest argument: {t}')
99+
t = argT
100+
elif p.kind == inspect.Parameter.VAR_KEYWORD:
101+
valT = None
102+
# For **kwargs annotated as dict[str, X], extract the value type X
103+
origin = getattr(t, '__origin__', None)
104+
if origin is dict:
105+
type_args = getattr(t, '__args__', None)
106+
if type_args and len(type_args) >= 2:
107+
valT = type_args[1]
108+
elif t is dict:
109+
return
110+
else:
111+
raise ValueError(f'Invalid type for keyword argument: {t}')
112+
t = valT
86113
locDecl = info.getParamSourceLocation(name)
87114
if not handleMatchesTyResult(matchesTy(a, t, cfg.ns), locDecl):
88115
cn = location.CallableName.mk(info)
@@ -109,20 +136,42 @@ def raiseArgMismatch():
109136
len(paramNames) - offset,
110137
mandatory - offset,
111138
len(args) - offset)
139+
# Classify parameters by kind
140+
varPositionalParam: Optional[inspect.Parameter] = None
141+
varKeywordParam: Optional[inspect.Parameter] = None
142+
positionalNames: list[str] = []
143+
for pName in paramNames:
144+
p = sig.parameters[pName]
145+
if p.kind == inspect.Parameter.VAR_POSITIONAL:
146+
varPositionalParam = p
147+
elif p.kind == inspect.Parameter.VAR_KEYWORD:
148+
varKeywordParam = p
149+
elif p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
150+
positionalNames.append(pName)
112151
if len(args) + len(kwargs) < mandatory:
113152
raiseArgMismatch()
153+
# Check positional args
114154
for i in range(len(args)):
115-
if i >= len(paramNames):
116-
raiseArgMismatch()
117-
name = paramNames[i]
118-
p = sig.parameters[name]
119155
locArg = None if fi is None else location.locationOfArgument(fi, i)
120-
checkArgument(p, name, i - offset, args[i], locArg, info, cfg)
156+
if i < len(positionalNames):
157+
name = positionalNames[i]
158+
p = sig.parameters[name]
159+
checkArgument(p, name, i - offset, args[i], locArg, info, cfg)
160+
elif varPositionalParam is not None:
161+
checkArgument(varPositionalParam, varPositionalParam.name, i - offset, args[i], locArg, info, cfg)
162+
else:
163+
raiseArgMismatch()
164+
# Check keyword args
121165
for name in kwargs:
122-
if name not in sig.parameters:
123-
raise errors.WyppTypeError.unknownKeywordArgument(cn, callLoc, name)
124166
locArg = None if fi is None else location.locationOfArgument(fi, name)
125-
checkArgument(sig.parameters[name], name, None, kwargs[name], locArg, info, cfg)
167+
if name in sig.parameters and sig.parameters[name].kind not in (
168+
inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD
169+
):
170+
checkArgument(sig.parameters[name], name, None, kwargs[name], locArg, info, cfg)
171+
elif varKeywordParam is not None:
172+
checkArgument(varKeywordParam, name, None, kwargs[name], locArg, info, cfg)
173+
else:
174+
raise errors.WyppTypeError.unknownKeywordArgument(cn, callLoc, name)
126175

127176
def checkReturn(sig: inspect.Signature, returnFrame: Optional[inspect.FrameInfo],
128177
result: Any, info: location.CallableInfo, cfg: CheckCfg) -> None:
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Traceback (most recent call last):
2+
File "file-test-data/extras/args.py", line 9, in <module>
3+
f(1, 2, '3', 4)
4+
5+
WyppTypeError: '3'
6+
7+
Der Aufruf der Funktion `f` erwartet einen Wert vom Typ `int` als drittes Argument.
8+
Aber der übergebene Wert hat den Typ `str`.
9+
10+
## Datei file-test-data/extras/args.py
11+
## Fehlerhafter Aufruf in Zeile 9:
12+
13+
f(1, 2, '3', 4)
14+
15+
## Typ deklariert in Zeile 3:
16+
17+
def f(x: int, *rest: tuple[int,...]):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
x=1, rest=()
2+
x=1, rest=(2,)
3+
x=1, rest=(2, 3)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from wypp import *
2+
3+
def f(x: int, *rest: tuple[int,...]):
4+
print(f'x={x}, rest={rest}')
5+
6+
f(1)
7+
f(1, 2)
8+
f(1, 2, 3)
9+
f(1, 2, '3', 4)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Traceback (most recent call last):
2+
File "file-test-data/extras/args2.py", line 9, in <module>
3+
f(1, *[2, '3', 4])
4+
5+
WyppTypeError: '3'
6+
7+
Der Aufruf der Funktion `f` erwartet einen Wert vom Typ `int` als drittes Argument.
8+
Aber der übergebene Wert hat den Typ `str`.
9+
10+
## Datei file-test-data/extras/args2.py
11+
## Fehlerhafter Aufruf in Zeile 9:
12+
13+
f(1, *[2, '3', 4])
14+
15+
## Typ deklariert in Zeile 3:
16+
17+
def f(x: int, *rest: tuple[int,...]):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
x=1, rest=()
2+
x=1, rest=(2,)
3+
x=1, rest=(2, 3)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from wypp import *
2+
3+
def f(x: int, *rest: tuple[int,...]):
4+
print(f'x={x}, rest={rest}')
5+
6+
f(1)
7+
f(1, 2)
8+
f(1, 2, 3)
9+
f(1, *[2, '3', 4])

python/file-test-data/extras/args2_ok.err

Whitespace-only changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
x=1, kw={}
2+
x=1, kw={'y': 2}
3+
x=1, kw={'y': 2, 'z': 3}
4+
x=1, kw={'y': 2, 'z': 3}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from wypp import *
2+
3+
def f(x: int, **kw: dict[str, int]):
4+
print(f'x={x}, kw={kw}')
5+
6+
f(1)
7+
f(1, y=2)
8+
f(1, y=2, z=3)
9+
f(1, **{'y': 2, 'z': 3})

0 commit comments

Comments
 (0)