Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 80 additions & 4 deletions pyrefly/lib/alt/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,40 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
)
}

/// Validate that a quantified ParamSpec forwarding pattern has the expected
/// `*P.args` / `**P.kwargs` as the last positional and keyword arguments.
/// Called when `var_to_rparams` returns `Err(q)` (the Var resolved to a
/// still-quantified ParamSpec `q`).
fn check_paramspec_forwarding(
&self,
q: &Quantified,
args: &[CallArg],
keywords: &[CallKeyword],
arguments_range: TextRange,
arg_errors: &ErrorCollector,
call_errors: &ErrorCollector,
context: Option<&dyn Fn() -> ErrorContext>,
) {
let args_ok = args
.last()
.is_some_and(|x| self.is_param_spec_args(x, q, arg_errors));
let kwargs_ok = keywords
.last()
.is_some_and(|x| self.is_param_spec_kwargs(x, q, arg_errors));
if !args_ok || !kwargs_ok {
self.error(
call_errors,
arguments_range,
ErrorInfo::new(ErrorKind::InvalidParamSpec, context),
format!(
"Expected *-unpacked {}.args and **-unpacked {}.kwargs",
q.name(),
q.name()
),
);
}
}

// See comment on `callable_infer` about `arg_errors` and `call_errors`.
fn callable_infer_params(
&self,
Expand Down Expand Up @@ -566,7 +600,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let mut variadic_name: Option<&Name> = None;
let mut variadic_collected: Vec<Type> = Vec::new();

let var_to_rparams = |var| {
// Resolve a deferred ParamSpec Var into additional parameters.
// Returns `Err(q)` when the Var resolved to a quantified ParamSpec `q`
// (forwarding case), meaning the caller should validate that the
// remaining args are `*P.args` / `**P.kwargs` and stop matching.
let var_to_rparams = |var| -> Result<Vec<&Param>, Quantified> {
let ps = match self.solver().force_var(var) {
Type::ParamSpecValue(ps) => ps,
Type::Any(_) | Type::Ellipsis => ParamList::everything(),
Expand All @@ -575,6 +613,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let ps = ParamList::everything();
ps.prepend_types(&prefix).into_owned()
}
// The ParamSpec Var resolved to another quantified ParamSpec (e.g.,
// one generic helper forwarding `*args: P.args, **kwargs: P.kwargs`
// to another). There are no concrete parameters to contribute;
// the caller must validate the forwarding pattern.
Type::Quantified(q) if q.is_param_spec() => return Err(*q),
t => {
error(
call_errors,
Expand All @@ -585,7 +628,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
ParamList::everything()
}
};
param_list_owner.push(ps).items().iter().rev().collect()
Ok(param_list_owner.push(ps).items().iter().rev().collect())
};
for arg in self_arg.iter().chain(args.iter()) {
let mut arg_pre = arg.pre_eval(self, arg_errors);
Expand All @@ -596,7 +639,24 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
// We've run out of parameters but haven't finished matching arguments. If we
// have a ParamSpec Var, it may contribute more parameters; force it and tack
// the result onto the parameter list.
rparams = var_to_rparams(var);
match var_to_rparams(var) {
Ok(new_rparams) => rparams = new_rparams,
Err(q) => {
// Quantified ParamSpec forwarding: validate that the
// remaining args/kwargs are the expected `*P.args` /
// `**P.kwargs` pair and stop matching.
self.check_paramspec_forwarding(
&q,
args,
keywords,
arguments_range,
arg_errors,
call_errors,
context,
);
return;
}
}
paramspec = None;
continue;
} else {
Expand Down Expand Up @@ -809,7 +869,23 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Some(p) => p,
None if let Some(var) = paramspec => {
// We've reached the end of our regular parameter list. Now check if we have more parameters from a ParamSpec.
rparams = var_to_rparams(var);
match var_to_rparams(var) {
Ok(new_rparams) => rparams = new_rparams,
Err(q) => {
// Quantified ParamSpec forwarding: validate both
// *P.args and **P.kwargs in the original arguments.
self.check_paramspec_forwarding(
&q,
args,
keywords,
arguments_range,
arg_errors,
call_errors,
context,
);
return;
}
}
paramspec = None;
continue;
}
Expand Down
93 changes: 93 additions & 0 deletions pyrefly/lib/test/paramspec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,99 @@ def wrap(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"#,
);

testcase!(
test_paramspec_forwarding_between_generic_helpers,
r#"
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def run_and_get_code(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...

def run_and_get_kernels(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
return run_and_get_code(fn, *args, **kwargs)
"#,
);

testcase!(
test_paramspec_forwarding_bad_args,
r#"
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def inner(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...

def outer(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
return inner(fn, 1, 2) # E: Expected *-unpacked P.args and **-unpacked P.kwargs
"#,
);

testcase!(
test_paramspec_forwarding_with_concatenate,
r#"
from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def inner(fn: Callable[Concatenate[int, P], R], x: int, *args: P.args, **kwargs: P.kwargs) -> R: ...

def outer(fn: Callable[Concatenate[int, P], R], x: int, *args: P.args, **kwargs: P.kwargs) -> R:
return inner(fn, x, *args, **kwargs)
"#,
);

testcase!(
test_paramspec_forwarding_extra_concrete_arg,
r#"
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def inner(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...

def outer(fn: Callable[P, R], extra: int, *args: P.args, **kwargs: P.kwargs) -> R:
return inner(fn, *args, **kwargs)
"#,
);

testcase!(
test_paramspec_forwarding_chained,
r#"
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def level1(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...

def level2(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
return level1(fn, *args, **kwargs)

def level3(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
return level2(fn, *args, **kwargs)
"#,
);

testcase!(
test_paramspec_forwarding_kwargs_only,
r#"
from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def inner(fn: Callable[Concatenate[int, P], R], x: int, *args: P.args, **kwargs: P.kwargs) -> R: ...

def outer(fn: Callable[Concatenate[int, P], R], x: int, *args: P.args, **kwargs: P.kwargs) -> R:
return inner(fn, x, **kwargs) # E: Expected *-unpacked P.args and **-unpacked P.kwargs
"#,
);

testcase!(
test_param_spec_ellipsis,
r#"
Expand Down
Loading