Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,59 @@ mod test {
assert_eq!(ws.humanize_type(success_result), "boolean");
assert_eq!(ws.humanize_type(failure_result), "string");
}

#[test]
fn test_pcall_with_overloaded_module_function() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

// When a module function has both @field annotation and actual implementation,
// pcall should correctly infer the return type from the overloaded callable.
ws.def(
r#"
---@class ShlexModule
---@field split fun(s: string): string[]
local Shlex = {}

---@param s string
---@return string[]
function Shlex.split(s)
return {}
end

ok, args = pcall(Shlex.split, "hello world")
"#,
);

let ok_ty = ws.expr_ty("ok");
let args_ty = ws.expr_ty("args");
assert_eq!(ok_ty, ws.ty("true|false"));
assert_eq!(args_ty, ws.ty("string[]|string"));
}

#[test]
fn test_pcall_with_overloaded_module_function_scalar() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

// Same scenario but with a scalar return type (integer).
ws.def(
r#"
---@class ModScalar
---@field compute fun(s: string): integer
local M = {}

---@param s string
---@return integer
function M.compute(s)
return 1
end

ok, result = pcall(M.compute, "hello")
"#,
);

let ok_ty = ws.expr_ty("ok");
let result_ty = ws.expr_ty("result");
assert_eq!(ok_ty, ws.ty("true|false"));
assert_eq!(result_ty, ws.ty("integer|string"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
},
};
use crate::{
LuaMemberOwner, LuaSemanticDeclId, SemanticDeclLevel, infer_node_semantic_decl,
LuaMemberOwner, LuaSemanticDeclId, LuaUnionType, SemanticDeclLevel, infer_node_semantic_decl,
tpl_pattern_match_args,
};

Expand Down Expand Up @@ -115,6 +115,30 @@ pub fn as_doc_function_type(
.ok_or(InferFailReason::None)?
.to_doc_func_type(),
),
LuaType::Union(union) => {
match union.as_ref() {
LuaUnionType::Basic(basic) => {
for member in basic.iter() {
if let Some(func) = as_doc_function_type(db, &member)? {
return Ok(Some(func));
}
}
}
LuaUnionType::Nullable(ty) => {
if let Some(func) = as_doc_function_type(db, ty)? {
return Ok(Some(func));
}
}
LuaUnionType::Multi(types) => {
for member in types {
if let Some(func) = as_doc_function_type(db, member)? {
return Ok(Some(func));
}
}
}
}
None
}
_ => None,
})
}
Expand Down
Loading