|
6 | 6 | "errors" |
7 | 7 | "fmt" |
8 | 8 | "net/http" |
9 | | - "net/url" |
10 | 9 | "strings" |
11 | 10 | "time" |
12 | 11 |
|
@@ -163,34 +162,23 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio |
163 | 162 | if i.bedrockCfg != nil { |
164 | 163 | ctx, cancel := context.WithTimeout(ctx, time.Second*30) |
165 | 164 | defer cancel() |
166 | | - bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg) |
| 165 | + bedrockOpts, err := i.withAWSBedrockOptions(ctx, i.bedrockCfg) |
167 | 166 | if err != nil { |
168 | 167 | return anthropic.MessageService{}, err |
169 | 168 | } |
170 | | - opts = append(opts, bedrockOpt) |
| 169 | + opts = append(opts, bedrockOpts...) |
171 | 170 | i.augmentRequestForBedrock() |
172 | | - |
173 | | - // If an endpoint override is set (for testing), add a custom HTTP client AFTER the bedrock config |
174 | | - // This overrides any HTTP client set by the bedrock middleware |
175 | | - if i.bedrockCfg.EndpointOverride != "" { |
176 | | - opts = append(opts, option.WithHTTPClient(&http.Client{ |
177 | | - Transport: &redirectTransport{ |
178 | | - base: http.DefaultTransport, |
179 | | - redirectToURL: i.bedrockCfg.EndpointOverride, |
180 | | - }, |
181 | | - })) |
182 | | - } |
183 | 171 | } |
184 | 172 |
|
185 | 173 | return anthropic.NewMessageService(opts...), nil |
186 | 174 | } |
187 | 175 |
|
188 | | -func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AWSBedrock) (option.RequestOption, error) { |
| 176 | +func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { |
189 | 177 | if cfg == nil { |
190 | 178 | return nil, fmt.Errorf("nil config given") |
191 | 179 | } |
192 | | - if cfg.Region == "" { |
193 | | - return nil, fmt.Errorf("region required") |
| 180 | + if cfg.Region == "" && cfg.BaseURL == "" { |
| 181 | + return nil, fmt.Errorf("region or base url required") |
194 | 182 | } |
195 | 183 | if cfg.AccessKey == "" { |
196 | 184 | return nil, fmt.Errorf("access key required") |
@@ -221,7 +209,15 @@ func (i *interceptionBase) withAWSBedrock(ctx context.Context, cfg *aibconfig.AW |
221 | 209 | return nil, fmt.Errorf("failed to load AWS Bedrock config: %w", err) |
222 | 210 | } |
223 | 211 |
|
224 | | - return bedrock.WithConfig(awsCfg), nil |
| 212 | + var out []option.RequestOption |
| 213 | + out = append(out, bedrock.WithConfig(awsCfg)) |
| 214 | + |
| 215 | + // If a custom base URL is set, override the default endpoint constructed by the bedrock middleware. |
| 216 | + if cfg.BaseURL != "" { |
| 217 | + out = append(out, option.WithBaseURL(cfg.BaseURL)) |
| 218 | + } |
| 219 | + |
| 220 | + return out, nil |
225 | 221 | } |
226 | 222 |
|
227 | 223 | // augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support |
@@ -261,28 +257,6 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err |
261 | 257 | } |
262 | 258 | } |
263 | 259 |
|
264 | | -// redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint. |
265 | | -// This is useful for testing when we need to redirect AWS Bedrock requests to a mock server. |
266 | | -type redirectTransport struct { |
267 | | - base http.RoundTripper |
268 | | - redirectToURL string |
269 | | -} |
270 | | - |
271 | | -func (t *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
272 | | - // Parse the redirect URL |
273 | | - redirectURL, err := url.Parse(t.redirectToURL) |
274 | | - if err != nil { |
275 | | - return nil, err |
276 | | - } |
277 | | - |
278 | | - // Redirect the request to the mock server |
279 | | - req.URL.Scheme = redirectURL.Scheme |
280 | | - req.URL.Host = redirectURL.Host |
281 | | - req.Host = redirectURL.Host |
282 | | - |
283 | | - return t.base.RoundTrip(req) |
284 | | -} |
285 | | - |
286 | 260 | // accumulateUsage accumulates usage statistics from source into dest. |
287 | 261 | // It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any]. |
288 | 262 | // The function uses reflection to handle the differences between the types: |
|
0 commit comments