Skip to content
54 changes: 7 additions & 47 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,6 @@ const (
recordingTimeout = time.Second * 5
)

const (
// Possible values for the "client" field in interception records.
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
ClientClaude = "Claude Code"
ClientCodex = "Codex"
ClientCursor = "Cursor"
ClientCopilotVSC = "GitHub Copilot (VS Code)"
ClientCopilotCLI = "GitHub Copilot (CLI)"
ClientKilo = "Kilo Code"
ClientMux = "Mux"
ClientRoo = "Roo Code"
ClientZed = "Zed"
ClientUnknown = "Unknown"
)

// RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs;
// specifically, OpenAI's & Anthropic's at present.
// RequestBridge intercepts requests to - and responses from - these upstream services to provide
Expand Down Expand Up @@ -167,6 +152,11 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
ctx, span := tracer.Start(r.Context(), "Intercept")
defer span.End()

// We execute this before CreateInterceptor since the interceptors
// read the request body and don't reset them.
client := guessClient(r)
sessionID := guessSessionID(client, r)

interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
if err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err))
Expand Down Expand Up @@ -203,13 +193,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
interceptor.Setup(logger, asyncRecorder, mcpProxy)

if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
Client: guessClient(r),
ID: interceptor.ID().String(),
InitiatorID: actor.ID,
Metadata: actor.Metadata,
Model: interceptor.Model(),
Provider: p.Name(),
UserAgent: r.UserAgent(),
Client: string(client),
ClientSessionID: sessionID,
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
}); err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
Expand Down Expand Up @@ -338,34 +329,3 @@ func mergeContexts(base, other context.Context) context.Context {
}()
return ctx
}

// guessClient attempts to guess the client application from the request headers.
// Not all clients set proper user agent headers, so this is a best-effort approach.
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
func guessClient(r *http.Request) string {
userAgent := strings.ToLower(r.UserAgent())
originator := r.Header.Get("originator")

// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
switch {
case strings.HasPrefix(userAgent, "mux/"):
return ClientMux
case strings.HasPrefix(userAgent, "claude"):
return ClientClaude
case strings.HasPrefix(userAgent, "codex"):
return ClientCodex
case strings.HasPrefix(userAgent, "zed/"):
return ClientZed
case strings.HasPrefix(userAgent, "githubcopilotchat/"):
return ClientCopilotVSC
case strings.HasPrefix(userAgent, "copilot/"):
return ClientCopilotCLI
case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code":
return ClientKilo
case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code":
return ClientRoo
case r.Header.Get("x-cursor-client-version") != "":
return ClientCursor
}
return ClientUnknown
}
127 changes: 124 additions & 3 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ func TestSimple(t *testing.T) {
createRequest func(*testing.T, string, []byte) *http.Request
expectedMsgID string
userAgent string
expectedClient string
expectedClient aibridge.Client
}{
{
name: config.ProviderAnthropic,
Expand All @@ -561,7 +561,7 @@ func TestSimple(t *testing.T) {
createRequest: createAnthropicMessagesReq,
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
userAgent: "claude-cli/2.0.67 (external, cli)",
expectedClient: aibridge.ClientClaude,
expectedClient: aibridge.ClientClaudeCode,
},
{
name: config.ProviderOpenAI,
Expand Down Expand Up @@ -671,7 +671,7 @@ func TestSimple(t *testing.T) {
interceptions := recorderClient.RecordedInterceptions()
require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions)
assert.Equal(t, tc.userAgent, interceptions[0].UserAgent)
assert.Equal(t, tc.expectedClient, interceptions[0].Client)
assert.Equal(t, string(tc.expectedClient), interceptions[0].Client)

recorderClient.VerifyAllInterceptionsEnded(t)
})
Expand All @@ -680,6 +680,127 @@ func TestSimple(t *testing.T) {
}
}

func TestSessionIDTracking(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
fixture []byte
expectedClient aibridge.Client
sessionID string
configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error)
createRequest func(t *testing.T, baseURL string, body []byte) *http.Request
}{
// Session in header.
{
name: "mux",
fixture: fixtures.AntSimple,
expectedClient: aibridge.ClientMux,
sessionID: "mux-workspace-321",
configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
},
createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request {
t.Helper()
req := createAnthropicMessagesReq(t, baseURL, body)
req.Header.Set("User-Agent", "mux/1.0.0")
req.Header.Set("X-Mux-Workspace-Id", "mux-workspace-321")
return req
},
},
// Session in body.
{
name: "claude_code",
fixture: fixtures.AntSimple,
expectedClient: aibridge.ClientClaudeCode,
sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479",
configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
},
createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request {
t.Helper()
// Claude Code embeds the session ID in metadata.user_id within the body.
body, err := sjson.SetBytes(body, "metadata.user_id",
"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479")
require.NoError(t, err)
req := createAnthropicMessagesReq(t, baseURL, body)
req.Header.Set("User-Agent", "claude-cli/2.0.67 (external, cli)")
return req
},
},
// No session.
{
name: "zed",
fixture: fixtures.AntSimple,
expectedClient: aibridge.ClientZed,
configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
},
createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request {
t.Helper()
req := createAnthropicMessagesReq(t, baseURL, body)
req.Header.Set("User-Agent", "Zed/0.219.4+stable.119.abc123 (macos; aarch64)")
return req
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

fix := fixtures.Parse(t, tc.fixture)
upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix))

recorderClient := &testutil.MockRecorder{}

b, err := tc.configureFunc(t, upstream.URL, recorderClient)
require.NoError(t, err)
mockSrv := httptest.NewUnstartedServer(b)
t.Cleanup(mockSrv.Close)
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibcontext.AsActor(ctx, userID, nil)
}
mockSrv.Start()

req := tc.createRequest(t, mockSrv.URL, fix.Request())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()

// Drain the body to let the stream complete.
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)

interceptions := recorderClient.RecordedInterceptions()
require.Len(t, interceptions, 1, "expected exactly one interception")
assert.Equal(t, string(tc.expectedClient), interceptions[0].Client)

if tc.sessionID == "" {
assert.Nil(t, interceptions[0].ClientSessionID, "expected nil session ID for %s", tc.name)
} else {
require.NotNil(t, interceptions[0].ClientSessionID, "expected non-nil session ID for %s", tc.name)
assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID)
}

recorderClient.VerifyAllInterceptionsEnded(t)
})
}
}

func TestFallthrough(t *testing.T) {
t.Parallel()

Expand Down
100 changes: 0 additions & 100 deletions bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,103 +104,3 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
})
}
}

func TestGuessClient(t *testing.T) {
t.Parallel()

tests := []struct {
name string
userAgent string
headers map[string]string
wantClient string
}{
{
name: "mux",
userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22",
wantClient: ClientMux,
},
{
name: "claude_code",
userAgent: "claude-cli/2.0.67 (external, cli)",
wantClient: ClientClaude,
},
{
name: "codex_cli",
userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef",
wantClient: ClientCodex,
},
{
name: "zed",
userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)",
wantClient: ClientZed,
},
{
name: "github_copilot_vsc",
userAgent: "GitHubCopilotChat/0.37.2026011603",
wantClient: ClientCopilotVSC,
},
{
name: "github_copilot_cli",
userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)",
wantClient: ClientCopilotCLI,
},
{
name: "kilo_code_user_agent",
userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientKilo,
},
{
name: "kilo_code_originator",
headers: map[string]string{"Originator": "kilo-code"},
wantClient: ClientKilo,
},
{
name: "roo_code_user_agent",
userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientRoo,
},
{
name: "roo_code_originator",
headers: map[string]string{"Originator": "roo-code"},
wantClient: ClientRoo,
},
{
name: "cursor_x_cursor_client_version",
userAgent: "connect-es/1.6.1",
headers: map[string]string{"X-Cursor-client-version": "0.50.0"},
wantClient: ClientCursor,
},
{
name: "cursor_x_cursor_some_other_header",
headers: map[string]string{"x-cursor-client-version": "abc123"},
wantClient: ClientCursor,
},
{
name: "unknown_client",
userAgent: "ccclaude-cli/calude-with-wrong-prefix",
wantClient: ClientUnknown,
},
{
name: "empty_user_agent",
userAgent: "",
wantClient: ClientUnknown,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)

req.Header.Set("User-Agent", tt.userAgent)
for key, value := range tt.headers {
req.Header.Set(key, value)
}

got := guessClient(req)
require.Equal(t, tt.wantClient, got)
})
}
}
54 changes: 54 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package aibridge

import (
"net/http"
"strings"
)

type Client string

const (
// Possible values for the "client" field in interception records.
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
ClientClaudeCode Client = "Claude Code"
ClientCodex Client = "Codex"
ClientZed Client = "Zed"
ClientCopilotVSC Client = "GitHub Copilot (VS Code)"
ClientCopilotCLI Client = "GitHub Copilot (CLI)"
ClientKilo Client = "Kilo Code"
ClientMux Client = "Mux"
ClientRoo Client = "Roo Code"
ClientCursor Client = "Cursor"
ClientUnknown Client = "Unknown"
)

// guessClient attempts to guess the client application from the request headers.
// Not all clients set proper user agent headers, so this is a best-effort approach.
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
func guessClient(r *http.Request) Client {
userAgent := strings.ToLower(r.UserAgent())
originator := r.Header.Get("originator")

// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
switch {
case strings.HasPrefix(userAgent, "mux/"):
return ClientMux
case strings.HasPrefix(userAgent, "claude"):
return ClientClaudeCode
case strings.HasPrefix(userAgent, "codex"):
return ClientCodex
case strings.HasPrefix(userAgent, "zed/"):
return ClientZed
case strings.HasPrefix(userAgent, "githubcopilotchat/"):
return ClientCopilotVSC
case strings.HasPrefix(userAgent, "copilot/"):
return ClientCopilotCLI
case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code":
return ClientKilo
case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code":
return ClientRoo
case r.Header.Get("x-cursor-client-version") != "":
return ClientCursor
}
return ClientUnknown
}
Loading