Skip to content
193 changes: 121 additions & 72 deletions pkg/transport/proxy/httpsse/http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ type HTTPSSEProxy struct {
// Session manager for SSE clients
sessionManager *session.Manager

// liveSSESessions tracks active SSE connections local to this instance.
// Keys are clientID strings; values are *session.SSESession.
// This is separate from sessionManager so that distributed storage backends
// (e.g. Redis) can be used for session metadata without breaking SSE fan-out,
// which must iterate live in-memory connections regardless of storage backend.
liveSSESessions sync.Map

// Pending messages for SSE clients
pendingMessages []*ssecommon.PendingSSEMessage
pendingMutex sync.Mutex
Expand All @@ -84,9 +91,42 @@ type HTTPSSEProxy struct {
// Health checker
healthChecker *healthcheck.HealthChecker

// Track closed clients to prevent double-close
closedClients map[string]bool
closedClientsMutex sync.Mutex
// stopOnce ensures Stop is idempotent even when called concurrently.
stopOnce sync.Once
}

// Option configures an HTTPSSEProxy.
type Option func(*HTTPSSEProxy)

// WithSessionStorage injects a custom storage backend into the session manager.
// When not provided, the proxy uses in-memory LocalStorage (single-replica default).
// Provide a Redis-backed storage for multi-replica deployments so all replicas
// share the same session store.
//
// Architectural note: HTTPSSEProxy is used by StdioTransport for stdio-backed MCP
// servers. SSE fan-out (ForwardResponseToClients) and POST handling are both local
// to the instance holding the live SSE connection, so Redis storage enables
// cross-replica session metadata sharing but does NOT solve cross-replica message
// delivery — a POST accepted on replica B won't reach a client whose SSE connection
// is on replica A. Callers must ensure an external load balancer provides session
// affinity (sticky sessions) when using distributed storage with this proxy.
//
// Prefer Streamable HTTP (ProxyModeStreamableHTTP), also supported on StdioTransport,
// which does not have this affinity constraint.
//
// Note: SSE fan-out and graceful disconnect use a separate in-memory liveSSESessions
// registry, not the session manager, so any Storage implementation is safe to inject here.
func WithSessionStorage(storage session.Storage) Option {
return func(p *HTTPSSEProxy) {
if storage == nil {
return
}
if p.sessionManager != nil {
_ = p.sessionManager.Stop()
}
sseFactory := func(id string) session.Session { return session.NewSSESession(id) }
p.sessionManager = session.NewManagerWithStorage(session.DefaultSessionTTL, sseFactory, storage)
}
Comment thread
yrobla marked this conversation as resolved.
Comment thread
yrobla marked this conversation as resolved.
Comment thread
yrobla marked this conversation as resolved.
}
Comment thread
yrobla marked this conversation as resolved.

// NewHTTPSSEProxy creates a new HTTP SSE proxy for transports.
Expand All @@ -95,7 +135,8 @@ func NewHTTPSSEProxy(
port int,
trustProxyHeaders bool,
prometheusHandler http.Handler,
middlewares ...types.NamedMiddleware,
middlewares []types.NamedMiddleware,
opts ...Option,
) *HTTPSSEProxy {
Comment thread
yrobla marked this conversation as resolved.
// Create a factory for SSE sessions
sseFactory := func(id string) session.Session {
Expand All @@ -112,7 +153,10 @@ func NewHTTPSSEProxy(
sessionManager: session.NewManager(session.DefaultSessionTTL, sseFactory),
pendingMessages: []*ssecommon.PendingSSEMessage{},
prometheusHandler: prometheusHandler,
closedClients: make(map[string]bool),
}

for _, opt := range opts {
opt(proxy)
}

// Create MCP pinger and health checker
Expand Down Expand Up @@ -207,32 +251,44 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error {
return nil
}

// Stop stops the HTTP SSE proxy.
// Stop stops the HTTP SSE proxy. It is safe to call Stop more than once or
// concurrently; only the first call performs the shutdown sequence.
func (p *HTTPSSEProxy) Stop(ctx context.Context) error {
// Signal shutdown
close(p.shutdownCh)

// Stop the session manager cleanup routine
if p.sessionManager != nil {
if err := p.sessionManager.Stop(); err != nil {
slog.Error("failed to stop session manager", "error", err)
}
}

// Disconnect all active sessions
p.sessionManager.Range(func(_, value interface{}) bool {
if sess, ok := value.(*session.SSESession); ok {
sess.Disconnect()
var stopErr error
p.stopOnce.Do(func() {
// Signal shutdown to SSE handlers waiting on shutdownCh.
close(p.shutdownCh)

// Disconnect all active SSE connections.
p.liveSSESessions.Range(func(_, value interface{}) bool {
if sess, ok := value.(*session.SSESession); ok {
sess.Disconnect()
}
return true
})

// Stop the session manager last: terminates the cleanup goroutine and
// closes any underlying storage connections (e.g. Redis client).
// Deferred so it always runs even if server.Shutdown returns an error.
defer func() {
if p.sessionManager != nil {
if err := p.sessionManager.Stop(); err != nil {
slog.Error("failed to stop session manager", "error", err)
}
}
}()

// Drain active HTTP connections before tearing down storage. This ensures
// that removeClient calls (triggered by SSE handler cancellation) can still
// reach sessionManager.Delete without hitting a closed storage backend.
if p.server != nil {
if err := p.server.Shutdown(ctx); err != nil {
stopErr = err
return
}
}
return true
})

// Stop the HTTP server
if p.server != nil {
return p.server.Shutdown(ctx)
}

return nil
return stopErr
}

// IsRunning checks if the proxy is running.
Expand Down Expand Up @@ -273,9 +329,9 @@ func (p *HTTPSSEProxy) ForwardResponseToClients(_ context.Context, msg jsonrpc2.
// Create an SSE message
sseMsg := ssecommon.NewSSEMessage("message", string(data))

// Check if there are any connected clients by checking session count
// Check if there are any connected clients
hasClients := false
p.sessionManager.Range(func(_, _ interface{}) bool {
p.liveSSESessions.Range(func(_, _ interface{}) bool {
hasClients = true
return false // Stop iteration after finding first session
})
Expand Down Expand Up @@ -320,6 +376,7 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
p.liveSSESessions.Store(clientID, sseSession)

// Process any pending messages for this client
p.processPendingMessages(clientID, messageCh)
Expand Down Expand Up @@ -398,13 +455,23 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request)
return
}

// Check if the session exists
// Check if the session exists in the distributed store.
_, exists := p.sessionManager.Get(sessionID)
if !exists {
http.Error(w, "Could not find session", http.StatusNotFound)
return
}

// Verify the live SSE connection for this session is held by this instance.
// With a distributed storage backend (e.g. Redis), sessionManager.Get succeeds
// on any replica, but fan-out only reaches clients connected locally. Rejecting
// here with 503 surfaces the affinity failure explicitly instead of silently
// dropping the response after forwarding to the backend.
if _, local := p.liveSSESessions.Load(sessionID); !local {
http.Error(w, "SSE connection not held by this instance", http.StatusServiceUnavailable)
return
}

// Read the request body
body, err := io.ReadAll(r.Body)
if err != nil {
Expand Down Expand Up @@ -439,29 +506,26 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error {
// Convert the message to an SSE-formatted string
sseString := msg.ToSSEString()

// Iterate through all sessions and send to SSE sessions
p.sessionManager.Range(func(key, value interface{}) bool {
// Iterate through all live SSE connections and deliver the event
p.liveSSESessions.Range(func(key, value interface{}) bool {
clientID, ok := key.(string)
if !ok {
return true // Continue iteration
}

sess, ok := value.(session.Session)
sseSession, ok := value.(*session.SSESession)
if !ok {
return true // Continue iteration
}

// Check if this is an SSE session
if sseSession, ok := sess.(*session.SSESession); ok {
// Try to send the message
if err := sseSession.SendMessage(sseString); err != nil {
// Log the error but continue sending to other clients
switch {
case errors.Is(err, session.ErrSessionDisconnected):
slog.Debug("client is disconnected, skipping message", "client_id", clientID)
case errors.Is(err, session.ErrMessageChannelFull):
slog.Debug("client channel full, skipping message", "client_id", clientID)
}
// Try to send the message
if err := sseSession.SendMessage(sseString); err != nil {
// Log the error but continue sending to other clients
switch {
case errors.Is(err, session.ErrSessionDisconnected):
slog.Debug("client is disconnected, skipping message", "client_id", clientID)
case errors.Is(err, session.ErrMessageChannelFull):
slog.Debug("client channel full, skipping message", "client_id", clientID)
}
}

Expand All @@ -471,40 +535,25 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error {
return nil
}

// removeClient safely removes a client and closes its channel
// removeClient removes a client and closes its channel.
// liveSSESessions.LoadAndDelete is atomic, so concurrent calls for the same
// clientID are safe: only one will find the entry and call Disconnect.
func (p *HTTPSSEProxy) removeClient(clientID string) {
// Check if already closed
p.closedClientsMutex.Lock()
if p.closedClients[clientID] {
p.closedClientsMutex.Unlock()
return
}
p.closedClients[clientID] = true
p.closedClientsMutex.Unlock()

// Get the session from the manager
sess, exists := p.sessionManager.Get(clientID)
if !exists {
return
}

// If it's an SSE session, disconnect it
if sseSession, ok := sess.(*session.SSESession); ok {
sseSession.Disconnect()
// Disconnect the live session directly from liveSSESessions. With a
// distributed storage backend (e.g. Redis), sessionManager.Get returns a
// freshly-deserialized SSESession with a different MessageCh than the
// actual live connection, so calling Disconnect() on it would close the
// wrong channel and leave the real connection undrained.
if val, ok := p.liveSSESessions.LoadAndDelete(clientID); ok {
if sseSession, ok := val.(*session.SSESession); ok {
sseSession.Disconnect()
}
}

// Remove the session from the manager
if err := p.sessionManager.Delete(clientID); err != nil {
slog.Debug("failed to delete session", "client_id", clientID, "error", err)
}

// Clean up closed clients map periodically (prevent memory leak)
p.closedClientsMutex.Lock()
if len(p.closedClients) > 1000 {
// Reset the map when it gets too large
p.closedClients = make(map[string]bool)
}
p.closedClientsMutex.Unlock()
}

// processPendingMessages processes any pending messages for a new client.
Expand Down
56 changes: 21 additions & 35 deletions pkg/transport/proxy/httpsse/http_proxy_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestIntegrationSSEProxyStressTest(t *testing.T) {
t.Parallel()

// Create proxy with a random port
proxy := NewHTTPSSEProxy("localhost", 0, false, nil)
proxy := NewHTTPSSEProxy("localhost", 0, false, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand Down Expand Up @@ -172,7 +172,7 @@ func TestIntegrationConcurrentClientsWithLongRunning(t *testing.T) {
t.Parallel()

// Create and start proxy
proxy := NewHTTPSSEProxy("localhost", 0, false, nil)
proxy := NewHTTPSSEProxy("localhost", 0, false, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand Down Expand Up @@ -314,10 +314,11 @@ func TestIntegrationConcurrentClientsWithLongRunning(t *testing.T) {
}
}

// TestIntegrationMemoryLeakPrevention tests that the closedClients map doesn't grow unbounded
func TestIntegrationMemoryLeakPrevention(t *testing.T) {
// TestIntegrationLiveSessionsCleanup verifies that liveSSESessions entries are
// removed after clients disconnect, so the map does not grow unbounded.
func TestIntegrationLiveSessionsCleanup(t *testing.T) {
t.Parallel()
proxy := NewHTTPSSEProxy("localhost", 0, false, nil)
proxy := NewHTTPSSEProxy("localhost", 0, false, nil, nil)
ctx := context.Background()

err := proxy.Start(ctx)
Expand All @@ -330,43 +331,28 @@ func TestIntegrationMemoryLeakPrevention(t *testing.T) {

proxyURL := fmt.Sprintf("http://%s", proxy.server.Addr)

// Create and remove many clients to trigger cleanup
for i := 0; i < 1500; i++ {
// Quick connect and disconnect
// Connect and immediately disconnect several clients.
for i := 0; i < 20; i++ {
resp, err := http.Get(proxyURL + "/sse")
if err != nil {
continue
}

// Extract session ID and immediately close
sessionID, _ := extractSessionID(resp.Body)
resp.Body.Close()

// The disconnection should trigger removeClient
if sessionID != "" {
// Give time for the disconnect to be processed
time.Sleep(1 * time.Millisecond)
}

// Check closedClients size periodically
if i%100 == 0 {
proxy.closedClientsMutex.Lock()
size := len(proxy.closedClients)
proxy.closedClientsMutex.Unlock()

// Should have been reset when it hit 1000
assert.Less(t, size, 1100, "closedClients map should be cleaned up")
t.Logf("After %d clients, closedClients size: %d", i, size)
}
}

// Final check
proxy.closedClientsMutex.Lock()
finalSize := len(proxy.closedClients)
proxy.closedClientsMutex.Unlock()

assert.Less(t, finalSize, 1000, "Final closedClients size should be less than 1000")
t.Logf("Final closedClients size: %d", finalSize)
// Poll until liveSSESessions drains or the deadline is reached.
// Disconnect propagation is asynchronous (the server goroutine must observe
// the client disconnect and call removeClient), so a fixed sleep is fragile
// on loaded CI runners.
require.Eventually(t, func() bool {
var liveCount int
proxy.liveSSESessions.Range(func(_, _ interface{}) bool {
liveCount++
return true
})
return liveCount == 0
}, 5*time.Second, 10*time.Millisecond,
"liveSSESessions should be empty after all clients disconnect")
}

// Helper function to extract session ID from SSE response
Expand Down
Loading
Loading