Skip to content

Commit 5b14d4c

Browse files
committed
fix(tutorial_test): allow tool_request & tool_response to be read before the final message in message order II
1 parent f1b2773 commit 5b14d4c

3 files changed

Lines changed: 55 additions & 30 deletions

File tree

examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/tests/test_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,11 @@ async def test_send_event_and_poll(self, client: AsyncAgentex, agent_id: str):
118118
content_length = len(str(agent_text))
119119
final_message = message
120120

121-
# Stop when we get DONE status
121+
# Stop when we get DONE status (after tool_response if a tool was used;
122+
# tool rows can appear on a later poll than final text).
122123
if message.streaming_status == "DONE" and content_length > 0:
124+
if seen_tool_request and not seen_tool_response:
125+
continue
123126
break
124127

125128
# Verify we got all the expected pieces

examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/tests/test_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,12 @@ async def test_send_event_and_poll_with_human_approval(self, client: AsyncAgente
148148
if message.content and message.content.type == "text" and message.content.author == "agent":
149149
content_length = len(message.content.content) if message.content.content else 0
150150

151-
# Stop when we get DONE status with actual content
151+
# Stop when we get DONE with content (after tool_response if a tool ran;
152+
# tool rows can appear on a later poll than final text).
152153
if message.streaming_status == "DONE" and content_length > 0:
153154
found_final_response = True
155+
if seen_tool_request and not seen_tool_response:
156+
continue
154157
break
155158

156159
# Verify that we saw the complete flow: tool_request -> human approval -> tool_response -> final answer

examples/tutorials/test_utils/async_utils.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,26 @@
1717
from agentex.types.agent_rpc_params import ParamsSendEventRequest
1818
from agentex.types.agent_rpc_result import StreamTaskMessageDone, StreamTaskMessageFull
1919
from agentex.types.text_content_param import TextContentParam
20-
21-
22-
def _task_message_poll_sort_key(message: TaskMessage) -> tuple[int, datetime]:
23-
"""Order messages within one poll so tool lifecycle rows precede other content.
24-
25-
Streaming assistant ``text`` rows are often created before ``tool_request`` /
26-
``tool_response`` rows from the same model turn (earlier ``created_at``). Sorting
27-
only by ``created_at`` makes consumers that ``break`` on DONE agent text exit the
28-
poll generator before tool rows in the same ``list()`` batch are yielded.
29-
"""
30-
ts = message.created_at if message.created_at else datetime.min.replace(tzinfo=timezone.utc)
31-
ctype = getattr(message.content, "type", None) if message.content else None
32-
phase = 0 if ctype in ("tool_request", "tool_response") else 1
20+
from agentex.types.tool_request_content import ToolRequestContent
21+
from agentex.types.tool_response_content import ToolResponseContent
22+
23+
24+
def _is_tool_lifecycle_message(message: TaskMessage) -> bool:
25+
"""True for tool request/response rows (by model type or ``content.type`` string)."""
26+
c = message.content
27+
if c is None:
28+
return False
29+
if isinstance(c, (ToolRequestContent, ToolResponseContent)):
30+
return True
31+
ctype = getattr(c, "type", None)
32+
return ctype in ("tool_request", "tool_response")
33+
34+
35+
def _pending_poll_sort_key(item: tuple[TaskMessage, int | None]) -> tuple[int, datetime]:
36+
"""Yield tool lifecycle messages before other content in the same poll batch."""
37+
m, _ = item
38+
ts = m.created_at if m.created_at else datetime.min.replace(tzinfo=timezone.utc)
39+
phase = 0 if _is_tool_lifecycle_message(m) else 1
3340
return (phase, ts)
3441

3542

@@ -106,12 +113,14 @@ async def poll_messages(
106113
Yields:
107114
TaskMessage objects as they are discovered or updated.
108115
109-
Within each poll, ``tool_request`` and ``tool_response`` messages are yielded before
110-
other types (when present in the same batch), so streaming tests can stop on DONE
111-
agent text without missing tool lifecycle rows.
116+
Within each poll, messages to emit are collected first, then re-ordered so
117+
``tool_request`` / ``tool_response`` rows are yielded before other content in that
118+
batch. ``tool_response`` can also appear on a later poll than final assistant text;
119+
callers that ``break`` on DONE text should keep polling until they have seen
120+
``tool_response`` after a ``tool_request`` (see tutorial tests).
112121
"""
113122
# Keep track of messages we've already yielded
114-
seen_message_ids = set()
123+
seen_message_ids: set[str] = set()
115124
# Track message content hashes to detect updates (for streaming)
116125
message_content_hashes: dict[str, int] = {}
117126
start_time = datetime.now()
@@ -120,11 +129,15 @@ async def poll_messages(
120129
while (datetime.now() - start_time).seconds < timeout:
121130
messages = await client.messages.list(task_id=task_id)
122131

123-
# Sort so tool_request / tool_response appear before agent text in the same poll;
124-
# then by created_at (see _task_message_poll_sort_key).
125-
sorted_messages = sorted(messages, key=_task_message_poll_sort_key)
132+
sorted_messages = sorted(
133+
messages,
134+
key=lambda m: m.created_at if m.created_at else datetime.min.replace(tzinfo=timezone.utc),
135+
)
136+
137+
# Collect (message, hash) for this poll without mutating dedupe state yet, then
138+
# yield tool lifecycle rows before streaming text / other updates in the same batch.
139+
pending: list[tuple[TaskMessage, int | None]] = []
126140

127-
new_messages_found = 0
128141
for message in sorted_messages:
129142
# Check if message passes timestamp filter
130143
if messages_created_after and message.created_at:
@@ -156,16 +169,22 @@ async def poll_messages(
156169
is_updated = message.id in message_content_hashes and message_content_hashes[message.id] != content_hash
157170

158171
if is_new_message or is_updated:
159-
message_content_hashes[message.id] = content_hash
160-
seen_message_ids.add(message.id)
161-
new_messages_found += 1
162-
yield message
172+
pending.append((message, content_hash))
163173
else:
164174
# Original behavior: only yield each message ID once
165175
if is_new_message:
166-
seen_message_ids.add(message.id)
167-
new_messages_found += 1
168-
yield message
176+
pending.append((message, None))
177+
178+
pending.sort(key=_pending_poll_sort_key)
179+
180+
for message, content_hash in pending:
181+
mid = message.id
182+
if not mid:
183+
continue
184+
if yield_updates and content_hash is not None:
185+
message_content_hashes[mid] = content_hash
186+
seen_message_ids.add(mid)
187+
yield message
169188

170189
# Sleep before next poll
171190
await asyncio.sleep(sleep_interval)

0 commit comments

Comments
 (0)