Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions internal/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,10 @@ func AddCustomEndpoint(cfg *config.Config, u *ui.UI, endpoint, modelName, apiKey
return nil
}

// probeBackoffSleep pauses between retries of the ValidateCustomEndpoint
// inference probe. It is a package-level variable rather than a direct call to
// time.Sleep so tests can substitute a no-op and avoid waiting out real
// backoff intervals.
var probeBackoffSleep func(time.Duration) = time.Sleep

// ValidateCustomEndpoint validates that a custom OpenAI-compatible endpoint works.
// It runs a 2-step validation: reachability check, then inference probe.
// The inference probe is the definitive test — some servers (e.g., mlx-lm) don't
Expand Down Expand Up @@ -917,9 +921,34 @@ func ValidateCustomEndpoint(endpoint, modelName, apiKey string) error {
probeReq.Header.Set("Authorization", authHeader)
}

probeResp, err := client.Do(probeReq)
if err != nil {
return fmt.Errorf("inference probe failed — cannot reach %s: %w", completionsURL, err)
// Retry on transient network errors (DNS flake, TCP reset, route loss).
// Only client.Do errors are retried — non-200 HTTP responses are real
// upstream signals (4xx = config bug, 5xx = upstream broken) and fail fast.
const probeMaxAttempts = 3
probeBackoffs := []time.Duration{
250 * time.Millisecond,
1 * time.Second,
4 * time.Second,
}

var probeResp *http.Response
var probeErr error
for attempt := 0; attempt < probeMaxAttempts; attempt++ {
// Bodies are single-use; re-attach the payload for each attempt.
attemptReq := probeReq.Clone(probeReq.Context())
attemptReq.Body = io.NopCloser(bytes.NewReader(probePayload))

probeResp, probeErr = client.Do(attemptReq)
if probeErr == nil {
break
}
if attempt < probeMaxAttempts-1 {
probeBackoffSleep(probeBackoffs[attempt])
}
}
if probeErr != nil {
return fmt.Errorf("inference probe failed after %d attempts — cannot reach %s: %w",
probeMaxAttempts, completionsURL, probeErr)
}
defer probeResp.Body.Close()

Expand Down
146 changes: 146 additions & 0 deletions internal/model/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
)

func TestBuildModelEntries(t *testing.T) {
Expand Down Expand Up @@ -550,6 +552,150 @@ func TestValidateCustomEndpoint(t *testing.T) {
})
}

// withNoSleep installs a no-op in place of probeBackoffSleep for the lifetime
// of the calling test, so retry tests never block on real backoff delays. The
// previous value is restored automatically via t.Cleanup.
func withNoSleep(t *testing.T) {
	t.Helper()

	saved := probeBackoffSleep
	t.Cleanup(func() {
		probeBackoffSleep = saved
	})
	probeBackoffSleep = func(time.Duration) {}
}

// abortAfterNHandler returns a handler that simulates a flaky inference
// endpoint, together with a pointer to its POST counter.
//
// Every POST whose path ends in /chat/completions increments the counter. The
// first n such requests panic with http.ErrAbortHandler, which aborts the
// response to the client; later ones receive a 200 with a valid "choices"
// payload. Requests whose path ends in /models get a one-model listing, and
// anything else gets an empty 200, so reachability checks always succeed.
//
// The counter must be read with the sync/atomic functions.
func abortAfterNHandler(n int) (http.Handler, *int32) {
	counter := new(int32)

	serve := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")

		switch {
		case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions"):
			// Count the hit before deciding, so aborted attempts are included.
			if hit := atomic.AddInt32(counter, 1); int(hit) <= n {
				panic(http.ErrAbortHandler)
			}
			fmt.Fprint(w, `{"choices":[{"message":{"content":"pong"}}]}`)
		case strings.HasSuffix(r.URL.Path, "/models"):
			fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`)
		default:
			// Remaining reachability paths.
			w.WriteHeader(http.StatusOK)
		}
	}

	return http.HandlerFunc(serve), counter
}

// TestValidateCustomEndpoint_RetriesOnNetworkError checks the inference-probe
// retry budget of ValidateCustomEndpoint. The test server aborts the first
// errorsBeforeOK POSTs to /chat/completions, so the client sees a
// transport-level error for each, and then answers normally. Validation must
// succeed as long as a good response arrives within three attempts, must fail
// with an error naming the attempt count otherwise, and must never issue more
// than three POSTs.
func TestValidateCustomEndpoint_RetriesOnNetworkError(t *testing.T) {
	withNoSleep(t)

	cases := []struct {
		name             string
		errorsBeforeOK   int
		wantErr          bool
		wantPostAttempts int32
	}{
		{name: "success first try", errorsBeforeOK: 0, wantErr: false, wantPostAttempts: 1},
		{name: "one transient error, then ok", errorsBeforeOK: 1, wantErr: false, wantPostAttempts: 2},
		{name: "two transient errors, then ok (at limit)", errorsBeforeOK: 2, wantErr: false, wantPostAttempts: 3},
		{name: "three transient errors — budget exhausted", errorsBeforeOK: 3, wantErr: true, wantPostAttempts: 3},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			handler, posts := abortAfterNHandler(tc.errorsBeforeOK)
			srv := httptest.NewServer(handler)
			defer srv.Close()

			// No log silencing is needed here: net/http does not log a stack
			// trace for http.ErrAbortHandler panics. (A nil Server.ErrorLog
			// would not silence anything anyway; it selects the standard
			// logger, and it is already the default for httptest servers.)

			err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "")
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("expected error, got nil")
			case !tc.wantErr && err != nil:
				t.Fatalf("unexpected error: %v", err)
			case tc.wantErr && !strings.Contains(err.Error(), "inference probe failed after 3 attempts"):
				t.Errorf("error message should reference attempt count, got: %v", err)
			}

			if got := atomic.LoadInt32(posts); got != tc.wantPostAttempts {
				t.Errorf("POST attempts: got %d, want %d", got, tc.wantPostAttempts)
			}
		})
	}
}

// TestValidateCustomEndpoint_NoRetryOnNon2xx asserts that an HTTP error status
// from the inference probe is reported immediately: the returned error names
// the status code and exactly one POST reaches /chat/completions.
func TestValidateCustomEndpoint_NoRetryOnNon2xx(t *testing.T) {
	withNoSleep(t)

	for _, status := range []int{
		http.StatusUnauthorized,
		http.StatusNotFound,
		http.StatusServiceUnavailable,
	} {
		t.Run(fmt.Sprintf("HTTP %d", status), func(t *testing.T) {
			var probeHits int32

			// The probe always fails with the status under test; the
			// reachability paths stay healthy.
			serve := func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				switch {
				case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions"):
					atomic.AddInt32(&probeHits, 1)
					w.WriteHeader(status)
					fmt.Fprint(w, `{"error":"nope"}`)
				case strings.HasSuffix(r.URL.Path, "/models"):
					fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`)
				default:
					w.WriteHeader(http.StatusOK)
				}
			}
			srv := httptest.NewServer(http.HandlerFunc(serve))
			defer srv.Close()

			err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "")
			if err == nil {
				t.Fatalf("expected error for HTTP %d", status)
			}
			if want := fmt.Sprintf("returned %d", status); !strings.Contains(err.Error(), want) {
				t.Errorf("error should reference returned status %d, got: %v", status, err)
			}
			if hits := atomic.LoadInt32(&probeHits); hits != 1 {
				t.Errorf("non-2xx must not retry: got %d POSTs, want 1", hits)
			}
		})
	}
}

// TestValidateCustomEndpoint_NoRetryOnInvalidResponseBody asserts that a probe
// response with an unparseable body is treated as a final failure: the error
// mentions "invalid response" and exactly one POST reaches /chat/completions.
func TestValidateCustomEndpoint_NoRetryOnInvalidResponseBody(t *testing.T) {
	withNoSleep(t)

	var probeHits int32

	// The probe answers with a body that is not valid JSON; the reachability
	// paths stay healthy.
	serve := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions"):
			atomic.AddInt32(&probeHits, 1)
			fmt.Fprint(w, `not json {{{`)
		case strings.HasSuffix(r.URL.Path, "/models"):
			fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`)
		default:
			w.WriteHeader(http.StatusOK)
		}
	}
	srv := httptest.NewServer(http.HandlerFunc(serve))
	defer srv.Close()

	err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "")
	if err == nil {
		t.Fatal("expected JSON decode error")
	}
	if msg := err.Error(); !strings.Contains(msg, "invalid response") {
		t.Errorf("error should mention 'invalid response', got: %v", err)
	}
	if hits := atomic.LoadInt32(&probeHits); hits != 1 {
		t.Errorf("malformed body must not retry: got %d POSTs, want 1", hits)
	}
}

func TestFormatBytes(t *testing.T) {
tests := []struct {
input int64
Expand Down
Loading