Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,15 +2205,32 @@ def migrate():
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--allow_unsafe_unpickling",
is_flag=True,
default=False,
help=(
"Optional. Allow unsafe pickle loading for trusted legacy session"
" databases."
),
)
def cli_migrate_session(
*, source_db_url: str, dest_db_url: str, log_level: str
*,
source_db_url: str,
dest_db_url: str,
log_level: str,
allow_unsafe_unpickling: bool,
):
"""Migrates a session database to the latest schema version."""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
try:
from ..sessions.migration import migration_runner

migration_runner.upgrade(source_db_url, dest_db_url)
migration_runner.upgrade(
source_db_url,
dest_db_url,
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
click.secho("Migration check and upgrade process finished.", fg="green")
except Exception as e:
click.secho(f"Migration failed: {e}", fg="red", err=True)
Expand Down
132 changes: 123 additions & 9 deletions src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import argparse
from datetime import datetime
from datetime import timezone
import io
import json
import logging
import pickle
import sys
from typing import Any
from typing import cast

from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
Expand All @@ -37,6 +39,89 @@

logger = logging.getLogger("google_adk." + __name__)

_ALLOWED_PICKLE_GLOBALS: set[tuple[str, str]] = {
# Builtin containers/primitives.
("builtins", "dict"),
("builtins", "list"),
("builtins", "set"),
("builtins", "tuple"),
("builtins", "str"),
("builtins", "bytes"),
("builtins", "bytearray"),
("builtins", "int"),
("builtins", "float"),
("builtins", "bool"),
# Expected pickled payload for v0 session schema events.
("fastapi.openapi.models", "APIKey"),
("fastapi.openapi.models", "APIKeyIn"),
("fastapi.openapi.models", "HTTPBase"),
("fastapi.openapi.models", "HTTPBearer"),
("fastapi.openapi.models", "OAuth2"),
("fastapi.openapi.models", "OAuthFlow"),
("fastapi.openapi.models", "OAuthFlowAuthorizationCode"),
("fastapi.openapi.models", "OAuthFlowClientCredentials"),
("fastapi.openapi.models", "OAuthFlowImplicit"),
("fastapi.openapi.models", "OAuthFlowPassword"),
("fastapi.openapi.models", "OAuthFlows"),
("fastapi.openapi.models", "OpenIdConnect"),
("fastapi.openapi.models", "SecurityBase"),
("fastapi.openapi.models", "SecurityScheme"),
("fastapi.openapi.models", "SecuritySchemeType"),
("google.adk.auth.auth_credential", "AuthCredential"),
("google.adk.auth.auth_credential", "AuthCredentialTypes"),
("google.adk.auth.auth_credential", "HttpAuth"),
("google.adk.auth.auth_credential", "HttpCredentials"),
("google.adk.auth.auth_credential", "OAuth2Auth"),
("google.adk.auth.auth_credential", "ServiceAccountCredential"),
("google.adk.auth.auth_schemes", "CustomAuthScheme"),
("google.adk.auth.auth_schemes", "ExtendedOAuth2"),
("google.adk.auth.auth_schemes", "OAuthGrantType"),
("google.adk.auth.auth_schemes", "OpenIdConnectWithConfig"),
("google.adk.auth.auth_tool", "AuthConfig"),
("google.adk.events.event_actions", "EventActions"),
("google.adk.events.event_actions", "EventCompaction"),
("google.adk.tools.tool_confirmation", "ToolConfirmation"),
("google.genai.types", "Blob"),
("google.genai.types", "CodeExecutionResult"),
("google.genai.types", "Content"),
("google.genai.types", "ExecutableCode"),
("google.genai.types", "FileData"),
("google.genai.types", "FunctionCall"),
("google.genai.types", "FunctionResponse"),
("google.genai.types", "FunctionResponseBlob"),
("google.genai.types", "FunctionResponseFileData"),
("google.genai.types", "FunctionResponsePart"),
("google.genai.types", "Part"),
("google.genai.types", "PartMediaResolution"),
("google.genai.types", "VideoMetadata"),
}


class _RestrictedUnpickler(pickle.Unpickler):
"""Restricted unpickler for migrating legacy v0 schema actions.

The v0 session schema stored `EventActions` as a pickled blob. During
migration we treat the raw bytes read from the source DB as untrusted input
and only allow the minimum set of safe globals needed to reconstruct
`EventActions`.
"""

def find_class(self, module: str, name: str) -> Any: # noqa: ANN001
if (module, name) in _ALLOWED_PICKLE_GLOBALS:
return super().find_class(module, name)
raise pickle.UnpicklingError(
f"Blocked global during migration unpickle: {module}.{name}"
)


def _restricted_pickle_loads(
data: bytes, *, allow_unsafe_unpickling: bool = False
) -> Any:
"""Load a pickle payload using the restricted unpickler by default."""
if allow_unsafe_unpickling:
return pickle.loads(data)
return _RestrictedUnpickler(io.BytesIO(data)).load()


def _to_datetime_obj(val: Any) -> datetime | Any:
"""Converts string to datetime if needed."""
Expand All @@ -51,15 +136,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any:
return val


def _row_to_event(row: dict) -> Event:
def _row_to_event(
row: dict[str, Any], *, allow_unsafe_unpickling: bool = False
) -> Event:
"""Converts event row (dict) to event object, handling missing columns and deserializing."""

actions_val = row.get("actions")
actions = None
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
actions = _restricted_pickle_loads(
actions_val, allow_unsafe_unpickling=allow_unsafe_unpickling
)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
Expand All @@ -75,7 +164,7 @@ def _row_to_event(row: dict) -> Event:
else:
actions = EventActions()

def _safe_json_load(val):
def _safe_json_load(val: Any) -> dict[str, Any] | None:
data = None
if isinstance(val, str):
try:
Expand All @@ -85,7 +174,7 @@ def _safe_json_load(val):
return None
elif isinstance(val, dict):
data = val # for postgres JSONB
return data
return cast(dict[str, Any] | None, data)

content_dict = _safe_json_load(row.get("content"))
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
Expand Down Expand Up @@ -147,13 +236,13 @@ def _safe_json_load(val):
)


def _get_state_dict(state_val: Any) -> dict:
def _get_state_dict(state_val: Any) -> dict[str, Any]:
"""Safely load dict from JSON string or return dict if already dict."""
if isinstance(state_val, dict):
return state_val
if isinstance(state_val, str):
try:
return json.loads(state_val)
return cast(dict[str, Any], json.loads(state_val))
except json.JSONDecodeError:
logger.warning(
"Failed to parse state JSON string, defaulting to empty dict."
Expand All @@ -163,7 +252,11 @@ def _get_state_dict(state_val: Any) -> dict:


# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
def migrate(
source_db_url: str,
dest_db_url: str,
allow_unsafe_unpickling: bool = False,
) -> None:
"""Migrates data from old pickle schema to new JSON schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
Expand All @@ -172,6 +265,11 @@ def migrate(source_db_url: str, dest_db_url: str):
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)

logger.info(f"Connecting to source database: {source_db_url}")
if allow_unsafe_unpickling:
logger.warning(
"Unsafe pickle migration mode is enabled. Only use this with a trusted"
" source database."
)
try:
source_engine = create_engine(source_sync_url)
SourceSession = sessionmaker(bind=source_engine)
Expand Down Expand Up @@ -265,7 +363,10 @@ def migrate(source_db_url: str, dest_db_url: str):
text("SELECT * FROM events")
).mappings():
try:
event_obj = _row_to_event(dict(row))
event_obj = _row_to_event(
dict(row),
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
new_event = v1.StorageEvent(
id=event_obj.id,
app_name=row["app_name"],
Expand Down Expand Up @@ -309,9 +410,22 @@ def migrate(source_db_url: str, dest_db_url: str):
required=True,
help="SQLAlchemy URL of destination database",
)
parser.add_argument(
"--allow_unsafe_unpickling",
"--allow-unsafe-unpickling",
action="store_true",
help=(
"Allow legacy pickle payloads to use Python's unsafe pickle loader."
" Only use this with a trusted source database."
),
)
args = parser.parse_args()
try:
migrate(args.source_db_url, args.dest_db_url)
migrate(
args.source_db_url,
args.dest_db_url,
allow_unsafe_unpickling=args.allow_unsafe_unpickling,
)
except Exception as e:
logger.error(f"Migration failed: {e}")
sys.exit(1)
18 changes: 16 additions & 2 deletions src/google/adk/sessions/migration/migration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION


def upgrade(source_db_url: str, dest_db_url: str):
def upgrade(
source_db_url: str,
dest_db_url: str,
allow_unsafe_unpickling: bool = False,
) -> None:
"""Migrates a database from its current version to the latest version.

If the source database schema is older than the latest version, this
Expand All @@ -61,6 +65,9 @@ def upgrade(source_db_url: str, dest_db_url: str):
source_db_url: The SQLAlchemy URL of the database to migrate from.
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
different from source_db_url.
allow_unsafe_unpickling: If true, use Python's unsafe pickle loader for the
legacy pickle migration step. Only use this with a trusted source
database.

Raises:
RuntimeError: If source_db_url and dest_db_url are the same, or if no
Expand Down Expand Up @@ -113,7 +120,14 @@ def upgrade(source_db_url: str, dest_db_url: str):
logger.info(
f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
)
migrate_func(in_url, out_url)
if migrate_func is migrate_from_sqlalchemy_pickle.migrate:
migrate_func(
in_url,
out_url,
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
else:
migrate_func(in_url, out_url)
logger.info("Finished migration step to schema %s.", end_version)
# The output of this step becomes the input for the next step.
in_url = out_url
Expand Down
Loading