Skip to content

Commit ce8810c

Browse files
committed
Add DoS protection and idle timeout to vMCP session manager
Implements resource-exhaustion protections for the session-scoped backend lifecycle (SessionManagementV2), resolving issue #3874. Global session limit: new session requests (no Mcp-Session-Id header) receive HTTP 503 with a Retry-After header when the server-wide cap is reached. Default: 100 sessions. Per-client session limit: CreateSession enforces a maximum number of concurrent sessions per auth.Identity.Subject. Anonymous clients are exempt. The counter is rolled back on all failure paths. Default: 10 sessions per identity. Idle session timeout: a background reaper goroutine terminates sessions that have had no CallTool activity for longer than the configured threshold. The idle clock resets on every tool call and is initialised when a session is fully established. The reaper is wired into shutdownFuncs so it stops cleanly on server shutdown. Default: 5 min. All three limits are configurable via Config fields; zero values select the defaults. The Limits struct is passed to sessionmanager.New() and all existing call sites are updated. Closes: #3874
1 parent 85c5f3e commit ce8810c

4 files changed

Lines changed: 806 additions & 8 deletions

File tree

pkg/vmcp/server/server.go

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"log/slog"
1717
"net"
1818
"net/http"
19+
"strconv"
1920
"strings"
2021
"sync"
2122
"time"
@@ -67,6 +68,13 @@ const (
6768
// defaultSessionTTL is the default session time-to-live duration.
6869
// Sessions that are inactive for this duration will be automatically cleaned up.
6970
defaultSessionTTL = 30 * time.Minute
71+
72+
// defaultIdleCheckInterval is how often the idle reaper scans for inactive sessions.
73+
defaultIdleCheckInterval = time.Minute
74+
75+
// defaultRetryAfterSeconds is the Retry-After value returned with HTTP 503
76+
// when the global session limit is reached.
77+
defaultRetryAfterSeconds = 30
7078
)
7179

7280
//go:generate mockgen -destination=mocks/mock_watcher.go -package=mocks -source=server.go Watcher
@@ -160,6 +168,21 @@ type Config struct {
160168
// SessionFactory creates MultiSessions for Phase 2 session management.
161169
// Required when SessionManagementV2 is true; ignored otherwise.
162170
SessionFactory vmcpsession.MultiSessionFactory
171+
172+
// MaxSessions is the global concurrent session limit when SessionManagementV2 is enabled.
173+
// Requests that would exceed this limit receive HTTP 503 with a Retry-After header.
174+
// 0 uses the default (100). Requires SessionManagementV2 = true.
175+
MaxSessions int
176+
177+
// MaxSessionsPerClient is the per-identity session limit when SessionManagementV2 is enabled.
178+
// Keyed by auth.Identity.Subject; anonymous clients are not limited.
179+
// 0 uses the default (10). Requires SessionManagementV2 = true.
180+
MaxSessionsPerClient int
181+
182+
// IdleSessionTimeout is the duration after which inactive sessions are proactively
183+
// expired when SessionManagementV2 is enabled. Must be ≤ SessionTTL.
184+
// 0 uses the default (5 minutes). Requires SessionManagementV2 = true.
185+
IdleSessionTimeout time.Duration
163186
}
164187

165188
// Server is the Virtual MCP Server that aggregates multiple backends.
@@ -277,6 +300,24 @@ func New(
277300
if cfg.SessionTTL == 0 {
278301
cfg.SessionTTL = defaultSessionTTL
279302
}
303+
if cfg.MaxSessions == 0 {
304+
cfg.MaxSessions = sessionmanager.DefaultMaxSessions
305+
}
306+
if cfg.MaxSessionsPerClient == 0 {
307+
cfg.MaxSessionsPerClient = sessionmanager.DefaultMaxSessionsPerClient
308+
}
309+
if cfg.IdleSessionTimeout == 0 {
310+
cfg.IdleSessionTimeout = sessionmanager.DefaultIdleSessionTimeout
311+
}
312+
// IdleSessionTimeout must not exceed SessionTTL: if it did, the transport
313+
// TTL reaper could evict sessions before the idle reaper fires, leaving
314+
// per-client counters and idle-tracking maps stale.
315+
if cfg.IdleSessionTimeout > cfg.SessionTTL {
316+
slog.Warn("IdleSessionTimeout exceeds SessionTTL; clamping to SessionTTL",
317+
"idle_session_timeout", cfg.IdleSessionTimeout,
318+
"session_ttl", cfg.SessionTTL)
319+
cfg.IdleSessionTimeout = cfg.SessionTTL
320+
}
280321

281322
// Create hooks for SDK integration
282323
hooks := &server.Hooks{}
@@ -400,7 +441,12 @@ func New(
400441
if cfg.SessionFactory == nil {
401442
return nil, fmt.Errorf("SessionManagementV2 is enabled but no SessionFactory was provided")
402443
}
403-
vmcpSessMgr = sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry)
444+
limits := sessionmanager.Limits{
445+
MaxSessions: cfg.MaxSessions,
446+
MaxSessionsPerClient: cfg.MaxSessionsPerClient,
447+
IdleSessionTimeout: cfg.IdleSessionTimeout,
448+
}
449+
vmcpSessMgr = sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry, limits)
404450
slog.Info("session-scoped backend lifecycle enabled")
405451

406452
// Warn about incompatible optimizer configuration and disable it
@@ -557,6 +603,13 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
557603
slog.Info("audit middleware enabled for MCP endpoints")
558604
}
559605

606+
// Apply session limit middleware when V2 session management is active.
607+
// It is wrapped by the auth middleware applied below, so it runs after authentication (auth is the outermost layer), not before it.
608+
if s.vmcpSessionMgr != nil && s.config.MaxSessions > 0 {
609+
mcpHandler = s.sessionLimitMiddleware(mcpHandler)
610+
slog.Info("session limit middleware enabled", "max_sessions", s.config.MaxSessions)
611+
}
612+
560613
// Apply authentication middleware if configured (runs first in chain)
561614
if s.config.AuthMiddleware != nil {
562615
mcpHandler = s.config.AuthMiddleware(mcpHandler)
@@ -575,6 +628,37 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
575628
return mux, nil
576629
}
577630

631+
// sessionLimitMiddleware is a best-effort fast-fail for new session requests
632+
// (no Mcp-Session-Id header): it returns HTTP 503 + Retry-After before the
633+
// request reaches the SDK when the global session cap appears to be reached.
634+
// Existing sessions (with a valid Mcp-Session-Id) are never affected.
635+
//
636+
// This check is intentionally optimistic (non-atomic): it avoids the overhead
637+
// of routing and SDK processing for clearly-over-limit requests, but it does
638+
// not guarantee strict enforcement under concurrent load. Strict enforcement
639+
// is provided atomically by sessionmanager.Manager.Generate(), which uses an
640+
// increment-first reservation to prevent races between concurrent initialize
641+
// requests.
642+
func (s *Server) sessionLimitMiddleware(next http.Handler) http.Handler {
643+
// Resolve the concrete manager once so we can call ActiveSessionCount().
644+
mgr, _ := s.vmcpSessionMgr.(*sessionmanager.Manager)
645+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
646+
if r.Header.Get("Mcp-Session-Id") == "" && mgr != nil {
647+
if mgr.ActiveSessionCount() >= s.config.MaxSessions {
648+
w.Header().Set("Retry-After", strconv.Itoa(defaultRetryAfterSeconds))
649+
w.Header().Set("Content-Type", "application/json")
650+
w.WriteHeader(http.StatusServiceUnavailable)
651+
_, _ = w.Write([]byte(
652+
`{"error":{"code":-32000,"message":"Maximum concurrent sessions exceeded. ` +
653+
`Please try again later or contact administrator."}}`,
654+
))
655+
return
656+
}
657+
}
658+
next.ServeHTTP(w, r)
659+
})
660+
}
661+
578662
// Start starts the Virtual MCP Server and begins serving requests.
579663
//
580664
//nolint:gocyclo // Complexity from health monitoring and startup orchestration is acceptable
@@ -667,6 +751,19 @@ func (s *Server) Start(ctx context.Context) error {
667751
}
668752
}
669753

754+
// Start idle session reaper if V2 session management is active with an idle timeout.
755+
if mgr, ok := s.vmcpSessionMgr.(*sessionmanager.Manager); ok && s.config.IdleSessionTimeout > 0 {
756+
idleCtx, idleCancel := context.WithCancel(ctx)
757+
mgr.StartIdleReaper(idleCtx, defaultIdleCheckInterval)
758+
slog.Info("idle session reaper started",
759+
"idle_timeout", s.config.IdleSessionTimeout,
760+
"check_interval", defaultIdleCheckInterval)
761+
s.shutdownFuncs = append(s.shutdownFuncs, func(context.Context) error {
762+
idleCancel()
763+
return nil
764+
})
765+
}
766+
670767
// Start status reporter if configured
671768
if s.statusReporter != nil {
672769
shutdown, err := s.statusReporter.Start(ctx)

pkg/vmcp/server/session_management_v2_integration_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,55 @@ func buildV2Server(
218218
return ts
219219
}
220220

221+
// buildV2ServerWithLimits is like buildV2Server but accepts an explicit MaxSessions cap.
222+
func buildV2ServerWithLimits(
223+
t *testing.T,
224+
factory vmcpsession.MultiSessionFactory,
225+
maxSessions int,
226+
) *httptest.Server {
227+
t.Helper()
228+
229+
ctrl := gomock.NewController(t)
230+
t.Cleanup(ctrl.Finish)
231+
232+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
233+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
234+
mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl)
235+
236+
emptyAggCaps := &aggregator.AggregatedCapabilities{}
237+
mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes()
238+
mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(emptyAggCaps, nil).AnyTimes()
239+
mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
240+
241+
rt := router.NewDefaultRouter()
242+
243+
srv, err := server.New(
244+
context.Background(),
245+
&server.Config{
246+
Host: "127.0.0.1",
247+
Port: 0,
248+
SessionTTL: 5 * time.Minute,
249+
SessionManagementV2: true,
250+
SessionFactory: factory,
251+
MaxSessions: maxSessions,
252+
},
253+
rt,
254+
mockBackendClient,
255+
mockDiscoveryMgr,
256+
mockBackendRegistry,
257+
nil,
258+
)
259+
require.NoError(t, err)
260+
261+
handler, err := srv.Handler(context.Background())
262+
require.NoError(t, err)
263+
264+
ts := httptest.NewServer(handler)
265+
t.Cleanup(ts.Close)
266+
267+
return ts
268+
}
269+
221270
// postMCP sends a JSON-RPC POST to /mcp and returns the response.
222271
func postMCP(t *testing.T, baseURL string, body map[string]any, sessionID string) *http.Response {
223272
t.Helper()
@@ -474,3 +523,72 @@ func TestIntegration_SessionManagementV2_OldPathUnused(t *testing.T) {
474523
"MakeSessionWithID should NOT be called when SessionManagementV2 is false",
475524
)
476525
}
526+
527+
// TestIntegration_SessionManagementV2_SessionLimitMiddleware verifies that the
528+
// global session cap (MaxSessions) is enforced end-to-end: once the cap is
529+
// reached every new initialize request gets HTTP 503 with a Retry-After header
530+
// and a JSON error body, while existing sessions are unaffected.
531+
func TestIntegration_SessionManagementV2_SessionLimitMiddleware(t *testing.T) {
532+
t.Parallel()
533+
534+
const maxSessions = 2
535+
536+
factory := newV2FakeFactory([]vmcp.Tool{{Name: "noop"}})
537+
ts := buildV2ServerWithLimits(t, factory, maxSessions)
538+
539+
initReq := map[string]any{
540+
"jsonrpc": "2.0",
541+
"id": 1,
542+
"method": "initialize",
543+
"params": map[string]any{
544+
"protocolVersion": "2025-06-18",
545+
"capabilities": map[string]any{},
546+
"clientInfo": map[string]any{"name": "test", "version": "1.0"},
547+
},
548+
}
549+
550+
// Fill the pool to exactly MaxSessions.
551+
sessionIDs := make([]string, maxSessions)
552+
for i := range maxSessions {
553+
resp := postMCP(t, ts.URL, initReq, "")
554+
defer resp.Body.Close() //nolint:gocritic // deferred inside loop is intentional for test cleanup
555+
require.Equal(t, http.StatusOK, resp.StatusCode, "session %d should succeed", i+1)
556+
id := resp.Header.Get("Mcp-Session-Id")
557+
require.NotEmpty(t, id, "session %d should return a session ID", i+1)
558+
sessionIDs[i] = id
559+
}
560+
561+
// The next initialize request must be rejected with 503.
562+
overResp := postMCP(t, ts.URL, initReq, "")
563+
defer overResp.Body.Close()
564+
565+
assert.Equal(t, http.StatusServiceUnavailable, overResp.StatusCode,
566+
"initialize beyond MaxSessions must return 503")
567+
assert.NotEmpty(t, overResp.Header.Get("Retry-After"),
568+
"503 response must include Retry-After header")
569+
assert.Equal(t, "application/json", overResp.Header.Get("Content-Type"))
570+
571+
var errBody struct {
572+
Error struct {
573+
Code int `json:"code"`
574+
Message string `json:"message"`
575+
} `json:"error"`
576+
}
577+
require.NoError(t, json.NewDecoder(overResp.Body).Decode(&errBody))
578+
assert.Equal(t, -32000, errBody.Error.Code)
579+
assert.NotEmpty(t, errBody.Error.Message)
580+
581+
// Existing sessions must still be valid (DELETE returns 200, not 404/503).
582+
for _, id := range sessionIDs {
583+
req, err := http.NewRequestWithContext(
584+
context.Background(), http.MethodDelete, ts.URL+"/mcp", http.NoBody,
585+
)
586+
require.NoError(t, err)
587+
req.Header.Set("Mcp-Session-Id", id)
588+
delResp, err := http.DefaultClient.Do(req)
589+
require.NoError(t, err)
590+
delResp.Body.Close()
591+
assert.Equal(t, http.StatusOK, delResp.StatusCode,
592+
"existing session %s should still be terminable after cap is hit", id)
593+
}
594+
}

0 commit comments

Comments
 (0)