Skip to content

Commit 04d7823

Browse files
committed
pass in logger to main prompt
1 parent e6f4d75 commit 04d7823

File tree

6 files changed

+157
-162
lines changed

6 files changed

+157
-162
lines changed

backend/src/__tests__/cost-aggregation.integration.test.ts

Lines changed: 45 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
import { TEST_USER_ID } from '@codebuff/common/old-constants'
2-
import {
3-
clearMockedModules,
4-
mockModule,
5-
} from '@codebuff/common/testing/mock-modules'
62
import { getInitialSessionState } from '@codebuff/common/types/session-state'
73
import {
84
spyOn,
95
beforeEach,
106
afterEach,
11-
beforeAll,
12-
afterAll,
137
describe,
148
expect,
159
it,
@@ -24,6 +18,7 @@ import * as websocketAction from '../websockets/websocket-action'
2418

2519
import type { AgentTemplate } from '../templates/types'
2620
import type { ProjectFileContext } from '@codebuff/common/util/file'
21+
import type { Logger } from '@codebuff/types/logger'
2722
import type { WebSocket } from 'ws'
2823

2924
const mockFileContext: ProjectFileContext = {
@@ -104,19 +99,12 @@ class MockWebSocket {
10499
describe('Cost Aggregation Integration Tests', () => {
105100
let mockLocalAgentTemplates: Record<string, any>
106101
let mockWebSocket: MockWebSocket
107-
108-
beforeAll(() => {
109-
// Mock logger for backend
110-
mockModule('@codebuff/backend/util/logger', () => ({
111-
logger: {
112-
debug: () => {},
113-
error: () => {},
114-
info: () => {},
115-
warn: () => {},
116-
},
117-
withLoggerContext: async (context: any, fn: () => Promise<any>) => fn(),
118-
}))
119-
})
102+
const logger: Logger = {
103+
debug: () => {},
104+
error: () => {},
105+
info: () => {},
106+
warn: () => {},
107+
}
120108

121109
beforeEach(async () => {
122110
mockWebSocket = new MockWebSocket()
@@ -251,10 +239,6 @@ describe('Cost Aggregation Integration Tests', () => {
251239
mock.restore()
252240
})
253241

254-
afterAll(() => {
255-
clearMockedModules()
256-
})
257-
258242
it('should correctly aggregate costs across the entire main prompt flow', async () => {
259243
const sessionState = getInitialSessionState(mockFileContext)
260244
// Set the main agent to use the 'base' type which is defined in our mock templates
@@ -271,16 +255,15 @@ describe('Cost Aggregation Integration Tests', () => {
271255
toolResults: [],
272256
}
273257

274-
const result = await mainPrompt(
275-
mockWebSocket as unknown as WebSocket,
258+
const result = await mainPrompt({
259+
ws: mockWebSocket as unknown as WebSocket,
276260
action,
277-
{
278-
userId: TEST_USER_ID,
279-
clientSessionId: 'test-session',
280-
onResponseChunk: () => {},
281-
localAgentTemplates: mockLocalAgentTemplates,
282-
},
283-
)
261+
userId: TEST_USER_ID,
262+
clientSessionId: 'test-session',
263+
onResponseChunk: () => {},
264+
localAgentTemplates: mockLocalAgentTemplates,
265+
logger,
266+
})
284267

285268
// Verify the total cost includes both main agent and subagent costs
286269
const finalCreditsUsed = result.sessionState.mainAgentState.creditsUsed
@@ -307,15 +290,14 @@ describe('Cost Aggregation Integration Tests', () => {
307290
}
308291

309292
// Call through websocket action handler to test full integration
310-
await websocketAction.callMainPrompt(
311-
mockWebSocket as unknown as WebSocket,
293+
await websocketAction.callMainPrompt({
294+
ws: mockWebSocket as unknown as WebSocket,
312295
action,
313-
{
314-
userId: TEST_USER_ID,
315-
promptId: 'test-prompt',
316-
clientSessionId: 'test-session',
317-
},
318-
)
296+
userId: TEST_USER_ID,
297+
promptId: 'test-prompt',
298+
clientSessionId: 'test-session',
299+
logger,
300+
})
319301

320302
// Verify final cost is included in prompt response
321303
const promptResponse = mockWebSocket.sentActions.find(
@@ -378,16 +360,15 @@ describe('Cost Aggregation Integration Tests', () => {
378360
toolResults: [],
379361
}
380362

381-
const result = await mainPrompt(
382-
mockWebSocket as unknown as WebSocket,
363+
const result = await mainPrompt({
364+
ws: mockWebSocket as unknown as WebSocket,
383365
action,
384-
{
385-
userId: TEST_USER_ID,
386-
clientSessionId: 'test-session',
387-
onResponseChunk: () => {},
388-
localAgentTemplates: mockLocalAgentTemplates,
389-
},
390-
)
366+
userId: TEST_USER_ID,
367+
clientSessionId: 'test-session',
368+
onResponseChunk: () => {},
369+
localAgentTemplates: mockLocalAgentTemplates,
370+
logger,
371+
})
391372

392373
// Should aggregate costs from all levels: main + sub1 + sub2
393374
const finalCreditsUsed = result.sessionState.mainAgentState.creditsUsed
@@ -437,11 +418,14 @@ describe('Cost Aggregation Integration Tests', () => {
437418

438419
let result
439420
try {
440-
result = await mainPrompt(mockWebSocket as unknown as WebSocket, action, {
421+
result = await mainPrompt({
422+
ws: mockWebSocket as unknown as WebSocket,
423+
action,
441424
userId: TEST_USER_ID,
442425
clientSessionId: 'test-session',
443426
onResponseChunk: () => {},
444427
localAgentTemplates: mockLocalAgentTemplates,
428+
logger,
445429
})
446430
} catch (error) {
447431
// Expected to fail, but costs may still be tracked
@@ -483,11 +467,14 @@ describe('Cost Aggregation Integration Tests', () => {
483467
toolResults: [],
484468
}
485469

486-
await mainPrompt(mockWebSocket as unknown as WebSocket, action, {
470+
await mainPrompt({
471+
ws: mockWebSocket as unknown as WebSocket,
472+
action,
487473
userId: TEST_USER_ID,
488474
clientSessionId: 'test-session',
489475
onResponseChunk: () => {},
490476
localAgentTemplates: mockLocalAgentTemplates,
477+
logger,
491478
})
492479

493480
// Verify no duplicate message IDs (no double-counting)
@@ -520,15 +507,14 @@ describe('Cost Aggregation Integration Tests', () => {
520507
}
521508

522509
// Call through websocket action to test server-side reset
523-
await websocketAction.callMainPrompt(
524-
mockWebSocket as unknown as WebSocket,
510+
await websocketAction.callMainPrompt({
511+
ws: mockWebSocket as unknown as WebSocket,
525512
action,
526-
{
527-
userId: TEST_USER_ID,
528-
promptId: 'test-prompt',
529-
clientSessionId: 'test-session',
530-
},
531-
)
513+
userId: TEST_USER_ID,
514+
promptId: 'test-prompt',
515+
clientSessionId: 'test-session',
516+
logger,
517+
})
532518

533519
// Server should have reset the malicious value and calculated correct cost
534520
const promptResponse = mockWebSocket.sentActions.find(

backend/src/__tests__/main-prompt.integration.test.ts

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -384,21 +384,20 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
384384
toolResults: [],
385385
}
386386

387-
const { output, sessionState: finalSessionState } = await mainPrompt(
388-
new MockWebSocket() as unknown as WebSocket,
387+
const { output, sessionState: finalSessionState } = await mainPrompt({
388+
ws: new MockWebSocket() as unknown as WebSocket,
389389
action,
390-
{
391-
userId: TEST_USER_ID,
392-
clientSessionId: 'test-session-delete-function-integration',
393-
localAgentTemplates: mockLocalAgentTemplates,
394-
onResponseChunk: (chunk: string | PrintModeEvent) => {
395-
if (typeof chunk !== 'string') {
396-
return
397-
}
398-
process.stdout.write(chunk)
399-
},
390+
userId: TEST_USER_ID,
391+
clientSessionId: 'test-session-delete-function-integration',
392+
localAgentTemplates: mockLocalAgentTemplates,
393+
onResponseChunk: (chunk: string | PrintModeEvent) => {
394+
if (typeof chunk !== 'string') {
395+
return
396+
}
397+
process.stdout.write(chunk)
400398
},
401-
)
399+
logger,
400+
})
402401
const requestToolCallSpy = websocketAction.requestToolCall as any
403402

404403
// Find the write_file tool call
@@ -474,7 +473,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
474473
toolResults: [],
475474
}
476475

477-
await mainPrompt(new MockWebSocket() as unknown as WebSocket, action, {
476+
await mainPrompt({
477+
ws: new MockWebSocket() as unknown as WebSocket,
478+
action,
478479
userId: TEST_USER_ID,
479480
clientSessionId: 'test-session-delete-function-integration',
480481
localAgentTemplates: {
@@ -501,6 +502,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
501502
}
502503
process.stdout.write(chunk)
503504
},
505+
logger,
504506
})
505507

506508
const requestToolCallSpy = websocketAction.requestToolCall as any

0 commit comments

Comments
 (0)