Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pkg/vmcp/session/default_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ type defaultMultiSession struct {
prompts []vmcp.Prompt
backendSessions map[string]string

queue AdmissionQueue
queue AdmissionQueue
keepalive *KeepaliveManager // may be nil in tests that bypass the factory
}

// Tools returns a snapshot copy of the tools available in this session.
Expand Down Expand Up @@ -180,6 +181,12 @@ func (s *defaultMultiSession) GetPrompt(
func (s *defaultMultiSession) Close() error {
s.queue.CloseAndDrain()

// Stop keepalive goroutines before closing backend connections so they
// don't attempt probes on already-closed sessions.
if s.keepalive != nil {
s.keepalive.Stop()
}

var errs []error
for id, conn := range s.connections {
if err := conn.Close(); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/vmcp/session/default_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ func (m *mockConnectedBackend) GetPrompt(ctx context.Context, name string, argum
return &vmcp.PromptGetResult{Messages: "hello"}, nil
}

// SessionID reports the session identifier held in the mock's sessID field.
func (m *mockConnectedBackend) SessionID() string {
	return m.sessID
}
// Ping is the mock's keepalive probe: it ignores the supplied context and
// always reports success.
func (*mockConnectedBackend) Ping(_ context.Context) error {
	return nil
}
// SessionID reports the session identifier held in the mock's sessID field.
func (m *mockConnectedBackend) SessionID() string {
	return m.sessID
}
func (m *mockConnectedBackend) Close() error {
m.closeCalled.Store(true)
return m.closeErr
Expand Down
54 changes: 53 additions & 1 deletion pkg/vmcp/session/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/google/uuid"
"go.opentelemetry.io/otel/metric"

"github.com/stacklok/toolhive/pkg/auth"
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
Expand Down Expand Up @@ -112,6 +113,8 @@ type defaultMultiSessionFactory struct {
maxConcurrency int
backendInitTimeout time.Duration
hmacSecret []byte // Server-managed secret for HMAC-SHA256 token hashing
keepaliveCfg KeepaliveConfig
keepaliveMetrics *keepaliveMetrics
}

// MultiSessionFactoryOption configures a defaultMultiSessionFactory.
Expand All @@ -137,6 +140,43 @@ func WithBackendInitTimeout(d time.Duration) MultiSessionFactoryOption {
}
}

// WithKeepaliveConfig returns an option that stores cfg as the factory's
// server-level keepalive configuration, which is then handed to the keepalive
// manager of every session the factory creates.
//
// Per-backend overrides are expressed through BackendTarget.KeepaliveMethod;
// a target whose method is KeepaliveMethodNone has keepalive disabled.
func WithKeepaliveConfig(cfg KeepaliveConfig) MultiSessionFactoryOption {
	apply := func(f *defaultMultiSessionFactory) {
		f.keepaliveCfg = cfg
	}
	return apply
}

// WithMeterProvider returns an option that builds the factory's keepalive
// metrics from the given OTel MeterProvider.
//
// The option leaves the factory untouched when mp is nil or when the metrics
// cannot be constructed (the latter is logged as a warning). In both cases the
// factory falls back to no-op metrics, so keepalive keeps working without any
// metrics infrastructure.
func WithMeterProvider(mp metric.MeterProvider) MultiSessionFactoryOption {
	return func(f *defaultMultiSessionFactory) {
		if mp == nil {
			// Nothing to configure; the no-op fallback applies.
			return
		}

		metrics, err := newKeepaliveMetrics(mp)
		if err == nil {
			f.keepaliveMetrics = metrics
			return
		}

		// Metrics are best-effort: report the failure and carry on without them.
		slog.Warn("failed to initialise keepalive metrics; metrics will be no-ops", "error", err)
	}
}

// ensureKeepaliveMetrics yields the keepalive metrics to use for a new session.
// Metrics configured through WithMeterProvider are returned as-is; otherwise a
// fresh set is built from a nil provider on each call (the result is not stored
// on the factory).
func (f *defaultMultiSessionFactory) ensureKeepaliveMetrics() *keepaliveMetrics {
	if configured := f.keepaliveMetrics; configured != nil {
		return configured
	}

	// A nil provider is expected to select the no-op provider, for which
	// construction cannot fail, so the error is deliberately discarded.
	// NOTE(review): this relies on newKeepaliveMetrics' nil handling — confirm.
	fallback, _ := newKeepaliveMetrics(nil)
	return fallback
}

// WithHMACSecret sets the server-managed secret used for HMAC-SHA256 token hashing.
// The secret should be 32+ bytes and loaded from secure configuration (e.g., environment
// variable, secret management system).
Expand Down Expand Up @@ -404,7 +444,18 @@ func (f *defaultMultiSessionFactory) makeSession(

populateBackendMetadata(transportSess, results)

// Create the base session
// Build the target map needed by the keepalive manager (workloadID → target).
targetsByID := make(map[string]*vmcp.BackendTarget, len(results))
for i := range results {
targetsByID[results[i].target.WorkloadID] = results[i].target
}

// Create and start the keepalive manager. This factory is only constructed
// when SessionManagementV2 is enabled, so keepalive is gated by that flag.
km := newKeepaliveManager(connections, targetsByID, f.keepaliveCfg, f.ensureKeepaliveMetrics())
km.Start(context.Background())

// Create the base session without token binding
baseSession := &defaultMultiSession{
Session: transportSess,
connections: connections,
Expand All @@ -414,6 +465,7 @@ func (f *defaultMultiSessionFactory) makeSession(
prompts: allPrompts,
backendSessions: backendSessions,
queue: newAdmissionQueue(),
keepalive: km,
}

// Apply hijack prevention: computes token binding, stores metadata, and wraps
Expand Down
4 changes: 4 additions & 0 deletions pkg/vmcp/session/internal/backend/mcp_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ type mcpSession struct {
// SessionID reports the session ID assigned by the backend, as recorded in
// backendSessionID.
func (c *mcpSession) SessionID() string {
	return c.backendSessionID
}

// Ping issues an MCP protocol ping by delegating to the underlying client and
// returns that call's error unchanged. The keepalive mechanism uses it as a
// side-effect-free check that the backend connection is still alive.
func (c *mcpSession) Ping(ctx context.Context) error {
	return c.client.Ping(ctx)
}

// Close shuts down the underlying MCP client transport and returns whatever
// error the client reports.
func (c *mcpSession) Close() error {
	return c.client.Close()
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/vmcp/session/internal/backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ type Session interface {
arguments map[string]any,
) (*vmcp.PromptGetResult, error)

// Ping sends a protocol-level ping to the backend and returns an error if
// the backend is unreachable or does not respond. It is side-effect-free
// and is used exclusively by the keepalive mechanism.
Ping(ctx context.Context) error

// Close releases all resources held by this session. Implementations must
// be idempotent: calling Close multiple times returns nil.
Close() error
Expand Down
Loading
Loading