Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
105 changes: 76 additions & 29 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@
import logging
from typing import Any
from typing import Optional
from typing import overload

from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.inspection import inspect
from sqlalchemy.pool import StaticPool
from typing_extensions import override
from tzlocal import get_localzone
Expand Down Expand Up @@ -99,37 +98,85 @@ def __init__(self, version: str):
class DatabaseSessionService(BaseSessionService):
"""A session service that uses a database for storage."""

def __init__(self, db_url: str, **kwargs: Any):
"""Initializes the database session service with a database URL."""
# 1. Create DB engine for db connection
@overload
def __init__(
self,
db_url: str,
**kwargs: Any,
) -> None:
"""Initializes the database session service with a database URL.

Args:
db_url: Database URL string for creating a new engine.
**kwargs: Additional keyword arguments passed to create_async_engine.
"""

@overload
def __init__(
self,
*,
db_engine: AsyncEngine,
) -> None:
"""Initializes the database session service with an existing SQLAlchemy AsyncEngine.

Args:
db_engine: Existing SQLAlchemy AsyncEngine instance to use.
"""

def __init__(
self,
db_url: Optional[str] = None,
db_engine: Optional[AsyncEngine] = None,
**kwargs: Any,
) -> None:
"""Initializes the database session service.

Args:
db_url: Database URL string for creating a new engine. Mutually exclusive
with db_engine.
db_engine: Existing AsyncEngine instance. Mutually exclusive with db_url.
**kwargs: Additional keyword arguments passed to create_async_engine when
db_url is provided. Ignored when db_engine is provided.

Raises:
ValueError: If neither or both db_url and db_engine are provided, or if
engine creation fails.
"""
if (db_url is None) == (db_engine is None):
raise ValueError(
"Exactly one of 'db_url' or 'db_engine' must be provided."
)

# 1. Create or use provided DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properties
try:
engine_kwargs = dict(kwargs)
url = make_url(db_url)
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
engine_kwargs.setdefault("poolclass", StaticPool)
connect_args = dict(engine_kwargs.get("connect_args", {}))
connect_args.setdefault("check_same_thread", False)
engine_kwargs["connect_args"] = connect_args

db_engine = create_async_engine(db_url, **engine_kwargs)
if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)

except Exception as e:
if isinstance(e, ArgumentError):
raise ValueError(
f"Invalid database URL format or argument '{db_url}'."
) from e
if isinstance(e, ImportError):
if db_engine is None:
try:
engine_kwargs = dict(kwargs)
url = make_url(db_url)
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
engine_kwargs.setdefault("poolclass", StaticPool)
connect_args = dict(engine_kwargs.get("connect_args", {}))
connect_args.setdefault("check_same_thread", False)
engine_kwargs["connect_args"] = connect_args

db_engine = create_async_engine(db_url, **engine_kwargs)
if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)

except Exception as e:
if isinstance(e, ArgumentError):
raise ValueError(
f"Invalid database URL format or argument '{db_url}'."
) from e
if isinstance(e, ImportError):
raise ValueError(
f"Database related module not found for URL '{db_url}'."
) from e
raise ValueError(
f"Database related module not found for URL '{db_url}'."
f"Failed to create database engine for URL '{db_url}'"
) from e
raise ValueError(
f"Failed to create database engine for URL '{db_url}'"
) from e

# Get the local timezone
local_timezone = get_localzone()
Expand Down
Loading