Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
#[cfg(test)]
mod test {
use crate::VirtualWorkspace;

#[test]
fn test_simple_infer_through_generic_func() {
// First verify that basic infer through generic function works
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias Identity<T> T extends infer P and P or never

---@generic T
---@param v T
---@return Identity<T>
function identity(v) end

Z = identity("hello")
"#,
);

let z_ty = ws.expr_ty("Z");
// Should be "string" if basic infer works through generic functions
assert_eq!(ws.humanize_type(z_ty), "string");
}

#[test]
fn test_object_literal_infer_basic() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ExtractFoo<T> T extends { foo: infer F } and F or never

---@generic T
---@param v T
---@return ExtractFoo<T>
function extractFoo(v) end

---@type { foo: string, bar: number }
local myTable

A = extractFoo(myTable)
"#,
);

let a_ty = ws.expr_ty("A");
assert_eq!(ws.humanize_type(a_ty), "string");
}

#[test]
fn test_object_literal_infer_from_class() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ExtractFoo<T> T extends { foo: infer F } and F or never

---@class MyClass
---@field foo number
---@field bar string

---@generic T
---@param v T
---@return ExtractFoo<T>
function extractFoo(v) end

---@type MyClass
local myObj

B = extractFoo(myObj)
"#,
);

let b_ty = ws.expr_ty("B");
assert_eq!(ws.humanize_type(b_ty), "number");
}

#[test]
fn test_object_literal_infer_constructor_params_multiple() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ConstructorParams<T> T extends { constructor: fun(self: any, ...: infer P): any } and P or never

---@class Widget
---@field constructor fun(self: Widget, name: string, width: number): Widget

---@generic T
---@param v T
---@return ConstructorParams<T>
function getParams(v) end

---@type Widget
local widget

C = getParams(widget)
"#,
);

let c_ty = ws.expr_ty("C");
// Should be a tuple of the inferred parameters
assert_eq!(ws.humanize_type(c_ty), "(string,number)");
}

#[test]
fn test_object_literal_infer_constructor_params_single() {
// Test that single parameter constructors also return a tuple for consistent spreading
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ConstructorParams<T> T extends { constructor: fun(self: any, ...: infer P): any } and P or never

---@class SimpleWidget
---@field constructor fun(self: SimpleWidget, name: string): SimpleWidget

---@generic T
---@param v T
---@return ConstructorParams<T>
function getParams(v) end

---@type SimpleWidget
local widget

D = getParams(widget)
"#,
);

let d_ty = ws.expr_ty("D");
// Single parameter should also be a tuple for consistent variadic spreading
// This ensures `fun(...: ConstructorParams<T>...)` works correctly
assert_eq!(ws.humanize_type(d_ty), "(string)");
}

#[test]
fn test_object_literal_infer_nested() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ExtractNested<T> T extends { outer: { inner: infer I } } and I or never

---@generic T
---@param v T
---@return ExtractNested<T>
function extractNested(v) end

---@type { outer: { inner: boolean } }
local nested

D = extractNested(nested)
"#,
);

let d_ty = ws.expr_ty("D");
assert_eq!(ws.humanize_type(d_ty), "boolean");
}

#[test]
fn test_object_literal_infer_no_match() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ExtractFoo<T> T extends { foo: infer F } and F or never

---@generic T
---@param v T
---@return ExtractFoo<T>
function extractFoo(v) end

---@type { bar: string }
local noFoo

E = extractFoo(noFoo)
"#,
);

let e_ty = ws.expr_ty("E");
assert_eq!(ws.humanize_type(e_ty), "never");
}

#[test]
fn test_object_literal_infer_function_field() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ExtractCallback<T> T extends { callback: infer C } and C or never

---@generic T
---@param v T
---@return ExtractCallback<T>
function extractCallback(v) end

---@type { callback: fun(x: number): string }
local obj

F = extractCallback(obj)
"#,
);

let f_ty = ws.expr_ty("F");
assert_eq!(ws.humanize_type(f_ty), "fun(x: number) -> string");
}

#[test]
fn test_object_literal_infer_true_variadic_params() {
// Test that true variadic functions (fun(self, ...: T)) preserve variadic behavior
// This should NOT be wrapped in a tuple - it should stay as the base type
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ExtractVariadic<T> T extends { handler: fun(self: any, ...: infer P): any } and P or never

---@class VariadicWidget
---@field handler fun(self: VariadicWidget, ...: string): VariadicWidget

---@generic T
---@param v T
---@return ExtractVariadic<T>
function getVariadicType(v) end

---@type VariadicWidget
local widget

V = getVariadicType(widget)
"#,
);

let v_ty = ws.expr_ty("V");
// True variadic should return the base type (not wrapped in tuple)
// so that variadic spreading continues to work as expected
assert_eq!(ws.humanize_type(v_ty), "string");
}
}
Comment on lines +1 to +231

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new test suite is quite comprehensive. However, the PR description mentions a known limitation with inline table literals (TableConst). It would be beneficial to add a test case that specifically covers this scenario, for example, extractFoo({ foo = "hello" }). This would document the current behavior and provide a baseline for future improvements. Here is a suggestion for such a test:

    #[test]
    fn test_object_literal_infer_from_inline_table() {
        let mut ws = VirtualWorkspace::new();
        ws.def(
            r#"
            ---@alias ExtractFoo<T> T extends { foo: infer F } and F or never

            ---@generic T
            ---@param v T
            ---@return ExtractFoo<T>
            function extractFoo(v) end

            G = extractFoo({ foo = "hello" })
            "#,
        );

        let g_ty = ws.expr_ty("G");
        assert_eq!(ws.humanize_type(g_ty), "string");
    }

1 change: 1 addition & 0 deletions crates/emmylua_code_analysis/src/compilation/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod diagnostic_disable_test;
mod export_test;
mod flow;
mod for_range_var_infer_test;
mod generic_infer_test;
mod generic_test;
mod infer_str_tpl_test;
mod inherit_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::{

use super::type_substitutor::{SubstitutorValue, TypeSubstitutor};
use crate::TypeVisitTrait;
use crate::semantic::member::find_members_with_key;
pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic};
pub use instantiate_special_generic::get_keyof_members;
pub use instantiate_special_generic::instantiate_alias_call;
Expand Down Expand Up @@ -660,7 +661,23 @@ fn collect_infer_assignments(
}
let ty = match rest_types.len() {
0 => LuaType::Never,
1 => rest_types[0].clone(),
1 => {
// If the source function is truly variadic (has `...` param),
// return the type as-is for proper variadic spreading.
// If the source has named params, wrap in a tuple so that
// spreading unpacks to named params (var0, var1, etc.)
if source_func.is_variadic() {
rest_types[0].clone()
} else {
LuaType::Tuple(
LuaTupleType::new(
rest_types,
LuaTupleStatus::InferResolve,
)
.into(),
)
}
}
_ => LuaType::Tuple(
LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve)
.into(),
Expand Down Expand Up @@ -726,6 +743,18 @@ fn collect_infer_assignments(
false
}
}
LuaType::Object(pattern_object) => match source {
LuaType::Object(source_object) => {
collect_infer_from_object_to_object(db, source_object, pattern_object, assignments)
}
LuaType::Ref(type_id) | LuaType::Def(type_id) => {
collect_infer_from_class_to_object(db, type_id, pattern_object, assignments)
}
LuaType::TableConst(table_id) => {
collect_infer_from_table_to_object(db, table_id, pattern_object, assignments)
}
_ => false,
},
_ => {
if contains_conditional_infer(pattern) {
false
Expand All @@ -736,6 +765,84 @@ fn collect_infer_assignments(
}
}

/// Match object literal to object pattern, extracting infer types from fields
fn collect_infer_from_object_to_object(
db: &DbIndex,
source_object: &LuaObjectType,
pattern_object: &LuaObjectType,
assignments: &mut HashMap<String, LuaType>,
) -> bool {
let source_fields = source_object.get_fields();
let pattern_fields = pattern_object.get_fields();

for (key, pattern_field_ty) in pattern_fields {
if let Some(source_field_ty) = source_fields.get(key) {
if !collect_infer_assignments(db, source_field_ty, pattern_field_ty, assignments) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
// Pattern field contains infer but source doesn't have the field
return false;
}
}

true
}

/// Match class/ref type to object pattern by looking up class members
fn collect_infer_from_class_to_object(
db: &DbIndex,
type_id: &LuaTypeDeclId,
pattern_object: &LuaObjectType,
assignments: &mut HashMap<String, LuaType>,
) -> bool {
let pattern_fields = pattern_object.get_fields();
let source_type = LuaType::Ref(type_id.clone());

for (key, pattern_field_ty) in pattern_fields {
if let Some(member_infos) = find_members_with_key(db, &source_type, key.clone(), false) {
if let Some(member_info) = member_infos.first() {
if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
return false;
}
}

true
}

/// Match table constant to object pattern by looking up table members
fn collect_infer_from_table_to_object(
db: &DbIndex,
table_id: &crate::InFiled<rowan::TextRange>,
pattern_object: &LuaObjectType,
assignments: &mut HashMap<String, LuaType>,
) -> bool {
let pattern_fields = pattern_object.get_fields();
let source_type = LuaType::TableConst(table_id.clone());

for (key, pattern_field_ty) in pattern_fields {
if let Some(member_infos) = find_members_with_key(db, &source_type, key.clone(), false) {
if let Some(member_info) = member_infos.first() {
if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
return false;
}
} else if contains_conditional_infer(pattern_field_ty) {
return false;
}
}

true
}
Comment on lines +792 to +844

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The functions collect_infer_from_class_to_object and collect_infer_from_table_to_object are almost identical, leading to code duplication. Their logic can be extracted into a single helper function that takes source_type as a parameter. This would improve maintainability.

Additionally, the member lookup logic within both functions can be simplified by chaining .as_deref() and .and_then() to avoid nested if let statements and a redundant check for contains_conditional_infer.

While I can't suggest adding a new helper function directly, I've refactored both functions to be more concise. You could then easily extract the common logic.

/// Match class/ref type to object pattern by looking up class members
fn collect_infer_from_class_to_object(
    db: &DbIndex,
    type_id: &LuaTypeDeclId,
    pattern_object: &LuaObjectType,
    assignments: &mut HashMap<String, LuaType>,
) -> bool {
    let pattern_fields = pattern_object.get_fields();
    let source_type = LuaType::Ref(type_id.clone());

    for (key, pattern_field_ty) in pattern_fields {
        if let Some(member_info) = find_members_with_key(db, &source_type, key.clone(), false)
            .as_deref()
            .and_then(|infos| infos.first())
        {
            if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) {
                return false;
            }
        } else if contains_conditional_infer(pattern_field_ty) {
            return false;
        }
    }

    true
}

/// Match table constant to object pattern by looking up table members
fn collect_infer_from_table_to_object(
    db: &DbIndex,
    table_id: &crate::InFiled<rowan::TextRange>,
    pattern_object: &LuaObjectType,
    assignments: &mut HashMap<String, LuaType>,
) -> bool {
    let pattern_fields = pattern_object.get_fields();
    let source_type = LuaType::TableConst(table_id.clone());

    for (key, pattern_field_ty) in pattern_fields {
        if let Some(member_info) = find_members_with_key(db, &source_type, key.clone(), false)
            .as_deref()
            .and_then(|infos| infos.first())
        {
            if !collect_infer_assignments(db, &member_info.typ, pattern_field_ty, assignments) {
                return false;
            }
        } else if contains_conditional_infer(pattern_field_ty) {
            return false;
        }
    }

    true
}


fn strict_type_match(db: &DbIndex, source: &LuaType, pattern: &LuaType) -> bool {
if source == pattern {
return true;
Expand Down