Skip to content

Commit f2047b8

Browse files
committed
feat(cli): add configurable workers retries and fail-fast mode
1 parent 2979efb commit f2047b8

3 files changed

Lines changed: 257 additions & 10 deletions

File tree

internal/cli/run.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
const (
2929
defaultChunkSize = 2000
3030
defaultTranslateWorkers = 4
31+
defaultMaxRetries = 5
3132
defaultOutDir = "out"
3233
summaryFileName = "_summary.json"
3334
stateFileName = ".transblog.state.json"
@@ -43,6 +44,9 @@ type options struct {
4344
Model string
4445
OutPath string
4546
View bool
47+
Workers int
48+
MaxRetries int
49+
FailFast bool
4650
Glossary string
4751
Timeout time.Duration
4852
ChunkSize int
@@ -173,7 +177,7 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error {
173177
return err
174178
}
175179

176-
openAIClient := openai.NewClient(apiKey, baseURL, httpClient)
180+
openAIClient := openai.NewClient(apiKey, baseURL, httpClient, opts.MaxRetries)
177181
runCtx, stopSignal := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
178182
defer stopSignal()
179183
runStart := time.Now()
@@ -200,6 +204,10 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error {
200204
summary.FailureCount++
201205
summary.Results = append(summary.Results, item)
202206
_, _ = fmt.Fprintf(stderr, "Failed [%s]: %s (%s)\n", errorType, compactURL(sourceURL), errorMessage)
207+
if opts.FailFast {
208+
_, _ = fmt.Fprintln(stderr, "Fail-fast enabled: stop after first failure.")
209+
break
210+
}
203211
continue
204212
}
205213

@@ -245,6 +253,9 @@ func parseFlags(args []string, stderr io.Writer) (options, error) {
245253
fs.StringVar(&opts.Model, "model", "gpt-5.2", "OpenAI model name")
246254
fs.StringVar(&opts.OutPath, "out", "", "Output path: file for single URL, directory for multiple URLs (default: ./out/)")
247255
fs.BoolVar(&opts.View, "view", false, "Generate HTML split-view with Markdown rendering and synchronized scrolling")
256+
fs.IntVar(&opts.Workers, "workers", defaultTranslateWorkers, "Translation worker count")
257+
fs.IntVar(&opts.MaxRetries, "max-retries", defaultMaxRetries, "Maximum retries for OpenAI requests")
258+
fs.BoolVar(&opts.FailFast, "fail-fast", false, "Stop at first URL failure (default: continue for partial success)")
248259
fs.StringVar(&opts.Glossary, "glossary", "", "Path to glossary JSON map, e.g. {\"term\":\"translation\"}")
249260
fs.DurationVar(&opts.Timeout, "timeout", 90*time.Second, "HTTP timeout, e.g. 120s")
250261
fs.IntVar(&opts.ChunkSize, "chunk-size", defaultChunkSize, "Target chunk size in characters")
@@ -268,6 +279,12 @@ func parseFlags(args []string, stderr io.Writer) (options, error) {
268279
if opts.Timeout <= 0 {
269280
return options{}, errors.New("--timeout must be positive")
270281
}
282+
if opts.Workers <= 0 {
283+
return options{}, errors.New("--workers must be greater than 0")
284+
}
285+
if opts.MaxRetries < 0 {
286+
return options{}, errors.New("--max-retries must be 0 or greater")
287+
}
271288
if opts.ChunkSize <= 0 {
272289
opts.ChunkSize = defaultChunkSize
273290
}
@@ -561,6 +578,8 @@ func processURL(
561578
glossaryMap,
562579
tasks,
563580
progress,
581+
opts.Workers,
582+
opts.FailFast,
564583
func(task translationTask, translated string) error {
565584
if err := stateStore.saveChunk(
566585
sourceURL,
@@ -754,6 +773,8 @@ func translateAllChunks(
754773
glossaryMap map[string]string,
755774
tasks []translationTask,
756775
progress io.Writer,
776+
workers int,
777+
failFast bool,
757778
onChunkTranslated func(task translationTask, translated string) error,
758779
) ([]pair, error) {
759780
if len(tasks) == 0 {
@@ -763,7 +784,7 @@ func translateAllChunks(
763784
progress = io.Discard
764785
}
765786

766-
workerCount := defaultTranslateWorkers
787+
workerCount := workers
767788
if workerCount > len(tasks) {
768789
workerCount = len(tasks)
769790
}
@@ -795,8 +816,11 @@ func translateAllChunks(
795816
case results <- translationResult{task: task, err: fmt.Errorf("translate chunk %d for %s: %w", task.chunkNumber, task.sourceURL, err)}:
796817
case <-runCtx.Done():
797818
}
798-
cancel()
799-
return
819+
if failFast {
820+
cancel()
821+
return
822+
}
823+
continue
800824
}
801825

802826
translated = glossary.Apply(translated, glossaryMap)

internal/cli/run_test.go

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"strings"
1414
"sync/atomic"
1515
"testing"
16+
"time"
1617
)
1718

1819
func TestRunMultiURLWritesPerURLOutputsAndSummary(t *testing.T) {
@@ -345,6 +346,223 @@ func TestRunSummaryIncludesOutputErrorType(t *testing.T) {
345346
}
346347
}
347348

349+
func TestRunFailFastStopsAfterFirstFailure(t *testing.T) {
350+
contentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
351+
switch r.URL.Path {
352+
case "/bad":
353+
http.Error(w, "broken", http.StatusInternalServerError)
354+
case "/ok":
355+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
356+
_, _ = w.Write([]byte(sampleArticle("OK", "good content")))
357+
default:
358+
http.NotFound(w, r)
359+
}
360+
}))
361+
t.Cleanup(contentServer.Close)
362+
363+
openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
364+
if r.URL.Path != "/v1/responses" {
365+
http.NotFound(w, r)
366+
return
367+
}
368+
w.Header().Set("Content-Type", "application/json")
369+
_, _ = io.WriteString(w, `{"output_text":"译文内容"}`)
370+
}))
371+
t.Cleanup(openAIServer.Close)
372+
373+
tmpDir := useTempWorkingDir(t)
374+
t.Setenv("OPENAI_API_KEY", "test-key")
375+
t.Setenv("OPENAI_BASE_URL", openAIServer.URL)
376+
377+
badURL := contentServer.URL + "/bad"
378+
okURL := contentServer.URL + "/ok"
379+
380+
var stdout bytes.Buffer
381+
var stderr bytes.Buffer
382+
err := Run([]string{"--fail-fast", badURL, okURL}, &stdout, &stderr)
383+
if err == nil {
384+
t.Fatalf("Run() error = nil, want fail-fast error")
385+
}
386+
387+
summaryPath := filepath.Join(tmpDir, "out", "_summary.json")
388+
rawSummary, err := os.ReadFile(summaryPath)
389+
if err != nil {
390+
t.Fatalf("read summary file: %v", err)
391+
}
392+
393+
var summary taskSummary
394+
if err := json.Unmarshal(rawSummary, &summary); err != nil {
395+
t.Fatalf("unmarshal summary JSON: %v", err)
396+
}
397+
398+
if len(summary.Results) != 1 {
399+
t.Fatalf("summary result len=%d, want 1 due to fail-fast stop", len(summary.Results))
400+
}
401+
if summary.Results[0].SourceURL != badURL {
402+
t.Fatalf("summary first source_url=%q, want %q", summary.Results[0].SourceURL, badURL)
403+
}
404+
if !strings.Contains(stderr.String(), "Fail-fast enabled: stop after first failure.") {
405+
t.Fatalf("stderr missing fail-fast message: %s", stderr.String())
406+
}
407+
if !strings.Contains(stdout.String(), "Done: 0 succeeded, 1 failed") {
408+
t.Fatalf("stdout missing final summary: %s", stdout.String())
409+
}
410+
}
411+
412+
func TestRunMaxRetriesFlagControlsOpenAIRetry(t *testing.T) {
413+
contentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
414+
if r.URL.Path != "/retry" {
415+
http.NotFound(w, r)
416+
return
417+
}
418+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
419+
_, _ = w.Write([]byte(sampleArticle("Retry", "retry content")))
420+
}))
421+
t.Cleanup(contentServer.Close)
422+
423+
var callCount int32
424+
openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
425+
if r.URL.Path != "/v1/responses" {
426+
http.NotFound(w, r)
427+
return
428+
}
429+
430+
n := atomic.AddInt32(&callCount, 1)
431+
if n == 1 {
432+
w.WriteHeader(http.StatusInternalServerError)
433+
_, _ = io.WriteString(w, `{"error":{"message":"temporary upstream failure"}}`)
434+
return
435+
}
436+
437+
w.Header().Set("Content-Type", "application/json")
438+
_, _ = io.WriteString(w, `{"output_text":"译文内容"}`)
439+
}))
440+
t.Cleanup(openAIServer.Close)
441+
442+
t.Setenv("OPENAI_API_KEY", "test-key")
443+
t.Setenv("OPENAI_BASE_URL", openAIServer.URL)
444+
sourceURL := contentServer.URL + "/retry"
445+
446+
dirNoRetry := t.TempDir()
447+
runInWorkingDir(t, dirNoRetry, func() string {
448+
var stdout bytes.Buffer
449+
var stderr bytes.Buffer
450+
atomic.StoreInt32(&callCount, 0)
451+
452+
err := Run([]string{"--chunk-size", "10000", "--max-retries", "0", sourceURL}, &stdout, &stderr)
453+
if err == nil {
454+
t.Fatalf("Run() error = nil, want failure when retries are disabled")
455+
}
456+
if got := atomic.LoadInt32(&callCount); got != 1 {
457+
t.Fatalf("OpenAI call count=%d, want 1 with --max-retries=0", got)
458+
}
459+
return ""
460+
})
461+
462+
dirWithRetry := t.TempDir()
463+
runInWorkingDir(t, dirWithRetry, func() string {
464+
var stdout bytes.Buffer
465+
var stderr bytes.Buffer
466+
atomic.StoreInt32(&callCount, 0)
467+
468+
if err := Run([]string{"--chunk-size", "10000", "--max-retries", "1", sourceURL}, &stdout, &stderr); err != nil {
469+
t.Fatalf("Run() error = %v; stderr=%s", err, stderr.String())
470+
}
471+
if got := atomic.LoadInt32(&callCount); got != 2 {
472+
t.Fatalf("OpenAI call count=%d, want 2 with --max-retries=1", got)
473+
}
474+
return ""
475+
})
476+
}
477+
478+
func TestRunWorkersFlagChangesConcurrency(t *testing.T) {
479+
contentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
480+
if r.URL.Path != "/parallel" {
481+
http.NotFound(w, r)
482+
return
483+
}
484+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
485+
_, _ = w.Write([]byte(sampleLongArticle("Parallel")))
486+
}))
487+
t.Cleanup(contentServer.Close)
488+
489+
var inFlight int32
490+
var maxInFlight int32
491+
openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
492+
if r.URL.Path != "/v1/responses" {
493+
http.NotFound(w, r)
494+
return
495+
}
496+
497+
current := atomic.AddInt32(&inFlight, 1)
498+
for {
499+
prev := atomic.LoadInt32(&maxInFlight)
500+
if current <= prev {
501+
break
502+
}
503+
if atomic.CompareAndSwapInt32(&maxInFlight, prev, current) {
504+
break
505+
}
506+
}
507+
time.Sleep(25 * time.Millisecond)
508+
atomic.AddInt32(&inFlight, -1)
509+
510+
w.Header().Set("Content-Type", "application/json")
511+
_, _ = io.WriteString(w, `{"output_text":"译文内容"}`)
512+
}))
513+
t.Cleanup(openAIServer.Close)
514+
515+
t.Setenv("OPENAI_API_KEY", "test-key")
516+
t.Setenv("OPENAI_BASE_URL", openAIServer.URL)
517+
sourceURL := contentServer.URL + "/parallel"
518+
519+
singleWorkerDir := t.TempDir()
520+
runInWorkingDir(t, singleWorkerDir, func() string {
521+
var stdout bytes.Buffer
522+
var stderr bytes.Buffer
523+
atomic.StoreInt32(&inFlight, 0)
524+
atomic.StoreInt32(&maxInFlight, 0)
525+
526+
if err := Run([]string{"--chunk-size", "80", "--workers", "1", sourceURL}, &stdout, &stderr); err != nil {
527+
t.Fatalf("Run() with --workers=1 error = %v; stderr=%s", err, stderr.String())
528+
}
529+
if got := atomic.LoadInt32(&maxInFlight); got != 1 {
530+
t.Fatalf("max in-flight=%d, want 1 when workers=1", got)
531+
}
532+
return ""
533+
})
534+
535+
multiWorkerDir := t.TempDir()
536+
runInWorkingDir(t, multiWorkerDir, func() string {
537+
var stdout bytes.Buffer
538+
var stderr bytes.Buffer
539+
atomic.StoreInt32(&inFlight, 0)
540+
atomic.StoreInt32(&maxInFlight, 0)
541+
542+
if err := Run([]string{"--chunk-size", "80", "--workers", "4", sourceURL}, &stdout, &stderr); err != nil {
543+
t.Fatalf("Run() with --workers=4 error = %v; stderr=%s", err, stderr.String())
544+
}
545+
if got := atomic.LoadInt32(&maxInFlight); got <= 1 {
546+
t.Fatalf("max in-flight=%d, want >1 when workers=4", got)
547+
}
548+
return ""
549+
})
550+
}
551+
552+
func TestParseFlagsRejectsInvalidWorkers(t *testing.T) {
553+
_, err := parseFlags([]string{"--workers", "0", "https://example.com"}, io.Discard)
554+
if err == nil || !strings.Contains(err.Error(), "--workers must be greater than 0") {
555+
t.Fatalf("parseFlags error=%v, want workers validation error", err)
556+
}
557+
}
558+
559+
func TestParseFlagsRejectsInvalidMaxRetries(t *testing.T) {
560+
_, err := parseFlags([]string{"--max-retries", "-1", "https://example.com"}, io.Discard)
561+
if err == nil || !strings.Contains(err.Error(), "--max-retries must be 0 or greater") {
562+
t.Fatalf("parseFlags error=%v, want max-retries validation error", err)
563+
}
564+
}
565+
348566
func TestRunResumeReusesSavedChunksAndMatchesSingleRun(t *testing.T) {
349567
contentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
350568
if r.URL.Path != "/long" {

internal/openai/responses.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,35 @@ import (
1717
)
1818

1919
const (
20-
defaultBaseURL = "https://api.openai.com/v1"
21-
maxErrBody = 2048
22-
maxRetries = 5
20+
defaultBaseURL = "https://api.openai.com/v1"
21+
maxErrBody = 2048
22+
defaultMaxRetries = 5
2323
)
2424

2525
type Client struct {
2626
apiKey string
2727
endpoint string
2828
httpClient *http.Client
29+
maxRetries int
2930
}
3031

31-
func NewClient(apiKey string, baseURL string, httpClient *http.Client) *Client {
32+
func NewClient(apiKey string, baseURL string, httpClient *http.Client, maxRetries int) *Client {
3233
if strings.TrimSpace(baseURL) == "" {
3334
baseURL = defaultBaseURL
3435
}
3536
baseURL = strings.TrimSuffix(strings.TrimSpace(baseURL), "/")
3637
if strings.HasSuffix(baseURL, "/v1") {
3738
baseURL = strings.TrimSuffix(baseURL, "/v1")
3839
}
40+
if maxRetries < 0 {
41+
maxRetries = defaultMaxRetries
42+
}
3943

4044
return &Client{
4145
apiKey: apiKey,
4246
endpoint: baseURL + "/v1/responses",
4347
httpClient: httpClient,
48+
maxRetries: maxRetries,
4449
}
4550
}
4651

@@ -88,14 +93,14 @@ func (c *Client) TranslateMarkdownChunk(ctx context.Context, model string, mdChu
8893
}
8994

9095
var lastErr error
91-
for attempt := 0; attempt <= maxRetries; attempt++ {
96+
for attempt := 0; attempt <= c.maxRetries; attempt++ {
9297
translated, retry, err := c.callResponses(ctx, body)
9398
if err == nil {
9499
return translated, nil
95100
}
96101

97102
lastErr = err
98-
if !retry || attempt == maxRetries {
103+
if !retry || attempt == c.maxRetries {
99104
break
100105
}
101106

0 commit comments

Comments
 (0)