Skip to content

Commit be9caeb

Browse files
committed
fix(generic): resolve higher-order callable returns by args
Flatten callable unions and intersections into callable candidates, instantiate them from the remaining argument types, and select the matching return via overload resolution instead of unioning member returns directly. Add focused pcall regressions for callable union and intersection values. Callable union handling still mirrors overload-style resolution and is not semantically complete yet; leave that for follow-up work.
1 parent 3e0bb16 commit be9caeb

3 files changed

Lines changed: 155 additions & 31 deletions

File tree

crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,42 @@ mod test {
141141
assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string"));
142142
assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string"));
143143
}
144+
145+
#[test]
146+
fn test_return_overload_infers_callable_union_member() {
147+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
148+
149+
ws.def(
150+
r#"
151+
---@alias FnA fun(x: integer): integer
152+
---@alias FnB fun(x: string): integer
153+
154+
---@type FnA | FnB
155+
local run
156+
157+
_, a = pcall(run, 1)
158+
"#,
159+
);
160+
161+
assert_eq!(ws.expr_ty("a"), ws.ty("integer|string"));
162+
}
163+
164+
#[test]
165+
fn test_return_overload_selects_callable_intersection_member() {
166+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
167+
168+
ws.def(
169+
r#"
170+
---@alias FnA fun(x: integer): integer
171+
---@alias FnB fun(x: string): boolean
172+
173+
---@type FnA & FnB
174+
local run
175+
176+
_, a = pcall(run, 1)
177+
"#,
178+
);
179+
180+
assert_eq!(ws.expr_ty("a"), ws.ty("integer|string"));
181+
}
144182
}

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::{
2222
},
2323
infer::InferFailReason,
2424
infer_expr,
25+
overload_resolve::resolve_signature_by_args,
2526
},
2627
};
2728
use crate::{
@@ -156,60 +157,145 @@ pub fn infer_callable_return_from_remaining_args(
156157
return Ok(None);
157158
}
158159

159-
let Some(callable) = as_doc_function_type(context.db, callable_type)? else {
160+
let mut overloads = Vec::new();
161+
collect_callable_overloads(context.db, callable_type, &mut overloads)?;
162+
if overloads.is_empty() {
160163
return Ok(None);
164+
}
165+
166+
let db = context.db;
167+
168+
// Fall back to the union of all candidate returns when args cannot narrow the callable.
169+
let fallback_return = || {
170+
LuaType::from_vec(
171+
overloads
172+
.iter()
173+
.map(|callable| {
174+
let mut callable_tpls = HashSet::new();
175+
callable.visit_type(&mut |ty| {
176+
if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty
177+
{
178+
callable_tpls.insert(generic_tpl.get_tpl_id());
179+
}
180+
});
181+
182+
let mut callable_substitutor = TypeSubstitutor::new();
183+
callable_substitutor.add_need_infer_tpls(callable_tpls);
184+
infer_return_from_callable(db, callable, &callable_substitutor)
185+
})
186+
.collect(),
187+
)
161188
};
162189

190+
let call_arg_types = match infer_expr_list_types(db, context.cache, arg_exprs, None, infer_expr)
191+
{
192+
Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::<Vec<_>>(),
193+
Err(_) => return Ok(Some(fallback_return())),
194+
};
195+
if call_arg_types.is_empty() {
196+
return Ok(None);
197+
}
198+
199+
let instantiated_overloads = overloads
200+
.iter()
201+
.map(|callable| instantiate_callable_from_arg_types(context, callable, &call_arg_types))
202+
.collect::<Vec<_>>();
203+
204+
Ok(Some(
205+
resolve_signature_by_args(db, &instantiated_overloads, &call_arg_types, false, None)
206+
.map(|callable| callable.get_ret().clone())
207+
.unwrap_or_else(|_| fallback_return()),
208+
))
209+
}
210+
211+
fn collect_callable_overloads(
212+
db: &DbIndex,
213+
callable_type: &LuaType,
214+
overloads: &mut Vec<Arc<LuaFunctionType>>,
215+
) -> Result<(), InferFailReason> {
216+
// TODO: Distinguish callable union vs intersection semantics here instead of flattening both
217+
// into one overload-candidate pool. Keep in sync with `infer_union` / `infer_intersection`.
218+
match callable_type {
219+
LuaType::DocFunction(doc_func) => overloads.push(doc_func.clone()),
220+
LuaType::Signature(sig_id) => {
221+
let signature = db
222+
.get_signature_index()
223+
.get(sig_id)
224+
.ok_or(InferFailReason::None)?;
225+
overloads.extend(signature.overloads.iter().cloned());
226+
overloads.push(signature.to_doc_func_type());
227+
}
228+
LuaType::Ref(type_id) | LuaType::Def(type_id) => {
229+
if let Some(origin_type) = db
230+
.get_type_index()
231+
.get_type_decl(type_id)
232+
.ok_or(InferFailReason::None)?
233+
.get_alias_origin(db, None)
234+
{
235+
collect_callable_overloads(db, &origin_type, overloads)?;
236+
}
237+
}
238+
LuaType::Generic(generic) => {
239+
let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec());
240+
if let Some(origin_type) = db
241+
.get_type_index()
242+
.get_type_decl(&generic.get_base_type_id())
243+
.ok_or(InferFailReason::None)?
244+
.get_alias_origin(db, Some(&substitutor))
245+
{
246+
collect_callable_overloads(db, &origin_type, overloads)?;
247+
}
248+
}
249+
LuaType::Union(union) => {
250+
for member in union.into_vec() {
251+
collect_callable_overloads(db, &member, overloads)?;
252+
}
253+
}
254+
LuaType::Intersection(intersection) => {
255+
for member in intersection.get_types() {
256+
collect_callable_overloads(db, member, overloads)?;
257+
}
258+
}
259+
_ => {}
260+
}
261+
262+
Ok(())
263+
}
264+
265+
fn instantiate_callable_from_arg_types(
266+
context: &mut TplContext,
267+
callable: &Arc<LuaFunctionType>,
268+
call_arg_types: &[LuaType],
269+
) -> Arc<LuaFunctionType> {
163270
let mut callable_tpls = HashSet::new();
164271
callable.visit_type(&mut |ty| {
165272
if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty {
166273
callable_tpls.insert(generic_tpl.get_tpl_id());
167274
}
168275
});
169276
if callable_tpls.is_empty() {
170-
return Ok(Some(callable.get_ret().clone()));
171-
}
172-
173-
let mut callable_substitutor = TypeSubstitutor::new();
174-
callable_substitutor.add_need_infer_tpls(callable_tpls);
175-
let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor);
176-
177-
let call_arg_types =
178-
match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) {
179-
Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::<Vec<_>>(),
180-
Err(_) => return Ok(Some(fallback_return)),
181-
};
182-
if call_arg_types.is_empty() {
183-
return Ok(None);
277+
return callable.clone();
184278
}
185279

186280
let callable_param_types = callable
187281
.get_params()
188282
.iter()
189283
.map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown))
190284
.collect::<Vec<_>>();
191-
285+
let mut callable_substitutor = TypeSubstitutor::new();
286+
callable_substitutor.add_need_infer_tpls(callable_tpls);
192287
let mut callable_context = TplContext {
193288
db: context.db,
194289
cache: context.cache,
195290
substitutor: &mut callable_substitutor,
196291
call_expr: context.call_expr.clone(),
197292
};
198-
if tpl_pattern_match_args(
199-
&mut callable_context,
200-
&callable_param_types,
201-
&call_arg_types,
202-
)
203-
.is_err()
204-
{
205-
return Ok(Some(fallback_return));
206-
}
293+
let _ = tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types);
207294

208-
Ok(Some(infer_return_from_callable(
209-
context.db,
210-
&callable,
211-
&callable_substitutor,
212-
)))
295+
match instantiate_doc_function(context.db, callable, &callable_substitutor) {
296+
LuaType::DocFunction(func) => func,
297+
_ => callable.clone(),
298+
}
213299
}
214300

215301
fn infer_generic_types_from_call(

crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use super::{
1616
infer::{InferCallFuncResult, InferFailReason},
1717
};
1818

19-
use resolve_signature_by_args::resolve_signature_by_args;
19+
pub(crate) use resolve_signature_by_args::resolve_signature_by_args;
2020

2121
pub fn resolve_signature(
2222
db: &DbIndex,

0 commit comments

Comments
 (0)