Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,28 @@
# for 'autogenerate' support
from google.adk.sessions.database_session_service import Base

# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# Only operate on tables defined in these metadata objects.
TARGET_METADATAS = (Base.metadata,)
target_metadata = TARGET_METADATAS[0]

_ALLOWED_TABLE_NAMES = frozenset(
table_name
for metadata in TARGET_METADATAS
for table_name in metadata.tables
)


def include_object(obj, name, type_, reflected, compare_to):
"""Only include tables (and their indexes) defined in TARGET_METADATAS."""
if type_ == "table":
return bool(_ALLOWED_TABLE_NAMES) and name in _ALLOWED_TABLE_NAMES
if type_ == "index":
try:
return obj.table.name in _ALLOWED_TABLE_NAMES
except AttributeError:
return False
return True


# other values from the config, defined by the needs of env.py,
# can be acquired:
Expand All @@ -58,6 +78,8 @@ def run_migrations_offline() -> None:
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
include_object=include_object,
version_table="alembic_version_adk",
)

with context.begin_transaction():
Expand All @@ -78,7 +100,12 @@ def run_migrations_online() -> None:
)

with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(
connection=connection,
target_metadata=target_metadata,
include_object=include_object,
version_table="alembic_version_adk",
)

with context.begin_transaction():
context.run_migrations()
Expand Down
128 changes: 108 additions & 20 deletions scripts/db_migration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,119 @@ fi
echo " Set sqlalchemy.url"

# --- 3. Set target_metadata in alembic/env.py ---
echo "Configuring ${ENV_FILE}..."

# Edit 1: Uncomment and replace the model import line
sed -i.bak "s/# from myapp import mymodel/from ${MODEL_PATH} import Base/" "${ENV_FILE}"
echo "Writing safe ${ENV_FILE} (only operate on provided metadata tables)..."
cat > "${ENV_FILE}" <<EOF
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

config = context.config

if config.config_file_name is not None:
fileConfig(config.config_file_name)

from ${MODEL_PATH} import Base

TARGET_METADATAS = (Base.metadata,)
target_metadata = TARGET_METADATAS[0]
_ALLOWED_TABLE_NAMES = frozenset(
table_name
for metadata in TARGET_METADATAS
for table_name in metadata.tables
)


def include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return bool(_ALLOWED_TABLE_NAMES) and name in _ALLOWED_TABLE_NAMES
if type_ == "index":
try:
return obj.table.name in _ALLOWED_TABLE_NAMES
except AttributeError:
return False
return True


def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
include_object=include_object,
version_table="alembic_version_adk",
)

with context.begin_transaction():
context.run_migrations()


def run_migrations_online() -> None:
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)

with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
include_object=include_object,
version_table="alembic_version_adk",
)

with context.begin_transaction():
context.run_migrations()


if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
EOF
if [ $? -ne 0 ]; then
echo "Error: Failed to set model import in ${ENV_FILE}."
echo "Error: Failed to write ${ENV_FILE}."
exit 1
fi
echo " Set target_metadata and include_object filter"
echo ""

# Edit 2: Set the target_metadata to use the imported Base
sed -i.bak "s/target_metadata = None/target_metadata = Base.metadata/" "${ENV_FILE}"
# --- 4. Clean up backup files ---
echo "Cleaning up backup files..."
rm -f "${INI_FILE}.bak"
rm -f "${ENV_FILE}.bak"

# --- 5. Reset stale alembic_version (if any) ---
echo "Resetting any existing alembic_version entry..."
python - <<'PY'
import configparser
import pathlib

from sqlalchemy import create_engine, text

ini = pathlib.Path("alembic.ini")
parser = configparser.ConfigParser()
parser["DEFAULT"] = {"here": str(ini.parent)}
parser.read(ini)
db_url = parser.get("alembic", "sqlalchemy.url")

engine = create_engine(db_url)
with engine.begin() as conn:
conn.execute(text("DROP TABLE IF EXISTS alembic_version_adk"))
PY
if [ $? -ne 0 ]; then
echo "Error: Failed to set target_metadata in ${ENV_FILE}."
echo "Error: Failed to reset alembic_version table."
exit 1
fi

echo " Set target_metadata"
echo " alembic_version reset (if it existed)."
echo ""

# --- 4. Clean up backup files ---
echo "Cleaning up backup files..."
rm "${INI_FILE}.bak"
rm "${ENV_FILE}.bak"

# --- 5. Run alembic stamp head ---
# --- 6. Run alembic stamp head ---
echo "Running 'alembic stamp head'..."
alembic stamp head
if [ $? -ne 0 ]; then
Expand All @@ -101,7 +189,7 @@ fi
echo "stamping complete."
echo ""

# --- 6. Run alembic upgrade ---
# --- 7. Run alembic revision ---
echo "Running 'alembic revision --autogenerate'..."
alembic revision --autogenerate -m "ADK session DB upgrade"
if [ $? -ne 0 ]; then
Expand All @@ -111,7 +199,7 @@ fi
echo "revision complete."
echo ""

# --- 7. Add import statement to version files ---
# --- 8. Add import statement to version files ---
echo "Adding import statement to version files..."
for f in ${ALEMBIC_DIR}/versions/*.py; do
if [ -f "$f" ]; then
Expand All @@ -130,7 +218,7 @@ done
echo "Import statements added."
echo ""

# --- 8. Run alembic upgrade ---
# --- 9. Run alembic upgrade ---
echo "running 'alembic upgrade'..."
alembic upgrade head
if [ $? -ne 0 ]; then
Expand All @@ -141,4 +229,4 @@ echo "upgrade complete."
echo ""

echo "---"
echo "✅ ADK session DB is Updated!"
echo "✅ ADK session DB is Updated!"