110 changes: 110 additions & 0 deletions intercept/chatcompletions/streaming_test.go
@@ -0,0 +1,110 @@
package chatcompletions

import (
	"io"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"

	"cdr.dev/slog/v3"
	"cdr.dev/slog/v3/sloggers/sloghuman"
	"github.com/coder/aibridge/config"
	"github.com/coder/aibridge/internal/testutil"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel"
)

// TestStreamingInterception_RelaysUpstreamErrorToClient verifies that when the
// upstream provider returns an error before streaming starts, the error status
// code and body are relayed to the client.
func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name           string
		statusCode     int
		responseBody   string
		expectedErrStr string
		expectedBody   string
	}{
		{
			name:           "bad request error",
			statusCode:     http.StatusBadRequest,
			responseBody:   `{"error":{"message":"Invalid request","type":"invalid_request_error","code":"invalid_request"}}`,
			expectedErrStr: strconv.Itoa(http.StatusBadRequest),
			expectedBody:   "invalid_request",
		},
		{
			name:           "rate limit error",
			statusCode:     http.StatusTooManyRequests,
			responseBody:   `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
			expectedErrStr: strconv.Itoa(http.StatusTooManyRequests),
			expectedBody:   "rate_limit",
		},
		{
			name:           "internal server error",
			statusCode:     http.StatusInternalServerError,
			responseBody:   `{"error":{"message":"Internal server error","type":"server_error","code":"internal_error"}}`,
			expectedErrStr: strconv.Itoa(http.StatusInternalServerError),
			expectedBody:   "server_error",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Set up a mock server that returns an error immediately (before any streaming).
			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				// Tell the SDK's retry logic not to retry this failure.
				w.Header().Set("x-should-retry", "false")
				w.WriteHeader(tc.statusCode)
				_, _ = w.Write([]byte(tc.responseBody))
			}))
			t.Cleanup(mockServer.Close)

			// Point the interceptor's OpenAI config at the mock server.
			cfg := config.OpenAI{
				BaseURL: mockServer.URL,
				Key:     "test-key",
			}

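			// Build a minimal streaming chat completion request.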
			req := &ChatCompletionNewParamsWrapper{
				ChatCompletionNewParams: openai.ChatCompletionNewParams{
					Model: "gpt-4",
					Messages: []openai.ChatCompletionMessageParamUnion{
						openai.UserMessage("hello"),
					},
				},
				Stream: true,
			}

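			// Construct the streaming interceptor and wire it to a no-op logger and a mock recorder.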
tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer)

logger := slog.Make(sloghuman.Sink(io.Discard))
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)

// Create test request
w := httptest.NewRecorder()
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

// Process the request
err := interceptor.ProcessRequest(w, httpReq)

// Verify error was returned
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedErrStr)

// Verify status code was written to response
assert.Equal(t, tc.statusCode, w.Code, "expected status code to be relayed to client")

// Verify error body contains expected error info
body := w.Body.String()
assert.Contains(t, body, tc.expectedBody, "expected error type in response body")
})
}
}
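
To run just this table-driven test from the repository root, a standard invocation would be:

go test ./intercept/chatcompletions -run TestStreamingInterception_RelaysUpstreamErrorToClient -v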