Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,8 @@ github.com/samber/oops v1.21.0 h1:18atcO4oEigNFuGXqr3NZWZ6P0XOSEXyBSAMXdQRxTc=
github.com/samber/oops v1.21.0/go.mod h1:Hsm/sKPxtCfPh0w/cE3xVoRfSiE1joDRiStPAsmG9bo=
github.com/samber/ro v0.3.0 h1:fxrdIL9yuA6JiGQuPE/xSXovowdG6+WcMBCCEL00wVk=
github.com/samber/ro v0.3.0/go.mod h1:eInj5R1BbXfGoT1ef0HIO5Qie0wlPkkyL0koOaEmfNM=
github.com/samber/ro/plugins/fsnotify v0.0.0-20260516194255-a8c153943435 h1:xy3P6prHbHBCBe/pn5Qjd61k0dBnZmCN64ob9g5W27c=
github.com/samber/ro/plugins/fsnotify v0.0.0-20260516194255-a8c153943435/go.mod h1:39S5OD7epUKdNVSeg4tpy916q1SjxgIlMUN0twGI6sM=
github.com/samber/ro/plugins/fsnotify v0.0.0-20260522200928-fb86b5fc464e h1:eOxaWpgXm/ygKqAz2Avvei898lBD5YOe9twefAFYEZE=
github.com/samber/ro/plugins/fsnotify v0.0.0-20260522200928-fb86b5fc464e/go.mod h1:QtQE2XaB1KYIWnUotePRq36qZi7goPjtR/pm7gy2XWU=
github.com/samber/ro/plugins/signal v0.0.0-20260516194255-a8c153943435 h1:tySZEg7gguzOtqdOxOoe7EU4G34aPOb77d3MU/wJsG8=
github.com/samber/ro/plugins/signal v0.0.0-20260516194255-a8c153943435/go.mod h1:/zJel4/bMNX08spdRGRZvUWz1BE+iqXSZij3hDps3/A=
github.com/samber/ro/plugins/signal v0.0.0-20260522200928-fb86b5fc464e h1:ehkGCX5UBJoGHHN3KMbAjt/bJLBe5asb0gY0KibWSmU=
github.com/samber/ro/plugins/signal v0.0.0-20260522200928-fb86b5fc464e/go.mod h1:qZvrMezFzFNkX7V/Qh01B79tVW/Nw8TIGLjsrSdccy0=
github.com/samber/slog-common v0.21.0 h1:Wo2hTly1Br5RjYqX/BTWJJeDnTE85oWk/7vqlpZuAUc=
Expand Down
1 change: 1 addition & 0 deletions internal/assistant/anthropic_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ func testCompletionRequestAuth(args ...string) *CompletionRequest {
CWD: "",
Auth: testRequestAuth(apiKey),
Messages: nil,
Usage: model.EmptyTokenUsage(),
Model: model.Model{
ThinkingLevelMap: nil,
Headers: nil,
Expand Down
2 changes: 2 additions & 0 deletions internal/assistant/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (
jsonAssistantRole = "assistant"
jsonToolRole = "tool"
jsonCommandKey = "command"
jsonBreakdownKey = "breakdown"
jsonReadToolName = "read"
jsonBashToolName = "bash"
jsonEditToolName = "edit"
Expand Down Expand Up @@ -92,6 +93,7 @@ type CompletionRequest struct {
CWD string `json:"cwd"`
Auth model.RequestAuth `json:"auth"`
Messages []database.MessageEntity `json:"messages"`
Usage model.TokenUsage `json:"usage"`
Model model.Model `json:"model"`
}

Expand Down
20 changes: 13 additions & 7 deletions internal/assistant/context_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,19 @@ func (runtime *Runtime) modelContextBase(
}

func initialContextBuildResult(base *modelContextBase, selectedModel *model.Model) *contextBuildResult {
breakdown := contextBreakdown(base.SystemTokens, base.SkillTokens, base.HistoryTokens, nil)

return &contextBuildResult{
Contributions: []contextContribution{},
Messages: base.Messages,
Breakdown: contextBreakdown(base.SystemTokens, base.SkillTokens, base.HistoryTokens, nil),
Breakdown: breakdown,
SystemPrompt: base.SystemPrompt,
Usage: estimateContextBuildUsage(
base.SystemPrompt,
base.Messages,
nil,
selectedModel,
breakdown,
),
}
}
Expand All @@ -170,32 +173,35 @@ func recalculateContextBuildResult(
base *modelContextBase,
selectedModel *model.Model,
) {
result.Usage = estimateContextBuildUsage(
base.SystemPrompt,
base.Messages,
result.Contributions,
selectedModel,
)
result.Breakdown = contextBreakdown(
base.SystemTokens,
base.SkillTokens,
base.HistoryTokens,
result.Contributions,
)
result.Usage = estimateContextBuildUsage(
result.SystemPrompt,
base.Messages,
result.Contributions,
selectedModel,
result.Breakdown,
)
}

func estimateContextBuildUsage(
systemPrompt string,
messages []database.MessageEntity,
contributions []contextContribution,
selectedModel *model.Model,
breakdown map[string]int,
) model.TokenUsage {
inputTokens := estimateInputTokens(systemPrompt, messages)
for index := range contributions {
inputTokens += contributions[index].Tokens
}

return model.TokenUsage{
Breakdown: cloneIntMapForUsage(breakdown),
ContextWindow: selectedModel.ContextWindow,
ContextTokens: inputTokens,
InputTokens: inputTokens,
Expand Down
2 changes: 2 additions & 0 deletions internal/assistant/context_build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ end)
assert.Contains(t, client.request.SystemPrompt, "<extension_context>")
assert.Contains(t, client.request.SystemPrompt, "project-note")
assert.Contains(t, client.request.SystemPrompt, "Always mention extension context")
require.NotNil(t, client.request.Usage.Breakdown)
assert.Greater(t, client.request.Usage.Breakdown["extensions"], 0)
}

func TestRuntime_ContextBuildRejectsOversizedExtensionContributions(t *testing.T) {
Expand Down
13 changes: 13 additions & 0 deletions internal/assistant/context_usage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package assistant

// cloneIntMapForUsage returns an independent copy of values, so later writes
// to the returned map never show up in the caller's original map.
//
// A nil or empty input yields nil rather than an allocated empty map.
func cloneIntMapForUsage(values map[string]int) map[string]int {
	size := len(values)
	if size < 1 {
		return nil
	}

	// Pre-size the copy so filling it does not trigger incremental growth.
	duplicate := make(map[string]int, size)
	for name := range values {
		duplicate[name] = values[name]
	}

	return duplicate
}
3 changes: 2 additions & 1 deletion internal/assistant/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func contextBuildLifecyclePayload(
lifecycleCWDKey: cwd,
jsonSessionIDKey: sessionID,
"message_count": len(base.Messages),
"breakdown": cloneIntMap(result.Breakdown),
jsonBreakdownKey: cloneIntMap(result.Breakdown),
"contributions": []any{},
"max_contribution_tokens": contextContributionMaxTokens,
"system_tokens": base.SystemTokens,
Expand Down Expand Up @@ -311,6 +311,7 @@ func entryLifecyclePayload(entry *database.EntryEntity) map[string]any {

func tokenUsageLifecyclePayload(usage model.TokenUsage) map[string]any {
return map[string]any{
jsonBreakdownKey: cloneIntMap(usage.Breakdown),
jsonContextTokensKey: usage.ContextTokens,
jsonContextWindowKey: usage.ContextWindow,
jsonInputTokensKey: usage.InputTokens,
Expand Down
3 changes: 3 additions & 0 deletions internal/assistant/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ func (runtime *Runtime) modelResponse(
sessionID,
contextResult.SystemPrompt,
cwd,
contextResult.Usage,
registry,
onEvent,
)
Expand All @@ -812,6 +813,7 @@ func (runtime *Runtime) modelCompletionRequest(
sessionID string,
systemPrompt string,
cwd string,
usage model.TokenUsage,
registry *tool.Registry,
onEvent func(StreamEvent),
) *CompletionRequest {
Expand All @@ -826,6 +828,7 @@ func (runtime *Runtime) modelCompletionRequest(
CWD: cwd,
Auth: auth,
Messages: messages,
Usage: usage,
Model: *selectedModel,
}
}
Expand Down
8 changes: 7 additions & 1 deletion internal/assistant/runtime_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ func TestRuntime_ProviderLifecyclePublishesReactiveEvents(t *testing.T) {
Text: testCompletionText,
Thinking: nil,
ToolEvents: nil,
Usage: model.TokenUsage{InputTokens: 9, OutputTokens: 3, ContextTokens: 9, ContextWindow: 100},
Usage: model.TokenUsage{
Breakdown: nil,
ContextWindow: 100,
ContextTokens: 9,
InputTokens: 9,
OutputTokens: 3,
},
},
err: nil,
})
Expand Down
8 changes: 7 additions & 1 deletion internal/assistant/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,13 @@ func (testCompletionClient) Complete(
Text: "test assistant response for " + request.Messages[len(request.Messages)-1].Content,
Thinking: nil,
ToolEvents: nil,
Usage: model.TokenUsage{InputTokens: 12, OutputTokens: 4, ContextTokens: 12, ContextWindow: 1000},
Usage: model.TokenUsage{
Breakdown: nil,
ContextWindow: 1000,
ContextTokens: 12,
InputTokens: 12,
OutputTokens: 4,
},
}, nil
}

Expand Down
12 changes: 11 additions & 1 deletion internal/assistant/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func mergeUsage(estimated, reported model.TokenUsage) model.TokenUsage {
if reported.OutputTokens > 0 {
usage.OutputTokens = reported.OutputTokens
}
if len(usage.Breakdown) == 0 && len(reported.Breakdown) > 0 {
usage.Breakdown = cloneIntMapForUsage(reported.Breakdown)
}

return usage
}

Expand All @@ -57,7 +61,13 @@ func usageFromObject(value any) model.TokenUsage {
input := usageInputTokens(object)
output := intFromAny(firstPresent(object, jsonOutputTokensKey, "completion_tokens"))

return model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: input, OutputTokens: output}
return model.TokenUsage{
Breakdown: nil,
ContextWindow: 0,
ContextTokens: 0,
InputTokens: input,
OutputTokens: output,
}
}

func usageInputTokens(object map[string]any) int {
Expand Down
1 change: 1 addition & 0 deletions internal/assistant/usage_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func (runtime *Runtime) emitUsage(ctx context.Context, onEvent func(StreamEvent)
Text: "",
})
payload := map[string]any{
jsonBreakdownKey: cloneIntMap(usage.Breakdown),
jsonContextWindowKey: usage.ContextWindow,
jsonContextTokensKey: usage.ContextTokens,
jsonInputTokensKey: usage.InputTokens,
Expand Down
50 changes: 41 additions & 9 deletions internal/assistant/usage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,32 @@ func TestUsageFromObjectParsesProviderShapes(t *testing.T) {
"input_tokens": float64(123),
jsonOutputTokensKey: float64(45),
},
expected: model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 123, OutputTokens: 45},
expected: model.TokenUsage{
Breakdown: nil, ContextWindow: 0, ContextTokens: 0,
InputTokens: 123, OutputTokens: 45,
},
},
{
name: "chat completions",
usage: map[string]any{
"prompt_tokens": json.Number("77"),
"completion_tokens": json.Number("9"),
},
expected: model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 77, OutputTokens: 9},
expected: model.TokenUsage{
Breakdown: nil, ContextWindow: 0, ContextTokens: 0,
InputTokens: 77, OutputTokens: 9,
},
},
{
name: "total tokens does not become input tokens",
usage: map[string]any{
"total_tokens": json.Number("120"),
jsonOutputTokensKey: json.Number("20"),
},
expected: model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 100, OutputTokens: 20},
expected: model.TokenUsage{
Breakdown: nil, ContextWindow: 0, ContextTokens: 0,
InputTokens: 100, OutputTokens: 20,
},
},
}

Expand All @@ -54,10 +63,17 @@ func TestUsageFromObjectParsesProviderShapes(t *testing.T) {
func TestMergeUsagePreservesEstimatedContextWindow(t *testing.T) {
t.Parallel()

estimated := model.TokenUsage{ContextWindow: 1000, ContextTokens: 200, InputTokens: 200, OutputTokens: 0}
reported := model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 150, OutputTokens: 25}
estimated := model.TokenUsage{
Breakdown: nil, ContextWindow: 1000, ContextTokens: 200,
InputTokens: 200, OutputTokens: 0,
}
reported := model.TokenUsage{
Breakdown: nil, ContextWindow: 0, ContextTokens: 0,
InputTokens: 150, OutputTokens: 25,
}

assert.Equal(t, model.TokenUsage{
Breakdown: nil,
ContextWindow: 1000,
ContextTokens: 200,
InputTokens: 150,
Expand All @@ -68,10 +84,17 @@ func TestMergeUsagePreservesEstimatedContextWindow(t *testing.T) {
func TestMergeUsageNeverShrinksEstimatedContext(t *testing.T) {
t.Parallel()

estimated := model.TokenUsage{ContextWindow: 100_000, ContextTokens: 14_000, InputTokens: 14_000, OutputTokens: 0}
reported := model.TokenUsage{ContextWindow: 0, ContextTokens: 12_000, InputTokens: 12_000, OutputTokens: 700}
estimated := model.TokenUsage{
Breakdown: nil, ContextWindow: 100_000, ContextTokens: 14_000,
InputTokens: 14_000, OutputTokens: 0,
}
reported := model.TokenUsage{
Breakdown: nil, ContextWindow: 0, ContextTokens: 12_000,
InputTokens: 12_000, OutputTokens: 700,
}

assert.Equal(t, model.TokenUsage{
Breakdown: nil,
ContextWindow: 100_000,
ContextTokens: 14_000,
InputTokens: 12_000,
Expand All @@ -82,10 +105,17 @@ func TestMergeUsageNeverShrinksEstimatedContext(t *testing.T) {
func TestMergeUsageDoesNotPromoteProviderTotalToContext(t *testing.T) {
t.Parallel()

estimated := model.TokenUsage{ContextWindow: 272_000, ContextTokens: 0, InputTokens: 0, OutputTokens: 0}
reported := model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 13_000_000, OutputTokens: 100}
estimated := model.TokenUsage{
Breakdown: nil, ContextWindow: 272_000, ContextTokens: 0,
InputTokens: 0, OutputTokens: 0,
}
reported := model.TokenUsage{
Breakdown: nil, ContextWindow: 0, ContextTokens: 0,
InputTokens: 13_000_000, OutputTokens: 100,
}

assert.Equal(t, model.TokenUsage{
Breakdown: nil,
ContextWindow: 272_000,
ContextTokens: 0,
InputTokens: 13_000_000,
Expand All @@ -106,6 +136,7 @@ func TestParseSSEResultPreservesUsageWhenItemsProvideText(t *testing.T) {
result, err := parseSSEResult(strings.NewReader(stream), nil)
require.NoError(t, err)
assert.Equal(t, model.TokenUsage{
Breakdown: nil,
ContextWindow: 0,
ContextTokens: 0,
InputTokens: 12,
Expand All @@ -128,6 +159,7 @@ func TestParseSSEResultPreservesUsageAcrossLaterResponseEvents(t *testing.T) {
result, err := parseSSEResult(strings.NewReader(stream), nil)
require.NoError(t, err)
assert.Equal(t, model.TokenUsage{
Breakdown: nil,
ContextWindow: 0,
ContextTokens: 0,
InputTokens: 12,
Expand Down
17 changes: 12 additions & 5 deletions internal/model/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@ package model

// TokenUsage tracks model context and request/response token counts.
type TokenUsage struct {
ContextWindow int `json:"context_window,omitempty"`
ContextTokens int `json:"context_tokens,omitempty"`
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
Breakdown map[string]int `json:"breakdown,omitempty"`
ContextWindow int `json:"context_window,omitempty"`
ContextTokens int `json:"context_tokens,omitempty"`
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
}

// EmptyTokenUsage returns a zero-value token usage with explicit fields.
func EmptyTokenUsage() TokenUsage {
return TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 0, OutputTokens: 0}
return TokenUsage{
Breakdown: nil,
ContextWindow: 0,
ContextTokens: 0,
InputTokens: 0,
OutputTokens: 0,
}
}

// TotalTokens returns input plus output tokens reported for the turn.
Expand Down
Loading
Loading