Skip to content

Commit 74935ba

Browse files
committed
add promptAiSdk to AgentRuntimeDeps
1 parent 1f0554b commit 74935ba

File tree

14 files changed

+154
-112
lines changed

14 files changed

+154
-112
lines changed

backend/src/__tests__/main-prompt.integration.test.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@ import {
1515
} from 'bun:test'
1616

1717
import * as checkTerminalCommandModule from '../check-terminal-command'
18-
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
1918
import { mainPrompt } from '../main-prompt'
2019
import * as websocketAction from '../websockets/websocket-action'
2120

21+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
2222
import type { PrintModeEvent } from '@codebuff/common/types/print-mode'
2323
import type { ProjectFileContext } from '@codebuff/common/util/file'
2424
import type { WebSocket } from 'ws'
2525

2626
// --- Shared Mocks & Helpers ---
2727

28+
let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL }
29+
2830
class MockWebSocket {
2931
send(msg: string) {}
3032
close() {}
@@ -103,6 +105,7 @@ describe.skip('mainPrompt (Integration)', () => {
103105

104106
afterEach(() => {
105107
mock.restore()
108+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
106109
})
107110

108111
it('should delete a specified function while preserving other code', async () => {
@@ -337,7 +340,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
337340
)
338341

339342
// Mock LLM calls
340-
spyOn(aisdk, 'promptAiSdk').mockResolvedValue('Mocked non-stream AiSdk')
343+
agentRuntimeImpl.promptAiSdk = async function () {
344+
return 'Mocked non-stream AiSdk'
345+
}
341346

342347
const sessionState = getInitialSessionState(mockFileContext)
343348
sessionState.mainAgentState.messageHistory.push(
@@ -377,7 +382,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
377382
}
378383

379384
const { output, sessionState: finalSessionState } = await mainPrompt({
380-
...TEST_AGENT_RUNTIME_IMPL,
385+
...agentRuntimeImpl,
381386
ws: new MockWebSocket() as unknown as WebSocket,
382387
action,
383388
userId: TEST_USER_ID,
@@ -419,7 +424,9 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
419424
).mockResolvedValue(null)
420425

421426
// Mock LLM calls
422-
spyOn(aisdk, 'promptAiSdk').mockResolvedValue('Mocked non-stream AiSdk')
427+
agentRuntimeImpl.promptAiSdk = async function () {
428+
return 'Mocked non-stream AiSdk'
429+
}
423430

424431
const sessionState = getInitialSessionState(mockFileContext)
425432
sessionState.mainAgentState.messageHistory.push(
@@ -459,7 +466,7 @@ export function getMessagesSubset(messages: Message[], otherTokens: number) {
459466
}
460467

461468
await mainPrompt({
462-
...TEST_AGENT_RUNTIME_IMPL,
469+
...agentRuntimeImpl,
463470
ws: new MockWebSocket() as unknown as WebSocket,
464471
action,
465472
userId: TEST_USER_ID,

backend/src/__tests__/malformed-tool-call.test.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@ import {
1616
} from 'bun:test'
1717

1818
import { MockWebSocket, mockFileContext } from './test-utils'
19-
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
2019
import { processStreamWithTools } from '../tools/stream-parser'
2120
import * as websocketAction from '../websockets/websocket-action'
2221

2322
import type { AgentTemplate } from '../templates/types'
23+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
2424
import type {
2525
Message,
2626
ToolMessage,
2727
} from '@codebuff/common/types/messages/codebuff-message'
2828
import type { WebSocket } from 'ws'
2929

30+
let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL }
31+
3032
describe('malformed tool call error handling', () => {
3133
let testAgent: AgentTemplate
3234
let mockWs: MockWebSocket
@@ -72,16 +74,17 @@ describe('malformed tool call error handling', () => {
7274
}))
7375

7476
// Mock LLM APIs
75-
spyOn(aisdk, 'promptAiSdk').mockImplementation(() =>
76-
Promise.resolve('Test response'),
77-
)
77+
agentRuntimeImpl.promptAiSdk = async function () {
78+
return 'Test response'
79+
}
7880

7981
// Mock generateCompactId for consistent test results
8082
spyOn(stringUtils, 'generateCompactId').mockReturnValue('test-tool-call-id')
8183
})
8284

8385
afterEach(() => {
8486
mock.restore()
87+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
8588
})
8689

8790
function createMockStream(chunks: string[]) {

backend/src/__tests__/process-file-block.test.ts

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { TEST_USER_ID } from '@codebuff/common/old-constants'
2+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
23
import {
34
clearMockedModules,
45
mockModule,
@@ -9,14 +10,9 @@ import { applyPatch } from 'diff'
910

1011
import { processFileBlock } from '../process-file-block'
1112

12-
import type { Logger } from '@codebuff/common/types/contracts/logger'
13+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
1314

14-
const logger: Logger = {
15-
debug: () => {},
16-
info: () => {},
17-
warn: () => {},
18-
error: () => {},
19-
}
15+
let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL }
2016

2117
describe('processFileBlockModule', () => {
2218
beforeAll(() => {
@@ -74,6 +70,7 @@ describe('processFileBlockModule', () => {
7470
const expectedContent = 'function test() {\n return true;\n}'
7571

7672
const result = await processFileBlock({
73+
...agentRuntimeImpl,
7774
path: 'test.ts',
7875
instructions: undefined,
7976
initialContentPromise: Promise.resolve(null),
@@ -85,7 +82,6 @@ describe('processFileBlockModule', () => {
8582
fingerprintId: 'fingerprintId',
8683
userInputId: 'userInputId',
8784
userId: TEST_USER_ID,
88-
logger,
8985
})
9086

9187
expect(result).not.toBeNull()
@@ -111,6 +107,7 @@ describe('processFileBlockModule', () => {
111107
'}\r\n'
112108

113109
const result = await processFileBlock({
110+
...agentRuntimeImpl,
114111
path: 'test.ts',
115112
instructions: undefined,
116113
initialContentPromise: Promise.resolve(oldContent),
@@ -122,7 +119,6 @@ describe('processFileBlockModule', () => {
122119
fingerprintId: 'fingerprintId',
123120
userInputId: 'userInputId',
124121
userId: TEST_USER_ID,
125-
logger,
126122
})
127123

128124
expect(result).not.toBeNull()
@@ -144,6 +140,7 @@ describe('processFileBlockModule', () => {
144140
const newContent = 'function test() {\n return true;\n}\n'
145141

146142
const result = await processFileBlock({
143+
...agentRuntimeImpl,
147144
path: 'test.ts',
148145
instructions: undefined,
149146
initialContentPromise: Promise.resolve(oldContent),
@@ -155,7 +152,6 @@ describe('processFileBlockModule', () => {
155152
fingerprintId: 'fingerprintId',
156153
userInputId: 'userInputId',
157154
userId: TEST_USER_ID,
158-
logger,
159155
})
160156

161157
expect(result).not.toBeNull()
@@ -170,6 +166,7 @@ describe('processFileBlockModule', () => {
170166
const newContent = 'const x = 1;\r\nconst z = 3;\r\n'
171167

172168
const result = await processFileBlock({
169+
...agentRuntimeImpl,
173170
path: 'test.ts',
174171
instructions: undefined,
175172
initialContentPromise: Promise.resolve(oldContent),
@@ -181,7 +178,6 @@ describe('processFileBlockModule', () => {
181178
fingerprintId: 'fingerprintId',
182179
userInputId: 'userInputId',
183180
userId: TEST_USER_ID,
184-
logger,
185181
})
186182

187183
expect(result).not.toBeNull()
@@ -217,6 +213,7 @@ describe('processFileBlockModule', () => {
217213
'// ... existing code ...\nconst x = 1;\n// ... existing code ...'
218214

219215
const result = await processFileBlock({
216+
...agentRuntimeImpl,
220217
path: 'test.ts',
221218
instructions: undefined,
222219
initialContentPromise: Promise.resolve(null),
@@ -228,7 +225,6 @@ describe('processFileBlockModule', () => {
228225
fingerprintId: 'fingerprintId',
229226
userInputId: 'userInputId',
230227
userId: TEST_USER_ID,
231-
logger,
232228
})
233229

234230
expect(result).not.toBeNull()

backend/src/__tests__/run-agent-step-tools.test.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import {
1818

1919
// Mock imports
2020
import * as liveUserInputs from '../live-user-inputs'
21-
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
2221
import { runAgentStep } from '../run-agent-step'
2322
import { clearAgentGeneratorCache } from '../run-programmatic-step'
2423
import { asUserMessage } from '../util/messages'
@@ -106,9 +105,9 @@ describe('runAgentStep - set_output tool', () => {
106105
// Don't mock requestToolCall for integration test - let real tool execution happen
107106

108107
// Mock LLM APIs
109-
spyOn(aisdk, 'promptAiSdk').mockImplementation(() =>
110-
Promise.resolve('Test response'),
111-
)
108+
agentRuntimeImpl.promptAiSdk = async function () {
109+
return 'Test response'
110+
}
112111
clearAgentGeneratorCache(agentRuntimeImpl)
113112
})
114113

backend/src/__tests__/web-search-tool.test.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import * as requestFilesPrompt from '../find-files/request-files-prompt'
2323
import * as liveUserInputs from '../live-user-inputs'
2424
import { MockWebSocket, mockFileContext } from './test-utils'
2525
import * as linkupApi from '../llm-apis/linkup-api'
26-
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
2726
import { runAgentStep } from '../run-agent-step'
2827
import { assembleLocalAgentTemplates } from '../templates/agent-registry'
2928
import * as websocketAction from '../websockets/websocket-action'
@@ -67,9 +66,9 @@ describe('web_search tool with researcher agent', () => {
6766
}))
6867

6968
// Mock LLM APIs
70-
spyOn(aisdk, 'promptAiSdk').mockImplementation(() =>
71-
Promise.resolve('Test response'),
72-
)
69+
agentRuntimeImpl.promptAiSdk = async function () {
70+
return 'Test response'
71+
}
7372

7473
// Mock other required modules
7574
spyOn(requestFilesPrompt, 'requestRelevantFiles').mockImplementation(

backend/src/impl/agent-runtime.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import { addAgentStep, finishAgentRun, startAgentRun } from '../agent-run'
2-
import { promptAiSdkStream } from '../llm-apis/vercel-ai-sdk/ai-sdk'
2+
import {
3+
promptAiSdk,
4+
promptAiSdkStream,
5+
} from '../llm-apis/vercel-ai-sdk/ai-sdk'
36
import { logger } from '../util/logger'
47

58
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
@@ -12,4 +15,5 @@ export const BACKEND_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({
1215
addAgentStep,
1316

1417
promptAiSdkStream,
18+
promptAiSdk,
1519
})

backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ import { openRouterLanguageModel } from '../openrouter'
1717
import { vertexFinetuned } from './vertex-finetuned'
1818

1919
import type { Model, OpenAIModel } from '@codebuff/common/old-constants'
20-
import type { PromptAiSdkStreamFn } from '@codebuff/common/types/contracts/llm'
21-
import type { Logger } from '@codebuff/common/types/contracts/logger'
2220
import type {
23-
ParamsExcluding,
24-
ParamsOf,
25-
} from '@codebuff/common/types/function-params'
21+
PromptAiSdkFn,
22+
PromptAiSdkStreamFn,
23+
} from '@codebuff/common/types/contracts/llm'
24+
import type { Logger } from '@codebuff/common/types/contracts/logger'
25+
import type { ParamsOf } from '@codebuff/common/types/function-params'
2626
import type { Message } from '@codebuff/common/types/messages/codebuff-message'
2727
import type {
2828
OpenRouterProviderOptions,
@@ -231,21 +231,8 @@ export const promptAiSdkStream = async function* (
231231

232232
// TODO: figure out a nice way to unify stream & non-stream versions maybe?
233233
export const promptAiSdk = async function (
234-
params: {
235-
messages: Message[]
236-
clientSessionId: string
237-
fingerprintId: string
238-
userInputId: string
239-
model: Model
240-
userId: string | undefined
241-
chargeUser?: boolean
242-
agentId?: string
243-
onCostCalculated?: (credits: number) => Promise<void>
244-
includeCacheControl?: boolean
245-
maxRetries?: number
246-
logger: Logger
247-
} & ParamsExcluding<typeof generateText, 'model' | 'messages'>,
248-
): Promise<string> {
234+
params: ParamsOf<PromptAiSdkFn>,
235+
): ReturnType<PromptAiSdkFn> {
249236
const { logger } = params
250237

251238
if (!checkLiveUserInput(params)) {

backend/src/process-file-block.ts

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,32 @@ import {
88
parseAndGetDiffBlocksSingleFile,
99
retryDiffBlocksPrompt,
1010
} from './generate-diffs-prompt'
11-
import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk'
1211
import { countTokens } from './util/token-counter'
1312

14-
import type { Message } from '@codebuff/common/types/messages/codebuff-message'
13+
import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm'
1514
import type { Logger } from '@codebuff/common/types/contracts/logger'
15+
import type { ParamsExcluding } from '@codebuff/common/types/function-params'
16+
import type { Message } from '@codebuff/common/types/messages/codebuff-message'
1617

17-
export async function processFileBlock(params: {
18-
path: string
19-
instructions: string | undefined
20-
initialContentPromise: Promise<string | null>
21-
newContent: string
22-
messages: Message[]
23-
fullResponse: string
24-
lastUserPrompt: string | undefined
25-
clientSessionId: string
26-
fingerprintId: string
27-
userInputId: string
28-
userId: string | undefined
29-
logger: Logger
30-
}): Promise<
18+
export async function processFileBlock(
19+
params: {
20+
path: string
21+
instructions: string | undefined
22+
initialContentPromise: Promise<string | null>
23+
newContent: string
24+
messages: Message[]
25+
fullResponse: string
26+
lastUserPrompt: string | undefined
27+
clientSessionId: string
28+
fingerprintId: string
29+
userInputId: string
30+
userId: string | undefined
31+
logger: Logger
32+
} & ParamsExcluding<
33+
typeof handleLargeFile,
34+
'oldContent' | 'editSnippet' | 'filePath'
35+
>,
36+
): Promise<
3137
| {
3238
tool: 'write_file'
3339
path: string
@@ -113,14 +119,10 @@ export async function processFileBlock(params: {
113119
)
114120
if (tokenCount > LARGE_FILE_TOKEN_LIMIT) {
115121
const largeFileContent = await handleLargeFile({
122+
...params,
116123
oldContent: normalizedInitialContent,
117124
editSnippet: normalizedEditSnippet,
118-
clientSessionId,
119-
fingerprintId,
120-
userInputId,
121-
userId,
122125
filePath: path,
123-
logger,
124126
})
125127

126128
if (!largeFileContent) {
@@ -239,6 +241,7 @@ export async function handleLargeFile(params: {
239241
userId: string | undefined
240242
filePath: string
241243
logger: Logger
244+
promptAiSdk: PromptAiSdkFn
242245
}): Promise<string | null> {
243246
const {
244247
oldContent,
@@ -248,6 +251,7 @@ export async function handleLargeFile(params: {
248251
userInputId,
249252
userId,
250253
filePath,
254+
promptAiSdk,
251255
logger,
252256
} = params
253257
const startTime = Date.now()

0 commit comments

Comments
 (0)