Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pkg/transport/session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (
"github.com/stacklok/toolhive/pkg/logger"
)

// contextKey is an unexported string type used for context keys declared in
// this package. Because the type is distinct from plain string, a key built
// from it cannot be equal to a string key with the same text.
type contextKey string

// SessionIDContextKey is the context key under which a session ID is stored.
const SessionIDContextKey contextKey = "session-id"

// Session interface defines the contract for all session types
type Session interface {
ID() string
Expand Down
45 changes: 43 additions & 2 deletions pkg/transport/session/storage_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,26 @@ func (s *LocalStorage) Load(_ context.Context, id string) (Session, error) {
}

// Delete removes a session from local storage.
//
// The entry is removed with a single atomic LoadAndDelete, so when several
// goroutines delete the same ID concurrently only one of them observes the
// stored value and closes it, and a value stored after the removal is never
// dropped without having been seen. If the removed value is a Session that
// also implements Close() error, Close is called after the entry is gone.
//
// An empty id is rejected with an error. Deleting an unknown id is a no-op and
// returns nil. Errors from the session's Close are ignored on purpose: cleanup
// failures must not prevent deletion.
func (s *LocalStorage) Delete(_ context.Context, id string) error {
	if id == "" {
		return fmt.Errorf("cannot delete session with empty ID")
	}

	// Remove first, then clean up: this makes the delete-and-close sequence
	// race-free without any extra locking.
	val, loaded := s.sessions.LoadAndDelete(id)
	if !loaded {
		return nil
	}

	session, ok := val.(Session)
	if !ok {
		return nil
	}

	// Only sessions that opt in to cleanup expose Close.
	if closer, ok := session.(interface{ Close() error }); ok {
		// Best-effort cleanup: the entry is already gone, so the error is
		// intentionally discarded.
		// TODO: Add logger once available in this package
		_ = closer.Close()
	}

	return nil
}
Expand Down Expand Up @@ -87,24 +102,50 @@ func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) erro
return true
})

// Second pass: delete expired sessions
// Second pass: close and delete expired sessions
for _, id := range toDelete {
// Try to close the session if it supports cleanup
if val, ok := s.sessions.Load(id); ok {
if session, ok := val.(Session); ok {
// Check if session implements Close() method
if closer, ok := session.(interface{ Close() error }); ok {
if err := closer.Close(); err != nil {
// Log error but continue with deletion
// We don't want cleanup errors to prevent session expiration
_ = err // TODO: Add logger once available in this package
}
}
}
}
s.sessions.Delete(id)
}

return nil
}

// Close clears all sessions from local storage.
// Calls Close() on sessions that implement it before removing them.
func (s *LocalStorage) Close() error {
// Collect keys first to avoid modifying map during iteration
var toDelete []any
s.sessions.Range(func(key, _ any) bool {
toDelete = append(toDelete, key)
return true
})
// Clear all sessions
// Close and clear all sessions
for _, key := range toDelete {
// Try to close the session if it supports cleanup
if val, ok := s.sessions.Load(key); ok {
if session, ok := val.(Session); ok {
// Check if session implements Close() method
if closer, ok := session.(interface{ Close() error }); ok {
if err := closer.Close(); err != nil {
// Log error but continue with cleanup
_ = err // TODO: Add logger once available in this package
}
}
}
}
s.sessions.Delete(key)
}
return nil
Expand Down
108 changes: 72 additions & 36 deletions pkg/vmcp/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,35 +526,18 @@ func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.B
return capabilities, nil
}

// CallTool invokes a tool on the backend MCP server.
// Returns the complete tool result including _meta field.
// callToolWithClient calls a tool using an already-initialized MCP client.
// This helper is used by both CallTool (ephemeral client) and pooled client (reused client).
//
//nolint:gocyclo // this function is complex because it handles tool calls with various content types and error handling.
func (h *httpBackendClient) CallTool(
func (*httpBackendClient) callToolWithClient(
ctx context.Context,
c *client.Client,
target *vmcp.BackendTarget,
toolName string,
arguments map[string]any,
meta map[string]any,
) (*vmcp.ToolCallResult, error) {
logger.Debugf("Calling tool %s on backend %s", toolName, target.WorkloadName)

// Create a client for this backend
c, err := h.clientFactory(ctx, target)
if err != nil {
return nil, wrapBackendError(err, target.WorkloadID, "create client")
}
defer func() {
if err := c.Close(); err != nil {
logger.Debugf("Failed to close client: %v", err)
}
}()

// Initialize the client
if _, err := initializeClient(ctx, c); err != nil {
return nil, wrapBackendError(err, target.WorkloadID, "initialize client")
}

// Call the tool using the original capability name from the backend's perspective.
// When conflict resolution renames tools (e.g., "fetch" → "fetch_fetch"),
// we must use the original backend name when forwarding requests.
Expand Down Expand Up @@ -637,12 +620,16 @@ func (h *httpBackendClient) CallTool(
}, nil
}

// ReadResource retrieves a resource from the backend MCP server.
// Returns the complete resource result including _meta field.
func (h *httpBackendClient) ReadResource(
ctx context.Context, target *vmcp.BackendTarget, uri string,
) (*vmcp.ResourceReadResult, error) {
logger.Debugf("Reading resource %s from backend %s", uri, target.WorkloadName)
// CallTool invokes a tool on the backend MCP server.
// Returns the complete tool result including _meta field.
func (h *httpBackendClient) CallTool(
ctx context.Context,
target *vmcp.BackendTarget,
toolName string,
arguments map[string]any,
meta map[string]any,
) (*vmcp.ToolCallResult, error) {
logger.Debugf("Calling tool %s on backend %s", toolName, target.WorkloadName)

// Create a client for this backend
c, err := h.clientFactory(ctx, target)
Expand All @@ -660,6 +647,17 @@ func (h *httpBackendClient) ReadResource(
return nil, wrapBackendError(err, target.WorkloadID, "initialize client")
}

return h.callToolWithClient(ctx, c, target, toolName, arguments, meta)
}

// readResourceWithClient reads a resource using an already-initialized MCP client.
// This helper is used by both ReadResource (ephemeral client) and pooled client (reused client).
func (*httpBackendClient) readResourceWithClient(
ctx context.Context,
c *client.Client,
target *vmcp.BackendTarget,
uri string,
) (*vmcp.ResourceReadResult, error) {
// Read the resource using the original URI from the backend's perspective.
// When conflict resolution renames resources, we must use the original backend URI.
backendURI := target.GetBackendCapabilityName(uri)
Expand Down Expand Up @@ -716,15 +714,12 @@ func (h *httpBackendClient) ReadResource(
}, nil
}

// GetPrompt retrieves a prompt from the backend MCP server.
// Returns the complete prompt result including _meta field.
func (h *httpBackendClient) GetPrompt(
ctx context.Context,
target *vmcp.BackendTarget,
name string,
arguments map[string]any,
) (*vmcp.PromptGetResult, error) {
logger.Debugf("Getting prompt %s from backend %s", name, target.WorkloadName)
// ReadResource retrieves a resource from the backend MCP server.
// Returns the complete resource result including _meta field.
func (h *httpBackendClient) ReadResource(
ctx context.Context, target *vmcp.BackendTarget, uri string,
) (*vmcp.ResourceReadResult, error) {
logger.Debugf("Reading resource %s from backend %s", uri, target.WorkloadName)

// Create a client for this backend
c, err := h.clientFactory(ctx, target)
Expand All @@ -742,6 +737,18 @@ func (h *httpBackendClient) GetPrompt(
return nil, wrapBackendError(err, target.WorkloadID, "initialize client")
}

return h.readResourceWithClient(ctx, c, target, uri)
}

// getPromptWithClient retrieves a prompt using an already-initialized MCP client.
// This helper is used by both GetPrompt (ephemeral client) and pooled client (reused client).
func (*httpBackendClient) getPromptWithClient(
ctx context.Context,
c *client.Client,
target *vmcp.BackendTarget,
name string,
arguments map[string]any,
) (*vmcp.PromptGetResult, error) {
// Get the prompt using the original prompt name from the backend's perspective.
// When conflict resolution renames prompts, we must use the original backend name.
backendPromptName := target.GetBackendCapabilityName(name)
Expand Down Expand Up @@ -788,3 +795,32 @@ func (h *httpBackendClient) GetPrompt(
Meta: meta,
}, nil
}

// GetPrompt retrieves a prompt from the backend MCP server.
// Returns the complete prompt result including _meta field.
//
// A client is created through h.clientFactory for this single call,
// initialized, used for the request via getPromptWithClient, and closed when
// the call returns. Failures while creating or initializing the client are
// wrapped with wrapBackendError and the target's WorkloadID.
func (h *httpBackendClient) GetPrompt(
	ctx context.Context,
	target *vmcp.BackendTarget,
	name string,
	arguments map[string]any,
) (*vmcp.PromptGetResult, error) {
	logger.Debugf("Getting prompt %s from backend %s", name, target.WorkloadName)

	// Build a client dedicated to this request.
	mcpClient, createErr := h.clientFactory(ctx, target)
	if createErr != nil {
		return nil, wrapBackendError(createErr, target.WorkloadID, "create client")
	}

	// Always release the client; a close failure is only worth a debug line.
	defer func() {
		closeErr := mcpClient.Close()
		if closeErr == nil {
			return
		}
		logger.Debugf("Failed to close client: %v", closeErr)
	}()

	// The client must be initialized before it can serve requests.
	_, initErr := initializeClient(ctx, mcpClient)
	if initErr != nil {
		return nil, wrapBackendError(initErr, target.WorkloadID, "initialize client")
	}

	return h.getPromptWithClient(ctx, mcpClient, target, name, arguments)
}
100 changes: 100 additions & 0 deletions pkg/vmcp/client/pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package client

import (
"context"
"fmt"
"sync"

"github.com/mark3labs/mcp-go/client"
)

// BackendClientPool holds MCP clients keyed by backend ID so that a client
// created once can be handed out again on later lookups instead of being
// rebuilt. It is intended to be owned by a single session.
//
// All methods take the pool's mutex, so the pool may be shared between
// goroutines.
type BackendClientPool struct {
	// mu guards clients.
	mu sync.RWMutex
	// clients maps a backend ID to the client cached for that backend.
	clients map[string]*client.Client
}

// NewBackendClientPool returns a pool with no cached clients and an
// allocated, empty client map.
func NewBackendClientPool() *BackendClientPool {
	pool := new(BackendClientPool)
	pool.clients = make(map[string]*client.Client)
	return pool
}

// GetOrCreate returns the cached client for backendID, or creates one with
// factory and caches it.
//
// Parameters:
//   - ctx is passed through to factory unchanged.
//   - backendID is the cache key.
//   - factory builds a ready-to-use client; it is invoked only when no usable
//     (non-nil) client is cached for backendID.
//
// Returns the cached or newly created client. An error is returned, and
// nothing is cached, when factory is nil, when factory fails, or when factory
// reports success but yields a nil client.
//
// Safe for concurrent use. Note that factory runs while the pool's write lock
// is held, so a slow factory delays every other operation on the pool.
func (p *BackendClientPool) GetOrCreate(
	ctx context.Context,
	backendID string,
	factory func(context.Context) (*client.Client, error),
) (*client.Client, error) {
	// Fast path: look the client up under the read lock. A nil entry is
	// treated the same as a missing one so callers never receive a nil client.
	p.mu.RLock()
	cached := p.clients[backendID]
	p.mu.RUnlock()
	if cached != nil {
		return cached, nil
	}

	// Slow path: take the write lock to create the client.
	p.mu.Lock()
	defer p.mu.Unlock()

	// Re-check: another goroutine may have created the client while this one
	// was waiting for the write lock.
	if c := p.clients[backendID]; c != nil {
		return c, nil
	}

	if factory == nil {
		return nil, fmt.Errorf("failed to create client for backend %s: factory is nil", backendID)
	}

	c, err := factory(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create client for backend %s: %w", backendID, err)
	}
	if c == nil {
		// Caching a nil client would hand it to every later caller, who would
		// then fail on first use; surface the problem here instead.
		return nil, fmt.Errorf("failed to create client for backend %s: factory returned nil client", backendID)
	}

	// Allocate the map on demand so a zero-value pool does not panic on write.
	if p.clients == nil {
		p.clients = make(map[string]*client.Client)
	}
	p.clients[backendID] = c
	return c, nil
}

// MarkUnhealthy evicts the client cached for backendID so that the next
// GetOrCreate call for that backend builds a new one. It is meant to be called
// after a client hits a connection-level failure (EOF, reset, and similar).
//
// If no entry exists for backendID the call does nothing. A non-nil evicted
// client is closed on a best-effort basis; the close error is discarded.
func (p *BackendClientPool) MarkUnhealthy(backendID string) {
	p.mu.Lock()
	defer p.mu.Unlock()

	stale, found := p.clients[backendID]
	if !found {
		return
	}

	// The entry may hold a nil client; only a real client can be closed.
	if stale != nil {
		_ = stale.Close()
	}
	delete(p.clients, backendID)
}

// Close closes every non-nil client held by the pool and then replaces the
// client map with a fresh, empty one. It is intended to run when the owning
// session expires.
//
// Every client is closed even if an earlier one fails. Only the first close
// failure is returned, wrapped with the ID of the backend it belongs to; nil is
// returned when all closes succeed.
func (p *BackendClientPool) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()

	var result error
	for id, cl := range p.clients {
		if cl == nil {
			continue
		}
		closeErr := cl.Close()
		if closeErr == nil || result != nil {
			// Nothing to report, or an earlier failure is already recorded.
			continue
		}
		result = fmt.Errorf("failed to close client for backend %s: %w", id, closeErr)
	}

	// Drop all entries by starting over with an empty map.
	p.clients = make(map[string]*client.Client)
	return result
}
Loading
Loading