Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions src/cai/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,13 +1244,39 @@ def fix_message_list(messages): # pylint: disable=R0914,R0915,R0912

# If this isn't the first message, check that this tool result belongs to the nearest preceding assistant message
if i > 0:
prev_msg = processed_messages[i - 1]
# Previous logic only checked the immediately preceding message.
# That fails when one assistant message calls multiple tools:
#
# prev_msg = processed_messages[i - 1]
# is_valid_sequence = (
# prev_msg.get("role") == "assistant"
# and prev_msg.get("tool_calls")
# and any(tc.get("id") == tool_id for tc in prev_msg.get("tool_calls", []))
# )
#
# Valid multi-tool sequences look like:
# assistant(tool_calls=[a, b, c]), tool(a), tool(b), tool(c).
# So walk back over contiguous tool results and validate against
# the nearest preceding assistant tool_calls block.
sequence_assistant_idx = i - 1
while (
sequence_assistant_idx >= 0
and processed_messages[sequence_assistant_idx].get("role") == "tool"
):
sequence_assistant_idx -= 1

# Check that the nearest preceding non-tool message is an assistant message whose tool_calls include this tool_call_id
sequence_assistant = (
processed_messages[sequence_assistant_idx]
if sequence_assistant_idx >= 0
else {}
)
is_valid_sequence = (
prev_msg.get("role") == "assistant"
and prev_msg.get("tool_calls")
and any(tc.get("id") == tool_id for tc in prev_msg.get("tool_calls", []))
sequence_assistant.get("role") == "assistant"
and sequence_assistant.get("tool_calls")
and any(
tc.get("id") == tool_id
for tc in sequence_assistant.get("tool_calls", [])
)
)

if not is_valid_sequence:
Expand All @@ -1273,18 +1299,33 @@ def fix_message_list(messages): # pylint: disable=R0914,R0915,R0912
# Remember to save the tool message
tool_msg = processed_messages.pop(i)

# Insert right after the assistant message
processed_messages.insert(assistant_idx + 1, tool_msg)
# If the assistant was after the tool message, its index
# shifts left after pop(i).
if assistant_idx > i:
assistant_idx -= 1

# Adjust i to account for the move
if assistant_idx < i:
# We moved the message backward, so i should point to the next message
# which is now at position i (since we removed a message before it)
continue
else:
# We moved the message forward, so i should now point to the message
# that is now at position i
continue
assistant_tool_ids = {
tc.get("id")
for tc in processed_messages[assistant_idx].get("tool_calls", [])
}
insert_idx = assistant_idx + 1
while (
insert_idx < len(processed_messages)
and processed_messages[insert_idx].get("role") == "tool"
and processed_messages[insert_idx].get("tool_call_id")
in assistant_tool_ids
):
insert_idx += 1

# Insert after the assistant's existing tool-result block
# instead of always at assistant_idx + 1, which would reorder
# multiple tool results forever.
processed_messages.insert(insert_idx, tool_msg)

# Move past the current position to avoid reprocessing
# the same slot after mutating the list.
i += 1
continue
else:
# No matching assistant message found - create one
assistant_msg = {
Expand Down
47 changes: 47 additions & 0 deletions tests/cli/test_cli_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,53 @@ def test_fix_message_list_with_interrupted_tools(self):
# No need to clean up _Converter state since it's instance-based
return False

def test_fix_message_list_allows_multiple_tool_results(self):
    """Verify that several tool results may follow one assistant message.

    The conversation is user -> assistant(tool_calls=[a, b, c]) ->
    tool(a), tool(b), tool(c). After ``fix_message_list`` the role
    sequence and the order of the tool_call_ids must match the input.
    """
    from cai.util import fix_message_list

    # One row per tool call: (tool_call_id, function name, result content).
    calls = (
        ("call_a", "tool_a", "result a"),
        ("call_b", "tool_b", "result b"),
        ("call_c", "tool_c", "result c"),
    )

    # A single assistant turn that requests all three tools at once.
    assistant_turn = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": call_id,
                "type": "function",
                "function": {"name": tool_name, "arguments": "{}"},
            }
            for call_id, tool_name, _ in calls
        ],
    }

    # The matching tool results, in the same order as the calls.
    tool_results = [
        {"role": "tool", "tool_call_id": call_id, "content": result}
        for call_id, _, result in calls
    ]

    messages = [
        {"role": "user", "content": "Run three checks"},
        assistant_turn,
        *tool_results,
    ]

    fixed_messages = fix_message_list(messages)

    # The role sequence must be exactly the one that was passed in.
    roles = [msg["role"] for msg in fixed_messages]
    assert roles == ["user", "assistant", "tool", "tool", "tool"]

    # Tool results must keep their original relative order.
    tool_ids = [
        msg.get("tool_call_id")
        for msg in fixed_messages
        if msg.get("role") == "tool"
    ]
    assert tool_ids == ["call_a", "call_b", "call_c"]

def test_generic_linux_command_interrupt_simulation(self):
"""Test generic_linux_command behavior during interruption."""

Expand Down