Skip to content

Commit cfec2c9

Browse files
committed
feat(llm): add openai chat completions
1 parent 84d13fc commit cfec2c9

4 files changed

Lines changed: 326 additions & 7 deletions

File tree

.env.example

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# API mode: "responses" (default) or "chat" for Chat Completions API
2+
# CLI flag: --api-mode=chat
3+
OPENAI_API_MODE=responses
4+
5+
# Default API key (used as fallback for both modes)
16
OPENAI_API_KEY=sk-your-key-here
2-
# Optional: custom OpenAI-compatible gateway
7+
# Default base URL (used as fallback for both modes)
38
OPENAI_BASE_URL=https://api.openai.com
9+
10+
# Responses API (default mode)
11+
# Override with mode-specific keys/URLs for cleaner switching
12+
# OPENAI_RESPONSES_API_KEY=
13+
# OPENAI_RESPONSES_BASE_URL=
14+
15+
# Chat Completions API (use --api-mode=chat to enable)
16+
# OPENAI_CHAT_API_KEY=
17+
# OPENAI_CHAT_BASE_URL=

internal/cli/run.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ type options struct {
7171
Refresh bool
7272
Timeout time.Duration
7373
ChunkSize int
74+
APIMode string
7475
ShowHelp bool
7576
SourceURLs []string
7677
setFlags map[string]bool
@@ -287,8 +288,40 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error {
287288
}
288289
applyEnvDefaults(&opts, envValues)
289290

290-
apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
291291
opts.SourceURLs = normalizeSourceURLs(opts.SourceURLs)
292+
293+
// Determine API mode first (CLI flag > env var > default)
294+
apiMode := opts.APIMode
295+
if apiMode == "" {
296+
apiMode = strings.TrimSpace(os.Getenv("OPENAI_API_MODE"))
297+
}
298+
if apiMode == "" {
299+
apiMode = "responses"
300+
}
301+
302+
// Read API key and base URL based on mode
303+
var apiKey, baseURL string
304+
switch apiMode {
305+
case "chat":
306+
apiKey = strings.TrimSpace(os.Getenv("OPENAI_CHAT_API_KEY"))
307+
baseURL = strings.TrimSpace(os.Getenv("OPENAI_CHAT_BASE_URL"))
308+
if apiKey == "" {
309+
apiKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
310+
}
311+
if baseURL == "" {
312+
baseURL = strings.TrimSpace(os.Getenv("OPENAI_BASE_URL"))
313+
}
314+
default:
315+
apiKey = strings.TrimSpace(os.Getenv("OPENAI_RESPONSES_API_KEY"))
316+
baseURL = strings.TrimSpace(os.Getenv("OPENAI_RESPONSES_BASE_URL"))
317+
if apiKey == "" {
318+
apiKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
319+
}
320+
if baseURL == "" {
321+
baseURL = strings.TrimSpace(os.Getenv("OPENAI_BASE_URL"))
322+
}
323+
}
324+
292325
if apiKey == "" || len(opts.SourceURLs) == 0 {
293326
opts, apiKey, err = runOnboarding(opts, envPath, envValues, apiKey, stdout, stderr)
294327
if err != nil {
@@ -302,7 +335,6 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error {
302335
return errors.New("at least one URL is required")
303336
}
304337

305-
baseURL := strings.TrimSpace(os.Getenv("OPENAI_BASE_URL"))
306338
httpClient := &http.Client{Timeout: opts.Timeout}
307339

308340
glossaryMap, err := glossary.Load(opts.Glossary)
@@ -341,7 +373,13 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error {
341373
return err
342374
}
343375

344-
openAIClient := openai.NewClient(apiKey, baseURL, httpClient, opts.MaxRetries)
376+
var openAIClient openai.Translator
377+
switch apiMode {
378+
case "chat":
379+
openAIClient = openai.NewChatClient(apiKey, baseURL, httpClient, opts.MaxRetries)
380+
default:
381+
openAIClient = openai.NewClient(apiKey, baseURL, httpClient, opts.MaxRetries)
382+
}
345383
runCtx, stopSignal := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
346384
defer stopSignal()
347385
runStart := time.Now()
@@ -515,6 +553,7 @@ func parseFlags(args []string, stderr io.Writer) (options, error) {
515553
fs.StringVar(&opts.Glossary, "glossary", "", "Path to glossary JSON map, e.g. {\"term\":\"translation\"}")
516554
fs.DurationVar(&opts.Timeout, "timeout", 90*time.Second, "HTTP timeout, e.g. 120s")
517555
fs.IntVar(&opts.ChunkSize, "chunk-size", defaultChunkSize, "Target chunk size in characters")
556+
fs.StringVar(&opts.APIMode, "api-mode", "", "API mode: responses or chat (default: responses)")
518557

519558
fs.Usage = func() {
520559
fmt.Fprintln(stderr, "Usage: transblog [flags] <url> [url...]")
@@ -1593,7 +1632,7 @@ func pairsToCachedPairs(items []pair) []cachedPair {
15931632
func processURL(
15941633
ctx context.Context,
15951634
httpClient *http.Client,
1596-
openAIClient *openai.Client,
1635+
openAIClient openai.Translator,
15971636
opts options,
15981637
glossaryMap map[string]string,
15991638
prices priceConfig,
@@ -2193,7 +2232,7 @@ type translationResult struct {
21932232

21942233
func translateAllChunks(
21952234
ctx context.Context,
2196-
client *openai.Client,
2235+
client openai.Translator,
21972236
model string,
21982237
glossaryMap map[string]string,
21992238
tasks []translationTask,
@@ -2360,7 +2399,7 @@ func translateAllChunks(
23602399

23612400
func translateChunkWithQualityGuard(
23622401
ctx context.Context,
2363-
client *openai.Client,
2402+
client openai.Translator,
23642403
model string,
23652404
sourceChunk string,
23662405
glossaryMap map[string]string,

internal/openai/chat.go

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"math/rand"
11+
"net/http"
12+
"strconv"
13+
"strings"
14+
"time"
15+
)
16+
17+
// ChatClient talks to the OpenAI Chat Completions API
// (POST {base}/v1/chat/completions) with retry support.
type ChatClient struct {
	apiKey     string       // bearer token sent in the Authorization header
	endpoint   string       // fully-resolved chat/completions URL
	httpClient *http.Client // transport; owns the request timeout
	maxRetries int          // extra attempts after the first failure
}
23+
24+
func NewChatClient(apiKey string, baseURL string, httpClient *http.Client, maxRetries int) *ChatClient {
25+
if strings.TrimSpace(baseURL) == "" {
26+
baseURL = defaultBaseURL
27+
}
28+
baseURL = strings.TrimSuffix(strings.TrimSpace(baseURL), "/")
29+
if strings.HasSuffix(baseURL, "/v1") {
30+
baseURL = strings.TrimSuffix(baseURL, "/v1")
31+
}
32+
if maxRetries < 0 {
33+
maxRetries = defaultMaxRetries
34+
}
35+
36+
return &ChatClient{
37+
apiKey: apiKey,
38+
endpoint: baseURL + "/v1/chat/completions",
39+
httpClient: httpClient,
40+
maxRetries: maxRetries,
41+
}
42+
}
43+
44+
func (c *ChatClient) TranslateMarkdownChunk(ctx context.Context, model string, mdChunk string, glossaryMap map[string]string) (string, error) {
45+
translated, _, err := c.TranslateMarkdownChunkWithUsage(ctx, model, mdChunk, glossaryMap)
46+
return translated, err
47+
}
48+
49+
func (c *ChatClient) TranslateMarkdownChunkWithUsage(
50+
ctx context.Context,
51+
model string,
52+
mdChunk string,
53+
glossaryMap map[string]string,
54+
) (string, Usage, error) {
55+
return c.translateMarkdownChunk(ctx, model, mdChunk, glossaryMap, false, "")
56+
}
57+
58+
func (c *ChatClient) TranslateMarkdownChunkStrict(
59+
ctx context.Context,
60+
model string,
61+
mdChunk string,
62+
glossaryMap map[string]string,
63+
failureReason string,
64+
) (string, error) {
65+
translated, _, err := c.TranslateMarkdownChunkStrictWithUsage(ctx, model, mdChunk, glossaryMap, failureReason)
66+
return translated, err
67+
}
68+
69+
func (c *ChatClient) TranslateMarkdownChunkStrictWithUsage(
70+
ctx context.Context,
71+
model string,
72+
mdChunk string,
73+
glossaryMap map[string]string,
74+
failureReason string,
75+
) (string, Usage, error) {
76+
return c.translateMarkdownChunk(ctx, model, mdChunk, glossaryMap, true, failureReason)
77+
}
78+
79+
func (c *ChatClient) translateMarkdownChunk(
80+
ctx context.Context,
81+
model string,
82+
mdChunk string,
83+
glossaryMap map[string]string,
84+
strict bool,
85+
failureReason string,
86+
) (string, Usage, error) {
87+
systemPrompt := buildSystemPrompt(strict)
88+
userPrompt := buildUserPrompt(mdChunk, glossaryMap, strict, failureReason)
89+
90+
// Build messages array for Chat Completions API
91+
payload := map[string]any{
92+
"model": model,
93+
"messages": []map[string]any{
94+
{
95+
"role": "system",
96+
"content": systemPrompt,
97+
},
98+
{
99+
"role": "user",
100+
"content": userPrompt,
101+
},
102+
},
103+
}
104+
105+
body, err := json.Marshal(payload)
106+
if err != nil {
107+
return "", Usage{}, fmt.Errorf("marshal OpenAI request: %w", err)
108+
}
109+
110+
var lastErr error
111+
for attempt := 0; attempt <= c.maxRetries; attempt++ {
112+
translated, usage, retry, err := c.callChatCompletions(ctx, body)
113+
if err == nil {
114+
return translated, usage, nil
115+
}
116+
117+
lastErr = err
118+
if !retry || attempt == c.maxRetries {
119+
break
120+
}
121+
122+
delay := backoffDelay(attempt)
123+
select {
124+
case <-time.After(delay):
125+
case <-ctx.Done():
126+
return "", Usage{}, ctx.Err()
127+
}
128+
}
129+
130+
if lastErr == nil {
131+
lastErr = errors.New("unknown translation error")
132+
}
133+
return "", Usage{}, lastErr
134+
}
135+
136+
func (c *ChatClient) callChatCompletions(ctx context.Context, body []byte) (translated string, usage Usage, retry bool, err error) {
137+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
138+
if err != nil {
139+
return "", Usage{}, false, fmt.Errorf("build OpenAI request: %w", err)
140+
}
141+
req.Header.Set("Authorization", "Bearer "+c.apiKey)
142+
req.Header.Set("Content-Type", "application/json")
143+
144+
resp, err := c.httpClient.Do(req)
145+
if err != nil {
146+
return "", Usage{}, true, fmt.Errorf("request OpenAI Chat Completions API: %w", err)
147+
}
148+
defer resp.Body.Close()
149+
150+
respBody, err := io.ReadAll(resp.Body)
151+
if err != nil {
152+
return "", Usage{}, true, fmt.Errorf("read OpenAI response body: %w", err)
153+
}
154+
155+
if resp.StatusCode < 200 || resp.StatusCode > 299 {
156+
message := parseAPIError(respBody)
157+
err := fmt.Errorf("OpenAI Chat Completions API status %d: %s", resp.StatusCode, message)
158+
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500 {
159+
if retryAfter := parseRetryAfter(resp.Header.Get("Retry-After")); retryAfter > 0 {
160+
select {
161+
case <-time.After(retryAfter):
162+
case <-ctx.Done():
163+
return "", Usage{}, false, ctx.Err()
164+
}
165+
}
166+
return "", Usage{}, true, err
167+
}
168+
return "", Usage{}, false, err
169+
}
170+
171+
output, err := extractChatOutputText(respBody)
172+
if err != nil {
173+
return "", Usage{}, false, err
174+
}
175+
usage = extractChatUsage(respBody)
176+
return output, usage, false, nil
177+
}
178+
179+
// extractChatOutputText returns the trimmed assistant message from the
// first choice of a Chat Completions response. It errors when the body is
// not JSON, has no choices, or the content is empty/whitespace-only.
func extractChatOutputText(body []byte) (string, error) {
	var parsed struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}

	if err := json.Unmarshal(body, &parsed); err != nil {
		return "", fmt.Errorf("parse OpenAI Chat response JSON: %w", err)
	}

	if len(parsed.Choices) == 0 {
		// errors.New for static messages (no format verbs needed).
		return "", errors.New("OpenAI Chat response has no choices")
	}

	content := strings.TrimSpace(parsed.Choices[0].Message.Content)
	if content == "" {
		return "", errors.New("OpenAI Chat response missing message content")
	}

	return content, nil
}
203+
204+
func extractChatUsage(body []byte) Usage {
205+
var parsed struct {
206+
Usage struct {
207+
PromptTokens int64 `json:"prompt_tokens"`
208+
CompletionTokens int64 `json:"completion_tokens"`
209+
TotalTokens int64 `json:"total_tokens"`
210+
} `json:"usage"`
211+
}
212+
213+
if err := json.Unmarshal(body, &parsed); err != nil {
214+
return Usage{}
215+
}
216+
217+
if parsed.Usage.PromptTokens == 0 && parsed.Usage.CompletionTokens == 0 && parsed.Usage.TotalTokens == 0 {
218+
return Usage{}
219+
}
220+
221+
return Usage{
222+
InputTokens: parsed.Usage.PromptTokens,
223+
OutputTokens: parsed.Usage.CompletionTokens,
224+
TotalTokens: parsed.Usage.TotalTokens,
225+
Available: true,
226+
}
227+
}
228+
229+
func parseRetryAfterChat(value string) time.Duration {
230+
value = strings.TrimSpace(value)
231+
if value == "" {
232+
return 0
233+
}
234+
235+
if seconds, err := strconv.Atoi(value); err == nil && seconds > 0 {
236+
return time.Duration(seconds) * time.Second
237+
}
238+
239+
if ts, err := http.ParseTime(value); err == nil {
240+
delta := time.Until(ts)
241+
if delta > 0 {
242+
return delta
243+
}
244+
}
245+
246+
return 0
247+
}
248+
249+
func backoffDelayChat(attempt int) time.Duration {
250+
base := time.Second
251+
delay := base * time.Duration(1<<attempt)
252+
jitter := time.Duration(rand.Intn(250)) * time.Millisecond
253+
max := 30 * time.Second
254+
if delay+jitter > max {
255+
return max
256+
}
257+
return delay + jitter
258+
}

0 commit comments

Comments
 (0)