Skip to content

Commit 43cac13

Browse files
committed
move callMainPrompt to agent-runtime
1 parent 1b3e22a commit 43cac13

File tree

3 files changed

+108
-128
lines changed

3 files changed

+108
-128
lines changed

backend/src/__tests__/cost-aggregation.integration.test.ts

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { disableLiveUserInputCheck } from '@codebuff/agent-runtime/live-user-inputs'
2-
import { mainPrompt } from '@codebuff/agent-runtime/main-prompt'
2+
import { callMainPrompt, mainPrompt } from '@codebuff/agent-runtime/main-prompt'
33
import * as agentRegistry from '@codebuff/agent-runtime/templates/agent-registry'
44
import { TEST_USER_ID } from '@codebuff/common/old-constants'
55
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
@@ -15,9 +15,6 @@ import {
1515
beforeAll,
1616
} from 'bun:test'
1717

18-
import * as messageCostTracker from '../llm-apis/message-cost-tracker'
19-
import * as websocketAction from '../websockets/websocket-action'
20-
2118
import type { AgentTemplate } from '@codebuff/agent-runtime/templates/types'
2219
import type { ServerAction } from '@codebuff/common/actions'
2320
import type {
@@ -135,14 +132,6 @@ describe('Cost Aggregation Integration Tests', () => {
135132
} satisfies AgentTemplate,
136133
}
137134

138-
// Mock cost tracking to return 0 so only onCostCalculated contributes
139-
spyOn(messageCostTracker, 'saveMessage').mockImplementation(
140-
async (value) => {
141-
// Return 0 so we can control costs only via onCostCalculated
142-
return 0
143-
},
144-
)
145-
146135
// Mock LLM streaming
147136
let callCount = 0
148137
const creditHistory: number[] = []
@@ -271,7 +260,7 @@ describe('Cost Aggregation Integration Tests', () => {
271260
}
272261

273262
// Call through websocket action handler to test full integration
274-
await websocketAction.callMainPrompt({
263+
await callMainPrompt({
275264
...agentRuntimeImpl,
276265
repoId: undefined,
277266
repoUrl: undefined,
@@ -424,17 +413,6 @@ describe('Cost Aggregation Integration Tests', () => {
424413
it('should not double-count costs in complex scenarios', async () => {
425414
// Track all saveMessage calls to ensure no duplication
426415
const saveMessageCalls: any[] = []
427-
spyOn(messageCostTracker, 'saveMessage').mockImplementation(
428-
async (value) => {
429-
saveMessageCalls.push({
430-
messageId: value.messageId,
431-
model: value.model,
432-
inputTokens: value.inputTokens,
433-
outputTokens: value.outputTokens,
434-
})
435-
return 8 // Each LLM call costs 8 credits
436-
},
437-
)
438416

439417
const sessionState = getInitialSessionState(mockFileContext)
440418
sessionState.mainAgentState.agentType = 'base'
@@ -490,7 +468,7 @@ describe('Cost Aggregation Integration Tests', () => {
490468
}
491469

492470
// Call through websocket action to test server-side reset
493-
await websocketAction.callMainPrompt({
471+
await callMainPrompt({
494472
...agentRuntimeImpl,
495473
repoId: undefined,
496474
repoUrl: undefined,

backend/src/websockets/websocket-action.ts

Lines changed: 1 addition & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import {
22
cancelUserInput,
3-
checkLiveUserInput,
43
startUserInput,
54
} from '@codebuff/agent-runtime/live-user-inputs'
6-
import { mainPrompt } from '@codebuff/agent-runtime/main-prompt'
7-
import { assembleLocalAgentTemplates } from '@codebuff/agent-runtime/templates/agent-registry'
85
import { calculateUsageAndBalance } from '@codebuff/billing'
96
import { trackEvent } from '@codebuff/common/analytics'
107
import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events'
@@ -19,13 +16,13 @@ import { getRequestContext } from './request-context'
1916
import { withLoggerContext } from '../util/logger'
2017

2118
import type { ClientAction, UsageResponse } from '@codebuff/common/actions'
22-
import type { SendActionFn } from '@codebuff/common/types/contracts/client'
2319
import type { GetUserInfoFromApiKeyFn } from '@codebuff/common/types/contracts/database'
2420
import type { UserInputRecord } from '@codebuff/common/types/contracts/live-user-input'
2521
import type { Logger } from '@codebuff/common/types/contracts/logger'
2622
import type { ParamsExcluding } from '@codebuff/common/types/function-params'
2723
import type { ClientMessage } from '@codebuff/common/websockets/websocket-schema'
2824
import type { WebSocket } from 'ws'
25+
import { callMainPrompt } from '@codebuff/agent-runtime/main-prompt'
2926

3027
/**
3128
* Generates a usage response object for the client
@@ -181,101 +178,6 @@ const onPrompt = async (
181178
)
182179
}
183180

184-
export const callMainPrompt = async (
185-
params: {
186-
action: ClientAction<'prompt'>
187-
userId: string
188-
promptId: string
189-
clientSessionId: string
190-
sendAction: SendActionFn
191-
liveUserInputRecord: UserInputRecord
192-
logger: Logger
193-
} & ParamsExcluding<
194-
typeof mainPrompt,
195-
'localAgentTemplates' | 'onResponseChunk'
196-
>,
197-
) => {
198-
const { action, userId, promptId, clientSessionId, sendAction, logger } =
199-
params
200-
const { fileContext } = action.sessionState
201-
202-
// Enforce server-side state authority: reset creditsUsed to 0
203-
// The server controls cost tracking, clients cannot manipulate this value
204-
action.sessionState.mainAgentState.creditsUsed = 0
205-
action.sessionState.mainAgentState.directCreditsUsed = 0
206-
207-
// Assemble local agent templates from fileContext
208-
const { agentTemplates: localAgentTemplates, validationErrors } =
209-
assembleLocalAgentTemplates({ fileContext, logger })
210-
211-
if (validationErrors.length > 0) {
212-
sendAction({
213-
action: {
214-
type: 'prompt-error',
215-
message: `Invalid agent config: ${validationErrors.map((err) => err.message).join('\n')}`,
216-
userInputId: promptId,
217-
},
218-
})
219-
}
220-
221-
sendAction({
222-
action: {
223-
type: 'response-chunk',
224-
userInputId: promptId,
225-
chunk: {
226-
type: 'start',
227-
agentId: action.sessionState.mainAgentState.agentType ?? undefined,
228-
messageHistoryLength:
229-
action.sessionState.mainAgentState.messageHistory.length,
230-
},
231-
},
232-
})
233-
234-
const result = await mainPrompt({
235-
...params,
236-
localAgentTemplates,
237-
onResponseChunk: (chunk) => {
238-
if (checkLiveUserInput({ ...params, userInputId: promptId })) {
239-
sendAction({
240-
action: {
241-
type: 'response-chunk',
242-
userInputId: promptId,
243-
chunk,
244-
},
245-
})
246-
}
247-
},
248-
})
249-
250-
const { sessionState, output } = result
251-
252-
sendAction({
253-
action: {
254-
type: 'response-chunk',
255-
userInputId: promptId,
256-
chunk: {
257-
type: 'finish',
258-
agentId: sessionState.mainAgentState.agentType ?? undefined,
259-
totalCost: sessionState.mainAgentState.creditsUsed,
260-
},
261-
},
262-
})
263-
264-
// Send prompt data back
265-
sendAction({
266-
action: {
267-
type: 'prompt-response',
268-
promptId,
269-
sessionState,
270-
toolCalls: [],
271-
toolResults: [],
272-
output,
273-
},
274-
})
275-
276-
return result
277-
}
278-
279181
/**
280182
* Handles initialization actions from the client
281183
* @param fileContext - The file context information

packages/agent-runtime/src/main-prompt.ts

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,22 @@ import { generateCompactId } from '@codebuff/common/util/string'
33
import { uniq } from 'lodash'
44

55
import { checkTerminalCommand } from './check-terminal-command'
6+
import { checkLiveUserInput } from './live-user-inputs'
67
import { loopAgentSteps } from './run-agent-step'
7-
import { getAgentTemplate } from './templates/agent-registry'
8+
import {
9+
assembleLocalAgentTemplates,
10+
getAgentTemplate,
11+
} from './templates/agent-registry'
812
import { expireMessages } from './util/messages'
913

1014
import type { AgentTemplate } from './templates/types'
1115
import type { ClientAction } from '@codebuff/common/actions'
1216
import type { CostMode } from '@codebuff/common/old-constants'
13-
import type { RequestToolCallFn } from '@codebuff/common/types/contracts/client'
17+
import type {
18+
RequestToolCallFn,
19+
SendActionFn,
20+
} from '@codebuff/common/types/contracts/client'
21+
import type { UserInputRecord } from '@codebuff/common/types/contracts/live-user-input'
1422
import type { Logger } from '@codebuff/common/types/contracts/logger'
1523
import type { ParamsExcluding } from '@codebuff/common/types/function-params'
1624
import type { PrintModeEvent } from '@codebuff/common/types/print-mode'
@@ -20,7 +28,7 @@ import type {
2028
AgentOutput,
2129
} from '@codebuff/common/types/session-state'
2230

23-
export const mainPrompt = async (
31+
export async function mainPrompt(
2432
params: {
2533
action: ClientAction<'prompt'>
2634

@@ -48,7 +56,7 @@ export const mainPrompt = async (
4856
): Promise<{
4957
sessionState: SessionState
5058
output: AgentOutput
51-
}> => {
59+
}> {
5260
const { action, localAgentTemplates, requestToolCall, logger } = params
5361

5462
const {
@@ -239,3 +247,95 @@ export const mainPrompt = async (
239247
},
240248
}
241249
}
250+
251+
export async function callMainPrompt(
252+
params: {
253+
action: ClientAction<'prompt'>
254+
promptId: string
255+
sendAction: SendActionFn
256+
liveUserInputRecord: UserInputRecord
257+
logger: Logger
258+
} & ParamsExcluding<
259+
typeof mainPrompt,
260+
'localAgentTemplates' | 'onResponseChunk'
261+
>,
262+
) {
263+
const { action, promptId, sendAction, logger } = params
264+
const { fileContext } = action.sessionState
265+
266+
// Enforce server-side state authority: reset creditsUsed to 0
267+
// The server controls cost tracking, clients cannot manipulate this value
268+
action.sessionState.mainAgentState.creditsUsed = 0
269+
action.sessionState.mainAgentState.directCreditsUsed = 0
270+
271+
// Assemble local agent templates from fileContext
272+
const { agentTemplates: localAgentTemplates, validationErrors } =
273+
assembleLocalAgentTemplates({ fileContext, logger })
274+
275+
if (validationErrors.length > 0) {
276+
sendAction({
277+
action: {
278+
type: 'prompt-error',
279+
message: `Invalid agent config: ${validationErrors.map((err) => err.message).join('\n')}`,
280+
userInputId: promptId,
281+
},
282+
})
283+
}
284+
285+
sendAction({
286+
action: {
287+
type: 'response-chunk',
288+
userInputId: promptId,
289+
chunk: {
290+
type: 'start',
291+
agentId: action.sessionState.mainAgentState.agentType ?? undefined,
292+
messageHistoryLength:
293+
action.sessionState.mainAgentState.messageHistory.length,
294+
},
295+
},
296+
})
297+
298+
const result = await mainPrompt({
299+
...params,
300+
localAgentTemplates,
301+
onResponseChunk: (chunk) => {
302+
if (checkLiveUserInput({ ...params, userInputId: promptId })) {
303+
sendAction({
304+
action: {
305+
type: 'response-chunk',
306+
userInputId: promptId,
307+
chunk,
308+
},
309+
})
310+
}
311+
},
312+
})
313+
314+
const { sessionState, output } = result
315+
316+
sendAction({
317+
action: {
318+
type: 'response-chunk',
319+
userInputId: promptId,
320+
chunk: {
321+
type: 'finish',
322+
agentId: sessionState.mainAgentState.agentType ?? undefined,
323+
totalCost: sessionState.mainAgentState.creditsUsed,
324+
},
325+
},
326+
})
327+
328+
// Send prompt data back
329+
sendAction({
330+
action: {
331+
type: 'prompt-response',
332+
promptId,
333+
sessionState,
334+
toolCalls: [],
335+
toolResults: [],
336+
output,
337+
},
338+
})
339+
340+
return result
341+
}

0 commit comments

Comments
 (0)