@@ -22,6 +22,7 @@ use crate::{
2222 } ,
2323 infer:: InferFailReason ,
2424 infer_expr,
25+ overload_resolve:: resolve_signature_by_args,
2526 } ,
2627} ;
2728use crate :: {
@@ -156,60 +157,145 @@ pub fn infer_callable_return_from_remaining_args(
156157 return Ok ( None ) ;
157158 }
158159
159- let Some ( callable) = as_doc_function_type ( context. db , callable_type) ? else {
160+ let mut overloads = Vec :: new ( ) ;
161+ collect_callable_overloads ( context. db , callable_type, & mut overloads) ?;
162+ if overloads. is_empty ( ) {
160163 return Ok ( None ) ;
164+ }
165+
166+ let db = context. db ;
167+
168+ // Fall back to the union of all candidate returns when args cannot narrow the callable.
169+ let fallback_return = || {
170+ LuaType :: from_vec (
171+ overloads
172+ . iter ( )
173+ . map ( |callable| {
174+ let mut callable_tpls = HashSet :: new ( ) ;
175+ callable. visit_type ( & mut |ty| {
176+ if let LuaType :: TplRef ( generic_tpl) | LuaType :: ConstTplRef ( generic_tpl) = ty
177+ {
178+ callable_tpls. insert ( generic_tpl. get_tpl_id ( ) ) ;
179+ }
180+ } ) ;
181+
182+ let mut callable_substitutor = TypeSubstitutor :: new ( ) ;
183+ callable_substitutor. add_need_infer_tpls ( callable_tpls) ;
184+ infer_return_from_callable ( db, callable, & callable_substitutor)
185+ } )
186+ . collect ( ) ,
187+ )
161188 } ;
162189
190+ let call_arg_types = match infer_expr_list_types ( db, context. cache , arg_exprs, None , infer_expr)
191+ {
192+ Ok ( types) => types. into_iter ( ) . map ( |( ty, _) | ty) . collect :: < Vec < _ > > ( ) ,
193+ Err ( _) => return Ok ( Some ( fallback_return ( ) ) ) ,
194+ } ;
195+ if call_arg_types. is_empty ( ) {
196+ return Ok ( None ) ;
197+ }
198+
199+ let instantiated_overloads = overloads
200+ . iter ( )
201+ . map ( |callable| instantiate_callable_from_arg_types ( context, callable, & call_arg_types) )
202+ . collect :: < Vec < _ > > ( ) ;
203+
204+ Ok ( Some (
205+ resolve_signature_by_args ( db, & instantiated_overloads, & call_arg_types, false , None )
206+ . map ( |callable| callable. get_ret ( ) . clone ( ) )
207+ . unwrap_or_else ( |_| fallback_return ( ) ) ,
208+ ) )
209+ }
210+
211+ fn collect_callable_overloads (
212+ db : & DbIndex ,
213+ callable_type : & LuaType ,
214+ overloads : & mut Vec < Arc < LuaFunctionType > > ,
215+ ) -> Result < ( ) , InferFailReason > {
216+ // TODO: Distinguish callable union vs intersection semantics here instead of flattening both
217+ // into one overload-candidate pool. Keep in sync with `infer_union` / `infer_intersection`.
218+ match callable_type {
219+ LuaType :: DocFunction ( doc_func) => overloads. push ( doc_func. clone ( ) ) ,
220+ LuaType :: Signature ( sig_id) => {
221+ let signature = db
222+ . get_signature_index ( )
223+ . get ( sig_id)
224+ . ok_or ( InferFailReason :: None ) ?;
225+ overloads. extend ( signature. overloads . iter ( ) . cloned ( ) ) ;
226+ overloads. push ( signature. to_doc_func_type ( ) ) ;
227+ }
228+ LuaType :: Ref ( type_id) | LuaType :: Def ( type_id) => {
229+ if let Some ( origin_type) = db
230+ . get_type_index ( )
231+ . get_type_decl ( type_id)
232+ . ok_or ( InferFailReason :: None ) ?
233+ . get_alias_origin ( db, None )
234+ {
235+ collect_callable_overloads ( db, & origin_type, overloads) ?;
236+ }
237+ }
238+ LuaType :: Generic ( generic) => {
239+ let substitutor = TypeSubstitutor :: from_type_array ( generic. get_params ( ) . to_vec ( ) ) ;
240+ if let Some ( origin_type) = db
241+ . get_type_index ( )
242+ . get_type_decl ( & generic. get_base_type_id ( ) )
243+ . ok_or ( InferFailReason :: None ) ?
244+ . get_alias_origin ( db, Some ( & substitutor) )
245+ {
246+ collect_callable_overloads ( db, & origin_type, overloads) ?;
247+ }
248+ }
249+ LuaType :: Union ( union) => {
250+ for member in union. into_vec ( ) {
251+ collect_callable_overloads ( db, & member, overloads) ?;
252+ }
253+ }
254+ LuaType :: Intersection ( intersection) => {
255+ for member in intersection. get_types ( ) {
256+ collect_callable_overloads ( db, member, overloads) ?;
257+ }
258+ }
259+ _ => { }
260+ }
261+
262+ Ok ( ( ) )
263+ }
264+
265+ fn instantiate_callable_from_arg_types (
266+ context : & mut TplContext ,
267+ callable : & Arc < LuaFunctionType > ,
268+ call_arg_types : & [ LuaType ] ,
269+ ) -> Arc < LuaFunctionType > {
163270 let mut callable_tpls = HashSet :: new ( ) ;
164271 callable. visit_type ( & mut |ty| {
165272 if let LuaType :: TplRef ( generic_tpl) | LuaType :: ConstTplRef ( generic_tpl) = ty {
166273 callable_tpls. insert ( generic_tpl. get_tpl_id ( ) ) ;
167274 }
168275 } ) ;
169276 if callable_tpls. is_empty ( ) {
170- return Ok ( Some ( callable. get_ret ( ) . clone ( ) ) ) ;
171- }
172-
173- let mut callable_substitutor = TypeSubstitutor :: new ( ) ;
174- callable_substitutor. add_need_infer_tpls ( callable_tpls) ;
175- let fallback_return = infer_return_from_callable ( context. db , & callable, & callable_substitutor) ;
176-
177- let call_arg_types =
178- match infer_expr_list_types ( context. db , context. cache , arg_exprs, None , infer_expr) {
179- Ok ( types) => types. into_iter ( ) . map ( |( ty, _) | ty) . collect :: < Vec < _ > > ( ) ,
180- Err ( _) => return Ok ( Some ( fallback_return) ) ,
181- } ;
182- if call_arg_types. is_empty ( ) {
183- return Ok ( None ) ;
277+ return callable. clone ( ) ;
184278 }
185279
186280 let callable_param_types = callable
187281 . get_params ( )
188282 . iter ( )
189283 . map ( |( _, ty) | ty. clone ( ) . unwrap_or ( LuaType :: Unknown ) )
190284 . collect :: < Vec < _ > > ( ) ;
191-
285+ let mut callable_substitutor = TypeSubstitutor :: new ( ) ;
286+ callable_substitutor. add_need_infer_tpls ( callable_tpls) ;
192287 let mut callable_context = TplContext {
193288 db : context. db ,
194289 cache : context. cache ,
195290 substitutor : & mut callable_substitutor,
196291 call_expr : context. call_expr . clone ( ) ,
197292 } ;
198- if tpl_pattern_match_args (
199- & mut callable_context,
200- & callable_param_types,
201- & call_arg_types,
202- )
203- . is_err ( )
204- {
205- return Ok ( Some ( fallback_return) ) ;
206- }
293+ let _ = tpl_pattern_match_args ( & mut callable_context, & callable_param_types, call_arg_types) ;
207294
208- Ok ( Some ( infer_return_from_callable (
209- context. db ,
210- & callable,
211- & callable_substitutor,
212- ) ) )
295+ match instantiate_doc_function ( context. db , callable, & callable_substitutor) {
296+ LuaType :: DocFunction ( func) => func,
297+ _ => callable. clone ( ) ,
298+ }
213299}
214300
215301fn infer_generic_types_from_call (
0 commit comments