Skip to content
Open
81 changes: 71 additions & 10 deletions src/google/adk/integrations/agent_registry/agent_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Generator
from enum import Enum
import logging
import os
import re
from typing import Any
from typing import Callable
Expand All @@ -39,9 +40,11 @@
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
import google.auth
import google.auth.transport.requests
from google.auth.transport import mtls
from google.auth.transport import requests as requests_auth
import httpx
from mcp import StdioServerParameters
import requests
from typing_extensions import override

# pylint: disable=g-import-not-at-top
Expand All @@ -61,6 +64,9 @@
logger = logging.getLogger("google_adk." + __name__)

AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha"
AGENT_REGISTRY_MTLS_BASE_URL = (
"https://agentregistry.mtls.googleapis.com/v1alpha"
)

_TRANSPORT_MAPPING = {
"HTTP_JSON": A2ATransport.http_json,
Expand Down Expand Up @@ -120,6 +126,14 @@ async def get_tools(
return tools


class _MtlsEndpoint(Enum):
"""The mTLS endpoint setting."""

AUTO = "auto"
ALWAYS = "always"
NEVER = "never"


class _ProtocolType(str, Enum):
"""Supported agent protocol types."""

Expand Down Expand Up @@ -224,23 +238,40 @@ def _make_request(
self, path: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
"""Helper function to make GET requests to the Agent Registry API."""
# Determine if mTLS should be used
session = requests_auth.AuthorizedSession(credentials=self._credentials)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

creating a new AuthorizedSession on every request could be a performance bottleneck - it would be better to instantiate and configure the session once during initialization (in the init) and reuse it


use_client_cert = _use_client_cert_effective()
client_cert_source = None

if use_client_cert:
client_cert_source = (
mtls.default_client_cert_source()
if mtls.has_default_client_cert_source()
else None
)
session.configure_mtls_channel(client_cert_source)

base_url = _get_agent_registry_base_url(client_cert_source)

if path.startswith("projects/"):
url = f"{AGENT_REGISTRY_BASE_URL}/{path}"
url = f"{base_url}/{path}"
else:
url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}"
url = f"{base_url}/{self._base_path}/{path}"

try:
headers = self._get_auth_headers()
with httpx.Client() as client:
response = client.get(url, headers=headers, params=params)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
# Using AuthorizedSession for internal API calls to handle mTLS/Auth.
response = session.get(
url, headers=self._get_auth_headers(), params=params
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
raise RuntimeError(
f"API request failed with status {e.response.status_code}:"
f" {e.response.text}"
) from e
except httpx.RequestError as e:
except requests.exceptions.RequestException as e:
raise RuntimeError(f"API request failed (network error): {e}") from e
except Exception as e:
raise RuntimeError(f"API request failed: {e}") from e
Expand Down Expand Up @@ -520,3 +551,33 @@ def get_remote_a2a_agent(
description=description,
httpx_client=httpx_client,
)


def _use_client_cert_effective() -> bool:
"""Returns whether client certificate should be used for mTLS."""
try:
return bool(mtls.should_use_client_cert())
except (ImportError, AttributeError):
use_client_cert_str = os.getenv(
"GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
).lower()
return use_client_cert_str == "true"


def _get_agent_registry_base_url(client_cert_source: Any | None = None) -> str:
"""Returns the base URL based on mTLS configuration and cert availability."""
use_mtls_endpoint_str = os.getenv(
"GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value
).lower()

try:
use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str)
except ValueError:
use_mtls_endpoint = _MtlsEndpoint.AUTO

if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or (
use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source
):
return AGENT_REGISTRY_MTLS_BASE_URL

return AGENT_REGISTRY_BASE_URL
Loading