Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 58 additions & 30 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy import update # Import update
from sqlalchemy.engine import make_url
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
Expand Down Expand Up @@ -143,13 +144,12 @@ def __init__(self, db_url: str, **kwargs: Any):
try:
engine_kwargs = dict(kwargs)
url = make_url(db_url)
if (
url.get_backend_name() == _SQLITE_DIALECT
and url.database == ":memory:"
):
if url.get_backend_name() == _SQLITE_DIALECT:
engine_kwargs.setdefault("poolclass", StaticPool)
connect_args = dict(engine_kwargs.get("connect_args", {}))
connect_args.setdefault("check_same_thread", False)
# Enforce SERIALIZABLE isolation for SQLite to prevent race conditions.
connect_args.setdefault("isolation_level", "SERIALIZABLE")
engine_kwargs["connect_args"] = connect_args
elif url.get_backend_name() != _SQLITE_DIALECT:
engine_kwargs.setdefault("pool_pre_ping", True)
Expand Down Expand Up @@ -205,7 +205,7 @@ async def _rollback_on_exception_session(
On normal exit the caller is responsible for committing; on any exception
the transaction is explicitly rolled back before the error propagates,
preventing connection-pool exhaustion from lingering invalid transactions.
"""
"""<
async with self.database_session_factory() as sql_session:
try:
yield sql_session
Expand Down Expand Up @@ -369,6 +369,9 @@ async def create_session(
if is_sqlite or is_postgresql:
now = now.replace(tzinfo=None)

# Initialize event_sequence for new sessions
session_state["event_sequence"] = 0

storage_session = schema.StorageSession(
app_name=app_name,
user_id=user_id,
Expand All @@ -385,7 +388,7 @@ async def create_session(
storage_app_state.state, storage_user_state.state, session_state
)
session = storage_session.to_session(
state=merged_state, is_sqlite=is_sqlite
state=merged_state, is_sqlite=is_sqlite, event_sequence=0
)
return session

Expand Down Expand Up @@ -440,6 +443,7 @@ async def get_session(
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
event_sequence = session_state.get("event_sequence", 0)

# Merge states
merged_state = _merge_state(app_state, user_state, session_state)
Expand All @@ -448,7 +452,7 @@ async def get_session(
events = [e.to_event() for e in reversed(storage_events)]
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
session = storage_session.to_session(
state=merged_state, events=events, is_sqlite=is_sqlite
state=merged_state, events=events, is_sqlite=is_sqlite, event_sequence=event_sequence
)
return session

Expand Down Expand Up @@ -497,8 +501,10 @@ async def list_sessions(
session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
merged_state = _merge_state(app_state, user_state, session_state)
# Pass event_sequence to to_session if it exists in state
event_sequence = session_state.get("event_sequence", 0)
sessions.append(
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite, event_sequence=event_sequence)
)
return ListSessionsResponse(sessions=sessions)

Expand Down Expand Up @@ -564,6 +570,41 @@ async def append_event(self, session: Session, event: Event) -> Event:
if storage_session is None:
raise ValueError(f"Session {session.id} not found.")

# Get the current event_sequence from storage.
stored_event_sequence = storage_session.state.get("event_sequence", 0)
next_event_sequence = stored_event_sequence + 1

# Atomically update the event_sequence in the database.
# This update includes a WHERE clause that checks the current
# event_sequence, ensuring optimistic concurrency control.
# If another writer has already updated the session, the rowcount
# will be 0, indicating a conflict.
update_session_stmt = (
update(schema.StorageSession)
.where(
schema.StorageSession.app_name == session.app_name,
schema.StorageSession.user_id == session.user_id,
schema.StorageSession.id == session.id,
schema.StorageSession.state["event_sequence"]
== stored_event_sequence,
)
.values(
state=storage_session.state | {"event_sequence": next_event_sequence}
)
)
update_result = await sql_session.execute(update_session_stmt)

if update_result.rowcount == 0:
raise ValueError(
"The session has been modified by another writer. "
"Please reload the session before appending events."
)

# Update the in-memory storage_session to reflect the database change
# This is important for subsequent ORM operations within the same session
# and for the final session.event_sequence update.
storage_session.state["event_sequence"] = next_event_sequence

storage_app_state = await _select_required_state(
sql_session=sql_session,
state_model=schema.StorageAppState,
Expand Down Expand Up @@ -591,27 +632,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
),
)

if (
storage_session.get_update_timestamp(is_sqlite)
> session.last_update_time
):
# Reload the session from storage if it has been updated since it was
# loaded.
app_state = storage_app_state.state
user_state = storage_user_state.state
session_state = storage_session.state
session.state = _merge_state(app_state, user_state, session_state)

stmt = (
select(schema.StorageEvent)
.filter(schema.StorageEvent.app_name == session.app_name)
.filter(schema.StorageEvent.session_id == session.id)
.filter(schema.StorageEvent.user_id == session.user_id)
.order_by(schema.StorageEvent.timestamp.asc())
)
result = await sql_session.stream_scalars(stmt)
storage_events = [e async for e in result]
session.events = [e.to_event() for e in storage_events]
# Removed the old timestamp-based stale session check and reload logic.

# Merge pre-extracted state deltas into storage.
if has_app_delta:
Expand All @@ -623,6 +644,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_user_state.state | state_deltas["user"]
)
if state_deltas["session"]:
# Note: storage_session.state["event_sequence"] was already updated
# by the explicit UPDATE statement above.
# We need to ensure other session state changes are merged.
# The | operator for dictionaries correctly merges.
storage_session.state = (
storage_session.state | state_deltas["session"]
)
Expand All @@ -642,8 +667,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
session.last_update_time = storage_session.get_update_timestamp(
is_sqlite
)
# Update in-memory session's event_sequence after successful commit
session.event_sequence = next_event_sequence

# Also update the in-memory session
# Also update the in-memory session (handled by super().append_event)
# This call might also update last_update_time, but event_sequence is explicitly handled now.
await super().append_event(session=session, event=event)
return event

Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ class Session(BaseModel):
call/response, etc."""
last_update_time: float = 0.0
"""The last update time of the session."""
event_sequence: int = 0
"""A monotonically increasing integer for optimistic concurrency control."""