Skip to content

Commit 2fa7311

Browse files
authored
Merge pull request #3330 from ian-pascoe:fix/generic-class-inheritance
fix: comprehensive generic type resolution improvements
2 parents 65a3fa0 + 5eaa187 commit 2fa7311

File tree

9 files changed

+871
-20
lines changed

9 files changed

+871
-20
lines changed

changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
## Unreleased
44
<!-- Add all new changes here. They will be moved under a version at release -->
55
* `FIX` convert all keys to string before using
6+
* `FIX` Generic class inheritance with type arguments now works correctly (e.g., `class Bar: Foo<integer>`) [#1929](https://github.com/LuaLS/lua-language-server/issues/1929)
7+
* `FIX` Method return types on generic classes now resolve correctly (e.g., `Box<string>:getValue()` returns `string`) [#1863](https://github.com/LuaLS/lua-language-server/issues/1863)
8+
* `FIX` Self-referential generic classes no longer cause infinite expansion in hover display [#1853](https://github.com/LuaLS/lua-language-server/issues/1853)
9+
* `FIX` Generic type parameters now work in `@overload` annotations [#723](https://github.com/LuaLS/lua-language-server/issues/723)
10+
* `NEW` Support `fun<T>` syntax for inline generic function types in `@field` and `@type` annotations [#1170](https://github.com/LuaLS/lua-language-server/issues/1170)
11+
* `FIX` Methods with `@generic T` and `@param self T` now correctly resolve return type to the receiver's concrete type (e.g., `List<number>:identity()` returns `List<number>`) [#1000](https://github.com/LuaLS/lua-language-server/issues/1000)
612

713
## 3.16.4
814
`2025-12-25`

script/core/diagnostics/param-type-mismatch.lua

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ local vm = require 'vm'
55
local await = require 'await'
66

77
---@param defNode vm.node
8-
local function expandGenerics(defNode)
8+
---@param classGenericMap table<string, vm.node>?
9+
local function expandGenerics(defNode, classGenericMap)
910
---@type parser.object[]
1011
local generics = {}
1112
for dn in defNode:eachObject() do
@@ -20,27 +21,78 @@ local function expandGenerics(defNode)
2021
end
2122

2223
for _, generic in ipairs(generics) do
23-
local limits = generic.generic and generic.generic.extends
24-
if limits then
25-
defNode:merge(vm.compileNode(limits))
24+
-- First check if this generic is a class generic that can be resolved
25+
local genericName = generic[1]
26+
if classGenericMap and genericName and classGenericMap[genericName] then
27+
defNode:merge(classGenericMap[genericName])
2628
else
27-
local unknownType = vm.declareGlobal('type', 'unknown')
28-
defNode:merge(unknownType)
29+
-- Fall back to constraint or unknown
30+
local limits = generic.generic and generic.generic.extends
31+
if limits then
32+
defNode:merge(vm.compileNode(limits))
33+
else
34+
local unknownType = vm.declareGlobal('type', 'unknown')
35+
defNode:merge(unknownType)
36+
end
37+
end
38+
end
39+
end
40+
41+
---@param uri uri
42+
---@param source parser.object
43+
---@return table<string, vm.node>?
44+
local function getReceiverGenericMap(uri, source)
45+
local callNode = source.node
46+
if not callNode then
47+
return nil
48+
end
49+
-- Only resolve generics for method calls (obj:method()), not static calls (Class.method())
50+
if callNode.type ~= 'getmethod' then
51+
return nil
52+
end
53+
local receiver = callNode.node
54+
if not receiver then
55+
return nil
56+
end
57+
local receiverNode = vm.compileNode(receiver)
58+
for rn in receiverNode:eachObject() do
59+
if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then
60+
local classGlobal = vm.getGlobal('type', rn.node[1])
61+
if classGlobal then
62+
return vm.getClassGenericMap(uri, classGlobal, rn.signs)
63+
end
2964
end
3065
end
66+
return nil
3167
end
3268

3369
---@param funcNode vm.node
3470
---@param i integer
71+
---@param classGenericMap table<string, vm.node>?
3572
---@return vm.node?
36-
local function getDefNode(funcNode, i)
73+
local function getDefNode(funcNode, i, classGenericMap)
3774
local defNode = vm.createNode()
3875
for src in funcNode:eachObject() do
3976
if src.type == 'function'
4077
or src.type == 'doc.type.function' then
4178
local param = src.args and src.args[i]
4279
if param then
43-
defNode:merge(vm.compileNode(param))
80+
local paramNode = vm.compileNode(param)
81+
-- Check for global type references that match class generic params
82+
if classGenericMap then
83+
local newNode = vm.createNode()
84+
for pn in paramNode:eachObject() do
85+
if pn.type == 'global' and pn.cate == 'type' and classGenericMap[pn.name] then
86+
-- Replace the global type reference with the resolved type
87+
newNode:merge(classGenericMap[pn.name])
88+
else
89+
newNode:merge(pn)
90+
end
91+
end
92+
defNode:merge(newNode)
93+
else
94+
defNode:merge(paramNode)
95+
end
4496
if param[1] == '...' then
4597
defNode:addOptional()
4698
end
@@ -51,7 +103,7 @@ local function getDefNode(funcNode, i)
51103
return nil
52104
end
53105

54-
expandGenerics(defNode)
106+
expandGenerics(defNode, classGenericMap)
55107

56108
return defNode
57109
end
@@ -87,12 +139,14 @@ return function (uri, callback)
87139
end
88140
await.delay()
89141
local funcNode = vm.compileNode(source.node)
142+
-- Get the class generic map for method calls on generic class instances
143+
local classGenericMap = getReceiverGenericMap(uri, source)
90144
for i, arg in ipairs(source.args) do
91145
local refNode = vm.compileNode(arg)
92146
if not refNode then
93147
goto CONTINUE
94148
end
95-
local defNode = getDefNode(funcNode, i)
149+
local defNode = getDefNode(funcNode, i, classGenericMap)
96150
if not defNode then
97151
goto CONTINUE
98152
end

script/core/diagnostics/undefined-doc-name.lua

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,100 @@ local guide = require 'parser.guide'
33
local lang = require 'language'
44
local vm = require 'vm'
55

6+
--- Check if name is a generic parameter from a class context
7+
---@param source parser.object The doc.type.name source
8+
---@param name string The type name to check
9+
---@param uri uri The file URI
10+
---@return boolean
11+
local function isClassGenericParam(source, name, uri)
12+
-- Find containing doc node
13+
local doc = guide.getParentTypes(source, {
14+
['doc.return'] = true,
15+
['doc.param'] = true,
16+
['doc.type'] = true,
17+
['doc.field'] = true,
18+
['doc.overload'] = true,
19+
['doc.vararg'] = true,
20+
})
21+
if not doc then
22+
return false
23+
end
24+
25+
-- Walk up to find a doc node with bindGroup (intermediate doc.type nodes don't have it)
26+
while doc and not doc.bindGroup do
27+
doc = doc.parent
28+
end
29+
if not doc then
30+
return false
31+
end
32+
33+
-- Check bindGroup for class/alias with matching generic sign
34+
local bindGroup = doc.bindGroup
35+
if bindGroup then
36+
for _, other in ipairs(bindGroup) do
37+
if (other.type == 'doc.class' or other.type == 'doc.alias') and other.signs then
38+
for _, sign in ipairs(other.signs) do
39+
if sign[1] == name then
40+
return true
41+
end
42+
end
43+
end
44+
end
45+
end
46+
47+
-- Check direct class reference (for doc.field, doc.overload, doc.operator)
48+
if doc.class and doc.class.signs then
49+
for _, sign in ipairs(doc.class.signs) do
50+
if sign[1] == name then
51+
return true
52+
end
53+
end
54+
end
55+
56+
-- Check if bound to a method on a generic class
57+
-- Find the function from any doc in the bindGroup
58+
local func = nil
59+
if bindGroup then
60+
for _, other in ipairs(bindGroup) do
61+
local bindSource = other.bindSource
62+
if bindSource then
63+
if bindSource.type == 'function' then
64+
-- doc.return binds directly to function
65+
func = bindSource
66+
break
67+
else
68+
-- doc.param binds to local param, find containing function
69+
func = guide.getParentFunction(bindSource)
70+
if func then
71+
break
72+
end
73+
end
74+
end
75+
end
76+
end
77+
78+
-- If we found a function, check if it's a method on a generic class
79+
if func and func.parent then
80+
local parent = func.parent
81+
if parent.type == 'setmethod' or parent.type == 'setfield' or parent.type == 'setindex' then
82+
local classGlobal = vm.getDefinedClass(uri, parent.node)
83+
if classGlobal then
84+
for _, set in ipairs(classGlobal:getSets(uri)) do
85+
if set.type == 'doc.class' and set.signs then
86+
for _, sign in ipairs(set.signs) do
87+
if sign[1] == name then
88+
return true
89+
end
90+
end
91+
end
92+
end
93+
end
94+
end
95+
end
96+
97+
return false
98+
end
99+
6100
return function (uri, callback)
7101
local state = files.getState(uri)
8102
if not state then
@@ -25,6 +119,9 @@ return function (uri, callback)
25119
if name == '...' or name == '_' or name == 'self' then
26120
return
27121
end
122+
if isClassGenericParam(source, name, uri) then
123+
return
124+
end
28125
if #vm.getDocSets(uri, name) > 0 then
29126
return
30127
end

script/parser/guide.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ local childMap = {
177177
['doc.generic.object'] = {'generic', 'extends', 'comment'},
178178
['doc.vararg'] = {'vararg', 'comment'},
179179
['doc.type.array'] = {'node'},
180-
['doc.type.function'] = {'#args', '#returns', 'comment'},
180+
['doc.type.function'] = {'#args', '#returns', '#signs', 'comment'},
181181
['doc.type.table'] = {'#fields', 'comment'},
182182
['doc.type.literal'] = {'node'},
183183
['doc.type.arg'] = {'name', 'extends'},

script/parser/luadoc.lua

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,8 @@ local function parseTypeUnitFunction(parent)
523523
args = {},
524524
returns = {},
525525
}
526+
-- Parse optional generic params: fun<T, V>(...)
527+
typeUnit.signs = parseSigns(typeUnit)
526528
if not nextSymbolOrError('(') then
527529
return nil
528530
end
@@ -617,6 +619,51 @@ local function parseTypeUnitFunction(parent)
617619
end
618620
end
619621
typeUnit.finish = getFinish()
622+
-- Bind local generics from fun<T, V> to type names within this function
623+
if typeUnit.signs then
624+
local generics = {}
625+
for _, sign in ipairs(typeUnit.signs) do
626+
generics[sign[1]] = sign
627+
end
628+
local function bindTypeNames(obj)
629+
if not obj then return end
630+
if obj.type == 'doc.type.name' and generics[obj[1]] then
631+
obj.type = 'doc.generic.name'
632+
obj.generic = generics[obj[1]]
633+
elseif obj.type == 'doc.type' and obj.types then
634+
for _, t in ipairs(obj.types) do
635+
bindTypeNames(t)
636+
end
637+
elseif obj.type == 'doc.type.array' then
638+
bindTypeNames(obj.node)
639+
elseif obj.type == 'doc.type.table' and obj.fields then
640+
for _, field in ipairs(obj.fields) do
641+
bindTypeNames(field.name)
642+
bindTypeNames(field.extends)
643+
end
644+
elseif obj.type == 'doc.type.sign' then
645+
bindTypeNames(obj.node)
646+
if obj.signs then
647+
for _, s in ipairs(obj.signs) do
648+
bindTypeNames(s)
649+
end
650+
end
651+
elseif obj.type == 'doc.type.function' then
652+
for _, arg in ipairs(obj.args) do
653+
bindTypeNames(arg.extends)
654+
end
655+
for _, ret in ipairs(obj.returns) do
656+
bindTypeNames(ret)
657+
end
658+
end
659+
end
660+
for _, arg in ipairs(typeUnit.args) do
661+
bindTypeNames(arg.extends)
662+
end
663+
for _, ret in ipairs(typeUnit.returns) do
664+
bindTypeNames(ret)
665+
end
666+
end
620667
return typeUnit
621668
end
622669

@@ -1030,6 +1077,12 @@ local docSwitch = util.switch()
10301077
}
10311078
return result
10321079
end
1080+
if extend.type == 'doc.extends.name' then
1081+
local signResult = parseTypeUnitSign(result, extend)
1082+
if signResult then
1083+
extend = signResult
1084+
end
1085+
end
10331086
result.extends[#result.extends+1] = extend
10341087
result.finish = getFinish()
10351088
if not checkToken('symbol', ',', 1) then
@@ -1850,7 +1903,9 @@ local function bindGeneric(binded)
18501903
or doc.type == 'doc.return'
18511904
or doc.type == 'doc.type'
18521905
or doc.type == 'doc.class'
1853-
or doc.type == 'doc.alias' then
1906+
or doc.type == 'doc.alias'
1907+
or doc.type == 'doc.field'
1908+
or doc.type == 'doc.overload' then
18541909
guide.eachSourceType(doc, 'doc.type.name', function (src)
18551910
local name = src[1]
18561911
if generics[name] then

0 commit comments

Comments
 (0)