Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,10 @@ async def _run_async_impl(
message_id=str(uuid.uuid4()),
parts=message_parts,
role="user",
context_id=context_id,
# Use existing context_id if available (for conversation continuity),
# otherwise use the local session ID to maintain session identity
# across local and remote agents.
context_id=context_id if context_id else ctx.session.id,
)

logger.debug(build_a2a_request_log(a2a_request))
Expand Down
97 changes: 97 additions & 0 deletions tests/unittests/agents/test_remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,103 @@ async def test_run_async_impl_successful_request(self):
in mock_event.custom_metadata
)

async def _run_context_id_test(
self, mock_context_id: str | None, expected_context_id: str
):
"""Helper to test context_id handling in _run_async_impl.

Args:
mock_context_id: The context_id to return from
_construct_message_parts_from_session.
expected_context_id: The expected context_id in the A2AMessage.
"""
from a2a.client import Client as A2AClient
from a2a.types import TextPart

with patch.object(self.agent, "_ensure_resolved"):
with patch.object(
self.agent, "_create_a2a_request_for_user_function_response"
) as mock_create_func:
mock_create_func.return_value = None

with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
mock_a2a_part = Mock(spec=TextPart)
mock_construct.return_value = ([mock_a2a_part], mock_context_id)

# Mock A2A client
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
mock_response = Mock()
mock_send_message = AsyncMock()
mock_send_message.__aiter__.return_value = [mock_response]
mock_a2a_client.send_message.return_value = mock_send_message
self.agent._a2a_client = mock_a2a_client

mock_event = Event(
author=self.agent.name,
invocation_id=self.mock_context.invocation_id,
branch=self.mock_context.branch,
)

with patch.object(self.agent, "_handle_a2a_response") as mock_handle:
mock_handle.return_value = mock_event

with patch(
"google.adk.agents.remote_a2a_agent.build_a2a_request_log"
) as mock_req_log:
with patch(
"google.adk.agents.remote_a2a_agent.build_a2a_response_log"
) as mock_resp_log:
mock_req_log.return_value = "Mock request log"
mock_resp_log.return_value = "Mock response log"

with patch(
"google.adk.agents.remote_a2a_agent.A2AMessage"
) as mock_message_class:
mock_message = Mock(spec=A2AMessage)
mock_message_class.return_value = mock_message
mock_response.model_dump.return_value = {"test": "response"}

# Execute
events = []
async for event in self.agent._run_async_impl(
self.mock_context
):
events.append(event)

# Verify A2AMessage was called with expected context_id
mock_message_class.assert_called_once()
call_kwargs = mock_message_class.call_args[1]
assert call_kwargs["context_id"] == expected_context_id

@pytest.mark.asyncio
async def test_run_async_impl_uses_session_id_when_no_context_id(self):
"""Test that session ID is used as context_id when no existing context.

When _construct_message_parts_from_session returns None for context_id,
the agent should use ctx.session.id to maintain session identity across
local and remote agents.
"""
await self._run_context_id_test(
mock_context_id=None,
expected_context_id=self.mock_session.id,
)

@pytest.mark.asyncio
async def test_run_async_impl_preserves_existing_context_id(self):
"""Test that existing context_id is preserved when available.

When _construct_message_parts_from_session returns a context_id from
a previous remote agent response, that context_id should be used
for conversation continuity.
"""
existing_context_id = "existing-context-456"
await self._run_context_id_test(
mock_context_id=existing_context_id,
expected_context_id=existing_context_id,
)

@pytest.mark.asyncio
async def test_run_async_impl_a2a_client_error(self):
"""Test _run_async_impl when A2A send_message fails."""
Expand Down