Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
self._map_a2a_stream(
a2a_stream,
background=background,
emit_intermediate=stream,
session=provider_session,
session_context=session_context,
),
Expand All @@ -327,6 +328,7 @@ async def _map_a2a_stream(
a2a_stream: AsyncIterable[A2AStreamItem],
*,
background: bool = False,
emit_intermediate: bool = False,
session: AgentSession | None = None,
session_context: SessionContext | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
Expand All @@ -339,6 +341,10 @@ async def _map_a2a_stream(
background: When False, in-progress task updates are silently
consumed (the stream keeps iterating until a terminal state).
When True, they are yielded with a continuation token.
emit_intermediate: When True, in-progress status updates that
carry message content are yielded to the caller. Typically
set for streaming callers so non-streaming consumers only
receive terminal task outputs.
session: The agent session for context providers.
session_context: The session context for context providers.
"""
Expand Down Expand Up @@ -373,7 +379,11 @@ async def _map_a2a_stream(
yield update
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
task, _update_event = item
for update in self._updates_from_task(task, background=background):
for update in self._updates_from_task(
task,
background=background,
emit_intermediate=emit_intermediate,
):
all_updates.append(update)
yield update
else:
Expand All @@ -389,15 +399,26 @@ async def _map_a2a_stream(
# Task helpers
# ------------------------------------------------------------------

def _updates_from_task(self, task: Task, *, background: bool = False) -> list[AgentResponseUpdate]:
def _updates_from_task(
self,
task: Task,
*,
background: bool = False,
emit_intermediate: bool = False,
) -> list[AgentResponseUpdate]:
"""Convert an A2A Task into AgentResponseUpdate(s).

Terminal tasks produce updates from their artifacts/history.
In-progress tasks produce a continuation token update only when
``background=True``; otherwise they are silently skipped so the
caller keeps consuming the stream until completion.
In-progress tasks produce a continuation token update when
``background=True``. When ``emit_intermediate=True`` (typically
set for streaming callers), any message content attached to an
in-progress status update is surfaced; otherwise the update is
silently skipped so the caller keeps consuming the stream until
completion.
"""
if task.status.state in TERMINAL_TASK_STATES:
status = task.status

if status.state in TERMINAL_TASK_STATES:
task_messages = self._parse_messages_from_task(task)
if task_messages:
return [
Expand All @@ -412,7 +433,7 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag
]
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]

if background and task.status.state in IN_PROGRESS_TASK_STATES:
if background and status.state in IN_PROGRESS_TASK_STATES:
token = self._build_continuation_token(task)
return [
AgentResponseUpdate(
Expand All @@ -424,6 +445,26 @@ def _updates_from_task(self, task: Task, *, background: bool = False) -> list[Ag
)
]

# Surface message content from in-progress status updates (e.g. working state)
# Only emitted when the caller opts in (streaming), so non-streaming
# consumers keep receiving only terminal task outputs.
if (
emit_intermediate
and status.state in IN_PROGRESS_TASK_STATES
and status.message is not None
and status.message.parts
):
contents = self._parse_contents_from_a2a(status.message.parts)
if contents:
return [
AgentResponseUpdate(
contents=contents,
role="assistant" if status.message.role == A2ARole.agent else "user",
response_id=task.id,
raw_representation=task,
)
]

return []

@staticmethod
Expand Down
157 changes: 154 additions & 3 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,18 @@ def add_in_progress_task_response(
task_id: str,
context_id: str = "test-context",
state: TaskState = TaskState.working,
text: str | None = None,
role: A2ARole = A2ARole.agent,
) -> None:
"""Add a mock in-progress Task response (non-terminal)."""
status = TaskStatus(state=state, message=None)
message = None
if text is not None:
message = A2AMessage(
message_id=str(uuid4()),
role=role,
parts=[Part(root=TextPart(text=text))],
)
status = TaskStatus(state=state, message=message)
task = Task(id=task_id, context_id=context_id, status=status)
client_event = (task, None)
self.responses.append(client_event)
Expand All @@ -102,9 +111,10 @@ async def send_message(self, message: Any) -> AsyncIterator[Any]:
"""Mock send_message method that yields responses."""
self.call_count += 1

if self.responses:
response = self.responses.pop(0)
# All queued responses are delivered as a single streaming batch per call.
for response in self.responses:
yield response
self.responses.clear()

async def resubscribe(self, request: Any) -> AsyncIterator[Any]:
"""Mock resubscribe method that yields responses."""
Expand Down Expand Up @@ -1039,3 +1049,144 @@ async def test_run_with_continuation_token_does_not_require_messages(mock_a2a_cl


# endregion

# region Streaming with in-progress message content


async def test_streaming_working_updates_yield_message_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streaming working updates with status.message yield content."""
mock_a2a_client.add_in_progress_task_response("task-w", context_id="ctx-w", text="Processing step 1...")
mock_a2a_client.add_in_progress_task_response("task-w", context_id="ctx-w", text="Processing step 2...")
mock_a2a_client.add_task_response("task-w", [{"id": "art-w", "content": "Final result"}])

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 3
assert updates[0].contents[0].text == "Processing step 1..."
assert updates[1].contents[0].text == "Processing step 2..."
assert updates[2].contents[0].text == "Final result"


async def test_streaming_single_working_update_with_message(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that a single working update with message content is not dropped."""
mock_a2a_client.add_in_progress_task_response("task-s", context_id="ctx-s", text="Thinking...")
mock_a2a_client.add_task_response("task-s", [{"id": "art-s", "content": "Done"}])

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 2
assert updates[0].contents[0].text == "Thinking..."
assert updates[0].role == "assistant"
assert updates[1].contents[0].text == "Done"


async def test_streaming_working_update_without_message_is_skipped(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that working updates without status.message are still silently skipped."""
mock_a2a_client.add_in_progress_task_response("task-n", context_id="ctx-n")
mock_a2a_client.add_task_response("task-n", [{"id": "art-n", "content": "Result"}])

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].contents[0].text == "Result"


async def test_streaming_working_update_user_role_mapping(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""Test that A2ARole.user in status message maps to role='user'."""
mock_a2a_client.add_in_progress_task_response("task-u", context_id="ctx-u", text="User echo", role=A2ARole.user)
mock_a2a_client.add_task_response("task-u", [{"id": "art-u", "content": "Done"}])

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 2
assert updates[0].contents[0].text == "User echo"
assert updates[0].role == "user"


async def test_background_with_status_message_yields_continuation_token(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that background=True takes precedence over status message content."""
mock_a2a_client.add_in_progress_task_response("task-bg", context_id="ctx-bg", text="Should be ignored")

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True, background=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].continuation_token is not None
assert updates[0].continuation_token["task_id"] == "task-bg"
assert updates[0].contents == []


async def test_non_streaming_does_not_surface_intermediate_messages(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that run(stream=False) does not include intermediate status messages."""
mock_a2a_client.add_in_progress_task_response("task-ns", context_id="ctx-ns", text="Intermediate")
mock_a2a_client.add_task_response("task-ns", [{"id": "art-ns", "content": "Final"}])

response = await a2a_agent.run("Hello")

assert len(response.messages) == 1
assert response.messages[0].text == "Final"


async def test_terminal_no_artifacts_after_working_with_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that a terminal task with no artifacts after working-state messages does not re-emit the working content."""
mock_a2a_client.add_in_progress_task_response("task-t", context_id="ctx-t", text="Working on it...")
# Terminal task with no artifacts and no history
status = TaskStatus(state=TaskState.completed, message=None)
task = Task(id="task-t", context_id="ctx-t", status=status)
mock_a2a_client.responses.append((task, None))

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 2
assert updates[0].contents[0].text == "Working on it..."
# Terminal task with no artifacts yields an empty-contents update
assert updates[1].contents == []


async def test_streaming_working_update_with_empty_parts_is_skipped(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that a working update with status.message but empty parts list is skipped."""
# Construct a message with an empty parts list (distinct from message=None)
message = A2AMessage(
message_id=str(uuid4()),
role=A2ARole.agent,
parts=[],
)
status = TaskStatus(state=TaskState.working, message=message)
task = Task(id="task-ep", context_id="ctx-ep", status=status)
mock_a2a_client.responses.append((task, None))
mock_a2a_client.add_task_response("task-ep", [{"id": "art-ep", "content": "Result"}])

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].contents[0].text == "Result"


# endregion
Loading