Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/provider"
Expand Down Expand Up @@ -1617,6 +1618,162 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
}
}

func TestActorHeaders(t *testing.T) {
t.Parallel()

actorUsername := "bob"

cases := []struct {
name string
createRequest createRequestFunc
createProviderFn func(url, key string, sendHeaders bool) aibridge.Provider
fixture []byte
streaming bool
}{
{
name: "openai/v1/chat/completions",
createRequest: createOpenAIChatCompletionsReq,
createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider {
cfg := openaiCfg(url, key)
cfg.SendActorHeaders = sendHeaders
return provider.NewOpenAI(cfg)
},
fixture: fixtures.OaiChatSimple,
streaming: true,
},
{
name: "openai/v1/chat/completions",
createRequest: createOpenAIChatCompletionsReq,
createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider {
cfg := openaiCfg(url, key)
cfg.SendActorHeaders = sendHeaders
return provider.NewOpenAI(cfg)
},
fixture: fixtures.OaiChatSimple,
streaming: false,
},
{
name: "openai/v1/responses",
createRequest: createOpenAIResponsesReq,
createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider {
cfg := openaiCfg(url, key)
cfg.SendActorHeaders = sendHeaders
return provider.NewOpenAI(cfg)
},
fixture: fixtures.OaiResponsesStreamingSimple,
streaming: true,
},
{
name: "openai/v1/responses",
createRequest: createOpenAIResponsesReq,
createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider {
cfg := openaiCfg(url, key)
cfg.SendActorHeaders = sendHeaders
return provider.NewOpenAI(cfg)
},
fixture: fixtures.OaiResponsesBlockingSimple,
streaming: false,
},
{
name: "anthropic/v1/messages",
createRequest: createAnthropicMessagesReq,
createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider {
cfg := anthropicCfg(url, key)
cfg.SendActorHeaders = sendHeaders
return provider.NewAnthropic(cfg, nil)
},
fixture: fixtures.AntSimple,
streaming: true,
},
{
name: "anthropic/v1/messages",
createRequest: createAnthropicMessagesReq,
createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider {
cfg := anthropicCfg(url, key)
cfg.SendActorHeaders = sendHeaders
return provider.NewAnthropic(cfg, nil)
},
fixture: fixtures.AntSimple,
streaming: false,
},
}

for _, tc := range cases {
for _, send := range []bool{true, false} {
t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)

arc := txtar.Parse(tc.fixture)
files := filesMap(arc)
reqBody := files[fixtureRequest]

// Add the stream param to the request.
newBody, err := setJSON(reqBody, "stream", tc.streaming)
require.NoError(t, err)
reqBody = newBody

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

// Track headers received by the upstream server.
var receivedHeaders http.Header
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusTeapot)
}))
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
srv.Start()
t.Cleanup(srv.Close)

rec := &testutil.MockRecorder{}
provider := tc.createProviderFn(srv.URL, apiKey, send)

b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
require.NoError(t, err, "failed to create handler")

mockSrv := httptest.NewUnstartedServer(b)
t.Cleanup(mockSrv.Close)

metadataKey := "Username"
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
// Attach an actor to the request context.
return aibcontext.AsActor(ctx, userID, recorder.Metadata{
metadataKey: actorUsername,
})
}
mockSrv.Start()

req := tc.createRequest(t, mockSrv.URL, reqBody)
client := &http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
require.NotEmpty(t, receivedHeaders)
defer resp.Body.Close()

// Verify that the actor headers were only received if intended.
found := make(map[string][]string)
for k, v := range receivedHeaders {
k = strings.ToLower(k)
if intercept.IsActorHeader(k) {
found[k] = v
}
}

if send {
require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{userID})
require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername})
} else {
require.Empty(t, found)
}
})
}
}
}

func calculateTotalInputTokens(in []*recorder.TokenUsageRecord) int64 {
var total int64
for _, el := range in {
Expand Down
50 changes: 26 additions & 24 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,32 @@ const (
ProviderOpenAI = "openai"
)

type Anthropic struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
SendActorHeaders bool
}

type AWSBedrock struct {
Region string
AccessKey, AccessKeySecret string
Model, SmallFastModel string
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint
// (https://bedrock-runtime.{region}.amazonaws.com).
// This is useful for routing requests through a proxy or for testing.
BaseURL string
}

type OpenAI struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
SendActorHeaders bool
}

// CircuitBreaker holds configuration for circuit breakers.
type CircuitBreaker struct {
// MaxRequests is the maximum number of requests allowed in half-open state.
Expand Down Expand Up @@ -34,27 +60,3 @@ func DefaultCircuitBreaker() CircuitBreaker {
MaxRequests: 3,
}
}

type Anthropic struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}

type AWSBedrock struct {
Region string
AccessKey, AccessKeySecret string
Model, SmallFastModel string
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint
// (https://bedrock-runtime.{region}.amazonaws.com).
// This is useful for routing requests through a proxy or for testing.
BaseURL string
}

type OpenAI struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}
79 changes: 79 additions & 0 deletions intercept/actor_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package intercept

import (
"fmt"
"strings"

ant_option "github.com/anthropics/anthropic-sdk-go/option"
"github.com/coder/aibridge/context"
oai_option "github.com/openai/openai-go/v3/option"
)

const (
prefix = "X-AI-Bridge-Actor"
)

func ActorIDHeader() string {
return fmt.Sprintf("%s-ID", prefix)
}

func ActorMetadataHeader(name string) string {
return fmt.Sprintf("%s-Metadata-%s", prefix, name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shouldn't we check if name != ""?
Otherwise, we could end with headers like X-AI-Bridge-Actor-Metadata-

}

func IsActorHeader(name string) bool {
return strings.HasPrefix(strings.ToLower(name), strings.ToLower(prefix))
}

// ActorHeadersAsOpenAIOpts produces a slice of headers using OpenAI's RequestOption type.
func ActorHeadersAsOpenAIOpts(actor *context.Actor) []oai_option.RequestOption {
var opts []oai_option.RequestOption

headers := headersFromActor(actor)
if len(headers) == 0 {
return nil
}

for k, v := range headers {
// [k] will be canonicalized, see [http.Header]'s [Add] method.
opts = append(opts, oai_option.WithHeaderAdd(k, v))
}

return opts
}

// ActorHeadersAsAnthropicOpts produces a slice of headers using Anthropic's RequestOption type.
func ActorHeadersAsAnthropicOpts(actor *context.Actor) []ant_option.RequestOption {
var opts []ant_option.RequestOption

headers := headersFromActor(actor)
if len(headers) == 0 {
return nil
}

for k, v := range headers {
// [k] will be canonicalized, see [http.Header]'s [Add] method.
opts = append(opts, ant_option.WithHeaderAdd(k, v))
}

return opts
}

// headersFromActor produces a map of headers from a given [context.Actor].
func headersFromActor(actor *context.Actor) map[string]string {
if actor == nil {
return nil
}

headers := make(map[string]string, len(actor.Metadata)+1)

// Add actor ID.
headers[ActorIDHeader()] = actor.ID
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can actor.ID be ""? Probably not, but be worth the check


// Add headers for provided metadata.
for k, v := range actor.Metadata {
headers[ActorMetadataHeader(k)] = fmt.Sprintf("%v", v)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we should probably be sanitizing the headers/vallues to make sure they don't have whitespaces or invalid characters

}

return headers
}
55 changes: 55 additions & 0 deletions intercept/actor_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package intercept

import (
"testing"

"github.com/coder/aibridge/context"
"github.com/coder/aibridge/recorder"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

func TestNilActor(t *testing.T) {
t.Parallel()

require.Nil(t, ActorHeadersAsOpenAIOpts(nil))
require.Nil(t, ActorHeadersAsAnthropicOpts(nil))
}

func TestBasic(t *testing.T) {
t.Parallel()

actorID := uuid.NewString()
actor := &context.Actor{
ID: actorID,
}

// We can't peek inside since these opts require an internal type to apply onto.
// All we can do is check the length.
// See TestActorHeaders for an integration test.
oaiOpts := ActorHeadersAsOpenAIOpts(actor)
require.Len(t, oaiOpts, 1)
antOpts := ActorHeadersAsAnthropicOpts(actor)
require.Len(t, antOpts, 1)
}

func TestBasicAndMetadata(t *testing.T) {
t.Parallel()

actorID := uuid.NewString()
actor := &context.Actor{
ID: actorID,
Metadata: recorder.Metadata{
"This": "That",
"And": "The other",
},
}

// We can't peek inside since these opts require an internal type to apply onto.
// All we can do is check the length.
// See TestActorHeaders for an integration test.
oaiOpts := ActorHeadersAsOpenAIOpts(actor)
require.Len(t, oaiOpts, 1+len(actor.Metadata))
antOpts := ActorHeadersAsAnthropicOpts(actor)
require.Len(t, antOpts, 1+len(actor.Metadata))
}
6 changes: 6 additions & 0 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"time"

"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/intercept/eventstream"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/recorder"
Expand Down Expand Up @@ -73,8 +75,12 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req

for {
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)

var opts []option.RequestOption
opts = append(opts, option.WithRequestTimeout(time.Second*600))
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit: should we maybe check the cfg first, as in most cases this will be false?

opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
}

completion, err = i.newChatCompletion(ctx, svc, opts)
if err != nil {
Expand Down
Loading