Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
AgentResponseUpdate,
AgentSession,
BaseAgent,
BaseHistoryProvider,
Content,
ContinuationToken,
Message,
ResponseStream,
SessionContext,
normalize_messages,
prepend_agent_framework_to_user_agent,
)
Expand Down Expand Up @@ -284,17 +286,36 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
When stream=True: A ResponseStream of AgentResponseUpdate items.
"""
del function_invocation_kwargs, client_kwargs, kwargs
normalized_messages = normalize_messages(messages)

if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
)
else:
normalized_messages = normalize_messages(messages)
if not normalized_messages:
raise ValueError("At least one message is required when starting a new task (no continuation_token).")
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
a2a_stream = self.client.send_message(a2a_message)

provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()

session_context = SessionContext(
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=normalized_messages or [],
options={},
)

response = ResponseStream(
self._map_a2a_stream(a2a_stream, background=background),
self._map_a2a_stream(
a2a_stream,
background=background,
session=provider_session,
session_context=session_context,
),
finalizer=AgentResponse.from_updates,
)
if stream:
Expand All @@ -306,6 +327,8 @@ async def _map_a2a_stream(
a2a_stream: AsyncIterable[A2AStreamItem],
*,
background: bool = False,
session: AgentSession | None = None,
session_context: SessionContext | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Map raw A2A protocol items to AgentResponseUpdates.

Expand All @@ -316,24 +339,52 @@ async def _map_a2a_stream(
background: When False, in-progress task updates are silently
consumed (the stream keeps iterating until a terminal state).
When True, they are yielded with a continuation token.
session: The agent session for context providers.
session_context: The session context for context providers.
"""
if session_context is None:
session_context = SessionContext(input_messages=[], options={})

# Run before_run providers (forward order)
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session,
context=session_context,
state=session.state.setdefault(provider.source_id, {}),
)

all_updates: list[AgentResponseUpdate] = []
async for item in a2a_stream:
if isinstance(item, A2AMessage):
# Process A2A Message
contents = self._parse_contents_from_a2a(item.parts)
yield AgentResponseUpdate(
update = AgentResponseUpdate(
contents=contents,
role="assistant" if item.role == A2ARole.agent else "user",
response_id=str(getattr(item, "message_id", uuid.uuid4())),
raw_representation=item,
)
all_updates.append(update)
yield update
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
task, _update_event = item
for update in self._updates_from_task(task, background=background):
all_updates.append(update)
yield update
else:
raise NotImplementedError("Only Message and Task responses are supported")

# Set the response on the context for after_run providers
if all_updates:
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]

await self._run_after_providers(session=session, context=session_context)

# ------------------------------------------------------------------
# Task helpers
# ------------------------------------------------------------------
Expand Down
190 changes: 189 additions & 1 deletion python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
from agent_framework import (
AgentResponse,
AgentResponseUpdate,
AgentSession,
BaseContextProvider,
Content,
Message,
SessionContext,
)
from agent_framework.a2a import A2AAgent
from pytest import fixture, raises
from pytest import fixture, mark, raises

from agent_framework_a2a import A2AContinuationToken
from agent_framework_a2a._agent import _get_uri_data # type: ignore
Expand Down Expand Up @@ -851,3 +854,188 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A


# endregion


# region Context Provider Tests


class TrackingContextProvider(BaseContextProvider):
"""A context provider that records when before_run and after_run are called."""

def __init__(self) -> None:
super().__init__(source_id="tracking-provider")
self.before_run_called = False
self.after_run_called = False
self.before_run_context: SessionContext | None = None
self.after_run_context: SessionContext | None = None

async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: SessionContext,
state: dict[str, Any],
) -> None:
self.before_run_called = True
self.before_run_context = context

async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: SessionContext,
state: dict[str, Any],
) -> None:
self.after_run_called = True
self.after_run_context = context


async def test_run_invokes_context_providers(mock_a2a_client: MockA2AClient) -> None:
"""Test that context providers are invoked during non-streaming run."""
provider = TrackingContextProvider()
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
context_providers=[provider],
http_client=None,
)
mock_a2a_client.add_message_response("msg-1", "Hello from A2A")
session = agent.create_session()

response = await agent.run("Hello", session=session)

assert provider.before_run_called
assert provider.after_run_called
assert response.text == "Hello from A2A"


async def test_run_streaming_invokes_context_providers(mock_a2a_client: MockA2AClient) -> None:
"""Test that context providers are invoked during streaming run."""
provider = TrackingContextProvider()
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
context_providers=[provider],
http_client=None,
)
mock_a2a_client.add_message_response("msg-1", "Streamed response")
session = agent.create_session()

stream = agent.run("Hello", stream=True, session=session)
updates = []
async for update in stream:
updates.append(update)

assert provider.before_run_called
assert provider.after_run_called
assert len(updates) == 1
assert updates[0].text == "Streamed response"


async def test_context_providers_receive_response(mock_a2a_client: MockA2AClient) -> None:
"""Test that after_run providers can access the response via session context."""
provider = TrackingContextProvider()
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
context_providers=[provider],
http_client=None,
)
mock_a2a_client.add_message_response("msg-1", "Response text")
session = agent.create_session()

await agent.run("Hello", session=session)

assert provider.after_run_context is not None
assert provider.after_run_context.response is not None
assert provider.after_run_context.response.text == "Response text"


async def test_context_providers_receive_input_messages(mock_a2a_client: MockA2AClient) -> None:
"""Test that before_run providers can access input messages via session context."""
provider = TrackingContextProvider()
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
context_providers=[provider],
http_client=None,
)
mock_a2a_client.add_message_response("msg-1", "Reply")
session = agent.create_session()

await agent.run("Hello world", session=session)

assert provider.before_run_context is not None
assert len(provider.before_run_context.input_messages) > 0
assert provider.before_run_context.input_messages[-1].text == "Hello world"


async def test_run_without_context_providers(mock_a2a_client: MockA2AClient) -> None:
"""Test that run works normally when no context providers are configured."""
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
http_client=None,
)
mock_a2a_client.add_message_response("msg-1", "Hello")

response = await agent.run("Hello")

assert response.text == "Hello"


async def test_run_creates_session_for_providers_when_none_provided(mock_a2a_client: MockA2AClient) -> None:
"""Test that a session is auto-created when context providers are configured but no session is passed."""
provider = TrackingContextProvider()
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
context_providers=[provider],
http_client=None,
)
mock_a2a_client.add_message_response("msg-1", "Hello")

await agent.run("Hello")

assert provider.before_run_called
assert provider.after_run_called


@mark.parametrize("messages", [None, []])
async def test_run_raises_when_no_messages_and_no_continuation_token(
mock_a2a_client: MockA2AClient, messages: list[str] | None
) -> None:
"""Test that run() raises ValueError when messages is None/empty and no continuation_token is provided."""
agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
http_client=None,
)

with raises(ValueError, match="At least one message is required"):
await agent.run(messages)


async def test_run_with_continuation_token_does_not_require_messages(mock_a2a_client: MockA2AClient) -> None:
"""Test that run() does not raise when messages is None but a continuation_token is provided."""
task = Task(
id="task-cont",
context_id="ctx-cont",
status=TaskStatus(state=TaskState.completed, message=None),
)
mock_a2a_client.resubscribe_responses.append((task, None))

agent = A2AAgent(
name="Test Agent",
client=mock_a2a_client,
http_client=None,
)

token = A2AContinuationToken(task_id="task-cont", context_id="ctx-cont")
response = await agent.run(None, continuation_token=token)
assert response is not None


# endregion
Loading