1717from agentex .types .agent_rpc_params import ParamsSendEventRequest
1818from agentex .types .agent_rpc_result import StreamTaskMessageDone , StreamTaskMessageFull
1919from agentex .types .text_content_param import TextContentParam
20-
21-
22- def _task_message_poll_sort_key (message : TaskMessage ) -> tuple [int , datetime ]:
23- """Order messages within one poll so tool lifecycle rows precede other content.
24-
25- Streaming assistant ``text`` rows are often created before ``tool_request`` /
26- ``tool_response`` rows from the same model turn (earlier ``created_at``). Sorting
27- only by ``created_at`` makes consumers that ``break`` on DONE agent text exit the
28- poll generator before tool rows in the same ``list()`` batch are yielded.
29- """
30- ts = message .created_at if message .created_at else datetime .min .replace (tzinfo = timezone .utc )
31- ctype = getattr (message .content , "type" , None ) if message .content else None
32- phase = 0 if ctype in ("tool_request" , "tool_response" ) else 1
20+ from agentex .types .tool_request_content import ToolRequestContent
21+ from agentex .types .tool_response_content import ToolResponseContent
22+
23+
24+ def _is_tool_lifecycle_message (message : TaskMessage ) -> bool :
25+ """True for tool request/response rows (by model type or ``content.type`` string)."""
26+ c = message .content
27+ if c is None :
28+ return False
29+ if isinstance (c , (ToolRequestContent , ToolResponseContent )):
30+ return True
31+ ctype = getattr (c , "type" , None )
32+ return ctype in ("tool_request" , "tool_response" )
33+
34+
35+ def _pending_poll_sort_key (item : tuple [TaskMessage , int | None ]) -> tuple [int , datetime ]:
36+ """Yield tool lifecycle messages before other content in the same poll batch."""
37+ m , _ = item
38+ ts = m .created_at if m .created_at else datetime .min .replace (tzinfo = timezone .utc )
39+ phase = 0 if _is_tool_lifecycle_message (m ) else 1
3340 return (phase , ts )
3441
3542
@@ -106,12 +113,14 @@ async def poll_messages(
106113 Yields:
107114 TaskMessage objects as they are discovered or updated.
108115
109- Within each poll, ``tool_request`` and ``tool_response`` messages are yielded before
110- other types (when present in the same batch), so streaming tests can stop on DONE
111- agent text without missing tool lifecycle rows.
116+ Within each poll, messages to emit are collected first, then re-ordered so
117+ ``tool_request`` / ``tool_response`` rows are yielded before other content in that
118+ batch. ``tool_response`` can also appear on a later poll than final assistant text;
119+ callers that ``break`` on DONE text should keep polling until they have seen
120+ ``tool_response`` after a ``tool_request`` (see tutorial tests).
112121 """
113122 # Keep track of messages we've already yielded
114- seen_message_ids = set ()
123+ seen_message_ids : set [ str ] = set ()
115124 # Track message content hashes to detect updates (for streaming)
116125 message_content_hashes : dict [str , int ] = {}
117126 start_time = datetime .now ()
@@ -120,11 +129,15 @@ async def poll_messages(
120129 while (datetime .now () - start_time ).seconds < timeout :
121130 messages = await client .messages .list (task_id = task_id )
122131
123- # Sort so tool_request / tool_response appear before agent text in the same poll;
124- # then by created_at (see _task_message_poll_sort_key).
125- sorted_messages = sorted (messages , key = _task_message_poll_sort_key )
132+ sorted_messages = sorted (
133+ messages ,
134+ key = lambda m : m .created_at if m .created_at else datetime .min .replace (tzinfo = timezone .utc ),
135+ )
136+
137+ # Collect (message, hash) for this poll without mutating dedupe state yet, then
138+ # yield tool lifecycle rows before streaming text / other updates in the same batch.
139+ pending : list [tuple [TaskMessage , int | None ]] = []
126140
127- new_messages_found = 0
128141 for message in sorted_messages :
129142 # Check if message passes timestamp filter
130143 if messages_created_after and message .created_at :
@@ -156,16 +169,22 @@ async def poll_messages(
156169 is_updated = message .id in message_content_hashes and message_content_hashes [message .id ] != content_hash
157170
158171 if is_new_message or is_updated :
159- message_content_hashes [message .id ] = content_hash
160- seen_message_ids .add (message .id )
161- new_messages_found += 1
162- yield message
172+ pending .append ((message , content_hash ))
163173 else :
164174 # Original behavior: only yield each message ID once
165175 if is_new_message :
166- seen_message_ids .add (message .id )
167- new_messages_found += 1
168- yield message
176+ pending .append ((message , None ))
177+
178+ pending .sort (key = _pending_poll_sort_key )
179+
180+ for message , content_hash in pending :
181+ mid = message .id
182+ if not mid :
183+ continue
184+ if yield_updates and content_hash is not None :
185+ message_content_hashes [mid ] = content_hash
186+ seen_message_ids .add (mid )
187+ yield message
169188
170189 # Sleep before next poll
171190 await asyncio .sleep (sleep_interval )
0 commit comments