Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/recorder"
Expand All @@ -29,6 +30,10 @@ type interceptionBase struct {
req *ChatCompletionNewParamsWrapper
cfg config.OpenAI

// clientHeaders holds the original client request headers to forward
// to upstream providers.
clientHeaders http.Header

logger slog.Logger
tracer trace.Tracer

Expand All @@ -37,7 +42,18 @@ type interceptionBase struct {
}

func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
var opts []option.RequestOption
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

option order matters? I'd assume it should not so old way of initializing opts could stay?


// Forward sanitized client headers to the upstream provider.
// Client headers are added first so that SDK auth appended
// below takes priority on any conflict.
for k, vals := range intercept.SanitizeClientHeaders(i.clientHeaders) {
for _, v := range vals {
opts = append(opts, option.WithHeader(k, v))
}
}

opts = append(opts, option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL))

// Add extra headers if configured.
// Some providers require additional headers that are not added by the SDK.
Expand Down
17 changes: 12 additions & 5 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,19 @@ type BlockingInterception struct {
interceptionBase
}

func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception {
func NewBlockingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
cfg config.OpenAI,
clientHeaders http.Header,
tracer trace.Tracer,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
req: req,
cfg: cfg,
tracer: tracer,
id: id,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
tracer: tracer,
}}
}

Expand Down
17 changes: 12 additions & 5 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ type StreamingInterception struct {
interceptionBase
}

func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception {
func NewStreamingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
cfg config.OpenAI,
clientHeaders http.Header,
tracer trace.Tracer,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
req: req,
cfg: cfg,
tracer: tracer,
id: id,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
tracer: tracer,
}}
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
}

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, nil, tracer)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
58 changes: 58 additions & 0 deletions intercept/client_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package intercept

import "net/http"

// hopByHopHeaders are connection-level headers specific to the connection
// between client and AI Bridge, not meant for the upstream.
// See https://www.rfc-editor.org/rfc/rfc2616#section-13.5.1.
var hopByHopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
"Upgrade",
}

// nonForwardedHeaders are headers that should not be forwarded to the upstream provider:
// - Connection-specific headers managed by AI Bridge or the HTTP transport.
// - Auth headers that are re-injected by the SDK from the provider configuration.
// - User-Agent is set by the SDK to identify AI Bridge as the upstream client.
var nonForwardedHeaders = []string{
"Host",
"Accept-Encoding",
"Content-Length",
"Content-Type",
"Authorization",
"X-Api-Key",
"User-Agent",
}

// SanitizeClientHeaders clones headers and returns a sanitized copy suitable
// for forwarding to an upstream provider.
//
// It removes:
// - Hop-by-hop headers
// - Non-forwarded headers (connection-specific, transport-managed, auth, and SDK-managed headers)
//
// Callers should apply these headers first, so that any subsequently
// added headers take priority in case of conflict.
func SanitizeClientHeaders(headers http.Header) http.Header {
if headers == nil {
return http.Header{}
}

outHeaders := headers.Clone()

for _, h := range hopByHopHeaders {
outHeaders.Del(h)
}

for _, h := range nonForwardedHeaders {
outHeaders.Del(h)
}

return outHeaders
}
109 changes: 109 additions & 0 deletions intercept/client_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package intercept

import (
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

func TestSanitizeClientHeaders(t *testing.T) {
t.Parallel()

tests := []struct {
name string
input http.Header
expectAbsent []string
expectPresent map[string][]string
expectEmpty bool
}{
{
name: "no headers returns empty header",
input: nil,
expectEmpty: true,
},
{
name: "hop-by-hop headers are removed",
input: http.Header{
"Connection": []string{"keep-alive"},
"Keep-Alive": []string{"timeout=5"},
"Transfer-Encoding": []string{"chunked"},
"Upgrade": []string{"websocket"},
},
expectAbsent: []string{"Connection", "Keep-Alive", "Transfer-Encoding", "Upgrade"},
},
{
name: "bridge headers are removed",
input: http.Header{
"Host": []string{"example.com"},
"Accept-Encoding": []string{"gzip"},
"Content-Length": []string{"42"},
"Content-Type": []string{"application/json"},
"User-Agent": []string{"client/1.0"},
},
expectAbsent: []string{"Host", "Accept-Encoding", "Content-Length", "Content-Type", "User-Agent"},
},
{
name: "auth headers are removed",
input: http.Header{
"Authorization": []string{"Bearer some-token"},
"X-Api-Key": []string{"some-key"},
},
expectAbsent: []string{"Authorization", "X-Api-Key"},
},
{
name: "custom headers are preserved",
input: http.Header{
"X-Custom-Header": []string{"custom-value"},
"X-Request-Id": []string{"req-123"},
},
expectPresent: map[string][]string{
"X-Custom-Header": {"custom-value"},
"X-Request-Id": {"req-123"},
},
},
{
name: "multi-value headers are preserved",
input: http.Header{
"X-Custom-Header": []string{"value-1", "value-2"},
},
expectPresent: map[string][]string{
"X-Custom-Header": {"value-1", "value-2"},
},
},
{
name: "input is not mutated",
input: http.Header{
"Connection": []string{"keep-alive"},
"X-Custom-Header": []string{"custom-value"},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

// Capture original state to verify no mutation.
originalCopy := tc.input.Clone()

result := SanitizeClientHeaders(tc.input)

if tc.expectEmpty {
require.Empty(t, result)
return
}

for _, h := range tc.expectAbsent {
require.Empty(t, result.Get(h), "expected header %q to be absent", h)
}

for h, vals := range tc.expectPresent {
require.Equal(t, vals, result[h], "expected header %q to be present", h)
}

// Verify input was not mutated.
require.Equal(t, originalCopy, tc.input)
})
}
}
14 changes: 14 additions & 0 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
aibconfig "github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/recorder"
Expand All @@ -40,6 +41,10 @@ type interceptionBase struct {
cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock

// clientHeaders holds the original client request headers to forward
// to upstream providers.
clientHeaders http.Header

tracer trace.Tracer
logger slog.Logger

Expand Down Expand Up @@ -178,6 +183,15 @@ func (i *interceptionBase) isSmallFastModel() bool {
}

func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
// Forward sanitized client headers to the upstream provider.
// Client headers are added first so that SDK auth appended
// below takes priority on any conflict.
for k, vals := range intercept.SanitizeClientHeaders(i.clientHeaders) {
for _, v := range vals {
opts = append(opts, option.WithHeader(k, v))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WithHeader uses Header.Set which overrides header value instead of adding it.
I think WithHeaderAdd would help here. Test case needed.

}
}

opts = append(opts, option.WithAPIKey(i.cfg.Key))
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))

Expand Down
23 changes: 16 additions & 7 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,23 @@ type BlockingInterception struct {
interceptionBase
}

func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception {
func NewBlockingInterceptor(
id uuid.UUID,
req *MessageNewParamsWrapper,
payload []byte,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
clientHeaders http.Header,
tracer trace.Tracer,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
clientHeaders: clientHeaders,
tracer: tracer,
}}
}

Expand Down
23 changes: 16 additions & 7 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,23 @@ type StreamingInterception struct {
interceptionBase
}

func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception {
func NewStreamingInterceptor(
id uuid.UUID,
req *MessageNewParamsWrapper,
payload []byte,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
clientHeaders http.Header,
tracer trace.Tracer,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
clientHeaders: clientHeaders,
tracer: tracer,
}}
}

Expand Down
Loading