Skip to content

Commit 3f83814

Browse files
giles17Copilot
andcommitted
Fix: gate intermediate status updates behind emit_intermediate flag and add missing test coverage
- Add emit_intermediate parameter to _updates_from_task and _map_a2a_stream - Thread stream flag from run() so only streaming callers see intermediate updates - Add IN_PROGRESS_TASK_STATES guard to emit_intermediate condition - Add role parameter to test helper add_in_progress_task_response - Add clarifying comment on MockA2AClient.send_message batch semantics - Add tests for user role mapping, background precedence, non-streaming behavior, terminal task with no artifacts, and empty parts edge case Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5acf3cb commit 3f83814

2 files changed

Lines changed: 126 additions & 12 deletions

File tree

python/packages/a2a/agent_framework_a2a/_agent.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
313313
self._map_a2a_stream(
314314
a2a_stream,
315315
background=background,
316+
emit_intermediate=stream,
316317
session=provider_session,
317318
session_context=session_context,
318319
),
@@ -327,6 +328,7 @@ async def _map_a2a_stream(
327328
a2a_stream: AsyncIterable[A2AStreamItem],
328329
*,
329330
background: bool = False,
331+
emit_intermediate: bool = False,
330332
session: AgentSession | None = None,
331333
session_context: SessionContext | None = None,
332334
) -> AsyncIterable[AgentResponseUpdate]:
@@ -339,6 +341,10 @@ async def _map_a2a_stream(
339341
background: When False, in-progress task updates are silently
340342
consumed (the stream keeps iterating until a terminal state).
341343
When True, they are yielded with a continuation token.
344+
emit_intermediate: When True, in-progress status updates that
345+
carry message content are yielded to the caller. Typically
346+
set for streaming callers so non-streaming consumers only
347+
receive terminal task outputs.
342348
session: The agent session for context providers.
343349
session_context: The session context for context providers.
344350
"""
@@ -373,7 +379,11 @@ async def _map_a2a_stream(
373379
yield update
374380
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
375381
task, _update_event = item
376-
for update in self._updates_from_task(task, background=background):
382+
for update in self._updates_from_task(
383+
task,
384+
background=background,
385+
emit_intermediate=emit_intermediate,
386+
):
377387
all_updates.append(update)
378388
yield update
379389
else:
@@ -389,17 +399,26 @@ async def _map_a2a_stream(
389399
# Task helpers
390400
# ------------------------------------------------------------------
391401

392-
def _updates_from_task(self, task: Task, *, background: bool = False) -> list[AgentResponseUpdate]:
402+
def _updates_from_task(
403+
self,
404+
task: Task,
405+
*,
406+
background: bool = False,
407+
emit_intermediate: bool = False,
408+
) -> list[AgentResponseUpdate]:
393409
"""Convert an A2A Task into AgentResponseUpdate(s).
394410
395411
Terminal tasks produce updates from their artifacts/history.
396412
In-progress tasks produce a continuation token update when
397-
``background=True``. When ``background=False``, any message
398-
content attached to the status update is surfaced; otherwise
399-
the update is silently skipped so the caller keeps consuming
400-
the stream until completion.
413+
``background=True``. When ``emit_intermediate=True`` (typically
414+
set for streaming callers), any message content attached to an
415+
in-progress status update is surfaced; otherwise the update is
416+
silently skipped so the caller keeps consuming the stream until
417+
completion.
401418
"""
402-
if task.status.state in TERMINAL_TASK_STATES:
419+
status = task.status
420+
421+
if status.state in TERMINAL_TASK_STATES:
403422
task_messages = self._parse_messages_from_task(task)
404423
if task_messages:
405424
return [
@@ -414,7 +433,7 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag
414433
]
415434
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
416435

417-
if background and task.status.state in IN_PROGRESS_TASK_STATES:
436+
if background and status.state in IN_PROGRESS_TASK_STATES:
418437
token = self._build_continuation_token(task)
419438
return [
420439
AgentResponseUpdate(
@@ -427,13 +446,20 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag
427446
]
428447

429448
# Surface message content from in-progress status updates (e.g. working state)
430-
if task.status.message is not None and task.status.message.parts:
431-
contents = self._parse_contents_from_a2a(task.status.message.parts)
449+
# Only emitted when the caller opts in (streaming), so non-streaming
450+
# consumers keep receiving only terminal task outputs.
451+
if (
452+
emit_intermediate
453+
and status.state in IN_PROGRESS_TASK_STATES
454+
and status.message is not None
455+
and status.message.parts
456+
):
457+
contents = self._parse_contents_from_a2a(status.message.parts)
432458
if contents:
433459
return [
434460
AgentResponseUpdate(
435461
contents=contents,
436-
role="assistant" if task.status.message.role == A2ARole.agent else "user",
462+
role="assistant" if status.message.role == A2ARole.agent else "user",
437463
response_id=task.id,
438464
raw_representation=task,
439465
)

python/packages/a2a/tests/test_a2a_agent.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,14 @@ def add_in_progress_task_response(
9292
context_id: str = "test-context",
9393
state: TaskState = TaskState.working,
9494
text: str | None = None,
95+
role: A2ARole = A2ARole.agent,
9596
) -> None:
9697
"""Add a mock in-progress Task response (non-terminal)."""
9798
message = None
9899
if text is not None:
99100
message = A2AMessage(
100101
message_id=str(uuid4()),
101-
role=A2ARole.agent,
102+
role=role,
102103
parts=[Part(root=TextPart(text=text))],
103104
)
104105
status = TaskStatus(state=state, message=message)
@@ -110,6 +111,7 @@ async def send_message(self, message: Any) -> AsyncIterator[Any]:
110111
"""Mock send_message method that yields responses."""
111112
self.call_count += 1
112113

114+
# All queued responses are delivered as a single streaming batch per call.
113115
for response in self.responses:
114116
yield response
115117
self.responses.clear()
@@ -1101,4 +1103,90 @@ async def test_streaming_working_update_without_message_is_skipped(
11011103
assert updates[0].contents[0].text == "Result"
11021104

11031105

1106+
async def test_streaming_working_update_user_role_mapping(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
1107+
"""Test that A2ARole.user in status message maps to role='user'."""
1108+
mock_a2a_client.add_in_progress_task_response("task-u", context_id="ctx-u", text="User echo", role=A2ARole.user)
1109+
mock_a2a_client.add_task_response("task-u", [{"id": "art-u", "content": "Done"}])
1110+
1111+
updates: list[AgentResponseUpdate] = []
1112+
async for update in a2a_agent.run("Hello", stream=True):
1113+
updates.append(update)
1114+
1115+
assert len(updates) == 2
1116+
assert updates[0].contents[0].text == "User echo"
1117+
assert updates[0].role == "user"
1118+
1119+
1120+
async def test_background_with_status_message_yields_continuation_token(
1121+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1122+
) -> None:
1123+
"""Test that background=True takes precedence over status message content."""
1124+
mock_a2a_client.add_in_progress_task_response("task-bg", context_id="ctx-bg", text="Should be ignored")
1125+
1126+
updates: list[AgentResponseUpdate] = []
1127+
async for update in a2a_agent.run("Hello", stream=True, background=True):
1128+
updates.append(update)
1129+
1130+
assert len(updates) == 1
1131+
assert updates[0].continuation_token is not None
1132+
assert updates[0].continuation_token["task_id"] == "task-bg"
1133+
assert updates[0].contents == []
1134+
1135+
1136+
async def test_non_streaming_does_not_surface_intermediate_messages(
1137+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1138+
) -> None:
1139+
"""Test that run(stream=False) does not include intermediate status messages."""
1140+
mock_a2a_client.add_in_progress_task_response("task-ns", context_id="ctx-ns", text="Intermediate")
1141+
mock_a2a_client.add_task_response("task-ns", [{"id": "art-ns", "content": "Final"}])
1142+
1143+
response = await a2a_agent.run("Hello")
1144+
1145+
assert len(response.messages) == 1
1146+
assert response.messages[0].text == "Final"
1147+
1148+
1149+
async def test_terminal_no_artifacts_after_working_with_content(
1150+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1151+
) -> None:
1152+
"""Test that a terminal task with no artifacts after working-state messages does not re-emit the working content."""
1153+
mock_a2a_client.add_in_progress_task_response("task-t", context_id="ctx-t", text="Working on it...")
1154+
# Terminal task with no artifacts and no history
1155+
status = TaskStatus(state=TaskState.completed, message=None)
1156+
task = Task(id="task-t", context_id="ctx-t", status=status)
1157+
mock_a2a_client.responses.append((task, None))
1158+
1159+
updates: list[AgentResponseUpdate] = []
1160+
async for update in a2a_agent.run("Hello", stream=True):
1161+
updates.append(update)
1162+
1163+
assert len(updates) == 2
1164+
assert updates[0].contents[0].text == "Working on it..."
1165+
# Terminal task with no artifacts yields an empty-contents update
1166+
assert updates[1].contents == []
1167+
1168+
1169+
async def test_streaming_working_update_with_empty_parts_is_skipped(
1170+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1171+
) -> None:
1172+
"""Test that a working update with status.message but empty parts list is skipped."""
1173+
# Construct a message with an empty parts list (distinct from message=None)
1174+
message = A2AMessage(
1175+
message_id=str(uuid4()),
1176+
role=A2ARole.agent,
1177+
parts=[],
1178+
)
1179+
status = TaskStatus(state=TaskState.working, message=message)
1180+
task = Task(id="task-ep", context_id="ctx-ep", status=status)
1181+
mock_a2a_client.responses.append((task, None))
1182+
mock_a2a_client.add_task_response("task-ep", [{"id": "art-ep", "content": "Result"}])
1183+
1184+
updates: list[AgentResponseUpdate] = []
1185+
async for update in a2a_agent.run("Hello", stream=True):
1186+
updates.append(update)
1187+
1188+
assert len(updates) == 1
1189+
assert updates[0].contents[0].text == "Result"
1190+
1191+
11041192
# endregion

0 commit comments

Comments
 (0)