Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions internal/assistant/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,16 @@ func waitForRetry(ctx context.Context, delay time.Duration) error {

// ShouldRetryModelError reports whether a model/provider error is transient.
func ShouldRetryModelError(err error) bool {
if err == nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
if err == nil || errors.Is(err, context.Canceled) {
return false
}
message := strings.ToLower(err.Error())
if nonRetryableProviderMessage(message) {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return retryableDeadlineExceeded(message)
}
if code, ok := providerErrorCode(err); ok {
if nonRetryableProviderCode(code) {
return false
Expand All @@ -106,11 +113,19 @@ func ShouldRetryModelError(err error) bool {
if errors.As(err, &netErr) {
return true
}
message := strings.ToLower(err.Error())
if nonRetryableProviderMessage(message) {
return retryableProviderMessage(message)
}

func retryableDeadlineExceeded(message string) bool {
// Match provider/client timeout details, not wrapper call-site labels such as
// "request provider response", so caller-owned deadlines remain non-retryable.
providerTimeout := strings.Contains(message, "client.timeout exceeded") ||
strings.Contains(message, "awaiting headers")
if !providerTimeout {
return false
}
return retryableProviderMessage(message)

return !nonRetryableProviderMessage(message)
}

func providerErrorCode(err error) (string, bool) {
Expand Down
63 changes: 63 additions & 0 deletions internal/assistant/retry_timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package assistant_test

import (
"context"
"testing"

"github.com/samber/oops"
"github.com/stretchr/testify/assert"

"github.com/omarluq/librecode/internal/assistant"
)

func TestShouldRetryModelErrorHandlesDeadlineExceeded(t *testing.T) {
t.Parallel()

tests := []struct {
err error
name string
want bool
}{
{
name: "provider client timeout is transient",
err: oops.In("assistant").
Code("responses_http").
Wrapf(
context.DeadlineExceeded,
`request provider response: Post "https://chatgpt.com/backend-api/codex/responses": `+
`context deadline exceeded (Client.Timeout exceeded while awaiting headers)`,
),
want: true,
},
{
name: "auth timeout is not retried",
err: oops.In("assistant").
Code("responses_http").
Wrapf(
context.DeadlineExceeded,
"authentication request: Client.Timeout exceeded while awaiting headers",
),
want: false,
},
{
name: "wrapped caller deadline is not retried",
err: oops.In("assistant").
Code("responses_http").
Wrapf(context.DeadlineExceeded, "request provider response"),
want: false,
},
{
name: "caller deadline is not retried",
err: context.DeadlineExceeded,
want: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

assert.Equal(t, test.want, assistant.ShouldRetryModelError(test.err))
})
}
}
10 changes: 8 additions & 2 deletions internal/assistant/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package assistant

import (
"context"
"errors"
"fmt"
"log/slog"
"path/filepath"
Expand Down Expand Up @@ -298,7 +297,14 @@ func (runtime *Runtime) respondWithPartialProgress(
)
if err != nil {
persistErr := runtime.appendPartialPromptFailure(ctx, sessionID, userEntryID, progress, err)
return nil, false, errors.Join(err, persistErr)
if persistErr != nil {
return nil, false, oops.
In("assistant").
Code("persist_failed_prompt").
Wrapf(persistErr, "persist failed prompt progress")
}

return nil, false, err
}

return bundle, cached, nil
Expand Down
1 change: 1 addition & 0 deletions internal/assistant/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ func TestRuntime_PromptPersistsPartialProgressOnProviderFailure(t *testing.T) {
_, err := runtime.Prompt(context.Background(), request)

require.Error(t, err)
assert.EqualError(t, err, "provider returned an empty response")
require.NotEmpty(t, request.SessionID)
messages, err := repository.Messages(context.Background(), request.SessionID)
require.NoError(t, err)
Expand Down
26 changes: 26 additions & 0 deletions internal/terminal/async_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package terminal

import (
"context"
"strings"
"time"

"github.com/gdamore/tcell/v3"
Expand Down Expand Up @@ -364,6 +365,7 @@ func (app *App) applyPromptError(message string, promptID uint64) {
app.setStatus("response canceled; conversation reverted")
return
}
streamingBlocks := append([]chatMessage(nil), app.streamingBlocks...)
app.working = false
app.streamingText = ""
app.streamingThinkingText = ""
Expand All @@ -375,9 +377,33 @@ func (app *App) applyPromptError(message string, promptID uint64) {
return
}
app.activePrompt = nil
app.applyFailedPromptStreamedBlocks(streamingBlocks)
app.addMessage(database.RoleCustom, message)
}

func (app *App) applyFailedPromptStreamedBlocks(streamingBlocks []chatMessage) {
for _, block := range streamingBlocks {
if block.Content == "" {
continue
}
switch block.Role {
case database.RoleAssistant,
database.RoleToolResult,
database.RoleBashExecution,
database.RoleCustom:
app.addMessage(block.Role, block.Content)
case database.RoleThinking:
if strings.TrimSpace(block.Content) != "" {
app.addMessage(block.Role, block.Content)
}
case database.RoleUser,
database.RoleBranchSummary,
database.RoleCompactionSummary:
continue
}
}
}

func (app *App) consumeCanceledPrompt(promptID uint64) bool {
if _, ok := app.canceledPrompts[promptID]; !ok {
return false
Expand Down
25 changes: 25 additions & 0 deletions internal/terminal/render_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gdamore/tcell/v3"
cellcolor "github.com/gdamore/tcell/v3/color"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/omarluq/librecode/internal/assistant"
"github.com/omarluq/librecode/internal/database"
Expand Down Expand Up @@ -72,6 +73,30 @@ func TestPromptThinkingDeltaUsesSeparateStreamingBuffer(t *testing.T) {
}
}

func TestApplyPromptErrorKeepsStreamedProgressVisible(t *testing.T) {
t.Parallel()

app := newRenderTestApp(t)
app.working = true
app.activePrompt = newTestActivePrompt(nil)
app.handlePromptStreamEvent(context.Background(), newTestAsyncEvent(asyncEventPromptDelta, "partial"))
toolEvent := newTestAsyncEvent(asyncEventPromptToolResult, "")
toolEvent.ToolEvent = newTestToolEvent("read", "file content")
app.handlePromptStreamEvent(context.Background(), toolEvent)

app.applyPromptError("provider returned an empty response", app.activePrompt.ID)

assert.False(t, app.working)
require.Len(t, app.messages, 3)
assert.Equal(t, database.RoleAssistant, app.messages[0].Role)
assert.Equal(t, "partial", app.messages[0].Content)
assert.Equal(t, database.RoleToolResult, app.messages[1].Role)
assert.Contains(t, app.messages[1].Content, "tool: read")
assert.Equal(t, database.RoleCustom, app.messages[2].Role)
assert.Equal(t, "provider returned an empty response", app.messages[2].Content)
assert.Empty(t, app.streamingBlocks)
}

func TestFooterLinesIgnoreTransientStatus(t *testing.T) {
t.Parallel()

Expand Down
Loading