Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/pyrefly_config/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,16 @@ pub struct ConfigOverrideArgs {
num_args = 0..=1
)]
strict_callable_subtyping: Option<bool>,
/// Whether to use spec-compliant overload evaluation semantics.
/// When false (the default), Pyrefly attempts to resolve ambiguous calls precisely.
/// When true, overload evaluation follows the typing spec exactly, falling back to `Any` more frequently.
#[arg(
long,
default_missing_value = "true",
require_equals = true,
num_args = 0..=1
)]
spec_compliant_overloads: Option<bool>,
}

impl ConfigOverrideArgs {
Expand Down Expand Up @@ -398,6 +408,9 @@ impl ConfigOverrideArgs {
if let Some(x) = &self.strict_callable_subtyping {
config.root.strict_callable_subtyping = Some(*x);
}
if let Some(x) = &self.spec_compliant_overloads {
config.root.spec_compliant_overloads = Some(*x);
}
let apply_error_settings = |error_config: &mut ErrorDisplayConfig| {
for error_kind in &self.error {
error_config.set_error_severity(*error_kind, Severity::Error);
Expand Down
10 changes: 10 additions & 0 deletions crates/pyrefly_config/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ pub struct ConfigBase {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub strict_callable_subtyping: Option<bool>,

/// Whether to use spec-compliant overload evaluation semantics.
/// When false (the default), Pyrefly attempts to resolve ambiguous calls precisely.
/// When true, overload evaluation follows the typing spec exactly, falling back to `Any` more frequently.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub spec_compliant_overloads: Option<bool>,

/// Any unknown config items
#[serde(default, flatten)]
pub(crate) extras: ExtraConfigs,
Expand Down Expand Up @@ -265,4 +271,8 @@ impl ConfigBase {
pub fn get_strict_callable_subtyping(base: &Self) -> Option<bool> {
base.strict_callable_subtyping
}

pub fn get_spec_compliant_overloads(base: &Self) -> Option<bool> {
base.spec_compliant_overloads
}
}
17 changes: 17 additions & 0 deletions crates/pyrefly_config/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,14 @@ impl ConfigFile {
self.root.strict_callable_subtyping.unwrap())
}

pub fn spec_compliant_overloads(&self, path: &Path) -> bool {
self.get_from_sub_configs(ConfigBase::get_spec_compliant_overloads, path)
.unwrap_or_else(||
// we can use unwrap here, because the value in the root config must
// be set in `ConfigFile::configure()`.
self.root.spec_compliant_overloads.unwrap())
}

pub fn enabled_ignores(&self, path: &Path) -> &SmallSet<Tool> {
self.get_from_sub_configs(ConfigBase::get_enabled_ignores, path)
.unwrap_or_else(||
Expand Down Expand Up @@ -1171,6 +1179,10 @@ impl ConfigFile {
self.root.strict_callable_subtyping = Some(false);
}

if self.root.spec_compliant_overloads.is_none() {
self.root.spec_compliant_overloads = Some(false);
}

let tools_from_permissive_ignores = match self.root.permissive_ignores {
Some(true) => Some(Tool::all()),
Some(false) => Some(Tool::default_enabled()),
Expand Down Expand Up @@ -1543,6 +1555,7 @@ mod tests {
enabled_ignores: None,
recursion_depth_limit: None,
recursion_overflow_handler: None,
spec_compliant_overloads: None,
},
source_db: Default::default(),
sub_configs: vec![SubConfig {
Expand All @@ -1567,6 +1580,7 @@ mod tests {
enabled_ignores: None,
recursion_depth_limit: None,
recursion_overflow_handler: None,
spec_compliant_overloads: None,
}
}],
typeshed_path: None,
Expand Down Expand Up @@ -1979,6 +1993,7 @@ output-format = "omit-errors"
enabled_ignores: None,
recursion_depth_limit: None,
recursion_overflow_handler: None,
spec_compliant_overloads: None,
},
sub_configs: vec![
SubConfig {
Expand Down Expand Up @@ -2292,6 +2307,7 @@ output-format = "omit-errors"
enabled_ignores: None,
recursion_depth_limit: None,
recursion_overflow_handler: None,
spec_compliant_overloads: None,
},
sub_configs: vec![],
..Default::default()
Expand Down Expand Up @@ -2330,6 +2346,7 @@ output-format = "omit-errors"
enabled_ignores: None,
recursion_depth_limit: None,
recursion_overflow_handler: None,
spec_compliant_overloads: None,
},
sub_configs: vec![],
..Default::default()
Expand Down
115 changes: 87 additions & 28 deletions pyrefly/lib/alt/overload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,41 +629,100 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let _ = matched_overloads.split_off(split_point);
}
}
// Step 5, part 2: are all remaining return types equivalent to one another?
// If not, the call is ambiguous.
let mut matched_overloads = matched_overloads.into_iter();
let first_overload = matched_overloads.next().unwrap();
if matched_overloads.any(|o| !self.is_consistent(&first_overload.res, &o.res)) {
return (
let selected_overload = if self.solver().spec_compliant_overloads {
self.disambiguate_overloads_spec_compliant(&matched_overloads)
} else {
self.disambiguate_overloads(&matched_overloads)
};
if let Some(idx) = selected_overload {
let overload = matched_overloads
.into_iter()
.nth(idx)
.expect("Could not find selected overload");
// Now that we've selected an overload, use the hint to contextually type the arguments.
let contextual_overload = self.call_overload(
&overload.func,
metadata,
self_obj,
args,
keywords,
arguments_range,
&self.error_collector(),
hint,
ctor_targs,
);
(
if contextual_overload.call_errors.is_empty() {
contextual_overload
} else {
overload
},
true,
)
} else {
// Ambiguous call, return Any. Arbitrarily use the first overload as the matched one.
let first_overload = matched_overloads
.into_iter()
.next()
.expect("Expected at least one overload");
(
CalledOverload {
res: self.heap.mk_any_implicit(),
..first_overload
},
true,
);
)
}
}
}

fn disambiguate_overloads_spec_compliant(
&self,
matched_overloads: &[CalledOverload],
) -> Option<usize> {
// Step 5, part 2: are all remaining return types equivalent to one another?
// If not, the call is ambiguous.
let mut matched_overloads = matched_overloads.iter();
let first_overload = matched_overloads
.next()
.expect("Expected at least one overload");
if matched_overloads.any(|o| !self.is_equivalent(&first_overload.res, &o.res)) {
return None;
}
// Step 6: if there are still multiple matches, pick the first one.
Some(0)
}

fn disambiguate_overloads(&self, matched_overloads: &[CalledOverload]) -> Option<usize> {
// When a call to an overloaded function may match multiple overloads, the spec says to
// return Any when the return types are not all equivalent.
// However, neither mypy nor pyright fully follows this part of the spec, and many
// third-party libraries have come to rely on mypy and pyright's behavior. So we do the
// following for ecosystem compatibility:
//
// Step 6 (non-spec-compliant): does there exist a return type such that all
// materializations of every other return type are assignable to it? If so, use this
// return type. Else, return Any.
//
// We check materializations rather than assignability so that we end up with the most
// "general" return type. E.g., if the candidates are `A[None]` and `A[Any]`, we want
// to select `A[Any]`.
//
// First, find a candidate return type.
let mut candidate = 0;
for (i, o) in matched_overloads.iter().enumerate().skip(1) {
if !self.is_subset_eq(&o.res.materialize(), &matched_overloads[candidate].res) {
candidate = i;
}
}
// We've already checked every return type after the candidate.
// Check every return type before the candidate.
for o in matched_overloads.iter().take(candidate) {
if !self.is_subset_eq(&o.res.materialize(), &matched_overloads[candidate].res) {
return None;
}
// Step 6: if there are still multiple matches, pick the first one.
// Now that we've selected an overload, use the hint to contextually type the arguments.
let contextual_overload = self.call_overload(
&first_overload.func,
metadata,
self_obj,
args,
keywords,
arguments_range,
&self.error_collector(),
hint,
ctor_targs,
);
(
if contextual_overload.call_errors.is_empty() {
contextual_overload
} else {
first_overload
},
true,
)
}
Some(candidate)
}

fn call_overload(
Expand Down
3 changes: 3 additions & 0 deletions pyrefly/lib/solver/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ pub struct Solver {
pub heap: TypeHeap,
pub tensor_shapes: bool,
pub strict_callable_subtyping: bool,
pub spec_compliant_overloads: bool,
}

impl Display for Solver {
Expand All @@ -364,6 +365,7 @@ impl Solver {
infer_with_first_use: bool,
tensor_shapes: bool,
strict_callable_subtyping: bool,
spec_compliant_overloads: bool,
) -> Self {
Self {
variables: Default::default(),
Expand All @@ -372,6 +374,7 @@ impl Solver {
heap: TypeHeap::new(),
tensor_shapes,
strict_callable_subtyping,
spec_compliant_overloads,
}
}

Expand Down
4 changes: 4 additions & 0 deletions pyrefly/lib/state/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,8 @@ impl<'a> Transaction<'a> {
tensor_shapes: config.tensor_shapes(module_data.handle.path().as_path()),
strict_callable_subtyping: config
.strict_callable_subtyping(module_data.handle.path().as_path()),
spec_compliant_overloads: config
.spec_compliant_overloads(module_data.handle.path().as_path()),
recursion_limit_config: config.recursion_limit_config(),
pysa_context,
};
Expand Down Expand Up @@ -2058,6 +2060,8 @@ impl<'a> Transaction<'a> {
tensor_shapes: config.tensor_shapes(m.handle.path().as_path()),
strict_callable_subtyping: config
.strict_callable_subtyping(m.handle.path().as_path()),
spec_compliant_overloads: config
.spec_compliant_overloads(m.handle.path().as_path()),
recursion_limit_config: config.recursion_limit_config(),
pysa_context: None,
};
Expand Down
2 changes: 2 additions & 0 deletions pyrefly/lib/state/steps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub struct Context<'a, Lookup> {
pub infer_with_first_use: bool,
pub tensor_shapes: bool,
pub strict_callable_subtyping: bool,
pub spec_compliant_overloads: bool,
pub recursion_limit_config: Option<RecursionLimitConfig>,
/// Pysa context for building PysaSolutions during the Solutions step.
pub pysa_context: Option<PysaContext<'a>>,
Expand Down Expand Up @@ -414,6 +415,7 @@ impl Step {
ctx.infer_with_first_use,
ctx.tensor_shapes,
ctx.strict_callable_subtyping,
ctx.spec_compliant_overloads,
);
let enable_index = ctx.require.keep_index();
let enable_trace = ctx.require.keep_answers_trace() || ctx.pysa_context.is_some();
Expand Down
14 changes: 14 additions & 0 deletions pyrefly/lib/test/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,17 @@ def f(k: str | None):
assert_type(v, int | None)
"#,
);

testcase!(
test_dict_get_return,
r#"
from typing import Any
def f(outcomes: list[Any]) -> dict[str, int]:
ret = {noun: int(count) for (count, noun) in outcomes}
to_plural = {
"warning": "warnings",
"error": "errors",
}
return {to_plural.get(k, k): v for k, v in ret.items()}
"#,
);
34 changes: 30 additions & 4 deletions pyrefly/lib/test/overload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,7 @@ def ndim(shape: tuple[int, ...]) -> int:
return len(shape)

def demo_gradual(s: tuple[Any, ...]):
assert_type(ndim(s), Any)
assert_type(ndim(s), int)

def demo_one(s: tuple[int]):
assert_type(ndim(s), Literal[1])
Expand Down Expand Up @@ -1604,8 +1604,34 @@ def f(x: None):
);

testcase!(
bug = "op(A[Any], A[Any]) should be treated as ambiguous",
test_ambiguous,
test_resolve_ambiguous_precise,
r#"
from typing import Any, overload, assert_type

class A[T]: # covariant
def get(self) -> T: ...

@overload
def op(l: A[None], r: A[None]) -> A[None]: ...
@overload
def op(l: A[None], r: A[Any]) -> A[None]: ...
@overload
def op(l: A[Any], r: A[None]) -> A[None]: ...
@overload
def op(l: A[Any], r: A[Any]) -> A[Any]: ...
def op(l, r) -> A[None | Any]: ...

def test(x: A[None], y: A[Any]) -> None:
assert_type(op(x, x), A[None])
assert_type(op(x, y), A[None])
assert_type(op(y, x), A[None])
assert_type(op(y, y), A[Any])
"#,
);

testcase!(
test_resolve_ambiguous_spec_compliant,
TestEnv::new().enable_spec_compliant_overloads(),
r#"
from typing import Any, overload, assert_type

Expand All @@ -1626,7 +1652,7 @@ def test(x: A[None], y: A[Any]) -> None:
assert_type(op(x, x), A[None])
assert_type(op(x, y), A[None])
assert_type(op(y, x), A[None])
assert_type(op(y, y), Any) # E: assert_type(A[None], Any)
assert_type(op(y, y), Any)
"#,
);

Expand Down
Loading
Loading