Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(
memory_service: Optional[BaseMemoryService] = None,
credential_service: Optional[BaseCredentialService] = None,
plugin_close_timeout: float = 5.0,
auto_create_session: bool = False,
):
"""Initializes the Runner.

Expand All @@ -175,6 +176,9 @@ def __init__(
memory_service: The memory service for the runner.
credential_service: The credential service for the runner.
plugin_close_timeout: The timeout in seconds for plugin close methods.
auto_create_session: Whether to automatically create a session when
not found. Defaults to False. If False, a missing session raises
ValueError with a helpful message.

Raises:
ValueError: If `app` is provided along with `agent` or `plugins`, or if
Expand All @@ -195,6 +199,7 @@ def __init__(
self.plugin_manager = PluginManager(
plugins=plugins, close_timeout=plugin_close_timeout
)
self.auto_create_session = auto_create_session
(
self._agent_origin_app_name,
self._agent_origin_dir,
Expand Down Expand Up @@ -343,9 +348,43 @@ def _format_session_not_found_message(self, session_id: str) -> str:
return message
return (
f'{message}. {self._app_name_alignment_hint} '
'The mismatch prevents the runner from locating the session.'
'The mismatch prevents the runner from locating the session. '
'To automatically create a session when missing, set '
'auto_create_session=True when constructing the runner.'
)

async def _get_or_create_session(
self, *, user_id: str, session_id: str
) -> Session:
"""Gets the session or creates it if auto-creation is enabled.

This helper first attempts to retrieve the session. If not found and
auto_create_session is True, it creates a new session with the provided
identifiers. Otherwise, it raises a ValueError with a helpful message.

Args:
user_id: The user ID of the session.
session_id: The session ID of the session.

Returns:
The existing or newly created `Session`.

Raises:
ValueError: If the session is not found and auto_create_session is False.
"""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
if self.auto_create_session:
session = await self.session_service.create_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
else:
message = self._format_session_not_found_message(session_id)
raise ValueError(message)
return session

def run(
self,
*,
Expand Down Expand Up @@ -455,12 +494,9 @@ async def _run_with_trace(
invocation_id: Optional[str] = None,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
)
if not session:
message = self._format_session_not_found_message(session_id)
raise ValueError(message)
if not invocation_id and not new_message:
raise ValueError(
'Running an agent requires either a new_message or an '
Expand Down Expand Up @@ -534,12 +570,9 @@ async def rewind_async(
rewind_before_invocation_id: str,
) -> None:
"""Rewinds the session to before the specified invocation."""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(f'Session not found: {session_id}')

rewind_event_index = -1
for i, event in enumerate(session.events):
if event.invocation_id == rewind_before_invocation_id:
Expand Down Expand Up @@ -967,14 +1000,9 @@ async def run_live(
stacklevel=2,
)
if not session:
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(
f'Session not found for user id: {user_id} and session id:'
f' {session_id}'
)
invocation_context = self._new_invocation_context_for_live(
session,
live_request_queue=live_request_queue,
Expand Down
121 changes: 121 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ async def _run_async_impl(
)


class MockLiveAgent(BaseAgent):
"""Mock live agent for unit testing."""

def __init__(self, name: str):
super().__init__(name=name, sub_agents=[])

async def _run_live_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
content=types.Content(
role="model", parts=[types.Part(text="live hello")]
),
)


class MockLlmAgent(LlmAgent):
"""Mock LLM agent for unit testing."""

Expand Down Expand Up @@ -237,6 +255,109 @@ def _infer_agent_origin(
assert "Ensure the runner app_name matches" in message


@pytest.mark.asyncio
async def test_session_auto_creation():

class RunnerWithMismatch(Runner):

def _infer_agent_origin(
self, agent: BaseAgent
) -> tuple[Optional[str], Optional[Path]]:
del agent
return "expected_app", Path("/workspace/agents/expected_app")

session_service = InMemorySessionService()
runner = RunnerWithMismatch(
app_name="expected_app",
agent=MockLlmAgent("test_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
auto_create_session=True,
)

agen = runner.run_async(
user_id="user",
session_id="missing",
new_message=types.Content(role="user", parts=[types.Part(text="hi")]),
)

event = await agen.__anext__()
await agen.aclose()

# Verify that session_id="missing" doesn't error out - session is auto-created
assert event.author == "test_agent"
assert event.content.parts[0].text == "Test LLM response"


@pytest.mark.asyncio
async def test_rewind_auto_create_session_on_missing_session():
"""When auto_create_session=True, rewind should create session if missing.

The newly created session won't contain the target invocation, so
`rewind_async` should raise an Invocation ID not found error (rather than
a session not found error), demonstrating auto-creation occurred.
"""
session_service = InMemorySessionService()
runner = Runner(
app_name="auto_create_app",
agent=MockLlmAgent("agent_for_rewind"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
auto_create_session=True,
)

with pytest.raises(ValueError, match=r"Invocation ID not found: inv_missing"):
await runner.rewind_async(
user_id="user",
session_id="missing",
rewind_before_invocation_id="inv_missing",
)

# Verify the session actually exists now due to auto-creation.
session = await session_service.get_session(
app_name="auto_create_app", user_id="user", session_id="missing"
)
assert session is not None
assert session.app_name == "auto_create_app"


@pytest.mark.asyncio
async def test_run_live_auto_create_session():
"""run_live should auto-create session when missing and yield events."""
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name="live_app",
agent=MockLiveAgent("live_agent"),
session_service=session_service,
artifact_service=artifact_service,
auto_create_session=True,
)

# An empty LiveRequestQueue is sufficient for our mock agent.
from google.adk.agents.live_request_queue import LiveRequestQueue

live_queue = LiveRequestQueue()

agen = runner.run_live(
user_id="user",
session_id="missing",
live_request_queue=live_queue,
)

event = await agen.__anext__()
await agen.aclose()

assert event.author == "live_agent"
assert event.content.parts[0].text == "live hello"

# Session should have been created automatically.
session = await session_service.get_session(
app_name="live_app", user_id="user", session_id="missing"
)
assert session is not None


@pytest.mark.asyncio
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
project_root = tmp_path / "workspace"
Expand Down
Loading