Skip to content
Open
22 changes: 20 additions & 2 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,15 +2205,33 @@ def migrate():
default="INFO",
help="Optional. Set the logging level",
)
@click.option( # type: ignore[untyped-decorator]
"--allow-unsafe-unpickling",
"--allow_unsafe_unpickling",
is_flag=True,
default=False,
help=(
"Optional. Allow unsafe pickle loading for trusted legacy session"
" databases."
),
)
def cli_migrate_session(
*, source_db_url: str, dest_db_url: str, log_level: str
*,
source_db_url: str,
dest_db_url: str,
log_level: str,
allow_unsafe_unpickling: bool,
):
"""Migrates a session database to the latest schema version."""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
try:
from ..sessions.migration import migration_runner

migration_runner.upgrade(source_db_url, dest_db_url)
migration_runner.upgrade(
source_db_url,
dest_db_url,
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
click.secho("Migration check and upgrade process finished.", fg="green")
except Exception as e:
click.secho(f"Migration failed: {e}", fg="red", err=True)
Expand Down
150 changes: 139 additions & 11 deletions src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import argparse
from datetime import datetime
from datetime import timezone
import io
import json
import logging
import pickle
Expand All @@ -37,6 +38,92 @@

logger = logging.getLogger("google_adk." + __name__)

_ALLOWED_PICKLE_GLOBALS: set[tuple[str, str]] = {
# Builtin containers/primitives.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also allow datetime.datetime and datetime.timezone? It's quite common for legacy state_delta or other Any fields in EventActions to contain timestamp objects.

("builtins", "dict"),
("builtins", "list"),
("builtins", "set"),
("builtins", "tuple"),
("builtins", "str"),
("builtins", "bytes"),
("builtins", "bytearray"),
("builtins", "int"),
("builtins", "float"),
("builtins", "bool"),
("datetime", "datetime"),
("datetime", "timedelta"),
("datetime", "timezone"),
# Expected pickled payload for v0 session schema events.
("fastapi.openapi.models", "APIKey"),
("fastapi.openapi.models", "APIKeyIn"),
("fastapi.openapi.models", "HTTPBase"),
("fastapi.openapi.models", "HTTPBearer"),
("fastapi.openapi.models", "OAuth2"),
("fastapi.openapi.models", "OAuthFlow"),
("fastapi.openapi.models", "OAuthFlowAuthorizationCode"),
("fastapi.openapi.models", "OAuthFlowClientCredentials"),
("fastapi.openapi.models", "OAuthFlowImplicit"),
("fastapi.openapi.models", "OAuthFlowPassword"),
("fastapi.openapi.models", "OAuthFlows"),
("fastapi.openapi.models", "OpenIdConnect"),
("fastapi.openapi.models", "SecurityBase"),
("fastapi.openapi.models", "SecurityScheme"),
("fastapi.openapi.models", "SecuritySchemeType"),
("google.adk.auth.auth_credential", "AuthCredential"),
("google.adk.auth.auth_credential", "AuthCredentialTypes"),
("google.adk.auth.auth_credential", "HttpAuth"),
("google.adk.auth.auth_credential", "HttpCredentials"),
("google.adk.auth.auth_credential", "OAuth2Auth"),
("google.adk.auth.auth_credential", "ServiceAccountCredential"),
("google.adk.auth.auth_schemes", "CustomAuthScheme"),
("google.adk.auth.auth_schemes", "ExtendedOAuth2"),
("google.adk.auth.auth_schemes", "OAuthGrantType"),
("google.adk.auth.auth_schemes", "OpenIdConnectWithConfig"),
("google.adk.auth.auth_tool", "AuthConfig"),
("google.adk.events.event_actions", "EventActions"),
("google.adk.events.event_actions", "EventCompaction"),
("google.adk.tools.tool_confirmation", "ToolConfirmation"),
("google.genai.types", "Blob"),
("google.genai.types", "CodeExecutionResult"),
("google.genai.types", "Content"),
("google.genai.types", "ExecutableCode"),
("google.genai.types", "FileData"),
("google.genai.types", "FunctionCall"),
("google.genai.types", "FunctionResponse"),
("google.genai.types", "FunctionResponseBlob"),
("google.genai.types", "FunctionResponseFileData"),
("google.genai.types", "FunctionResponsePart"),
("google.genai.types", "Part"),
("google.genai.types", "PartMediaResolution"),
("google.genai.types", "VideoMetadata"),
}


class _RestrictedUnpickler(pickle.Unpickler):
  """Unpickler that resolves only allowlisted globals.

  Legacy v0 session rows store `EventActions` as a pickled blob. During
  migration the bytes read from the source DB are treated as untrusted, so a
  global is resolved only when its `(module, name)` pair appears in
  `_ALLOWED_PICKLE_GLOBALS`; any other lookup is refused.
  """

  def find_class(self, module: str, name: str) -> Any:  # noqa: ANN001
    """Return the global `module.name` if it is allowlisted.

    Raises:
      pickle.UnpicklingError: If `(module, name)` is not in
        `_ALLOWED_PICKLE_GLOBALS`.
    """
    requested = (module, name)
    if requested not in _ALLOWED_PICKLE_GLOBALS:
      # Reject before delegating so the disallowed global is never resolved
      # by the base class.
      raise pickle.UnpicklingError(
          f"Blocked global during migration unpickle: {module}.{name}"
      )
    return super().find_class(module, name)


def _restricted_pickle_loads(
    data: bytes, *, allow_unsafe_unpickling: bool = False
) -> Any:
  """Deserialize a pickle payload, restricted unless explicitly overridden.

  Args:
    data: Raw pickle bytes.
    allow_unsafe_unpickling: When true, pass the payload straight to
      `pickle.loads` instead of going through `_RestrictedUnpickler`.

  Returns:
    The object reconstructed from `data`.
  """
  if not allow_unsafe_unpickling:
    stream = io.BytesIO(data)
    return _RestrictedUnpickler(stream).load()
  # Opt-in escape hatch: plain pickle.loads, no allowlist applied.
  return pickle.loads(data)


def _to_datetime_obj(val: Any) -> datetime | Any:
"""Converts string to datetime if needed."""
Expand All @@ -51,15 +138,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any:
return val


def _row_to_event(row: dict) -> Event:
def _row_to_event(
row: dict[str, Any], *, allow_unsafe_unpickling: bool = False
) -> Event:
"""Converts event row (dict) to event object, handling missing columns and deserializing."""

actions_val = row.get("actions")
actions = None
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
actions = _restricted_pickle_loads(
actions_val, allow_unsafe_unpickling=allow_unsafe_unpickling
)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
Expand All @@ -75,17 +166,25 @@ def _row_to_event(row: dict) -> Event:
else:
actions = EventActions()

def _safe_json_load(val):
data = None
def _safe_json_load(val: Any) -> dict[str, Any] | None:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint suggests it returns a dict[str, Any], but json.loads could return a list. While it's unlikely for these specific columns, it might be safer to use dict[str, Any] | list[Any] | None or verify it's a dict. Also, the cast below might hide issues if it's actually a list.

if isinstance(val, str):
try:
data = json.loads(val)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON for event {row.get('id')}")
return None
elif isinstance(val, dict):
data = val # for postgres JSONB
return data
return val # for postgres JSONB
else:
return None

if isinstance(data, dict):
return data
logger.warning(
f"Expected JSON object for event {row.get('id')}, got"
f" {type(data).__name__}."
)
return None

content_dict = _safe_json_load(row.get("content"))
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
Expand Down Expand Up @@ -147,23 +246,31 @@ def _safe_json_load(val):
)


def _get_state_dict(state_val: Any) -> dict:
def _get_state_dict(state_val: Any) -> dict[str, Any]:
"""Safely load dict from JSON string or return dict if already dict."""
if isinstance(state_val, dict):
return state_val
if isinstance(state_val, str):
try:
return json.loads(state_val)
data = json.loads(state_val)
except json.JSONDecodeError:
logger.warning(
"Failed to parse state JSON string, defaulting to empty dict."
)
return {}
if isinstance(data, dict):
return data
logger.warning("State JSON was not an object, defaulting to empty dict.")
return {}
return {}


# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
def migrate(
source_db_url: str,
dest_db_url: str,
allow_unsafe_unpickling: bool = False,
) -> None:
"""Migrates data from old pickle schema to new JSON schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
Expand All @@ -172,6 +279,11 @@ def migrate(source_db_url: str, dest_db_url: str):
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)

logger.info(f"Connecting to source database: {source_db_url}")
if allow_unsafe_unpickling:
logger.warning(
"Unsafe pickle migration mode is enabled. Only use this with a trusted"
" source database."
)
try:
source_engine = create_engine(source_sync_url)
SourceSession = sessionmaker(bind=source_engine)
Expand Down Expand Up @@ -265,7 +377,10 @@ def migrate(source_db_url: str, dest_db_url: str):
text("SELECT * FROM events")
).mappings():
try:
event_obj = _row_to_event(dict(row))
event_obj = _row_to_event(
dict(row),
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
new_event = v1.StorageEvent(
id=event_obj.id,
app_name=row["app_name"],
Expand Down Expand Up @@ -309,9 +424,22 @@ def migrate(source_db_url: str, dest_db_url: str):
required=True,
help="SQLAlchemy URL of destination database",
)
parser.add_argument(
"--allow_unsafe_unpickling",
"--allow-unsafe-unpickling",
action="store_true",
help=(
"Allow legacy pickle payloads to use Python's unsafe pickle loader."
" Only use this with a trusted source database."
),
)
args = parser.parse_args()
try:
migrate(args.source_db_url, args.dest_db_url)
migrate(
args.source_db_url,
args.dest_db_url,
allow_unsafe_unpickling=args.allow_unsafe_unpickling,
)
except Exception as e:
logger.error(f"Migration failed: {e}")
sys.exit(1)
18 changes: 16 additions & 2 deletions src/google/adk/sessions/migration/migration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION


def upgrade(source_db_url: str, dest_db_url: str):
def upgrade(
source_db_url: str,
dest_db_url: str,
allow_unsafe_unpickling: bool = False,
) -> None:
"""Migrates a database from its current version to the latest version.

If the source database schema is older than the latest version, this
Expand All @@ -61,6 +65,9 @@ def upgrade(source_db_url: str, dest_db_url: str):
source_db_url: The SQLAlchemy URL of the database to migrate from.
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
different from source_db_url.
allow_unsafe_unpickling: If true, use Python's unsafe pickle loader for the
legacy pickle migration step. Only use this with a trusted source
database.

Raises:
RuntimeError: If source_db_url and dest_db_url are the same, or if no
Expand Down Expand Up @@ -113,7 +120,14 @@ def upgrade(source_db_url: str, dest_db_url: str):
logger.info(
f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
)
migrate_func(in_url, out_url)
if migrate_func is migrate_from_sqlalchemy_pickle.migrate:
migrate_func(
in_url,
out_url,
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
else:
migrate_func(in_url, out_url)
logger.info("Finished migration step to schema %s.", end_version)
# The output of this step becomes the input for the next step.
in_url = out_url
Expand Down
47 changes: 47 additions & 0 deletions tests/unittests/cli/utils/test_cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,53 @@ def test_cli_web_passes_service_uris(
assert called_kwargs.get("memory_service_uri") == "rag://mycorpus"


@pytest.mark.parametrize(
    "flag",
    ["--allow-unsafe-unpickling", "--allow_unsafe_unpickling"],
)
def test_cli_migrate_session_allows_unsafe_unpickling_flag(
    monkeypatch: pytest.MonkeyPatch, flag: str
) -> None:
  """Either spelling of the flag reaches `upgrade` as `True`."""
  source_url = "sqlite:///source.db"
  dest_url = "sqlite:///dest.db"
  recorded: list[dict[str, Any]] = []

  def _record_upgrade(
      source_db_url: str,
      dest_db_url: str,
      *,
      allow_unsafe_unpickling: bool = False,
  ) -> None:
    # Stand-in for the real upgrade: capture the arguments and do nothing else.
    received = {
        "source_db_url": source_db_url,
        "dest_db_url": dest_db_url,
        "allow_unsafe_unpickling": allow_unsafe_unpickling,
    }
    recorded.append(received)

  monkeypatch.setattr(
      "google.adk.sessions.migration.migration_runner.upgrade",
      _record_upgrade,
  )

  argv = [
      "migrate",
      "session",
      "--source_db_url",
      source_url,
      "--dest_db_url",
      dest_url,
      flag,
  ]
  result = CliRunner().invoke(cli_tools_click.main, argv)

  # Include output and exception in the failure message to ease debugging.
  assert result.exit_code == 0, (result.output, repr(result.exception))

  expected_call = {
      "source_db_url": source_url,
      "dest_db_url": dest_url,
      "allow_unsafe_unpickling": True,
  }
  assert recorded == [expected_call]


def test_cli_eval_with_eval_set_file_path(
mock_load_eval_set_from_file,
mock_get_root_agent,
Expand Down
Loading