diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index 9c3870b6e3..b843bee0ff 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -210,6 +210,10 @@ class FileArtifactVersion(ArtifactVersion): file_name: str = Field( description="Original filename supplied by the caller." ) + display_name: Optional[str] = Field( + default=None, + description="Original inline_data display name supplied by the caller.", + ) class FileArtifactService(BaseArtifactService): @@ -415,6 +419,9 @@ def _save_artifact_sync( _write_metadata( version_dir / "metadata.json", filename=filename, + display_name=( + artifact.inline_data.display_name if artifact.inline_data else None + ), mime_type=mime_type, version=next_version, canonical_uri=canonical_uri, @@ -491,7 +498,12 @@ def _load_artifact_sync( ) return None data = content_path.read_bytes() - return types.Part(inline_data=types.Blob(mime_type=mime_type, data=data)) + display_name = metadata.display_name if metadata else None + return types.Part( + inline_data=types.Blob( + mime_type=mime_type, data=data, display_name=display_name + ) + ) if not content_path.exists(): logger.warning("Text artifact %s missing at %s", filename, content_path) @@ -715,6 +727,7 @@ def _write_metadata( path: Path, *, filename: str, + display_name: Optional[str], mime_type: Optional[str], version: int, canonical_uri: str, @@ -723,6 +736,7 @@ def _write_metadata( """Persists metadata describing an artifact version.""" metadata = FileArtifactVersion( file_name=filename, + display_name=display_name, mime_type=mime_type, canonical_uri=canonical_uri, version=version, diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index f8706dedbd..c5d1d6e043 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -39,6 +39,15 @@ logger = logging.getLogger("google_adk." + __name__) +_DISPLAY_NAME_METADATA_KEY = "google_adk_display_name" + + +def _user_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, Any]: + """Returns blob metadata without ADK's internal storage keys.""" + user_metadata = dict(metadata or {}) + user_metadata.pop(_DISPLAY_NAME_METADATA_KEY, None) + return user_metadata + class GcsArtifactService(BaseArtifactService): """An artifact service implementation using Google Cloud Storage (GCS).""" @@ -216,8 +225,16 @@ def _save_artifact( app_name, user_id, filename, version, session_id ) blob = self.bucket.blob(blob_name) - if custom_metadata: - blob.metadata = {k: str(v) for k, v in custom_metadata.items()} + metadata = ( + {k: str(v) for k, v in custom_metadata.items()} + if custom_metadata + else {} + ) + + if artifact.inline_data and artifact.inline_data.display_name: + metadata[_DISPLAY_NAME_METADATA_KEY] = artifact.inline_data.display_name + if metadata: + blob.metadata = metadata if artifact.inline_data: blob.upload_from_string( @@ -268,8 +285,13 @@ def _load_artifact( artifact_bytes = blob.download_as_bytes() if not artifact_bytes: return None - artifact = types.Part.from_bytes( - data=artifact_bytes, mime_type=blob.content_type + display_name = (blob.metadata or {}).get(_DISPLAY_NAME_METADATA_KEY) + artifact = types.Part( + inline_data=types.Blob( + data=artifact_bytes, + mime_type=blob.content_type, + display_name=display_name, + ) ) return artifact @@ -391,7 +413,7 @@ def _get_artifact_version_sync( canonical_uri=canonical_uri, create_time=blob.time_created.timestamp(), mime_type=blob.content_type, - custom_metadata=blob.metadata if blob.metadata else {}, + custom_metadata=_user_metadata(blob.metadata), ) def _list_artifact_versions_sync( @@ -421,7 +443,7 @@ def _list_artifact_versions_sync( canonical_uri=canonical_uri, create_time=blob.time_created.timestamp(), mime_type=blob.content_type, - custom_metadata=blob.metadata if blob.metadata else {}, + custom_metadata=_user_metadata(blob.metadata), ) artifact_versions.append(av) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 8b82397097..1d38958587 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -263,6 +263,57 @@ async def test_save_load_delete(service_type, artifact_service_factory): ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_load_preserves_inline_data_display_name( + service_type, artifact_service_factory +): + artifact_service = artifact_service_factory(service_type) + artifact = types.Part( + inline_data=types.Blob( + data=b"test_data", + mime_type="text/plain", + display_name="report.txt", + ) + ) + + await artifact_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="stored.bin", + artifact=artifact, + custom_metadata={"source": "unit-test"}, + ) + + loaded = await artifact_service.load_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="stored.bin", + ) + + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.display_name == "report.txt" + + version = await artifact_service.get_artifact_version( + app_name="app0", + user_id="user0", + session_id="123", + filename="stored.bin", + ) + assert version is not None + assert version.custom_metadata == {"source": "unit-test"} + + @pytest.mark.asyncio @pytest.mark.parametrize( "service_type",