Skip to content
Merged

DI #337

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 75 additions & 79 deletions backend/src/__tests__/cost-aggregation.integration.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { TEST_USER_ID } from '@codebuff/common/old-constants'
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
import { getInitialSessionState } from '@codebuff/common/types/session-state'
import {
spyOn,
Expand All @@ -12,12 +12,12 @@ import {
} from 'bun:test'

import * as messageCostTracker from '../llm-apis/message-cost-tracker'
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
import { mainPrompt } from '../main-prompt'
import * as agentRegistry from '../templates/agent-registry'
import * as websocketAction from '../websockets/websocket-action'

import type { AgentTemplate } from '../templates/types'
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
import type { ProjectFileContext } from '@codebuff/common/util/file'
import type { WebSocket } from 'ws'

Expand Down Expand Up @@ -99,8 +99,10 @@ class MockWebSocket {
describe('Cost Aggregation Integration Tests', () => {
let mockLocalAgentTemplates: Record<string, any>
let mockWebSocket: MockWebSocket
let agentRuntimeImpl: AgentRuntimeDeps

beforeEach(async () => {
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
mockWebSocket = new MockWebSocket()

// Setup mock agent templates
Expand Down Expand Up @@ -150,33 +152,31 @@ describe('Cost Aggregation Integration Tests', () => {
// Mock LLM streaming
let callCount = 0
const creditHistory: number[] = []
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
async function* (options) {
callCount++
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
creditHistory.push(credits)

if (options.onCostCalculated) {
await options.onCostCalculated(credits)
}
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
callCount++
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
creditHistory.push(credits)

// Simulate different responses based on call
if (callCount === 1) {
// Main agent spawns a subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
}
} else {
// Subagent writes a file
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
}
if (options.onCostCalculated) {
await options.onCostCalculated(credits)
}

// Simulate different responses based on call
if (callCount === 1) {
// Main agent spawns a subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
}
return 'mock-message-id'
},
)
} else {
// Subagent writes a file
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
}
}
return 'mock-message-id'
}

// Mock tool call execution
spyOn(websocketAction, 'requestToolCall').mockImplementation(
Expand Down Expand Up @@ -250,7 +250,7 @@ describe('Cost Aggregation Integration Tests', () => {
}

const result = await mainPrompt({
...testAgentRuntimeImpl,
...agentRuntimeImpl,
ws: mockWebSocket as unknown as WebSocket,
action,
userId: TEST_USER_ID,
Expand Down Expand Up @@ -285,7 +285,7 @@ describe('Cost Aggregation Integration Tests', () => {

// Call through websocket action handler to test full integration
await websocketAction.callMainPrompt({
...testAgentRuntimeImpl,
...agentRuntimeImpl,
ws: mockWebSocket as unknown as WebSocket,
action,
userId: TEST_USER_ID,
Expand All @@ -308,37 +308,35 @@ describe('Cost Aggregation Integration Tests', () => {
it('should handle multi-level subagent hierarchies correctly', async () => {
// Mock a more complex scenario with nested subagents
let callCount = 0
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
async function* (options) {
callCount++
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
callCount++

if (options.onCostCalculated) {
await options.onCostCalculated(5) // Each call costs 5 credits
}
if (options.onCostCalculated) {
await options.onCostCalculated(5) // Each call costs 5 credits
}

if (callCount === 1) {
// Main agent spawns first-level subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
}
} else if (callCount === 2) {
// First-level subagent spawns second-level subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
}
} else {
// Second-level subagent does actual work
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
}
if (callCount === 1) {
// Main agent spawns first-level subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
}
} else if (callCount === 2) {
// First-level subagent spawns second-level subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
}
} else {
// Second-level subagent does actual work
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
}
}

return 'mock-message-id'
},
)
return 'mock-message-id'
}

const sessionState = getInitialSessionState(mockFileContext)
sessionState.mainAgentState.stepsRemaining = 10
Expand All @@ -355,7 +353,7 @@ describe('Cost Aggregation Integration Tests', () => {
}

const result = await mainPrompt({
...testAgentRuntimeImpl,
...agentRuntimeImpl,
ws: mockWebSocket as unknown as WebSocket,
action,
userId: TEST_USER_ID,
Expand All @@ -373,29 +371,27 @@ describe('Cost Aggregation Integration Tests', () => {
it('should maintain cost integrity when subagents fail', async () => {
// Mock scenario where subagent fails after incurring partial costs
let callCount = 0
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
async function* (options) {
callCount++
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
callCount++

if (options.onCostCalculated) {
await options.onCostCalculated(6) // Each call costs 6 credits
}
if (options.onCostCalculated) {
await options.onCostCalculated(6) // Each call costs 6 credits
}

if (callCount === 1) {
// Main agent spawns subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n</codebuff_tool_call>',
}
} else {
// Subagent fails after incurring cost
yield { type: 'text' as const, text: 'Some response' }
throw new Error('Subagent execution failed')
if (callCount === 1) {
// Main agent spawns subagent
yield {
type: 'text' as const,
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n</codebuff_tool_call>',
}
} else {
// Subagent fails after incurring cost
yield { type: 'text' as const, text: 'Some response' }
throw new Error('Subagent execution failed')
}

return 'mock-message-id'
},
)
return 'mock-message-id'
}

const sessionState = getInitialSessionState(mockFileContext)
sessionState.mainAgentState.agentType = 'base'
Expand All @@ -413,7 +409,7 @@ describe('Cost Aggregation Integration Tests', () => {
let result
try {
result = await mainPrompt({
...testAgentRuntimeImpl,
...agentRuntimeImpl,
ws: mockWebSocket as unknown as WebSocket,
action,
userId: TEST_USER_ID,
Expand Down Expand Up @@ -462,7 +458,7 @@ describe('Cost Aggregation Integration Tests', () => {
}

await mainPrompt({
...testAgentRuntimeImpl,
...agentRuntimeImpl,
ws: mockWebSocket as unknown as WebSocket,
action,
userId: TEST_USER_ID,
Expand Down Expand Up @@ -502,7 +498,7 @@ describe('Cost Aggregation Integration Tests', () => {

// Call through websocket action to test server-side reset
await websocketAction.callMainPrompt({
...testAgentRuntimeImpl,
...agentRuntimeImpl,
ws: mockWebSocket as unknown as WebSocket,
action,
userId: TEST_USER_ID,
Expand Down
8 changes: 4 additions & 4 deletions backend/src/__tests__/cost-aggregation.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
import {
getInitialAgentState,
getInitialSessionState,
Expand Down Expand Up @@ -180,7 +180,7 @@ describe('Cost Aggregation System', () => {
}

const result = handleSpawnAgents({
...testAgentRuntimeImpl,
...TEST_AGENT_RUNTIME_IMPL,
previousToolCallFinished: Promise.resolve(),
toolCall: mockToolCall,
fileContext: mockFileContext,
Expand Down Expand Up @@ -260,7 +260,7 @@ describe('Cost Aggregation System', () => {
}

const result = handleSpawnAgents({
...testAgentRuntimeImpl,
...TEST_AGENT_RUNTIME_IMPL,
previousToolCallFinished: Promise.resolve(),
toolCall: mockToolCall,
fileContext: mockFileContext,
Expand Down Expand Up @@ -417,7 +417,7 @@ describe('Cost Aggregation System', () => {
}

const result = handleSpawnAgents({
...testAgentRuntimeImpl,
...TEST_AGENT_RUNTIME_IMPL,
previousToolCallFinished: Promise.resolve(),
toolCall: mockToolCall,
fileContext: mockFileContext,
Expand Down
20 changes: 9 additions & 11 deletions backend/src/__tests__/fast-rewrite.test.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import path from 'path'

import { TEST_USER_ID } from '@codebuff/common/old-constants'
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
import {
clearMockedModules,
mockModule,
} from '@codebuff/common/testing/mock-modules'
import { afterAll, beforeAll, describe, expect, it } from 'bun:test'
import { afterAll, beforeAll, beforeEach, describe, expect, it } from 'bun:test'
import { createPatch } from 'diff'

import { rewriteWithOpenAI } from '../fast-rewrite'

import type { Logger } from '@codebuff/common/types/contracts/logger'
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'

const logger: Logger = {
debug: () => {},
info: () => {},
warn: () => {},
error: () => {},
}
let agentRuntimeImpl: AgentRuntimeDeps

describe.skip('rewriteWithOpenAI', () => {
beforeAll(() => {
Expand All @@ -42,6 +38,10 @@ describe.skip('rewriteWithOpenAI', () => {
}))
})

beforeEach(() => {
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
})

afterAll(() => {
clearMockedModules()
})
Expand All @@ -53,15 +53,13 @@ describe.skip('rewriteWithOpenAI', () => {
const expectedResult = await Bun.file(`${testDataDir}/expected.go`).text()

const result = await rewriteWithOpenAI({
...agentRuntimeImpl,
oldContent: originalContent,
editSnippet,
filePath: 'taskruntoolcall.go',
clientSessionId: 'clientSessionId',
fingerprintId: 'fingerprintId',
userInputId: 'userInputId',
userId: TEST_USER_ID,
userMessage: undefined,
logger,
})

const patch = createPatch('test.ts', expectedResult, result)
Expand Down
Loading