Skip to content

Commit 53ee6ef

Browse files
committed
make Import() function more module-aware with v%d support
Now paths like `foo/v2` will work automatically. For a more complicated use case, call `ImportAs` and assign a proper name manually.
1 parent 7513f50 commit 53ee6ef

File tree

6 files changed

+87
-4
lines changed

6 files changed

+87
-4
lines changed

analyzer/testdata/src/imports/f1.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,24 @@ type distraction struct{}
2525
func (distraction) Read(p []byte) (int, error) {
2626
return 0, nil
2727
}
28+
29+
type (
30+
v1Impl struct{}
31+
v2Impl struct{}
32+
v3Impl struct{}
33+
)
34+
35+
func (i *v1Impl) Do() {}
36+
37+
func (i *v2Impl) Do(x any) {}
38+
39+
func (i *v3Impl) Do(x int) {}
40+
41+
func json2() {
42+
var i1 *v1Impl
43+
var i2 *v2Impl
44+
var i3 *v3Impl
45+
_ = i1 // want `\Qv1 implemented`
46+
_ = i2 // want `\Qv2 implemented`
47+
_ = i3 // want `\Qv3 implemented`
48+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package mylib
2+
3+
type Contract interface {
4+
Do()
5+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package mylib
2+
3+
type Contract interface {
4+
Do(x any)
5+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package mylib
2+
3+
type Contract interface {
4+
Do(x int)
5+
}

analyzer/testdata/src/imports/rules.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,24 @@ func testCryptoRand(m dsl.Matcher) {
1515
m.Import(`crypto/rand`)
1616
m.Match(`rand.Read($*_)`).Report(`crypto/rand`)
1717
}
18+
19+
func testImportV1(m dsl.Matcher) {
20+
m.Import(`github.com/quasilyte/go-ruleguard/analyzer/testdata/src/imports/mylib`)
21+
m.Match(`_ = $x`).
22+
Where(m["x"].Type.Implements(`mylib.Contract`)).
23+
Report(`v1 implemented`)
24+
}
25+
26+
func testImportV2(m dsl.Matcher) {
27+
m.Import(`github.com/quasilyte/go-ruleguard/analyzer/testdata/src/imports/mylib/v2`)
28+
m.Match(`_ = $x`).
29+
Where(m["x"].Type.Implements(`mylib.Contract`)).
30+
Report(`v2 implemented`)
31+
}
32+
33+
func testImportV3(m dsl.Matcher) {
34+
m.Import(`github.com/quasilyte/go-ruleguard/analyzer/testdata/src/imports/mylib/v3`)
35+
m.Match(`_ = $x`).
36+
Where(m["x"].Type.Implements(`mylib.Contract`)).
37+
Report(`v3 implemented`)
38+
}

ruleguard/irconv/irconv.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"go/token"
88
"go/types"
99
"path"
10+
"regexp"
1011
"strconv"
1112
"strings"
1213

@@ -41,10 +42,11 @@ func ConvertFile(ctx *Context, f *ast.File) (result *ir.File, err error) {
4142
}()
4243

4344
conv := &converter{
44-
types: ctx.Types,
45-
pkg: ctx.Pkg,
46-
fset: ctx.Fset,
47-
src: ctx.Src,
45+
types: ctx.Types,
46+
pkg: ctx.Pkg,
47+
fset: ctx.Fset,
48+
src: ctx.Src,
49+
versionPathRe: regexp.MustCompile(`^v[0-9]+$`),
4850
}
4951
result = conv.ConvertFile(f)
5052
return result, nil
@@ -66,6 +68,8 @@ type converter struct {
6668
fset *token.FileSet
6769
src []byte
6870

71+
versionPathRe *regexp.Regexp
72+
6973
group *ir.RuleGroup
7074
groupFuncs []localMacroFunc
7175

@@ -224,6 +228,11 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup {
224228
panic(conv.errorf(call, "Import() should be used before any rules definitions"))
225229
}
226230
conv.doMatcherImport(call)
231+
case "ImportAs":
232+
if seenRules {
233+
panic(conv.errorf(call, "ImportAs() should be used before any rules definitions"))
234+
}
235+
conv.doMatcherImportAs(call)
227236
default:
228237
seenRules = true
229238
conv.convertRuleExpr(call)
@@ -375,7 +384,24 @@ func (conv *converter) localDefine(assign *ast.AssignStmt) {
375384

376385
func (conv *converter) doMatcherImport(call *ast.CallExpr) {
377386
pkgPath := conv.parseStringArg(call.Args[0])
387+
388+
// Try to be at least somewhat module-aware.
389+
// If the last path part is "/v%d", we might want to take
390+
// the previous path part as a package name.
378391
pkgName := path.Base(pkgPath)
392+
if conv.versionPathRe.MatchString(pkgName) {
393+
pkgName = path.Base(path.Dir(pkgPath))
394+
}
395+
396+
conv.group.Imports = append(conv.group.Imports, ir.PackageImport{
397+
Path: pkgPath,
398+
Name: pkgName,
399+
})
400+
}
401+
402+
func (conv *converter) doMatcherImportAs(call *ast.CallExpr) {
403+
pkgPath := conv.parseStringArg(call.Args[0])
404+
pkgName := conv.parseStringArg(call.Args[1])
379405
conv.group.Imports = append(conv.group.Imports, ir.PackageImport{
380406
Path: pkgPath,
381407
Name: pkgName,

0 commit comments

Comments
 (0)