-
Notifications
You must be signed in to change notification settings - Fork 56
feat(generics): add object literal pattern matching in conditional type infer #928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,231 @@ | ||
| #[cfg(test)] | ||
| mod test { | ||
| use crate::VirtualWorkspace; | ||
|
|
||
| #[test] | ||
| fn test_simple_infer_through_generic_func() { | ||
| // First verify that basic infer through generic function works | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias Identity<T> T extends infer P and P or never | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return Identity<T> | ||
| function identity(v) end | ||
|
|
||
| Z = identity("hello") | ||
| "#, | ||
| ); | ||
|
|
||
| let z_ty = ws.expr_ty("Z"); | ||
| // Should be "string" if basic infer works through generic functions | ||
| assert_eq!(ws.humanize_type(z_ty), "string"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_basic() { | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ExtractFoo<T> T extends { foo: infer F } and F or never | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ExtractFoo<T> | ||
| function extractFoo(v) end | ||
|
|
||
| ---@type { foo: string, bar: number } | ||
| local myTable | ||
|
|
||
| A = extractFoo(myTable) | ||
| "#, | ||
| ); | ||
|
|
||
| let a_ty = ws.expr_ty("A"); | ||
| assert_eq!(ws.humanize_type(a_ty), "string"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_from_class() { | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ExtractFoo<T> T extends { foo: infer F } and F or never | ||
|
|
||
| ---@class MyClass | ||
| ---@field foo number | ||
| ---@field bar string | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ExtractFoo<T> | ||
| function extractFoo(v) end | ||
|
|
||
| ---@type MyClass | ||
| local myObj | ||
|
|
||
| B = extractFoo(myObj) | ||
| "#, | ||
| ); | ||
|
|
||
| let b_ty = ws.expr_ty("B"); | ||
| assert_eq!(ws.humanize_type(b_ty), "number"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_constructor_params_multiple() { | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ConstructorParams<T> T extends { constructor: fun(self: any, ...: infer P): any } and P or never | ||
|
|
||
| ---@class Widget | ||
| ---@field constructor fun(self: Widget, name: string, width: number): Widget | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ConstructorParams<T> | ||
| function getParams(v) end | ||
|
|
||
| ---@type Widget | ||
| local widget | ||
|
|
||
| C = getParams(widget) | ||
| "#, | ||
| ); | ||
|
|
||
| let c_ty = ws.expr_ty("C"); | ||
| // Should be a tuple of the inferred parameters | ||
| assert_eq!(ws.humanize_type(c_ty), "(string,number)"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_constructor_params_single() { | ||
| // Test that single parameter constructors also return a tuple for consistent spreading | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ConstructorParams<T> T extends { constructor: fun(self: any, ...: infer P): any } and P or never | ||
|
|
||
| ---@class SimpleWidget | ||
| ---@field constructor fun(self: SimpleWidget, name: string): SimpleWidget | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ConstructorParams<T> | ||
| function getParams(v) end | ||
|
|
||
| ---@type SimpleWidget | ||
| local widget | ||
|
|
||
| D = getParams(widget) | ||
| "#, | ||
| ); | ||
|
|
||
| let d_ty = ws.expr_ty("D"); | ||
| // Single parameter should also be a tuple for consistent variadic spreading | ||
| // This ensures `fun(...: ConstructorParams<T>...)` works correctly | ||
| assert_eq!(ws.humanize_type(d_ty), "(string)"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_nested() { | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ExtractNested<T> T extends { outer: { inner: infer I } } and I or never | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ExtractNested<T> | ||
| function extractNested(v) end | ||
|
|
||
| ---@type { outer: { inner: boolean } } | ||
| local nested | ||
|
|
||
| D = extractNested(nested) | ||
| "#, | ||
| ); | ||
|
|
||
| let d_ty = ws.expr_ty("D"); | ||
| assert_eq!(ws.humanize_type(d_ty), "boolean"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_no_match() { | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ExtractFoo<T> T extends { foo: infer F } and F or never | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ExtractFoo<T> | ||
| function extractFoo(v) end | ||
|
|
||
| ---@type { bar: string } | ||
| local noFoo | ||
|
|
||
| E = extractFoo(noFoo) | ||
| "#, | ||
| ); | ||
|
|
||
| let e_ty = ws.expr_ty("E"); | ||
| assert_eq!(ws.humanize_type(e_ty), "never"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_function_field() { | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ExtractCallback<T> T extends { callback: infer C } and C or never | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ExtractCallback<T> | ||
| function extractCallback(v) end | ||
|
|
||
| ---@type { callback: fun(x: number): string } | ||
| local obj | ||
|
|
||
| F = extractCallback(obj) | ||
| "#, | ||
| ); | ||
|
|
||
| let f_ty = ws.expr_ty("F"); | ||
| assert_eq!(ws.humanize_type(f_ty), "fun(x: number) -> string"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_object_literal_infer_true_variadic_params() { | ||
| // Test that true variadic functions (fun(self, ...: T)) preserve variadic behavior | ||
| // This should NOT be wrapped in a tuple - it should stay as the base type | ||
| let mut ws = VirtualWorkspace::new(); | ||
| ws.def( | ||
| r#" | ||
| ---@alias ExtractVariadic<T> T extends { handler: fun(self: any, ...: infer P): any } and P or never | ||
|
|
||
| ---@class VariadicWidget | ||
| ---@field handler fun(self: VariadicWidget, ...: string): VariadicWidget | ||
|
|
||
| ---@generic T | ||
| ---@param v T | ||
| ---@return ExtractVariadic<T> | ||
| function getVariadicType(v) end | ||
|
|
||
| ---@type VariadicWidget | ||
| local widget | ||
|
|
||
| V = getVariadicType(widget) | ||
| "#, | ||
| ); | ||
|
|
||
| let v_ty = ws.expr_ty("V"); | ||
| // True variadic should return the base type (not wrapped in tuple) | ||
| // so that variadic spreading continues to work as expected | ||
| assert_eq!(ws.humanize_type(v_ty), "string"); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ use crate::{ | |
|
|
||
| use super::type_substitutor::{SubstitutorValue, TypeSubstitutor}; | ||
| use crate::TypeVisitTrait; | ||
| use crate::semantic::member::find_members_with_key; | ||
| pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; | ||
| pub use instantiate_special_generic::get_keyof_members; | ||
| pub use instantiate_special_generic::instantiate_alias_call; | ||
|
|
@@ -660,7 +661,23 @@ fn collect_infer_assignments( | |
| } | ||
| let ty = match rest_types.len() { | ||
| 0 => LuaType::Never, | ||
| 1 => rest_types[0].clone(), | ||
| 1 => { | ||
| // If the source function is truly variadic (has `...` param), | ||
| // return the type as-is for proper variadic spreading. | ||
| // If the source has named params, wrap in a tuple so that | ||
| // spreading unpacks to named params (var0, var1, etc.) | ||
| if source_func.is_variadic() { | ||
| rest_types[0].clone() | ||
| } else { | ||
| LuaType::Tuple( | ||
| LuaTupleType::new( | ||
| rest_types, | ||
| LuaTupleStatus::InferResolve, | ||
| ) | ||
| .into(), | ||
| ) | ||
| } | ||
| } | ||
| _ => LuaType::Tuple( | ||
| LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve) | ||
| .into(), | ||
|
|
@@ -726,6 +743,18 @@ fn collect_infer_assignments( | |
| false | ||
| } | ||
| } | ||
| LuaType::Object(pattern_object) => match source { | ||
| LuaType::Object(source_object) => { | ||
| collect_infer_from_object_to_object(db, source_object, pattern_object, assignments) | ||
| } | ||
| LuaType::Ref(type_id) | LuaType::Def(type_id) => { | ||
| collect_infer_from_class_to_object(db, type_id, pattern_object, assignments) | ||
| } | ||
| LuaType::TableConst(table_id) => { | ||
| collect_infer_from_table_to_object(db, table_id, pattern_object, assignments) | ||
| } | ||
| _ => false, | ||
| }, | ||
| _ => { | ||
| if contains_conditional_infer(pattern) { | ||
| false | ||
|
|
@@ -736,6 +765,84 @@ fn collect_infer_assignments( | |
| } | ||
| } | ||
|
|
||
| /// Match object literal to object pattern, extracting infer types from fields | ||
| fn collect_infer_from_object_to_object( | ||
| db: &DbIndex, | ||
| source_object: &LuaObjectType, | ||
| pattern_object: &LuaObjectType, | ||
| assignments: &mut HashMap<String, LuaType>, | ||
| ) -> bool { | ||
| let source_fields = source_object.get_fields(); | ||
| let pattern_fields = pattern_object.get_fields(); | ||
|
|
||
| for (key, pattern_field_ty) in pattern_fields { | ||
| if let Some(source_field_ty) = source_fields.get(key) { | ||
| if !collect_infer_assignments(db, source_field_ty, pattern_field_ty, assignments) { | ||
| return false; | ||
| } | ||
| } else if contains_conditional_infer(pattern_field_ty) { | ||
| // Pattern field contains infer but source doesn't have the field | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| true | ||
| } | ||
|
|
||
| /// Match class/ref type to object pattern by looking up class members | ||
| fn collect_infer_from_class_to_object( | ||
| db: &DbIndex, | ||
| type_id: &LuaTypeDeclId, | ||
| pattern_object: &LuaObjectType, | ||
| assignments: &mut HashMap<String, LuaType>, | ||
| ) -> bool { | ||
| let pattern_fields = pattern_object.get_fields(); | ||
| let source_type = LuaType::Ref(type_id.clone()); | ||
|
|
||
| for (key, pattern_field_ty) in pattern_fields { | ||
| if let Some(member_infos) = find_members_with_key(db, &source_type, key.clone(), false) { | ||
| if let Some(member_info) = member_infos.first() { | ||
| if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) { | ||
| return false; | ||
| } | ||
| } else if contains_conditional_infer(pattern_field_ty) { | ||
| return false; | ||
| } | ||
| } else if contains_conditional_infer(pattern_field_ty) { | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| true | ||
| } | ||
|
|
||
| /// Match table constant to object pattern by looking up table members | ||
| fn collect_infer_from_table_to_object( | ||
| db: &DbIndex, | ||
| table_id: &crate::InFiled<rowan::TextRange>, | ||
| pattern_object: &LuaObjectType, | ||
| assignments: &mut HashMap<String, LuaType>, | ||
| ) -> bool { | ||
| let pattern_fields = pattern_object.get_fields(); | ||
| let source_type = LuaType::TableConst(table_id.clone()); | ||
|
|
||
| for (key, pattern_field_ty) in pattern_fields { | ||
| if let Some(member_infos) = find_members_with_key(db, &source_type, key.clone(), false) { | ||
| if let Some(member_info) = member_infos.first() { | ||
| if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) { | ||
| return false; | ||
| } | ||
| } else if contains_conditional_infer(pattern_field_ty) { | ||
| return false; | ||
| } | ||
| } else if contains_conditional_infer(pattern_field_ty) { | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| true | ||
| } | ||
|
Comment on lines
+792
to
+844
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The functions Additionally, the member lookup logic within both functions can be simplified by chaining While I can't suggest adding a new helper function directly, I've refactored both functions to be more concise. You could then easily extract the common logic. /// Match class/ref type to object pattern by looking up class members
fn collect_infer_from_class_to_object(
db: &DbIndex,
type_id: &LuaTypeDeclId,
pattern_object: &LuaObjectType,
assignments: &mut HashMap<String, LuaType>,
) -> bool {
let pattern_fields = pattern_object.get_fields();
let source_type = LuaType::Ref(type_id.clone());
for (key, pattern_field_ty) in pattern_fields {
if let Some(member_info) = find_members_with_key(db, &source_type, key.clone(), false)
.as_deref()
.and_then(|infos| infos.first())
{
if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
return false;
}
}
true
}
/// Match table constant to object pattern by looking up table members
fn collect_infer_from_table_to_object(
db: &DbIndex,
table_id: &crate::InFiled<rowan::TextRange>,
pattern_object: &LuaObjectType,
assignments: &mut HashMap<String, LuaType>,
) -> bool {
let pattern_fields = pattern_object.get_fields();
let source_type = LuaType::TableConst(table_id.clone());
for (key, pattern_field_ty) in pattern_fields {
if let Some(member_info) = find_members_with_key(db, &source_type, key.clone(), false)
.as_deref()
.and_then(|infos| infos.first())
{
if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
return false;
}
}
true
} |
||
|
|
||
| fn strict_type_match(db: &DbIndex, source: &LuaType, pattern: &LuaType) -> bool { | ||
| if source == pattern { | ||
| return true; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new test suite is quite comprehensive. However, the PR description mentions a known limitation with inline table literals (
TableConst). It would be beneficial to add a test case that specifically covers this scenario, for example,extractFoo({ foo = "hello" }). This would document the current behavior and provide a baseline for future improvements. Here is a suggestion for such a test: