Skip to content

Commit 5522ab9

Browse files
committed
test: verify upstream errors are relayed to client in streaming chatcompletions
1 parent a127009 commit 5522ab9

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package chatcompletions
2+
3+
import (
4+
"io"
5+
"net/http"
6+
"net/http/httptest"
7+
"strconv"
8+
"testing"
9+
10+
"cdr.dev/slog/v3"
11+
"cdr.dev/slog/v3/sloggers/sloghuman"
12+
"github.com/coder/aibridge/config"
13+
"github.com/coder/aibridge/internal/testutil"
14+
"github.com/google/uuid"
15+
"github.com/openai/openai-go/v3"
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
"go.opentelemetry.io/otel"
19+
)
20+
21+
// Test that when the upstream provider returns an error before streaming starts,
22+
// the error status code and body are correctly relayed to the client.
23+
func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
24+
t.Parallel()
25+
26+
tests := []struct {
27+
name string
28+
statusCode int
29+
responseBody string
30+
expectedErrStr string
31+
expectedBody string
32+
}{
33+
{
34+
name: "bad request error",
35+
statusCode: http.StatusBadRequest,
36+
responseBody: `{"error":{"message":"Invalid request","type":"invalid_request_error","code":"invalid_request"}}`,
37+
expectedErrStr: strconv.Itoa(http.StatusBadRequest),
38+
expectedBody: "invalid_request",
39+
},
40+
{
41+
name: "rate limit error",
42+
statusCode: http.StatusTooManyRequests,
43+
responseBody: `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
44+
expectedErrStr: strconv.Itoa(http.StatusTooManyRequests),
45+
expectedBody: "rate_limit",
46+
},
47+
{
48+
name: "internal server error",
49+
statusCode: http.StatusInternalServerError,
50+
responseBody: `{"error":{"message":"Internal server error","type":"server_error","code":"internal_error"}}`,
51+
expectedErrStr: strconv.Itoa(http.StatusInternalServerError),
52+
expectedBody: "server_error",
53+
},
54+
}
55+
56+
for _, tc := range tests {
57+
t.Run(tc.name, func(t *testing.T) {
58+
t.Parallel()
59+
60+
// Setup a mock server that returns an error immediately (before any streaming)
61+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
62+
w.Header().Set("Content-Type", "application/json")
63+
w.Header().Set("x-should-retry", "false")
64+
w.WriteHeader(tc.statusCode)
65+
_, _ = w.Write([]byte(tc.responseBody))
66+
}))
67+
t.Cleanup(mockServer.Close)
68+
69+
// Create interceptor with mock server URL
70+
cfg := config.OpenAI{
71+
BaseURL: mockServer.URL,
72+
Key: "test-key",
73+
}
74+
75+
req := &ChatCompletionNewParamsWrapper{
76+
ChatCompletionNewParams: openai.ChatCompletionNewParams{
77+
Model: "gpt-4",
78+
Messages: []openai.ChatCompletionMessageParamUnion{
79+
openai.UserMessage("hello"),
80+
},
81+
},
82+
Stream: true,
83+
}
84+
85+
tracer := otel.Tracer("test")
86+
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer)
87+
88+
logger := slog.Make(sloghuman.Sink(io.Discard))
89+
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
90+
91+
// Create test request
92+
w := httptest.NewRecorder()
93+
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
94+
95+
// Process the request
96+
err := interceptor.ProcessRequest(w, httpReq)
97+
98+
// Verify error was returned
99+
require.Error(t, err)
100+
assert.Contains(t, err.Error(), tc.expectedErrStr)
101+
102+
// Verify status code was written to response
103+
assert.Equal(t, tc.statusCode, w.Code, "expected status code to be relayed to client")
104+
105+
// Verify error body contains expected error info
106+
body := w.Body.String()
107+
assert.Contains(t, body, tc.expectedBody, "expected error type in response body")
108+
})
109+
}
110+
}

0 commit comments

Comments
 (0)