Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,11 @@ def generate(
tool_schemas = None
tool_call_turns = 0
total_tool_calls = 0
curr_num_correction_steps = 0
# Counts parse attempts within the current conversation (initial attempt + corrections).
# The first parse is attempt 1, so `parse_attempts <= max_correction_steps` permits exactly
# `max_correction_steps` corrections after the initial attempt before falling through to
# restart-or-raise. Reset to 0 on each conversation restart.
parse_attempts = 0
curr_num_restarts = 0

mcp_facade = self._get_mcp_facade(tool_alias)
Expand Down Expand Up @@ -367,7 +371,7 @@ def generate(
response = (completion_response.message.content or "").strip()
reasoning_trace = completion_response.message.reasoning_content
messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
curr_num_correction_steps += 1
parse_attempts += 1

try:
output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below
Expand All @@ -379,12 +383,12 @@ def generate(
exc,
) from exc

if curr_num_correction_steps <= max_correction_steps:
if parse_attempts <= max_correction_steps:
# Add user message with error for correction
messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))

elif curr_num_restarts < max_conversation_restarts:
curr_num_correction_steps = 0
parse_attempts = 0
curr_num_restarts += 1
messages = deepcopy(restart_checkpoint)
tool_call_turns = checkpoint_tool_call_turns
Expand Down Expand Up @@ -425,7 +429,8 @@ async def agenerate(
tool_schemas = None
tool_call_turns = 0
total_tool_calls = 0
curr_num_correction_steps = 0
# See `generate` for a description of the parse-attempts counter semantics.
parse_attempts = 0
curr_num_restarts = 0

mcp_facade = self._get_mcp_facade(tool_alias)
Expand Down Expand Up @@ -469,7 +474,7 @@ async def agenerate(
response = (completion_response.message.content or "").strip()
reasoning_trace = completion_response.message.reasoning_content
messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None))
curr_num_correction_steps += 1
parse_attempts += 1

try:
output_obj = parser(response)
Expand All @@ -481,11 +486,11 @@ async def agenerate(
exc,
) from exc

if curr_num_correction_steps <= max_correction_steps:
if parse_attempts <= max_correction_steps:
messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc))))

elif curr_num_restarts < max_conversation_restarts:
curr_num_correction_steps = 0
parse_attempts = 0
curr_num_restarts += 1
messages = deepcopy(restart_checkpoint)
tool_call_turns = checkpoint_tool_call_turns
Expand Down
Loading