From b2fbf30eb123ab35ff33048f67493e861e527b07 Mon Sep 17 00:00:00 2001 From: agrawalradhika-cell Date: Tue, 26 May 2026 13:24:34 -0700 Subject: [PATCH 1/6] feat: Integrate mTLS support into Agent Registry API calls --- .../agent_registry/agent_registry.py | 81 ++++++++++++++++--- 1 file changed, 71 insertions(+), 10 deletions(-) diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index a486215151..236465a623 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -19,6 +19,7 @@ from collections.abc import Generator from enum import Enum import logging +import os import re from typing import Any from typing import Callable @@ -39,9 +40,11 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import google.auth -import google.auth.transport.requests +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth import httpx from mcp import StdioServerParameters +import requests from typing_extensions import override # pylint: disable=g-import-not-at-top @@ -61,6 +64,9 @@ logger = logging.getLogger("google_adk." + __name__) AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha" +AGENT_REGISTRY_MTLS_BASE_URL = ( + "https://agentregistry.mtls.googleapis.com/v1alpha" +) _TRANSPORT_MAPPING = { "HTTP_JSON": A2ATransport.http_json, @@ -120,6 +126,14 @@ async def get_tools( return tools +class _MtlsEndpoint(Enum): + """The mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" + + class _ProtocolType(str, Enum): """Supported agent protocol types.""" @@ -224,23 +238,40 @@ def _make_request( self, path: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: """Helper function to make GET requests to the Agent Registry API.""" + # Determine if mTLS should be used + session = requests_auth.AuthorizedSession(credentials=self._credentials) + + use_client_cert = _use_client_cert_effective() + client_cert_source = None + + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + + base_url = _get_agent_registry_base_url(client_cert_source) + if path.startswith("projects/"): - url = f"{AGENT_REGISTRY_BASE_URL}/{path}" + url = f"{base_url}/{path}" else: - url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}" + url = f"{base_url}/{self._base_path}/{path}" try: - headers = self._get_auth_headers() - with httpx.Client() as client: - response = client.get(url, headers=headers, params=params) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: + # Using AuthorizedSession for internal API calls to handle mTLS/Auth. + response = session.get( + url, headers=self._get_auth_headers(), params=params + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: raise RuntimeError( f"API request failed with status {e.response.status_code}:" f" {e.response.text}" ) from e - except httpx.RequestError as e: + except requests.exceptions.RequestException as e: raise RuntimeError(f"API request failed (network error): {e}") from e except Exception as e: raise RuntimeError(f"API request failed: {e}") from e @@ -520,3 +551,33 @@ def get_remote_a2a_agent( description=description, httpx_client=httpx_client, ) + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS.""" + try: + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + return use_client_cert_str == "true" + + +def _get_agent_registry_base_url(client_cert_source: Any | None = None) -> str: + """Returns the base URL based on mTLS configuration and cert availability.""" + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + use_mtls_endpoint = _MtlsEndpoint.AUTO + + if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source + ): + return AGENT_REGISTRY_MTLS_BASE_URL + + return AGENT_REGISTRY_BASE_URL From 46218f3c9433937e537784c3445af604af0fbba2 Mon Sep 17 00:00:00 2001 From: agrawalradhika-cell Date: Tue, 26 May 2026 13:25:58 -0700 Subject: [PATCH 2/6] chore: Add unit tests for mTLS support for AgentRegistry --- .../agent_registry/test_agent_registry.py | 263 ++++++++++++------ 1 file changed, 178 insertions(+), 85 deletions(-) diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index f4ba47cf25..bd8288c483 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -13,6 +13,7 @@ # limitations under the License. +import os from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -26,28 +27,32 @@ from google.adk.integrations.agent_registry.agent_registry import _ProtocolType from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +from google.auth.transport import requests as requests_auth import httpx from mcp import ClientSession from mcp.types import ListToolsResult from mcp.types import Tool import pytest +import requests class TestAgentRegistry: @pytest.fixture def registry(self): - with patch("google.auth.default", return_value=(MagicMock(), "project-id")): + mock_creds = MagicMock() + mock_creds.quota_project_id = None + with patch("google.auth.default", return_value=(mock_creds, "project-id")): return AgentRegistry(project_id="test-project", location="global") @pytest.mark.asyncio - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") @patch( "google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session", new_callable=AsyncMock, ) async def test_get_mcp_toolset_adds_destination_id( - self, mock_create_session, mock_httpx, registry + self, mock_create_session, mock_session_class, registry ): """Test that tools from get_mcp_toolset have the destination ID.""" # Arrange @@ -63,9 +68,7 @@ async def test_get_mcp_toolset_adds_destination_id( "protocolBinding": "JSONRPC", }], } - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_api_response - ) + mock_session_class.return_value.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -109,13 +112,13 @@ async def test_get_mcp_toolset_adds_destination_id( ) @pytest.mark.asyncio - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") @patch( "google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session", new_callable=AsyncMock, ) async def test_get_mcp_toolset_handles_missing_destination_id( - self, mock_create_session, mock_httpx, registry + self, mock_create_session, mock_session_class, registry ): """Test get_mcp_toolset when the destination ID is missing.""" # Arrange @@ -129,9 +132,7 @@ async def test_get_mcp_toolset_handles_missing_destination_id( "protocolBinding": "JSONRPC", }], } - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_api_response - ) + mock_session_class.return_value.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -258,14 +259,12 @@ def test_get_connection_uri_returns_none_if_no_url_in_interfaces( assert version is None assert binding is None - @patch("httpx.Client") - def test_list_agents(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_list_agents(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"agents": []} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response # Mock auth refresh registry._credentials.token = "token" @@ -274,14 +273,12 @@ def test_list_agents(self, mock_httpx, registry): agents = registry.list_agents() assert agents == {"agents": []} - @patch("httpx.Client") - def test_get_mcp_server(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_mcp_server(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"name": "test-mcp"} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -289,14 +286,12 @@ def test_get_mcp_server(self, mock_httpx, registry): server = registry.get_mcp_server("test-mcp") assert server == {"name": "test-mcp"} - @patch("httpx.Client") - def test_list_endpoints(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_list_endpoints(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"endpoints": []} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response # Mock auth refresh registry._credentials.token = "token" @@ -305,14 +300,12 @@ def test_list_endpoints(self, mock_httpx, registry): endpoints = registry.list_endpoints() assert endpoints == {"endpoints": []} - @patch("httpx.Client") - def test_get_endpoint(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_endpoint(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"name": "test-endpoint"} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -329,9 +322,14 @@ def test_get_endpoint(self, mock_httpx, registry): ("https://mcp.googleapis.com/v1", True, True), ], ) - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") def test_get_mcp_toolset_auth_headers( - self, mock_httpx, registry, url, expected_auth, use_custom_provider + self, + mock_session_class, + registry, + url, + expected_auth, + use_custom_provider, ): mock_response = MagicMock() mock_response.json.return_value = { @@ -342,16 +340,17 @@ def test_get_mcp_toolset_auth_headers( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response + + mock_creds = MagicMock() + mock_creds.quota_project_id = None if use_custom_provider: custom_header_provider = lambda context: { "Authorization": "Bearer custom_token" } with patch( - "google.auth.default", return_value=(MagicMock(), "project-id") + "google.auth.default", return_value=(mock_creds, "project-id") ): registry = AgentRegistry( project_id="test-project", @@ -375,8 +374,8 @@ def test_get_mcp_toolset_auth_headers( else: assert "Authorization" not in headers - @patch("httpx.Client") - def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_mcp_toolset_with_auth(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestPrefix", @@ -386,9 +385,7 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -408,9 +405,9 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): assert auth_config.auth_scheme == auth_scheme assert auth_config.raw_auth_credential == auth_credential - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") def test_get_mcp_toolset_with_auth_blocks_gcp_headers( - self, mock_httpx, registry + self, mock_session_class, registry ): mock_response = MagicMock() mock_response.json.return_value = { @@ -421,9 +418,7 @@ def test_get_mcp_toolset_with_auth_blocks_gcp_headers( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -442,8 +437,8 @@ def test_get_mcp_toolset_with_auth_blocks_gcp_headers( headers = toolset._header_provider(MagicMock()) assert "Authorization" not in headers - @patch("httpx.Client") - def test_get_remote_a2a_agent(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -460,9 +455,7 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): "skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -478,8 +471,8 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): assert agent._agent_card.preferred_transport == A2ATransport.http_json assert agent._agent_card.protocol_version == "0.4.0" - @patch("httpx.Client") - def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent_defaults(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -493,9 +486,7 @@ def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -505,8 +496,8 @@ def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): assert agent._agent_card.preferred_transport == A2ATransport.http_json assert agent._agent_card.protocol_version == "0.3.0" - @patch("httpx.Client") - def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent_with_card(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "name": "projects/p/locations/l/agents/a", @@ -530,9 +521,7 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): }, } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -547,8 +536,10 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): assert len(agent._agent_card.skills) == 1 assert agent._agent_card.skills[0].name == "S1" - @patch("httpx.Client") - def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent_with_httpx_client( + self, mock_session_class, registry + ): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -562,9 +553,7 @@ def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response custom_client = httpx.AsyncClient() agent = registry.get_remote_a2a_agent( @@ -572,9 +561,9 @@ def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): ) assert agent._httpx_client is custom_client - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") def test_get_remote_a2a_agent_configures_transports( - self, mock_httpx, registry + self, mock_session_class, registry ): mock_response = MagicMock() mock_response.json.return_value = { @@ -588,9 +577,7 @@ def test_get_remote_a2a_agent_configures_transports( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -616,15 +603,17 @@ def test_get_auth_headers_fallback_to_project_id(self, registry): assert headers["Authorization"] == "Bearer fake-token" assert headers["x-goog-user-project"] == "test-project" - @patch("httpx.Client") - def test_make_request_raises_http_status_error(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_raises_http_status_error( + self, mock_session_class, registry + ): mock_response = MagicMock() mock_response.status_code = 404 mock_response.text = "Not Found" - error = httpx.HTTPStatusError( + error = requests.exceptions.HTTPError( "Error", request=MagicMock(), response=mock_response ) - mock_httpx.return_value.__enter__.return_value.get.side_effect = error + mock_session_class.return_value.get.side_effect = error registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -634,10 +623,14 @@ def test_make_request_raises_http_status_error(self, mock_httpx, registry): ): registry._make_request("test-path") - @patch("httpx.Client") - def test_make_request_raises_request_error(self, mock_httpx, registry): - error = httpx.RequestError("Connection failed", request=MagicMock()) - mock_httpx.return_value.__enter__.return_value.get.side_effect = error + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_raises_request_error( + self, mock_session_class, registry + ): + error = requests.exceptions.RequestException( + "Connection failed", request=MagicMock() + ) + mock_session_class.return_value.get.side_effect = error registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -647,11 +640,11 @@ def test_make_request_raises_request_error(self, mock_httpx, registry): ): registry._make_request("test-path") - @patch("httpx.Client") - def test_make_request_raises_generic_exception(self, mock_httpx, registry): - mock_httpx.return_value.__enter__.return_value.get.side_effect = Exception( - "Generic error" - ) + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_raises_generic_exception( + self, mock_session_class, registry + ): + mock_session_class.return_value.get.side_effect = Exception("Generic error") registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -741,3 +734,103 @@ def side_effect(*args, **kwargs): == "projects/123/locations/l/authProviders/ap-789" ) assert toolset._auth_scheme.continue_uri == "https://override.com/continue" + + +class TestAgentRegistryMtls: + + @pytest.fixture + def registry(self): + with patch( + "google.auth.default", return_value=(MagicMock(), "test-project") + ): + return AgentRegistry(project_id="test-project", location="global") + + @patch("google.auth.transport.requests.AuthorizedSession") + @patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + def test_make_request_uses_authorized_session_no_mtls( + self, mock_has_cert, mock_session_class, registry + ): + """Verifies that AuthorizedSession is used for standard requests.""" + mock_session = mock_session_class.return_value + mock_response = MagicMock() + mock_response.json.return_value = {"key": "value"} + mock_session.get.return_value = mock_response + + result = registry._make_request("test-path") + + # Assert session initialization and usage + mock_session_class.assert_called_once_with( + credentials=registry._credentials + ) + mock_session.get.assert_called_once() + assert mock_session.configure_mtls_channel.call_count == 0 + assert result == {"key": "value"} + + @patch("google.auth.transport.requests.AuthorizedSession") + @patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + @patch("google.auth.transport.mtls.default_client_cert_source") + @patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_make_request_configures_mtls( + self, mock_cert_source, mock_has_cert, mock_session_class, registry + ): + """Verifies that mTLS is configured when supported and enabled.""" + mock_session = mock_session_class.return_value + mock_cert_source.return_value = lambda: (b"cert", b"key") + + registry._make_request("test-path") + + # Verify mTLS configuration and endpoint + mock_session.configure_mtls_channel.assert_called_once() + args, kwargs = mock_session.get.call_args + assert "agentregistry.mtls.googleapis.com" in args[0] + + @pytest.mark.parametrize( + "env_val, has_cert, expected", + [ + ("true", True, True), + ("true", False, True), + ("false", True, False), + ("false", False, False), + ], + ) + def test_use_client_cert_effective( + self, env_val, has_cert, expected, registry + ): + """Tests the logic for enabling mTLS based on env vars and cert availability.""" + with patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": env_val}): + with patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=has_cert, + ): + from google.adk.integrations.agent_registry.agent_registry import _use_client_cert_effective + + assert _use_client_cert_effective() == expected + + def test_get_agent_registry_base_url(self, registry): + """Verifies correct base URL selection for mTLS vs non-mTLS.""" + from google.adk.integrations.agent_registry.agent_registry import _get_agent_registry_base_url + + # Non-mTLS + assert "agentregistry.googleapis.com" in _get_agent_registry_base_url(None) + + # mTLS + assert "agentregistry.mtls.googleapis.com" in _get_agent_registry_base_url( + lambda: True + ) + + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_error_handling(self, mock_session_class, registry): + """Ensures exceptions from AuthorizedSession are handled gracefully.""" + mock_session = mock_session_class.return_value + mock_session.get.side_effect = Exception("Connection error") + + with pytest.raises( + RuntimeError, match="API request failed: Connection error" + ): + registry._make_request("test-path") From 0bb75602c493d28eb420534d1999b4771b0490c9 Mon Sep 17 00:00:00 2001 From: agrawalradhika-cell Date: Thu, 28 May 2026 11:04:28 -0700 Subject: [PATCH 3/6] fix: add client_cert_source in configure_mtls_channel fix: add client_cert_source in configure_mtls_channel --- src/google/adk/integrations/agent_registry/agent_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index 236465a623..d500b86ccc 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -250,7 +250,7 @@ def _make_request( if mtls.has_default_client_cert_source() else None ) - session.configure_mtls_channel() + session.configure_mtls_channel(client_cert_source) base_url = _get_agent_registry_base_url(client_cert_source) From 953304d57abac701fe2edfddc20579201204f0af Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 29 May 2026 05:52:38 +0800 Subject: [PATCH 4/6] fix(eval): include intermediate text in final response match (#5698) --- .../adk/evaluation/final_response_match_v2.py | 17 ++++- .../adk/evaluation/llm_as_judge_utils.py | 7 ++ .../test_final_response_match_v2.py | 75 +++++++++++++++++++ .../evaluation/test_llm_as_judge_utils.py | 30 ++++++++ 4 files changed, 125 insertions(+), 4 deletions(-) diff --git a/src/google/adk/evaluation/final_response_match_v2.py b/src/google/adk/evaluation/final_response_match_v2.py index 713b421e3d..445d65c13d 100644 --- a/src/google/adk/evaluation/final_response_match_v2.py +++ b/src/google/adk/evaluation/final_response_match_v2.py @@ -159,13 +159,22 @@ def format_auto_rater_prompt( if expected_invocation is None: raise ValueError("expected_invocation is required for this metric.") - reference = get_text_from_content(expected_invocation.final_response) - response = get_text_from_content(actual_invocation.final_response) + include_intermediate = ( + self._criterion.include_intermediate_responses_in_final + ) + reference = get_text_from_content( + expected_invocation, + include_intermediate_responses_in_final=include_intermediate, + ) + response = get_text_from_content( + actual_invocation, + include_intermediate_responses_in_final=include_intermediate, + ) user_prompt = get_text_from_content(expected_invocation.user_content) return self._auto_rater_prompt_template.format( prompt=user_prompt, - response=response, - golden_response=reference, + response=response or "", + golden_response=reference or "", ) @override diff --git a/src/google/adk/evaluation/llm_as_judge_utils.py b/src/google/adk/evaluation/llm_as_judge_utils.py index 0986f2bed0..edc057be7c 100644 --- a/src/google/adk/evaluation/llm_as_judge_utils.py +++ b/src/google/adk/evaluation/llm_as_judge_utils.py @@ -25,6 +25,7 @@ from .app_details import AppDetails from .common import EvalBaseModel from .eval_case import get_all_tool_calls_with_responses +from .eval_case import IntermediateData from .eval_case import IntermediateDataType from .eval_case import Invocation from .eval_case import InvocationEvents @@ -71,6 +72,12 @@ def get_text_from_content( text = get_text_from_content(event.content) if text: parts.append(text) + elif isinstance(content.intermediate_data, IntermediateData): + for _, response_parts in content.intermediate_data.intermediate_responses: + text = get_text_from_content(genai_types.Content(parts=response_parts)) + if text: + parts.append(text) + # Then fetch the final response text and append it to the end. final_text = get_text_from_content(content.final_response) if final_text: diff --git a/tests/unittests/evaluation/test_final_response_match_v2.py b/tests/unittests/evaluation/test_final_response_match_v2.py index ce44901ab5..4a609420b2 100644 --- a/tests/unittests/evaluation/test_final_response_match_v2.py +++ b/tests/unittests/evaluation/test_final_response_match_v2.py @@ -15,6 +15,8 @@ from __future__ import annotations from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_case import InvocationEvent +from google.adk.evaluation.eval_case import InvocationEvents from google.adk.evaluation.eval_metrics import BaseCriterion from google.adk.evaluation.eval_metrics import EvalMetric from google.adk.evaluation.eval_metrics import EvalStatus @@ -127,6 +129,8 @@ def create_test_template() -> str: def _create_test_evaluator_gemini( threshold: float, + *, + include_intermediate_responses_in_final: bool = False, ) -> FinalResponseMatchV2Evaluator: evaluator = FinalResponseMatchV2Evaluator( EvalMetric( @@ -134,6 +138,9 @@ def _create_test_evaluator_gemini( threshold=threshold, criterion=BaseCriterion( threshold=0.5, + include_intermediate_responses_in_final=( + include_intermediate_responses_in_final + ), ), ), ) @@ -168,6 +175,21 @@ def _create_test_invocations( return actual_invocation, expected_invocation +def _add_intermediate_text(invocation: Invocation, text: str) -> Invocation: + invocation.intermediate_data = InvocationEvents( + invocation_events=[ + InvocationEvent( + author="agent", + content=genai_types.Content( + parts=[genai_types.Part(text=text)], + role="model", + ), + ), + ] + ) + return invocation + + def test_format_auto_rater_prompt(): evaluator = _create_test_evaluator_gemini(threshold=0.8) actual_invocation, expected_invocation = _create_test_invocations( @@ -193,6 +215,59 @@ def test_format_auto_rater_prompt(): """ +def test_format_auto_rater_prompt_uses_empty_text_for_missing_final_response(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + actual_invocation, expected_invocation = _create_test_invocations( + "candidate text", "reference text" + ) + actual_invocation.final_response = None + expected_invocation.final_response = None + + prompt = evaluator.format_auto_rater_prompt( + actual_invocation, expected_invocation + ) + + assert "None" not in prompt + assert '"Agent response": ,' in prompt + assert '"Reference response": ,' in prompt + + +def test_format_auto_rater_prompt_ignores_intermediate_by_default(): + evaluator = _create_test_evaluator_gemini(threshold=0.8) + actual_invocation, expected_invocation = _create_test_invocations( + "candidate final", "reference final" + ) + _add_intermediate_text(actual_invocation, "candidate intro") + _add_intermediate_text(expected_invocation, "reference intro") + + prompt = evaluator.format_auto_rater_prompt( + actual_invocation, expected_invocation + ) + + assert "candidate final" in prompt + assert "reference final" in prompt + assert "candidate intro" not in prompt + assert "reference intro" not in prompt + + +def test_format_auto_rater_prompt_includes_intermediate_when_enabled(): + evaluator = _create_test_evaluator_gemini( + threshold=0.8, include_intermediate_responses_in_final=True + ) + actual_invocation, expected_invocation = _create_test_invocations( + "candidate final", "reference final" + ) + _add_intermediate_text(actual_invocation, "candidate intro") + _add_intermediate_text(expected_invocation, "reference intro") + + prompt = evaluator.format_auto_rater_prompt( + actual_invocation, expected_invocation + ) + + assert "candidate intro\ncandidate final" in prompt + assert "reference intro\nreference final" in prompt + + def test_convert_auto_rater_response_to_score_valid(): evaluator = _create_test_evaluator_gemini(threshold=0.8) auto_rater_response = """```json diff --git a/tests/unittests/evaluation/test_llm_as_judge_utils.py b/tests/unittests/evaluation/test_llm_as_judge_utils.py index 4b53a2dc43..c7cd5ff569 100644 --- a/tests/unittests/evaluation/test_llm_as_judge_utils.py +++ b/tests/unittests/evaluation/test_llm_as_judge_utils.py @@ -132,6 +132,36 @@ def test_get_text_from_content_with_invocation_include_intermediate_responses_in ) +def test_get_text_from_content_with_intermediate_data_full_response(): + invocation = Invocation( + user_content=genai_types.Content(parts=[genai_types.Part(text="user")]), + intermediate_data=IntermediateData( + intermediate_responses=[ + ("agent", [genai_types.Part(text="legacy intro")]), + ( + "tool", + [ + genai_types.Part( + function_call=genai_types.FunctionCall(name="lookup") + ) + ], + ), + ] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="final answer")] + ), + ) + + assert get_text_from_content(invocation) == "final answer" + assert ( + get_text_from_content( + invocation, include_intermediate_responses_in_final=True + ) + == "legacy intro\nfinal answer" + ) + + def test_get_eval_status_with_none_score(): """Tests get_eval_status returns NOT_EVALUATED for a None score.""" assert get_eval_status(score=None, threshold=0.5) == EvalStatus.NOT_EVALUATED From 72d1866550004212572928c4d05444beceb50444 Mon Sep 17 00:00:00 2001 From: agrawalradhika-cell Date: Mon, 1 Jun 2026 11:52:09 -0700 Subject: [PATCH 5/6] feat: Update for initializing session in init feat: Update for initializing session in init --- .../agent_registry/agent_registry.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index d500b86ccc..2b468e20eb 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -212,6 +212,22 @@ def __init__( raise RuntimeError( f"Failed to get default Google Cloud credentials: {e}" ) from e + # Instantiate and configure AuthorizedSession once during initialization + self._session = requests_auth.AuthorizedSession( + credentials=self._credentials + ) + + use_client_cert = _use_client_cert_effective() + client_cert_source = None + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + self._session.configure_mtls_channel(client_cert_source) + + self._base_url = _get_agent_registry_base_url(client_cert_source) def _get_auth_headers(self) -> Dict[str, str]: """Refreshes credentials and returns authorization headers.""" @@ -255,15 +271,19 @@ def _make_request( base_url = _get_agent_registry_base_url(client_cert_source) if path.startswith("projects/"): - url = f"{base_url}/{path}" + url = f"{self._base_url}/{path}" else: - url = f"{base_url}/{self._base_path}/{path}" + url = f"{self._base_url}/{self._base_path}/{path}" + headers = {} + quota_project_id = ( + getattr(self._credentials, "quota_project_id", None) or self.project_id + ) + if quota_project_id: + headers["x-goog-user-project"] = quota_project_id try: # Using AuthorizedSession for internal API calls to handle mTLS/Auth. - response = session.get( - url, headers=self._get_auth_headers(), params=params - ) + response = self._session.get(url, headers=headers, params=params) response.raise_for_status() return response.json() except requests.exceptions.HTTPError as e: From 761f0a7372d5658f1209dce078003ea25b904283 Mon Sep 17 00:00:00 2001 From: agrawalradhika-cell Date: Mon, 1 Jun 2026 11:57:33 -0700 Subject: [PATCH 6/6] fix: Update registry fixture with AuthorizedSession mock fix: Update registry fixture with AuthorizedSession mock --- .../integrations/agent_registry/test_agent_registry.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index bd8288c483..d22dedaae9 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -42,8 +42,14 @@ class TestAgentRegistry: def registry(self): mock_creds = MagicMock() mock_creds.quota_project_id = None - with patch("google.auth.default", return_value=(mock_creds, "project-id")): - return AgentRegistry(project_id="test-project", location="global") + with patch( + "google.auth.default", return_value=(mock_creds, "project-id") + ), patch( + "google.auth.transport.requests.AuthorizedSession" + ) as mock_session_class: + registry = AgentRegistry(project_id="test-project", location="global") + registry._mock_session = mock_session_class.return_value + return registry @pytest.mark.asyncio @patch("google.auth.transport.requests.AuthorizedSession")