Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions pkg/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ type MessageStreamChoice struct {

// MessageStreamResponse represents a streaming response from the model
type MessageStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []MessageStreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []MessageStreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
RateLimit *RateLimit `json:"rate_limit,omitempty"`
}

type Usage struct {
Expand All @@ -130,6 +131,13 @@ type Usage struct {
ReasoningTokens int64 `json:"reasoning_tokens,omitempty"`
}

// RateLimit carries provider rate-limit information attached to a streaming
// response. In the Anthropic adapter it is populated from the HTTP trailer
// headers X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset and
// Retry-After once the stream's final event has been received.
type RateLimit struct {
	// Limit is the value of the X-RateLimit-Limit trailer (0 when absent).
	Limit int64 `json:"limit,omitempty"`
	// Remaining is the value of the X-RateLimit-Remaining trailer (0 when absent).
	Remaining int64 `json:"remaining,omitempty"`
	// Reset is the value of the X-RateLimit-Reset trailer (0 when absent).
	// NOTE(review): presumably an epoch timestamp or seconds-until-reset —
	// confirm against the provider's documentation.
	Reset int64 `json:"reset,omitempty"`
	// RetryAfter is the value of the Retry-After trailer (0 when absent).
	RetryAfter int64 `json:"retry_after,omitempty"`
}

// MessageStream interface represents a stream of chat completions
type MessageStream interface {
// Recv gets the next completion chunk
Expand Down
32 changes: 27 additions & 5 deletions pkg/model/provider/anthropic/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"strconv"
"strings"

"github.com/anthropics/anthropic-sdk-go"
Expand All @@ -23,14 +24,16 @@ type streamAdapter struct {
toolCall bool
toolID string
// For single retry on context length error
retryFn func() *streamAdapter
retried bool
retryFn func() *streamAdapter
retried bool
getResponseTrailer func() http.Header
}

func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
func (c *Client) newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
return &streamAdapter{
stream: stream,
trackUsage: trackUsage,
stream: stream,
trackUsage: trackUsage,
getResponseTrailer: c.getResponseTrailer,
}
}

Expand Down Expand Up @@ -163,11 +166,30 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
} else {
response.Choices[0].FinishReason = chat.FinishReasonStop
}

// MessageStopEvent is the last event. Let's drain the response to get the trailing headers.
trailers := a.getResponseTrailer()
if trailers.Get("X-RateLimit-Limit") != "" {
response.RateLimit = &chat.RateLimit{
Limit: parseHeaderInt64(trailers.Get("X-RateLimit-Limit")),
Remaining: parseHeaderInt64(trailers.Get("X-RateLimit-Remaining")),
Reset: parseHeaderInt64(trailers.Get("X-RateLimit-Reset")),
RetryAfter: parseHeaderInt64(trailers.Get("Retry-After")),
}
}
}

return response, nil
}

// parseHeaderInt64 converts a header value into an int64. Empty strings and
// values that are not valid base-10 integers yield 0, so absent headers map
// cleanly to the zero value.
func parseHeaderInt64(headerValue string) int64 {
	if parsed, err := strconv.ParseInt(headerValue, 10, 64); err == nil {
		return parsed
	}
	return 0
}

// Close closes the stream
func (a *streamAdapter) Close() {
if a.stream != nil {
Expand Down
31 changes: 19 additions & 12 deletions pkg/model/provider/anthropic/beta_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"io"
"log/slog"
"net/http"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
Expand All @@ -14,18 +15,22 @@ import (

// betaStreamAdapter adapts the Anthropic Beta stream to our interface
type betaStreamAdapter struct {
stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion]
toolCall bool
toolID string
stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion]
trackUsage bool
toolCall bool
toolID string
// For single retry on context length error
retryFn func() *betaStreamAdapter
retried bool
retryFn func() *betaStreamAdapter
retried bool
getResponseTrailer func() http.Header
}

// newBetaStreamAdapter creates a new Beta stream adapter
func newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion]) *betaStreamAdapter {
func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion], trackUsage bool) *betaStreamAdapter {
return &betaStreamAdapter{
stream: stream,
stream: stream,
trackUsage: trackUsage,
getResponseTrailer: c.getResponseTrailer,
}
}

Expand Down Expand Up @@ -111,11 +116,13 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
return response, fmt.Errorf("unknown delta type: %T", deltaVariant)
}
case anthropic.BetaRawMessageDeltaEvent:
response.Usage = &chat.Usage{
InputTokens: eventVariant.Usage.InputTokens,
OutputTokens: eventVariant.Usage.OutputTokens,
CachedInputTokens: eventVariant.Usage.CacheReadInputTokens,
CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens,
if a.trackUsage {
response.Usage = &chat.Usage{
InputTokens: eventVariant.Usage.InputTokens,
OutputTokens: eventVariant.Usage.OutputTokens,
CachedInputTokens: eventVariant.Usage.CacheReadInputTokens,
CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens,
}
}
case anthropic.BetaRawMessageStopEvent:
if a.toolCall {
Expand Down
5 changes: 3 additions & 2 deletions pkg/model/provider/anthropic/beta_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ func (c *Client) createBetaStream(
"message_count", len(params.Messages))

stream := client.Beta.Messages.NewStreaming(ctx, params)
ad := newBetaStreamAdapter(stream)
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
ad := c.newBetaStreamAdapter(stream, trackUsage)

// Set up single retry for context length errors
ad.retryFn = func() *betaStreamAdapter {
Expand All @@ -120,7 +121,7 @@ func (c *Client) createBetaStream(
slog.Warn("Retrying with clamped max_tokens after context length error", "original", maxTokens, "clamped", newMaxTokens, "used", used)
retryParams := params
retryParams.MaxTokens = newMaxTokens
return newBetaStreamAdapter(client.Beta.Messages.NewStreaming(ctx, retryParams))
return c.newBetaStreamAdapter(client.Beta.Messages.NewStreaming(ctx, retryParams), trackUsage)
}

slog.Debug("Anthropic Beta API chat completion stream created successfully", "model", c.ModelConfig.Model)
Expand Down
44 changes: 30 additions & 14 deletions pkg/model/provider/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"

Expand All @@ -27,7 +29,20 @@ import (
// It holds the anthropic client and model config
type Client struct {
base.Config
clientFn func(context.Context) (anthropic.Client, error)
clientFn func(context.Context) (anthropic.Client, error)
lastHTTPResponse *http.Response
}

// getResponseTrailer returns the trailer headers of the most recent HTTP
// response, or nil when no response has been captured yet. HTTP trailers are
// only populated after the body has been consumed, so the body is drained
// (errors deliberately ignored — trailers are best-effort metadata) before
// the trailer map is read.
func (c *Client) getResponseTrailer() http.Header {
	resp := c.lastHTTPResponse
	if resp == nil {
		return nil
	}
	if body := resp.Body; body != nil {
		_, _ = io.Copy(io.Discard, body)
	}
	return resp.Trailer
}

// adjustMaxTokensForThinking checks if max_tokens needs adjustment for thinking_budget.
Expand Down Expand Up @@ -116,7 +131,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
}
}

var clientFn func(context.Context) (anthropic.Client, error)
anthropicClient := &Client{
Config: base.Config{
ModelConfig: *cfg,
ModelOptions: globalOptions,
Env: env,
},
}

if gateway := globalOptions.Gateway(); gateway == "" {
authToken, _ := env.Get(ctx, "ANTHROPIC_API_KEY")
if authToken == "" {
Expand All @@ -132,7 +154,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL))
}
client := anthropic.NewClient(requestOptions...)
clientFn = func(context.Context) (anthropic.Client, error) {
anthropicClient.clientFn = func(context.Context) (anthropic.Client, error) {
return client, nil
}
} else {
Expand All @@ -143,7 +165,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
}

// When using a Gateway, tokens are short-lived.
clientFn = func(ctx context.Context) (anthropic.Client, error) {
anthropicClient.clientFn = func(ctx context.Context) (anthropic.Client, error) {
// Query a fresh auth token each time the client is used
authToken, _ := env.Get(ctx, environment.DockerDesktopTokenEnv)
if authToken == "" {
Expand All @@ -168,6 +190,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
}

client := anthropic.NewClient(
option.WithResponseInto(&anthropicClient.lastHTTPResponse),
option.WithAuthToken(authToken),
option.WithAPIKey(authToken),
option.WithBaseURL(baseURL),
Expand All @@ -180,14 +203,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro

slog.Debug("Anthropic client created successfully", "model", cfg.Model)

return &Client{
Config: base.Config{
ModelConfig: *cfg,
ModelOptions: globalOptions,
Env: env,
},
clientFn: clientFn,
}, nil
return anthropicClient, nil
}

// CreateChatCompletionStream creates a streaming chat completion request
Expand Down Expand Up @@ -304,7 +320,7 @@ func (c *Client) CreateChatCompletionStream(

stream := client.Messages.NewStreaming(ctx, params, betaHeader)
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
ad := newStreamAdapter(stream, trackUsage)
ad := c.newStreamAdapter(stream, trackUsage)

// Set up single retry for context length errors
ad.retryFn = func() *streamAdapter {
Expand All @@ -321,7 +337,7 @@ func (c *Client) CreateChatCompletionStream(
slog.Warn("Retrying with clamped max_tokens after context length error", "original max_tokens", maxTokens, "clamped max_tokens", newMaxTokens, "used tokens", used)
retryParams := params
retryParams.MaxTokens = newMaxTokens
return newStreamAdapter(client.Messages.NewStreaming(ctx, retryParams, betaHeader), trackUsage)
return c.newStreamAdapter(client.Messages.NewStreaming(ctx, retryParams, betaHeader), trackUsage)
}

slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model)
Expand Down
1 change: 1 addition & 0 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ type Usage struct {
// It embeds chat.Usage and adds Cost and Model fields.
type MessageUsage struct {
chat.Usage
chat.RateLimit
Cost float64
Model string
}
Expand Down
14 changes: 12 additions & 2 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ type streamResult struct {
Stopped bool
ActualModel string // The actual model used (may differ from configured model with routing)
Usage *chat.Usage // Token usage for this stream
RateLimit *chat.RateLimit
}

type Opt func(*LocalRuntime)
Expand Down Expand Up @@ -946,6 +947,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
Model: messageModel,
}
}
if res.RateLimit != nil {
msgUsage.RateLimit = *res.RateLimit
}

sess.AddMessage(session.NewAgentMessage(a, &assistantMessage))
r.saveSession(ctx, sess)
Expand Down Expand Up @@ -1119,8 +1123,9 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
var actualModel string
var actualModelEventEmitted bool
var messageUsage *chat.Usage
modelID := getAgentModelID(a)
var messageRateLimit *chat.RateLimit

modelID := getAgentModelID(a)
toolCallIndex := make(map[string]int) // toolCallID -> index in toolCalls slice
emittedPartial := make(map[string]bool) // toolCallID -> whether we've emitted a partial event
toolDefMap := make(map[string]tools.Tool, len(agentTools))
Expand All @@ -1138,7 +1143,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
}

if response.Usage != nil {
// Capture the usage for this specific message
messageUsage = response.Usage

if m != nil && m.Cost != nil {
Expand All @@ -1159,6 +1163,10 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.Cost)
}

if response.RateLimit != nil {
messageRateLimit = response.RateLimit
}

if len(response.Choices) == 0 {
continue
}
Expand Down Expand Up @@ -1195,6 +1203,7 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
Stopped: true,
ActualModel: actualModel,
Usage: messageUsage,
RateLimit: messageRateLimit,
}, nil
}

Expand Down Expand Up @@ -1267,6 +1276,7 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
Stopped: stoppedDueToNoOutput,
ActualModel: actualModel,
Usage: messageUsage,
RateLimit: messageRateLimit,
}, nil
}

Expand Down