Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/cai/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,9 +1244,19 @@ def fix_message_list(messages): # pylint: disable=R0914,R0915,R0912

# If this isn't the first message, check if the previous message is a matching assistant message
if i > 0:
prev_msg = processed_messages[i - 1]

# Check if the previous message is an assistant message with matching tool_call_id
# Walk backward past sibling tool messages to find the nearest
# assistant. This avoids an infinite loop when an assistant has
# multiple tool_calls and their responses arrive out of order:
# the previous message may be a sibling tool response rather
# than the parent assistant message, which is still valid.
k = i - 1
while k >= 0 and processed_messages[k].get("role") == "tool":
k -= 1

prev_msg = processed_messages[k] if k >= 0 else {}

# Check if the nearest non-tool ancestor is an assistant message
# with a matching tool_call_id
is_valid_sequence = (
prev_msg.get("role") == "assistant"
and prev_msg.get("tool_calls")
Expand Down
45 changes: 45 additions & 0 deletions tests/test_fix_message_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import signal

import pytest

from cai.util import fix_message_list


def test_fix_message_list_handles_multiple_tool_results_without_loop():
if not hasattr(signal, "SIGALRM"):
pytest.skip("SIGALRM is required for this infinite-loop regression test")

messages = [
{"role": "user", "content": "Run both tools"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_one",
"type": "function",
"function": {"name": "first_tool", "arguments": "{}"},
},
{
"id": "call_two",
"type": "function",
"function": {"name": "second_tool", "arguments": "{}"},
},
],
},
{"role": "tool", "tool_call_id": "call_one", "content": "first result"},
{"role": "tool", "tool_call_id": "call_two", "content": "second result"},
]

def fail_on_timeout(_signum, _frame):
raise TimeoutError("fix_message_list did not terminate")

previous_handler = signal.signal(signal.SIGALRM, fail_on_timeout)
signal.alarm(2)
try:
fixed_messages = fix_message_list(messages)
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, previous_handler)

assert fixed_messages == messages