Skip to content

Commit bf7321e

Browse files
committed
lockdown mode: remove RepoAccessCache singleton and isolate viewer state per instance
1 parent f5f9c72 commit bf7321e

5 files changed

Lines changed: 231 additions & 107 deletions

File tree

internal/ghmcp/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv
103103
if cfg.RepoAccessTTL != nil {
104104
opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL))
105105
}
106-
repoAccessCache = lockdown.GetInstance(gqlClient, restClient, opts...)
106+
repoAccessCache = lockdown.NewRepoAccessCache(gqlClient, restClient, opts...)
107107
}
108108

109109
return &githubClients{

pkg/github/dependencies.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ func (d *RequestDeps) GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAcc
399399
}
400400

401401
// Create repo access cache
402-
instance := lockdown.GetInstance(gqlClient, restClient, d.RepoAccessOpts...)
402+
instance := lockdown.NewRepoAccessCache(gqlClient, restClient, d.RepoAccessOpts...)
403403
return instance, nil
404404
}
405405

pkg/github/issues_test.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,15 @@ func (rt *repoAccessMockTransport) RoundTrip(req *http.Request) (*http.Response,
7070
value = repoAccessValue{isPrivate: false}
7171
}
7272

73-
responseBody, err := json.Marshal(map[string]any{
74-
"data": map[string]any{
75-
"viewer": map[string]any{
76-
"login": "test-viewer",
77-
},
78-
"repository": map[string]any{
79-
"isPrivate": value.isPrivate,
80-
},
81-
},
82-
})
73+
data := map[string]any{}
74+
if strings.Contains(payload.Query, "viewer") {
75+
data["viewer"] = map[string]any{"login": "test-viewer"}
76+
}
77+
if strings.Contains(payload.Query, "repository") {
78+
data["repository"] = map[string]any{"isPrivate": value.isPrivate}
79+
}
80+
81+
responseBody, err := json.Marshal(map[string]any{"data": data})
8382
if err != nil {
8483
return nil, err
8584
}

pkg/lockdown/lockdown.go

Lines changed: 86 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"log/slog"
7+
"maps"
78
"strings"
89
"sync"
910
"time"
@@ -15,39 +16,36 @@ import (
1516

1617
// RepoAccessCache caches repository metadata related to lockdown checks so that
1718
// multiple tools can reuse the same access information safely across goroutines.
19+
// In HTTP mode each request must construct its own instance so viewer-scoped
20+
// lookups run under the requesting user's credentials.
1821
type RepoAccessCache struct {
1922
client *githubv4.Client
2023
restClient *github.Client
21-
mu sync.Mutex
2224
cache *cache2go.CacheTable
2325
ttl time.Duration
2426
logger *slog.Logger
2527
trustedBotLogins map[string]struct{}
28+
29+
viewerMu sync.Mutex
30+
viewerLogin string
2631
}
2732

2833
type repoAccessCacheEntry struct {
29-
isPrivate bool
30-
knownUsers map[string]bool // normalized login -> has push access
31-
viewerLogin string
34+
isPrivate bool
35+
knownUsers map[string]bool // normalized login -> has push access
3236
}
3337

3438
// RepoAccessInfo captures repository metadata needed for lockdown decisions.
3539
type RepoAccessInfo struct {
3640
IsPrivate bool
3741
HasPushAccess bool
38-
ViewerLogin string
3942
}
4043

4144
const (
4245
defaultRepoAccessTTL = 20 * time.Minute
4346
defaultRepoAccessCacheKey = "repo-access-cache"
4447
)
4548

46-
var (
47-
instance *RepoAccessCache
48-
instanceMu sync.Mutex
49-
)
50-
5149
// RepoAccessOption configures RepoAccessCache at construction time.
5250
type RepoAccessOption func(*RepoAccessCache)
5351

@@ -66,8 +64,8 @@ func WithLogger(logger *slog.Logger) RepoAccessOption {
6664
}
6765
}
6866

69-
// WithCacheName overrides the cache table name used for storing entries. This option is intended for tests
70-
// that need isolated cache instances.
67+
// WithCacheName overrides the cache table name used for storing entries.
68+
// Use this to isolate cache entries between tenants or in tests.
7169
func WithCacheName(name string) RepoAccessOption {
7270
return func(c *RepoAccessCache) {
7371
if name != "" {
@@ -76,25 +74,8 @@ func WithCacheName(name string) RepoAccessOption {
7674
}
7775
}
7876

79-
// GetInstance returns the singleton instance of RepoAccessCache.
80-
// It initializes the instance on first call with the provided client and options.
81-
// Subsequent calls ignore the client and options parameters and return the existing instance.
82-
// This is the preferred way to access the cache in production code.
83-
func GetInstance(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache {
84-
instanceMu.Lock()
85-
defer instanceMu.Unlock()
86-
if instance == nil {
87-
instance = newRepoAccessCache(client, restClient, opts...)
88-
}
89-
return instance
90-
}
91-
92-
// NewRepoAccessCache creates a standalone cache instance, used for tests.
77+
// NewRepoAccessCache creates a RepoAccessCache bound to the supplied clients.
9378
func NewRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache {
94-
return newRepoAccessCache(client, restClient, opts...)
95-
}
96-
97-
func newRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache {
9879
c := &RepoAccessCache{
9980
client: client,
10081
restClient: restClient,
@@ -113,13 +94,6 @@ func newRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts
11394
return c
11495
}
11596

116-
// SetLogger updates the logger used for cache diagnostics.
117-
func (c *RepoAccessCache) SetLogger(logger *slog.Logger) {
118-
c.mu.Lock()
119-
c.logger = logger
120-
c.mu.Unlock()
121-
}
122-
12397
// CacheStats summarizes cache activity counters.
12498
type CacheStats struct {
12599
Hits int64
@@ -150,10 +124,55 @@ func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, re
150124
c.logDebug(ctx, fmt.Sprintf("evaluated repo access for user %s to %s/%s for content filtering, result: hasPushAccess=%t, isPrivate=%t",
151125
username, owner, repo, repoInfo.HasPushAccess, repoInfo.IsPrivate))
152126

153-
if repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) {
127+
if repoInfo.IsPrivate {
128+
return true, nil
129+
}
130+
if repoInfo.HasPushAccess {
154131
return true, nil
155132
}
156-
return repoInfo.HasPushAccess, nil
133+
134+
viewerLogin, err := c.viewerLoginFor(ctx)
135+
if err != nil {
136+
return false, err
137+
}
138+
return viewerLogin == strings.ToLower(username), nil
139+
}
140+
141+
func (c *RepoAccessCache) viewerLoginFor(ctx context.Context) (string, error) {
142+
c.viewerMu.Lock()
143+
defer c.viewerMu.Unlock()
144+
if c.viewerLogin != "" {
145+
return c.viewerLogin, nil
146+
}
147+
if c.client == nil {
148+
return "", fmt.Errorf("nil GraphQL client")
149+
}
150+
var query struct {
151+
Viewer struct {
152+
Login githubv4.String
153+
}
154+
}
155+
if err := c.client.Query(ctx, &query, nil); err != nil {
156+
return "", fmt.Errorf("failed to query viewer login: %w", err)
157+
}
158+
login := strings.ToLower(string(query.Viewer.Login))
159+
if login == "" {
160+
return "", fmt.Errorf("viewer login returned empty")
161+
}
162+
c.viewerLogin = login
163+
return c.viewerLogin, nil
164+
}
165+
166+
// setViewerLogin seeds the cached viewer login from a piggy-backed query response.
167+
func (c *RepoAccessCache) setViewerLogin(login string) {
168+
if login == "" {
169+
return
170+
}
171+
c.viewerMu.Lock()
172+
defer c.viewerMu.Unlock()
173+
if c.viewerLogin == "" {
174+
c.viewerLogin = strings.ToLower(login)
175+
}
157176
}
158177

159178
func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) {
@@ -163,19 +182,16 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner
163182

164183
key := cacheKey(owner, repo)
165184
userKey := strings.ToLower(username)
166-
c.mu.Lock()
167-
defer c.mu.Unlock()
168185

169-
// Try to get entry from cache - this will keep the item alive if it exists
170-
cacheItem, err := c.cache.Value(key)
171-
if err == nil {
186+
// Entries are immutable once added: the cache table is shared across instances,
187+
// so we publish a fresh entry with a cloned knownUsers map on every miss.
188+
if cacheItem, err := c.cache.Value(key); err == nil {
172189
entry := cacheItem.Data().(*repoAccessCacheEntry)
173190
if cachedHasPush, known := entry.knownUsers[userKey]; known {
174191
c.logDebug(ctx, fmt.Sprintf("repo access cache hit for user %s to %s/%s", username, owner, repo))
175192
return RepoAccessInfo{
176193
IsPrivate: entry.isPrivate,
177194
HasPushAccess: cachedHasPush,
178-
ViewerLogin: entry.viewerLogin,
179195
}, nil
180196
}
181197

@@ -186,41 +202,48 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner
186202
return RepoAccessInfo{}, pushErr
187203
}
188204

189-
entry.knownUsers[userKey] = hasPush
190-
c.cache.Add(key, c.ttl, entry)
205+
users := make(map[string]bool, len(entry.knownUsers)+1)
206+
maps.Copy(users, entry.knownUsers)
207+
users[userKey] = hasPush
208+
c.cache.Add(key, c.ttl, &repoAccessCacheEntry{
209+
isPrivate: entry.isPrivate,
210+
knownUsers: users,
211+
})
191212

192213
return RepoAccessInfo{
193214
IsPrivate: entry.isPrivate,
194215
HasPushAccess: hasPush,
195-
ViewerLogin: entry.viewerLogin,
196216
}, nil
197217
}
198218

199219
c.logDebug(ctx, fmt.Sprintf("repo access cache miss for user %s to %s/%s", username, owner, repo))
200220

201-
info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo)
221+
isPrivate, viewerLogin, queryErr := c.queryRepoAccessInfo(ctx, owner, repo)
202222
if queryErr != nil {
203223
return RepoAccessInfo{}, queryErr
204224
}
225+
c.setViewerLogin(viewerLogin)
205226

206-
// Create new entry
207-
entry := &repoAccessCacheEntry{
208-
knownUsers: map[string]bool{userKey: info.HasPushAccess},
209-
isPrivate: info.IsPrivate,
210-
viewerLogin: info.ViewerLogin,
227+
hasPush, pushErr := c.checkPushAccess(ctx, username, owner, repo)
228+
if pushErr != nil {
229+
return RepoAccessInfo{}, pushErr
211230
}
212-
c.cache.Add(key, c.ttl, entry)
231+
232+
c.cache.Add(key, c.ttl, &repoAccessCacheEntry{
233+
knownUsers: map[string]bool{userKey: hasPush},
234+
isPrivate: isPrivate,
235+
})
213236

214237
return RepoAccessInfo{
215-
IsPrivate: entry.isPrivate,
216-
HasPushAccess: entry.knownUsers[userKey],
217-
ViewerLogin: entry.viewerLogin,
238+
IsPrivate: isPrivate,
239+
HasPushAccess: hasPush,
218240
}, nil
219241
}
220242

221-
func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) {
243+
// queryRepoAccessInfo fetches repository visibility and the viewer login in a single GraphQL round-trip.
244+
func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, owner, repo string) (bool, string, error) {
222245
if c.client == nil {
223-
return RepoAccessInfo{}, fmt.Errorf("nil GraphQL client")
246+
return false, "", fmt.Errorf("nil GraphQL client")
224247
}
225248

226249
var query struct {
@@ -238,22 +261,12 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own
238261
}
239262

240263
if err := c.client.Query(ctx, &query, variables); err != nil {
241-
return RepoAccessInfo{}, fmt.Errorf("failed to query repository metadata: %w", err)
242-
}
243-
244-
hasPush, err := c.checkPushAccess(ctx, username, owner, repo)
245-
if err != nil {
246-
return RepoAccessInfo{}, err
264+
return false, "", fmt.Errorf("failed to query repository metadata: %w", err)
247265
}
248266

249-
c.logDebug(ctx, fmt.Sprintf("queried repo access info for user %s to %s/%s: isPrivate=%t, hasPushAccess=%t, viewerLogin=%s",
250-
username, owner, repo, bool(query.Repository.IsPrivate), hasPush, query.Viewer.Login))
267+
c.logDebug(ctx, fmt.Sprintf("queried repo access info for %s/%s: isPrivate=%t", owner, repo, bool(query.Repository.IsPrivate)))
251268

252-
return RepoAccessInfo{
253-
IsPrivate: bool(query.Repository.IsPrivate),
254-
HasPushAccess: hasPush,
255-
ViewerLogin: string(query.Viewer.Login),
256-
}, nil
269+
return bool(query.Repository.IsPrivate), string(query.Viewer.Login), nil
257270
}
258271

259272
// checkPushAccess checks if the user has push access to the repository via the REST permission endpoint.

0 commit comments

Comments
 (0)