Skip to content

Commit d7a4ad2

Browse files
committed
refactor: use pre-computed token counts in prune state instead of re-scanning messages
1 parent 5e9a617 commit d7a4ad2

File tree

17 files changed

+114
-142
lines changed

17 files changed

+114
-142
lines changed

lib/commands/context.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ function analyzeTokens(state: SessionState, messages: WithParts[]): TokenBreakdo
7474
tools: 0,
7575
toolCount: 0,
7676
prunedTokens: state.stats.totalPruneTokens,
77-
prunedToolCount: state.prune.toolIds.size,
78-
prunedMessageCount: state.prune.messageIds.size,
77+
prunedToolCount: state.prune.tools.size,
78+
prunedMessageCount: state.prune.messages.size,
7979
total: 0,
8080
}
8181

@@ -129,7 +129,7 @@ function analyzeTokens(state: SessionState, messages: WithParts[]): TokenBreakdo
129129
foundToolIds.add(toolPart.callID)
130130
}
131131

132-
const isPruned = toolPart.callID && state.prune.toolIds.has(toolPart.callID)
132+
const isPruned = toolPart.callID && state.prune.tools.has(toolPart.callID)
133133
if (!isCompacted && !isPruned) {
134134
if (toolPart.state?.input) {
135135
const inputStr =

lib/commands/stats.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ export async function handleStatsCommand(ctx: StatsCommandContext): Promise<void
5151

5252
// Session stats from in-memory state
5353
const sessionTokens = state.stats.totalPruneTokens
54-
const sessionTools = state.prune.toolIds.size
55-
const sessionMessages = state.prune.messageIds.size
54+
const sessionTools = state.prune.tools.size
55+
const sessionMessages = state.prune.messages.size
5656

5757
// All-time stats from storage files
5858
const allTime = await loadAllSessionStats(logger)

lib/commands/sweep.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import type { SessionState, WithParts, ToolParameterEntry } from "../state"
1212
import type { PluginConfig } from "../config"
1313
import { sendIgnoredMessage } from "../ui/notification"
1414
import { formatPrunedItemsList } from "../ui/utils"
15-
import { getCurrentParams, calculateTokensSaved } from "../strategies/utils"
15+
import { getCurrentParams, getTotalToolTokens } from "../strategies/utils"
1616
import { buildToolIdList, isIgnoredUserMessage } from "../messages/utils"
1717
import { saveSessionState } from "../state/persistence"
1818
import { isMessageCompacted } from "../shared-utils"
@@ -164,7 +164,7 @@ export async function handleSweepCommand(ctx: SweepCommandContext): Promise<void
164164

165165
// Filter out already-pruned tools, protected tools, and protected file paths
166166
const newToolIds = toolIdsToSweep.filter((id) => {
167-
if (state.prune.toolIds.has(id)) {
167+
if (state.prune.tools.has(id)) {
168168
return false
169169
}
170170
const entry = state.toolParameters.get(id)
@@ -214,13 +214,13 @@ export async function handleSweepCommand(ctx: SweepCommandContext): Promise<void
214214
return
215215
}
216216

217+
const tokensSaved = getTotalToolTokens(state, newToolIds)
218+
217219
// Add to prune list
218220
for (const id of newToolIds) {
219-
state.prune.toolIds.add(id)
221+
const entry = state.toolParameters.get(id)
222+
state.prune.tools.set(id, entry?.tokenCount ?? 0)
220223
}
221-
222-
// Calculate tokens saved
223-
const tokensSaved = calculateTokensSaved(state, messages, newToolIds)
224224
state.stats.pruneTokenCounter += tokensSaved
225225
state.stats.totalPruneTokens += state.stats.pruneTokenCounter
226226
state.stats.pruneTokenCounter = 0

lib/messages/inject.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ const buildPrunableToolsList = (
162162
const toolIdList = state.toolIdList
163163

164164
state.toolParameters.forEach((toolParameterEntry, toolCallId) => {
165-
if (state.prune.toolIds.has(toolCallId)) {
165+
if (state.prune.tools.has(toolCallId)) {
166166
return
167167
}
168168

lib/messages/prune.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const pruneFullTool = (state: SessionState, logger: Logger, messages: WithParts[
3838
if (part.type !== "tool") {
3939
continue
4040
}
41-
if (!state.prune.toolIds.has(part.callID)) {
41+
if (!state.prune.tools.has(part.callID)) {
4242
continue
4343
}
4444
if (part.tool !== "edit" && part.tool !== "write") {
@@ -79,7 +79,7 @@ const pruneToolOutputs = (state: SessionState, logger: Logger, messages: WithPar
7979
if (part.type !== "tool") {
8080
continue
8181
}
82-
if (!state.prune.toolIds.has(part.callID)) {
82+
if (!state.prune.tools.has(part.callID)) {
8383
continue
8484
}
8585
if (part.state.status !== "completed") {
@@ -105,7 +105,7 @@ const pruneToolInputs = (state: SessionState, logger: Logger, messages: WithPart
105105
if (part.type !== "tool") {
106106
continue
107107
}
108-
if (!state.prune.toolIds.has(part.callID)) {
108+
if (!state.prune.tools.has(part.callID)) {
109109
continue
110110
}
111111
if (part.state.status !== "completed") {
@@ -133,7 +133,7 @@ const pruneToolErrors = (state: SessionState, logger: Logger, messages: WithPart
133133
if (part.type !== "tool") {
134134
continue
135135
}
136-
if (!state.prune.toolIds.has(part.callID)) {
136+
if (!state.prune.tools.has(part.callID)) {
137137
continue
138138
}
139139
if (part.state.status !== "error") {
@@ -158,7 +158,7 @@ const filterCompressedRanges = (
158158
logger: Logger,
159159
messages: WithParts[],
160160
): void => {
161-
if (!state.prune.messageIds?.size) {
161+
if (!state.prune.messages?.size) {
162162
return
163163
}
164164

@@ -193,7 +193,7 @@ const filterCompressedRanges = (
193193
}
194194

195195
// Skip messages that are in the prune list
196-
if (state.prune.messageIds.has(msgId)) {
196+
if (state.prune.messages.has(msgId)) {
197197
continue
198198
}
199199

lib/shared-utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export const isMessageCompacted = (state: SessionState, msg: WithParts): boolean
55
if (msg.info.time.created < state.lastCompaction) {
66
return true
77
}
8-
if (state.prune.messageIds.has(msg.info.id)) {
8+
if (state.prune.messages.has(msg.info.id)) {
99
return true
1010
}
1111
return false

lib/state/persistence.ts

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ import { join } from "path"
1111
import type { SessionState, SessionStats, CompressSummary } from "./types"
1212
import type { Logger } from "../logger"
1313

14-
/** Prune state as stored on disk (arrays for JSON compatibility) */
14+
/** Prune state as stored on disk */
1515
export interface PersistedPrune {
16-
toolIds: string[]
17-
messageIds: string[]
16+
// New format: tool/message IDs with token counts
17+
tools?: Record<string, number>
18+
messages?: Record<string, number>
19+
// Legacy format: plain ID arrays (backward compatibility)
20+
toolIds?: string[]
21+
messageIds?: string[]
1822
}
1923

2024
export interface PersistedSessionState {
@@ -58,8 +62,8 @@ export async function saveSessionState(
5862
const state: PersistedSessionState = {
5963
sessionName: sessionName,
6064
prune: {
61-
toolIds: [...sessionState.prune.toolIds],
62-
messageIds: [...sessionState.prune.messageIds],
65+
tools: Object.fromEntries(sessionState.prune.tools),
66+
messages: Object.fromEntries(sessionState.prune.messages),
6367
},
6468
compressSummaries: sessionState.compressSummaries,
6569
stats: sessionState.stats,
@@ -96,7 +100,9 @@ export async function loadSessionState(
96100
const content = await fs.readFile(filePath, "utf-8")
97101
const state = JSON.parse(content) as PersistedSessionState
98102

99-
if (!state || !state.prune || !Array.isArray(state.prune.toolIds) || !state.stats) {
103+
const hasNewFormat = state?.prune?.tools && typeof state.prune.tools === "object"
104+
const hasLegacyFormat = Array.isArray(state?.prune?.toolIds)
105+
if (!state || !state.prune || (!hasNewFormat && !hasLegacyFormat) || !state.stats) {
100106
logger.warn("Invalid session state file, ignoring", {
101107
sessionId: sessionId,
102108
})
@@ -166,10 +172,14 @@ export async function loadAllSessionStats(logger: Logger): Promise<AggregatedSta
166172
const content = await fs.readFile(filePath, "utf-8")
167173
const state = JSON.parse(content) as PersistedSessionState
168174

169-
if (state?.stats?.totalPruneTokens && state?.prune?.toolIds) {
175+
if (state?.stats?.totalPruneTokens && state?.prune) {
170176
result.totalTokens += state.stats.totalPruneTokens
171-
result.totalTools += state.prune.toolIds.length
172-
result.totalMessages += state.prune.messageIds?.length || 0
177+
result.totalTools += state.prune.tools
178+
? Object.keys(state.prune.tools).length
179+
: (state.prune.toolIds?.length ?? 0)
180+
result.totalMessages += state.prune.messages
181+
? Object.keys(state.prune.messages).length
182+
: (state.prune.messageIds?.length ?? 0)
173183
result.sessionCount++
174184
}
175185
} catch {

lib/state/state.ts

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
findLastCompactionTimestamp,
77
countTurns,
88
resetOnCompaction,
9+
loadPruneMap,
910
} from "./utils"
1011
import { getLastUserMessage } from "../shared-utils"
1112

@@ -48,8 +49,8 @@ export function createSessionState(): SessionState {
4849
sessionId: null,
4950
isSubAgent: false,
5051
prune: {
51-
toolIds: new Set<string>(),
52-
messageIds: new Set<string>(),
52+
tools: new Map<string, number>(),
53+
messages: new Map<string, number>(),
5354
},
5455
compressSummaries: [],
5556
stats: {
@@ -71,8 +72,8 @@ export function resetSessionState(state: SessionState): void {
7172
state.sessionId = null
7273
state.isSubAgent = false
7374
state.prune = {
74-
toolIds: new Set<string>(),
75-
messageIds: new Set<string>(),
75+
tools: new Map<string, number>(),
76+
messages: new Map<string, number>(),
7677
}
7778
state.compressSummaries = []
7879
state.stats = {
@@ -118,10 +119,8 @@ export async function ensureSessionInitialized(
118119
return
119120
}
120121

121-
state.prune = {
122-
toolIds: new Set(persisted.prune.toolIds || []),
123-
messageIds: new Set(persisted.prune.messageIds || []),
124-
}
122+
state.prune.tools = loadPruneMap(persisted.prune.tools, persisted.prune.toolIds)
123+
state.prune.messages = loadPruneMap(persisted.prune.messages, persisted.prune.messageIds)
125124
state.compressSummaries = persisted.compressSummaries || []
126125
state.stats = {
127126
pruneTokenCounter: persisted.stats?.pruneTokenCounter || 0,

lib/state/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ export interface CompressSummary {
2727
}
2828

2929
export interface Prune {
30-
toolIds: Set<string>
31-
messageIds: Set<string>
30+
tools: Map<string, number>
31+
messages: Map<string, number>
3232
}
3333

3434
export interface SessionState {

lib/state/utils.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,19 @@ export function countTurns(state: SessionState, messages: WithParts[]): number {
3636
return turnCount
3737
}
3838

39+
export function loadPruneMap(
40+
obj?: Record<string, number>,
41+
legacyArr?: string[],
42+
): Map<string, number> {
43+
if (obj) return new Map(Object.entries(obj))
44+
if (legacyArr) return new Map(legacyArr.map((id) => [id, 0]))
45+
return new Map()
46+
}
47+
3948
export function resetOnCompaction(state: SessionState): void {
4049
state.toolParameters.clear()
41-
state.prune.toolIds = new Set<string>()
42-
state.prune.messageIds = new Set<string>()
50+
state.prune.tools = new Map<string, number>()
51+
state.prune.messages = new Map<string, number>()
4352
state.compressSummaries = []
4453
state.nudgeCounter = 0
4554
state.lastToolPrune = false

0 commit comments

Comments
 (0)