Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/google/adk/artifacts/file_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ class FileArtifactVersion(ArtifactVersion):
file_name: str = Field(
description="Original filename supplied by the caller."
)
display_name: Optional[str] = Field(
default=None,
description="Original inline_data display name supplied by the caller.",
)


class FileArtifactService(BaseArtifactService):
Expand Down Expand Up @@ -415,6 +419,9 @@ def _save_artifact_sync(
_write_metadata(
version_dir / "metadata.json",
filename=filename,
display_name=(
artifact.inline_data.display_name if artifact.inline_data else None
),
mime_type=mime_type,
version=next_version,
canonical_uri=canonical_uri,
Expand Down Expand Up @@ -491,7 +498,12 @@ def _load_artifact_sync(
)
return None
data = content_path.read_bytes()
return types.Part(inline_data=types.Blob(mime_type=mime_type, data=data))
display_name = metadata.display_name if metadata else None
return types.Part(
inline_data=types.Blob(
mime_type=mime_type, data=data, display_name=display_name
)
)

if not content_path.exists():
logger.warning("Text artifact %s missing at %s", filename, content_path)
Expand Down Expand Up @@ -715,6 +727,7 @@ def _write_metadata(
path: Path,
*,
filename: str,
display_name: Optional[str],
mime_type: Optional[str],
version: int,
canonical_uri: str,
Expand All @@ -723,6 +736,7 @@ def _write_metadata(
"""Persists metadata describing an artifact version."""
metadata = FileArtifactVersion(
file_name=filename,
display_name=display_name,
mime_type=mime_type,
canonical_uri=canonical_uri,
version=version,
Expand Down
34 changes: 28 additions & 6 deletions src/google/adk/artifacts/gcs_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@

logger = logging.getLogger("google_adk." + __name__)

_DISPLAY_NAME_METADATA_KEY = "google_adk_display_name"


def _user_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, Any]:
"""Returns blob metadata without ADK's internal storage keys."""
user_metadata = dict(metadata or {})
user_metadata.pop(_DISPLAY_NAME_METADATA_KEY, None)
return user_metadata


class GcsArtifactService(BaseArtifactService):
"""An artifact service implementation using Google Cloud Storage (GCS)."""
Expand Down Expand Up @@ -216,8 +225,16 @@ def _save_artifact(
app_name, user_id, filename, version, session_id
)
blob = self.bucket.blob(blob_name)
if custom_metadata:
blob.metadata = {k: str(v) for k, v in custom_metadata.items()}
metadata = (
{k: str(v) for k, v in custom_metadata.items()}
if custom_metadata
else {}
)

if artifact.inline_data and artifact.inline_data.display_name:
metadata[_DISPLAY_NAME_METADATA_KEY] = artifact.inline_data.display_name
if metadata:
blob.metadata = metadata

if artifact.inline_data:
blob.upload_from_string(
Expand Down Expand Up @@ -268,8 +285,13 @@ def _load_artifact(
artifact_bytes = blob.download_as_bytes()
if not artifact_bytes:
return None
artifact = types.Part.from_bytes(
data=artifact_bytes, mime_type=blob.content_type
display_name = (blob.metadata or {}).get(_DISPLAY_NAME_METADATA_KEY)
artifact = types.Part(
inline_data=types.Blob(
data=artifact_bytes,
mime_type=blob.content_type,
display_name=display_name,
)
)
return artifact

Expand Down Expand Up @@ -391,7 +413,7 @@ def _get_artifact_version_sync(
canonical_uri=canonical_uri,
create_time=blob.time_created.timestamp(),
mime_type=blob.content_type,
custom_metadata=blob.metadata if blob.metadata else {},
custom_metadata=_user_metadata(blob.metadata),
)

def _list_artifact_versions_sync(
Expand Down Expand Up @@ -421,7 +443,7 @@ def _list_artifact_versions_sync(
canonical_uri=canonical_uri,
create_time=blob.time_created.timestamp(),
mime_type=blob.content_type,
custom_metadata=blob.metadata if blob.metadata else {},
custom_metadata=_user_metadata(blob.metadata),
)
artifact_versions.append(av)

Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/artifacts/test_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,57 @@ async def test_save_load_delete(service_type, artifact_service_factory):
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type",
[
ArtifactServiceType.IN_MEMORY,
ArtifactServiceType.GCS,
ArtifactServiceType.FILE,
],
)
async def test_save_load_preserves_inline_data_display_name(
service_type, artifact_service_factory
):
artifact_service = artifact_service_factory(service_type)
artifact = types.Part(
inline_data=types.Blob(
data=b"test_data",
mime_type="text/plain",
display_name="report.txt",
)
)

await artifact_service.save_artifact(
app_name="app0",
user_id="user0",
session_id="123",
filename="stored.bin",
artifact=artifact,
custom_metadata={"source": "unit-test"},
)

loaded = await artifact_service.load_artifact(
app_name="app0",
user_id="user0",
session_id="123",
filename="stored.bin",
)

assert loaded is not None
assert loaded.inline_data is not None
assert loaded.inline_data.display_name == "report.txt"

version = await artifact_service.get_artifact_version(
app_name="app0",
user_id="user0",
session_id="123",
filename="stored.bin",
)
assert version is not None
assert version.custom_metadata == {"source": "unit-test"}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type",
Expand Down