Skip to content

Commit e164b50

Browse files
authored
feat: allow bedrock base url definition (#131)
1 parent 9ff99a2 commit e164b50

5 files changed

Lines changed: 66 additions & 64 deletions

File tree

bridge_integration_test.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
186186
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
187187
t.Cleanup(cancel)
188188

189-
// Invalid bedrock config - missing region
189+
// Invalid bedrock config - missing region & base url
190190
bedrockCfg := &config.AWSBedrock{
191191
Region: "",
192192
AccessKey: "test-key",
@@ -218,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
218218
body, err := io.ReadAll(resp.Body)
219219
require.NoError(t, err)
220220
require.Contains(t, string(body), "create anthropic client")
221-
require.Contains(t, string(body), "region required")
221+
require.Contains(t, string(body), "region or base url required")
222222
})
223223

224224
t.Run("/v1/messages", func(t *testing.T) {
@@ -281,15 +281,14 @@ func TestAWSBedrockIntegration(t *testing.T) {
281281
srv.Start()
282282
t.Cleanup(srv.Close)
283283

284-
// Configure Bedrock with test credentials and model names.
285-
// The EndpointOverride will make requests go to the mock server instead of real AWS endpoints.
284+
// We define region here to validate that with Region & BaseURL defined, the latter takes precedence.
286285
bedrockCfg := &config.AWSBedrock{
287-
Region: "us-west-2",
288-
AccessKey: "test-access-key",
289-
AccessKeySecret: "test-secret-key",
290-
Model: "danthropic", // This model should override the request's given one.
291-
SmallFastModel: "danthropic-mini", // Unused but needed for validation.
292-
EndpointOverride: srv.URL,
286+
Region: "us-west-2",
287+
AccessKey: "test-access-key",
288+
AccessKeySecret: "test-secret-key",
289+
Model: "danthropic", // This model should override the request's given one.
290+
SmallFastModel: "danthropic-mini", // Unused but needed for validation.
291+
BaseURL: srv.URL, // Use the mock server.
293292
}
294293

295294
recorderClient := &testutil.MockRecorder{}

config/config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ type AWSBedrock struct {
4646
Region string
4747
AccessKey, AccessKeySecret string
4848
Model, SmallFastModel string
49-
// EndpointOverride allows overriding the Bedrock endpoint URL for testing.
50-
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint.
51-
EndpointOverride string
49+
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint
50+
// (https://bedrock-runtime.{region}.amazonaws.com).
51+
// This is useful for routing requests through a proxy or for testing.
52+
BaseURL string
5253
}
5354

5455
type OpenAI struct {

intercept/messages/base.go

Lines changed: 14 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"errors"
77
"fmt"
88
"net/http"
9-
"net/url"
109
"strings"
1110
"time"
1211

@@ -163,34 +162,23 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio
163162
if i.bedrockCfg != nil {
164163
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
165164
defer cancel()
166-
bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg)
165+
bedrockOpts, err := i.withAWSBedrockOptions(ctx, i.bedrockCfg)
167166
if err != nil {
168167
return anthropic.MessageService{}, err
169168
}
170-
opts = append(opts, bedrockOpt)
169+
opts = append(opts, bedrockOpts...)
171170
i.augmentRequestForBedrock()
172-
173-
// If an endpoint override is set (for testing), add a custom HTTP client AFTER the bedrock config
174-
// This overrides any HTTP client set by the bedrock middleware
175-
if i.bedrockCfg.EndpointOverride != "" {
176-
opts = append(opts, option.WithHTTPClient(&http.Client{
177-
Transport: &redirectTransport{
178-
base: http.DefaultTransport,
179-
redirectToURL: i.bedrockCfg.EndpointOverride,
180-
},
181-
}))
182-
}
183171
}
184172

185173
return anthropic.NewMessageService(opts...), nil
186174
}
187175

188-
func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AWSBedrock) (option.RequestOption, error) {
176+
func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
189177
if cfg == nil {
190178
return nil, fmt.Errorf("nil config given")
191179
}
192-
if cfg.Region == "" {
193-
return nil, fmt.Errorf("region required")
180+
if cfg.Region == "" && cfg.BaseURL == "" {
181+
return nil, fmt.Errorf("region or base url required")
194182
}
195183
if cfg.AccessKey == "" {
196184
return nil, fmt.Errorf("access key required")
@@ -221,7 +209,15 @@ func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AW
221209
return nil, fmt.Errorf("failed to load AWS Bedrock config: %w", err)
222210
}
223211

224-
return bedrock.WithConfig(awsCfg), nil
212+
var out []option.RequestOption
213+
out = append(out, bedrock.WithConfig(awsCfg))
214+
215+
// If a custom base URL is set, override the default endpoint constructed by the bedrock middleware.
216+
if cfg.BaseURL != "" {
217+
out = append(out, option.WithBaseURL(cfg.BaseURL))
218+
}
219+
220+
return out, nil
225221
}
226222

227223
// augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support
@@ -261,28 +257,6 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err
261257
}
262258
}
263259

264-
// redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint.
265-
// This is useful for testing when we need to redirect AWS Bedrock requests to a mock server.
266-
type redirectTransport struct {
267-
base http.RoundTripper
268-
redirectToURL string
269-
}
270-
271-
func (t *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
272-
// Parse the redirect URL
273-
redirectURL, err := url.Parse(t.redirectToURL)
274-
if err != nil {
275-
return nil, err
276-
}
277-
278-
// Redirect the request to the mock server
279-
req.URL.Scheme = redirectURL.Scheme
280-
req.URL.Host = redirectURL.Host
281-
req.Host = redirectURL.Host
282-
283-
return t.base.RoundTrip(req)
284-
}
285-
286260
// accumulateUsage accumulates usage statistics from source into dest.
287261
// It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any].
288262
// The function uses reflection to handle the differences between the types:

intercept/messages/base_test.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ func TestAWSBedrockValidation(t *testing.T) {
2121
expectError bool
2222
errorMsg string
2323
}{
24+
// Valid cases.
2425
{
25-
name: "valid",
26+
name: "valid with region",
2627
cfg: &config.AWSBedrock{
2728
Region: "us-east-1",
2829
AccessKey: "test-key",
@@ -32,7 +33,33 @@ func TestAWSBedrockValidation(t *testing.T) {
3233
},
3334
},
3435
{
35-
name: "missing region",
36+
name: "valid with base url",
37+
cfg: &config.AWSBedrock{
38+
BaseURL: "http://bedrock.internal",
39+
AccessKey: "test-key",
40+
AccessKeySecret: "test-secret",
41+
Model: "test-model",
42+
SmallFastModel: "test-small-model",
43+
},
44+
},
45+
{
46+
// There unfortunately isn't a way for us to determine precedence in a unit test,
47+
// since the produced options take a `requestconfig.RequestConfig` input value
48+
// which is internal to the anthropic SDK.
49+
//
50+
// See TestAWSBedrockIntegration which validates this.
51+
name: "valid with base url & region",
52+
cfg: &config.AWSBedrock{
53+
Region: "us-east-1",
54+
AccessKey: "test-key",
55+
AccessKeySecret: "test-secret",
56+
Model: "test-model",
57+
SmallFastModel: "test-small-model",
58+
},
59+
},
60+
// Invalid cases.
61+
{
62+
name: "missing region & base url",
3663
cfg: &config.AWSBedrock{
3764
Region: "",
3865
AccessKey: "test-key",
@@ -41,7 +68,7 @@ func TestAWSBedrockValidation(t *testing.T) {
4168
SmallFastModel: "test-small-model",
4269
},
4370
expectError: true,
44-
errorMsg: "region required",
71+
errorMsg: "region or base url required",
4572
},
4673
{
4774
name: "missing access key",
@@ -95,7 +122,7 @@ func TestAWSBedrockValidation(t *testing.T) {
95122
name: "all fields empty",
96123
cfg: &config.AWSBedrock{},
97124
expectError: true,
98-
errorMsg: "region required",
125+
errorMsg: "region or base url required",
99126
},
100127
{
101128
name: "nil config",
@@ -108,12 +135,13 @@ func TestAWSBedrockValidation(t *testing.T) {
108135
for _, tt := range tests {
109136
t.Run(tt.name, func(t *testing.T) {
110137
base := &interceptionBase{}
111-
_, err := base.withAWSBedrock(context.Background(), tt.cfg)
138+
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
112139

113140
if tt.expectError {
114141
require.Error(t, err)
115142
require.Contains(t, err.Error(), tt.errorMsg)
116143
} else {
144+
require.NotEmpty(t, opts)
117145
require.NoError(t, err)
118146
}
119147
})

trace_integration_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -863,11 +863,11 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e
863863

864864
func testBedrockCfg(url string) *config.AWSBedrock {
865865
return &config.AWSBedrock{
866-
Region: "us-west-2",
867-
AccessKey: "test-access-key",
868-
AccessKeySecret: "test-secret-key",
869-
Model: "beddel", // This model should override the request's given one.
870-
SmallFastModel: "modrock", // Unused but needed for validation.
871-
EndpointOverride: url,
866+
Region: "us-west-2",
867+
AccessKey: "test-access-key",
868+
AccessKeySecret: "test-secret-key",
869+
Model: "beddel", // This model should override the request's given one.
870+
SmallFastModel: "modrock", // Unused but needed for validation.
871+
BaseURL: url,
872872
}
873873
}

0 commit comments

Comments
 (0)