Skip to content

Commit b992969

Browse files
committed
Retry generating output schema if errors
1 parent 04689f2 commit b992969

File tree

2 files changed

+333
-0
lines changed

2 files changed

+333
-0
lines changed

backend/src/__tests__/loop-agent-steps.test.ts

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import type { AgentTemplate } from '../templates/types'
2828
import type { StepGenerator } from '@codebuff/common/types/agent-template'
2929
import type { AgentState } from '@codebuff/common/types/session-state'
3030
import type { WebSocket } from 'ws'
31+
import { z } from 'zod/v4'
3132

3233
describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () => {
3334
let mockTemplate: AgentTemplate
@@ -680,4 +681,300 @@ describe('loopAgentSteps - runAgentStep vs runProgrammaticStep behavior', () =>
680681
// Second call should have stepsComplete: true (after end_turn tool was called)
681682
expect(runProgrammaticStepCalls[1].options.stepsComplete).toBe(true)
682683
})
684+
685+
it('should restart loop when agent finishes without setting required output', async () => {
686+
// Test that when an agent has outputSchema but finishes without calling set_output,
687+
// the loop restarts with a system message
688+
689+
const outputSchema = z.object({
690+
result: z.string(),
691+
status: z.string(),
692+
})
693+
694+
const templateWithOutputSchema = {
695+
...mockTemplate,
696+
outputSchema,
697+
toolNames: ['set_output', 'end_turn'], // Add set_output to available tools
698+
handleSteps: undefined, // LLM-only agent
699+
}
700+
701+
const localAgentTemplates = {
702+
'test-agent': templateWithOutputSchema,
703+
}
704+
705+
let llmCallNumber = 0
706+
let capturedAgentState: AgentState | null = null
707+
708+
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) {
709+
llmCallNumber++
710+
if (llmCallNumber === 1) {
711+
// First call: agent tries to end turn without setting output
712+
yield {
713+
type: 'text' as const,
714+
text: `First response without output\n\n${getToolCallString('end_turn', {})}`,
715+
}
716+
} else if (llmCallNumber === 2) {
717+
// Second call: agent sets output after being reminded
718+
// Manually set the output to simulate the set_output tool execution
719+
if (capturedAgentState) {
720+
capturedAgentState.output = {
721+
result: 'test result',
722+
status: 'success',
723+
}
724+
}
725+
yield {
726+
type: 'text' as const,
727+
text: `Setting output now\n\n${getToolCallString('set_output', { result: 'test result', status: 'success' })}\n\n${getToolCallString('end_turn', {})}`,
728+
}
729+
} else {
730+
// Safety: if called more than twice, just end
731+
yield {
732+
type: 'text' as const,
733+
text: `Ending\n\n${getToolCallString('end_turn', {})}`,
734+
}
735+
}
736+
return 'mock-message-id'
737+
})
738+
739+
const mockCheckLiveUserInput = require('@codebuff/backend/live-user-inputs')
740+
let checkCount = 0
741+
spyOn(mockCheckLiveUserInput, 'checkLiveUserInput').mockImplementation(
742+
() => {
743+
checkCount++
744+
return checkCount < 10 // Limit to prevent infinite loop
745+
},
746+
)
747+
748+
// Capture the agent state during execution
749+
mockAgentState.output = undefined
750+
capturedAgentState = mockAgentState
751+
752+
const result = await loopAgentSteps(
753+
new MockWebSocket() as unknown as WebSocket,
754+
{
755+
userInputId: 'test-user-input',
756+
agentType: 'test-agent',
757+
agentState: mockAgentState,
758+
prompt: 'Test output schema validation',
759+
params: undefined,
760+
fingerprintId: 'test-fingerprint',
761+
fileContext: mockFileContext,
762+
localAgentTemplates,
763+
userId: TEST_USER_ID,
764+
clientSessionId: 'test-session',
765+
onResponseChunk: () => {},
766+
},
767+
)
768+
769+
// Should call LLM twice: once to try ending without output, once after reminder
770+
expect(llmCallNumber).toBe(2)
771+
772+
// Should have output set after the second attempt
773+
expect(result.agentState.output).toEqual({
774+
result: 'test result',
775+
status: 'success',
776+
})
777+
778+
// Check that a system message was added to message history
779+
const systemMessages = result.agentState.messageHistory.filter(
780+
(msg) =>
781+
msg.role === 'user' &&
782+
typeof msg.content === 'string' &&
783+
msg.content.includes('set_output'),
784+
)
785+
expect(systemMessages.length).toBeGreaterThan(0)
786+
})
787+
788+
it('should not restart loop if output is set correctly', async () => {
789+
// Test that when an agent has outputSchema and sets output correctly,
790+
// the loop ends normally without restarting
791+
792+
const outputSchema = z.object({
793+
result: z.string(),
794+
})
795+
796+
const templateWithOutputSchema = {
797+
...mockTemplate,
798+
outputSchema,
799+
toolNames: ['set_output', 'end_turn'],
800+
handleSteps: undefined,
801+
}
802+
803+
const localAgentTemplates = {
804+
'test-agent': templateWithOutputSchema,
805+
}
806+
807+
let llmCallNumber = 0
808+
let capturedAgentState: AgentState | null = null
809+
810+
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) {
811+
llmCallNumber++
812+
// Agent sets output correctly on first call
813+
if (capturedAgentState) {
814+
capturedAgentState.output = { result: 'success' }
815+
}
816+
yield {
817+
type: 'text' as const,
818+
text: `Setting output\n\n${getToolCallString('set_output', { result: 'success' })}\n\n${getToolCallString('end_turn', {})}`,
819+
}
820+
return 'mock-message-id'
821+
})
822+
823+
const mockCheckLiveUserInput = require('@codebuff/backend/live-user-inputs')
824+
spyOn(mockCheckLiveUserInput, 'checkLiveUserInput').mockImplementation(
825+
() => true,
826+
)
827+
828+
mockAgentState.output = undefined
829+
capturedAgentState = mockAgentState
830+
831+
const result = await loopAgentSteps(
832+
new MockWebSocket() as unknown as WebSocket,
833+
{
834+
userInputId: 'test-user-input',
835+
agentType: 'test-agent',
836+
agentState: mockAgentState,
837+
prompt: 'Test with correct output',
838+
params: undefined,
839+
fingerprintId: 'test-fingerprint',
840+
fileContext: mockFileContext,
841+
localAgentTemplates,
842+
userId: TEST_USER_ID,
843+
clientSessionId: 'test-session',
844+
onResponseChunk: () => {},
845+
},
846+
)
847+
848+
// Should only call LLM once since output was set correctly
849+
expect(llmCallNumber).toBe(1)
850+
851+
// Should have output set
852+
expect(result.agentState.output).toEqual({ result: 'success' })
853+
})
854+
855+
it('should allow agents without outputSchema to end normally', async () => {
856+
// Test that agents without outputSchema can end without setting output
857+
858+
const templateWithoutOutputSchema = {
859+
...mockTemplate,
860+
outputSchema: undefined,
861+
handleSteps: undefined,
862+
}
863+
864+
const localAgentTemplates = {
865+
'test-agent': templateWithoutOutputSchema,
866+
}
867+
868+
let llmCallNumber = 0
869+
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) {
870+
llmCallNumber++
871+
yield {
872+
type: 'text' as const,
873+
text: `Response without output\n\n${getToolCallString('end_turn', {})}`,
874+
}
875+
return 'mock-message-id'
876+
})
877+
878+
const mockCheckLiveUserInput = require('@codebuff/backend/live-user-inputs')
879+
spyOn(mockCheckLiveUserInput, 'checkLiveUserInput').mockImplementation(
880+
() => true,
881+
)
882+
883+
const result = await loopAgentSteps(
884+
new MockWebSocket() as unknown as WebSocket,
885+
{
886+
userInputId: 'test-user-input',
887+
agentType: 'test-agent',
888+
agentState: mockAgentState,
889+
prompt: 'Test without output schema',
890+
params: undefined,
891+
fingerprintId: 'test-fingerprint',
892+
fileContext: mockFileContext,
893+
localAgentTemplates,
894+
userId: TEST_USER_ID,
895+
clientSessionId: 'test-session',
896+
onResponseChunk: () => {},
897+
},
898+
)
899+
900+
// Should only call LLM once and end normally
901+
expect(llmCallNumber).toBe(1)
902+
903+
// Output should be undefined since no outputSchema required
904+
expect(result.agentState.output).toBeUndefined()
905+
})
906+
907+
it('should continue loop if agent does not end turn (has more work)', async () => {
908+
// Test that validation only triggers when shouldEndTurn is true
909+
910+
const outputSchema = z.object({
911+
result: z.string(),
912+
})
913+
914+
const templateWithOutputSchema = {
915+
...mockTemplate,
916+
outputSchema,
917+
toolNames: ['read_files', 'set_output', 'end_turn'],
918+
handleSteps: undefined,
919+
}
920+
921+
const localAgentTemplates = {
922+
'test-agent': templateWithOutputSchema,
923+
}
924+
925+
let llmCallNumber = 0
926+
let capturedAgentState: AgentState | null = null
927+
928+
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(async function* ({}) {
929+
llmCallNumber++
930+
if (llmCallNumber === 1) {
931+
// First call: agent does some work but doesn't end turn
932+
yield {
933+
type: 'text' as const,
934+
text: `Doing work\n\n${getToolCallString('read_files', { paths: ['test.txt'] })}`,
935+
}
936+
} else {
937+
// Second call: agent sets output and ends
938+
if (capturedAgentState) {
939+
capturedAgentState.output = { result: 'done' }
940+
}
941+
yield {
942+
type: 'text' as const,
943+
text: `Finishing\n\n${getToolCallString('set_output', { result: 'done' })}\n\n${getToolCallString('end_turn', {})}`,
944+
}
945+
}
946+
return 'mock-message-id'
947+
})
948+
949+
const mockCheckLiveUserInput = require('@codebuff/backend/live-user-inputs')
950+
spyOn(mockCheckLiveUserInput, 'checkLiveUserInput').mockImplementation(
951+
() => true,
952+
)
953+
954+
mockAgentState.output = undefined
955+
capturedAgentState = mockAgentState
956+
957+
const result = await loopAgentSteps(
958+
new MockWebSocket() as unknown as WebSocket,
959+
{
960+
userInputId: 'test-user-input',
961+
agentType: 'test-agent',
962+
agentState: mockAgentState,
963+
prompt: 'Test loop continues',
964+
params: undefined,
965+
fingerprintId: 'test-fingerprint',
966+
fileContext: mockFileContext,
967+
localAgentTemplates,
968+
userId: TEST_USER_ID,
969+
clientSessionId: 'test-session',
970+
onResponseChunk: () => {},
971+
},
972+
)
973+
974+
// Should call LLM twice: once for work, once to set output and end
975+
expect(llmCallNumber).toBe(2)
976+
977+
// Should have output set
978+
expect(result.agentState.output).toEqual({ result: 'done' })
979+
})
683980
})

backend/src/run-agent-step.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ export const loopAgentSteps = async (
555555
messageHistory: initialMessages,
556556
}
557557
let shouldEndTurn = false
558+
let hasRetriedOutputSchema = false
558559
let currentPrompt = prompt
559560
let currentParams = params
560561
let totalSteps = 0
@@ -616,6 +617,41 @@ export const loopAgentSteps = async (
616617
}
617618
}
618619

620+
// Check if output is required but missing
621+
if (
622+
agentTemplate.outputSchema &&
623+
currentAgentState.output === undefined &&
624+
shouldEndTurn &&
625+
!hasRetriedOutputSchema
626+
) {
627+
hasRetriedOutputSchema = true
628+
logger.warn(
629+
{
630+
agentType,
631+
agentId: currentAgentState.agentId,
632+
runId,
633+
},
634+
'Agent finished without setting required output, restarting loop',
635+
)
636+
637+
// Add system message instructing to use set_output
638+
const outputSchemaMessage = asSystemMessage(
639+
`You must use the "set_output" tool to provide a result that matches the output schema before ending your turn. The output schema is required for this agent.`,
640+
)
641+
642+
currentAgentState.messageHistory = [
643+
...currentAgentState.messageHistory,
644+
{
645+
role: 'user',
646+
content: outputSchemaMessage,
647+
keepDuringTruncation: true,
648+
},
649+
]
650+
651+
// Reset shouldEndTurn to continue the loop
652+
shouldEndTurn = false
653+
}
654+
619655
// End turn if programmatic step ended turn, or if the previous runAgentStep ended turn
620656
if (shouldEndTurn) {
621657
break

0 commit comments

Comments
 (0)