Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/google/adk/cli/utils/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for local .adk folder persistence."""

from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import Any
from typing import Mapping
from typing import Optional

Expand All @@ -27,6 +29,7 @@
from ...events.event import Event
from ...sessions.base_session_service import BaseSessionService
from ...sessions.base_session_service import GetSessionConfig
from ...sessions.base_session_service import ListSessionsConfig
from ...sessions.base_session_service import ListSessionsResponse
from ...sessions.session import Session
from .dot_adk_folder import dot_adk_folder_for_agent
Expand Down Expand Up @@ -155,15 +158,19 @@ async def create_session(
*,
app_name: str,
user_id: str,
state: Optional[dict[str, object]] = None,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
) -> Session:
service = await self._get_service(app_name)
return await service.create_session(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
labels=labels,
)

@override
Expand All @@ -189,9 +196,12 @@ async def list_sessions(
*,
app_name: str,
user_id: Optional[str] = None,
config: Optional[ListSessionsConfig] = None,
) -> ListSessionsResponse:
service = await self._get_service(app_name)
return await service.list_sessions(app_name=app_name, user_id=user_id)
return await service.list_sessions(
app_name=app_name, user_id=user_id, config=config
)

@override
async def delete_session(
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .in_memory_session_service import InMemorySessionService
from .session import Session
from .state import State
Expand All @@ -20,7 +23,10 @@
__all__ = [
'BaseSessionService',
'DatabaseSessionService',
'GetSessionConfig',
'InMemorySessionService',
'ListSessionsConfig',
'ListSessionsResponse',
'Session',
'State',
'VertexAiSessionService',
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/sessions/_session_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for session service."""

from __future__ import annotations

from typing import Any
Expand Down
20 changes: 19 additions & 1 deletion src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ class GetSessionConfig(BaseModel):
after_timestamp: Optional[float] = None


class ListSessionsConfig(BaseModel):
"""The configuration of listing sessions."""

labels: Optional[dict[str, str]] = None
"""Filter sessions by labels. Only sessions that have all the specified
labels will be returned."""


class ListSessionsResponse(BaseModel):
"""The response of listing sessions.

Expand All @@ -56,6 +64,8 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
) -> Session:
"""Creates a new session.

Expand All @@ -65,6 +75,9 @@ async def create_session(
state: the initial state of the session.
session_id: the client-provided id of the session. If not provided, a
generated ID will be used.
display_name: optional display name for the session.
labels: optional labels with user-defined metadata to organize sessions.
Label keys and values can be no longer than 64 characters.

Returns:
session: The newly created session instance.
Expand All @@ -83,14 +96,19 @@ async def get_session(

@abc.abstractmethod
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
config: Optional[ListSessionsConfig] = None,
) -> ListSessionsResponse:
"""Lists all the sessions for a user.

Args:
app_name: The name of the app.
user_id: The ID of the user. If not provided, lists all sessions for all
users.
config: Optional configuration for filtering sessions.

Returns:
A ListSessionsResponse containing the sessions.
Expand Down
38 changes: 35 additions & 3 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .migration import _schema_check_utils
from .schemas.v0 import Base as BaseV0
Expand Down Expand Up @@ -229,6 +230,8 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
) -> Session:
# 1. Populate states.
# 2. Build storage session object
Expand Down Expand Up @@ -280,6 +283,8 @@ async def create_session(
user_id=user_id,
id=session_id,
state=session_state,
display_name=display_name,
labels=labels or {},
)
sql_session.add(storage_session)
await sql_session.commit()
Expand Down Expand Up @@ -355,7 +360,11 @@ async def get_session(

@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
config: Optional[ListSessionsConfig] = None,
) -> ListSessionsResponse:
await self._prepare_tables()
schema = self._get_schema_classes()
Expand All @@ -366,6 +375,21 @@ async def list_sessions(
if user_id is not None:
stmt = stmt.filter(schema.StorageSession.user_id == user_id)

labels_filter = config.labels if config else None

# Apply label filter at database level for backends with native JSON
# support (PostgreSQL JSONB). For other backends, filter in Python.
apply_python_filter = False
if labels_filter:
if self.db_engine.dialect.name == "postgresql":
# PostgreSQL JSONB supports efficient containment queries via @>
stmt = stmt.filter(
schema.StorageSession.labels.contains(labels_filter)
)
else:
# For other backends (SQLite, MySQL), filter in Python after fetching
apply_python_filter = True

result = await sql_session.execute(stmt)
results = result.scalars().all()

Expand Down Expand Up @@ -394,6 +418,14 @@ async def list_sessions(

sessions = []
for storage_session in results:
# Apply Python-level label filter for non-PostgreSQL backends
if apply_python_filter and labels_filter:
session_labels = storage_session.labels or {}
if not all(
session_labels.get(k) == v for k, v in labels_filter.items()
):
continue

session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
merged_state = _merge_state(app_state, user_state, session_state)
Expand Down Expand Up @@ -436,8 +468,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
if storage_session.update_timestamp_tz > session.last_update_time:
raise ValueError(
"The last_update_time provided in the session object"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
" earlier than the update_time in the storage_session"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'}"
" is earlier than the update_time in the storage_session"
f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
" Please check if it is a stale session."
)
Expand Down
52 changes: 47 additions & 5 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsConfig
from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
Expand Down Expand Up @@ -58,12 +59,16 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
) -> Session:
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
labels=labels,
)

def create_session_sync(
Expand All @@ -73,13 +78,17 @@ def create_session_sync(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
) -> Session:
logger.warning('Deprecated. Please migrate to the async method.')
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
labels=labels,
)

def _create_session_impl(
Expand All @@ -89,6 +98,8 @@ def _create_session_impl(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
) -> Session:
if session_id and self._get_session_impl(
app_name=app_name, user_id=user_id, session_id=session_id
Expand Down Expand Up @@ -116,6 +127,8 @@ def _create_session_impl(
id=session_id,
state=session_state or {},
last_update_time=time.time(),
display_name=display_name,
labels=labels or {},
)

if app_name not in self.sessions:
Expand Down Expand Up @@ -218,20 +231,44 @@ def _merge_state(
][key]
return copied_session

def _matches_labels(
self, session: Session, labels: Optional[dict[str, str]]
) -> bool:
"""Checks if a session has all the specified labels."""
if not labels:
return True
return all(session.labels.get(k) == v for k, v in labels.items())

@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
config: Optional[ListSessionsConfig] = None,
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
return self._list_sessions_impl(
app_name=app_name, user_id=user_id, config=config
)

def list_sessions_sync(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
config: Optional[ListSessionsConfig] = None,
) -> ListSessionsResponse:
logger.warning('Deprecated. Please migrate to the async method.')
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
return self._list_sessions_impl(
app_name=app_name, user_id=user_id, config=config
)

def _list_sessions_impl(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
config: Optional[ListSessionsConfig] = None,
) -> ListSessionsResponse:
empty_response = ListSessionsResponse()
if app_name not in self.sessions:
Expand All @@ -240,17 +277,22 @@ def _list_sessions_impl(
return empty_response

sessions_without_events = []
labels_filter = config.labels if config else None

if user_id is None:
for user_id in self.sessions[app_name]:
for session_id in self.sessions[app_name][user_id]:
session = self.sessions[app_name][user_id][session_id]
if not self._matches_labels(session, labels_filter):
continue
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
else:
for session in self.sessions[app_name][user_id].values():
if not self._matches_labels(session, labels_filter):
continue
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/sessions/migration/migration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Migration runner to upgrade schemas to the latest version."""

from __future__ import annotations

import logging
Expand Down
9 changes: 9 additions & 0 deletions src/google/adk/sessions/schemas/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ class StorageSession(Base):
PreciseTimestamp, default=func.now(), onupdate=func.now()
)

display_name: Mapped[Optional[str]] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
labels: Mapped[MutableDict[str, str]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)

storage_events: Mapped[list[StorageEvent]] = relationship(
"StorageEvent",
back_populates="storage_session",
Expand Down Expand Up @@ -164,6 +171,8 @@ def to_session(
state=state,
events=events,
last_update_time=self.update_timestamp_tz,
display_name=self.display_name,
labels=self.labels or {},
)


Expand Down
Loading
Loading