Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/google/adk/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from pydantic import field_validator
from pydantic import model_validator

from ..sessions.base_session_service import GetSessionConfig

logger = logging.getLogger('google_adk.' + __name__)


Expand Down Expand Up @@ -254,6 +256,9 @@ class RunConfig(BaseModel):
custom_metadata: Optional[dict[str, Any]] = None
"""Custom metadata for the current invocation."""

get_session_config: Optional[GetSessionConfig] = None
"""Configuration for controlling which events are fetched from session storage."""

@model_validator(mode='before')
@classmethod
def check_for_deprecated_save_live_audio(cls, data: Any) -> Any:
Expand Down
19 changes: 14 additions & 5 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,11 @@ async def _run_with_trace(
invocation_id: Optional[str] = None,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
session = await self.session_service.get_session(
app_name=self.app_name,
user_id=user_id,
session_id=session_id,
config=run_config.get_session_config,
)
if not invocation_id and not new_message:
raise ValueError(
Expand Down Expand Up @@ -1000,8 +1003,11 @@ async def run_live(
stacklevel=2,
)
if not session:
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
session = await self.session_service.get_session(
app_name=self.app_name,
user_id=user_id,
session_id=session_id,
config=run_config.get_session_config,
)
invocation_context = self._new_invocation_context_for_live(
session,
Expand Down Expand Up @@ -1219,7 +1225,10 @@ async def run_debug(
Please use run_async() with proper configuration.
"""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
app_name=self.app_name,
user_id=user_id,
session_id=session_id,
config=run_config.get_session_config if run_config else None,
Copy link

@thomasErich135 thomasErich135 Dec 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To maintain the same logic, I recommend adding this code at the beginning of the run_debug method and removing the ternary check from the config parameter assignment

run_config = run_config or RunConfig()

)
if not session:
session = await self.session_service.create_session(
Expand Down
131 changes: 131 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from google.adk.events.event import Event
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.base_session_service import BaseSessionService
from google.adk.sessions.base_session_service import GetSessionConfig
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
Expand Down Expand Up @@ -1321,5 +1323,134 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent(
assert "actual_name" in runner._app_name_alignment_hint


class TestRunnerGetSessionConfig:
"""Tests for Runner get_session_config passing to session service."""

def setup_method(self):
"""Set up test fixtures."""
self.mock_session_service = AsyncMock(spec=BaseSessionService)
self.artifact_service = InMemoryArtifactService()
self.root_agent = MockLlmAgent("root_agent")

# Create a mock session to return
self.mock_session = Session(
id=TEST_SESSION_ID,
app_name=TEST_APP_ID,
user_id=TEST_USER_ID,
events=[],
)

# Configure the mock to return the session
self.mock_session_service.get_session = AsyncMock(
return_value=self.mock_session
)

self.runner = Runner(
app_name=TEST_APP_ID,
agent=self.root_agent,
session_service=self.mock_session_service,
artifact_service=self.artifact_service,
)

@pytest.mark.asyncio
async def test_run_async_passes_get_session_config(self):
"""Test that run_async passes get_session_config to session service."""
config = GetSessionConfig(num_recent_events=5)
run_config = RunConfig(get_session_config=config)

agen = self.runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(role="user", parts=[types.Part(text="test")]),
run_config=run_config,
)

# Consume first event to trigger get_session call
try:
await agen.__anext__()
except StopAsyncIteration:
pass
finally:
await agen.aclose()

# Verify get_session was called with the config
self.mock_session_service.get_session.assert_called_once()
call_kwargs = self.mock_session_service.get_session.call_args.kwargs
assert call_kwargs["config"] == config
assert call_kwargs["app_name"] == TEST_APP_ID
assert call_kwargs["user_id"] == TEST_USER_ID
assert call_kwargs["session_id"] == TEST_SESSION_ID

@pytest.mark.asyncio
async def test_run_async_passes_none_when_no_config(self):
"""Test that run_async passes None when get_session_config is not set."""
agen = self.runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(role="user", parts=[types.Part(text="test")]),
)

# Consume first event to trigger get_session call
try:
await agen.__anext__()
except StopAsyncIteration:
pass
finally:
await agen.aclose()

# Verify get_session was called with config=None
self.mock_session_service.get_session.assert_called_once()
call_kwargs = self.mock_session_service.get_session.call_args.kwargs
assert call_kwargs["config"] is None

@pytest.mark.asyncio
async def test_run_debug_passes_get_session_config(self):
"""Test that run_debug passes get_session_config to session service."""
# Mock create_session as well since run_debug creates session if not found
self.mock_session_service.create_session = AsyncMock(
return_value=self.mock_session
)

config = GetSessionConfig(num_recent_events=10)
run_config = RunConfig(get_session_config=config)

await self.runner.run_debug(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
user_messages="test",
run_config=run_config,
quiet=True,
)

# Verify get_session was called with the config
# Note: get_session is called twice - once in run_debug, once in run_async
assert self.mock_session_service.get_session.call_count == 2
# Check both calls had the config
for call in self.mock_session_service.get_session.call_args_list:
assert call.kwargs["config"] == config

@pytest.mark.asyncio
async def test_run_debug_passes_none_when_no_config(self):
"""Test that run_debug passes None when run_config is not provided."""
# Mock create_session
self.mock_session_service.create_session = AsyncMock(
return_value=self.mock_session
)

await self.runner.run_debug(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
user_messages="test",
quiet=True,
)

# Verify get_session was called with config=None
# Note: get_session is called twice - once in run_debug, once in run_async
assert self.mock_session_service.get_session.call_count == 2
# Check both calls had config=None
for call in self.mock_session_service.get_session.call_args_list:
assert call.kwargs["config"] is None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend adding unit tests for rewind_async and run_live methods


if __name__ == "__main__":
pytest.main([__file__])