Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,13 @@ async def exchange_auth_token(
return exchange_result.credential

async def parse_and_store_auth_response(self, state: State) -> None:
credential_key = self.auth_config.credential_key
if not credential_key:
raise ValueError("credential_key is empty.")

credential_key = "temp:" + self.auth_config.credential_key
temp_credential_key = "temp:" + credential_key

state[credential_key] = self.auth_config.exchanged_auth_credential
state[temp_credential_key] = self.auth_config.exchanged_auth_credential
if not isinstance(
self.auth_config.auth_scheme, SecurityBase
) or self.auth_config.auth_scheme.type_ not in (
Expand All @@ -67,15 +70,54 @@ async def parse_and_store_auth_response(self, state: State) -> None:
):
return

state[credential_key] = await self.exchange_auth_token()
state[temp_credential_key] = await self.exchange_auth_token()

def _validate(self) -> None:
if not self.auth_scheme:
if not self.auth_config.auth_scheme:
raise ValueError("auth_scheme is empty.")

def get_auth_response(self, state: State) -> AuthCredential:
credential_key = "temp:" + self.auth_config.credential_key
return state.get(credential_key, None)
def get_auth_response(self, state: State) -> AuthCredential | None:
# 1. Try reading the temp credential key (standard ADK flow)
credential_key = self.auth_config.credential_key
if not credential_key:
return None

temp_credential_key = "temp:" + credential_key
val = state.get(temp_credential_key, None)
if val is not None:
if isinstance(val, AuthCredential):
return val
if isinstance(val, str) and val:
return self._build_oauth2_credential(val)

# 2. Try reading the credential key without the 'temp:' prefix
val = state.get(credential_key, None)
if val is not None:
if isinstance(val, AuthCredential):
return val
if isinstance(val, str) and val:
return self._build_oauth2_credential(val)

# 3. Fallback: scan the state for any active Google OAuth access token (ya29.*)
try:
state_dict = state.to_dict() if hasattr(state, "to_dict") else state
if isinstance(state_dict, dict):
for k, v in state_dict.items():
if isinstance(v, str) and v.startswith("ya29."):
return self._build_oauth2_credential(v)
except Exception: # pylint: disable=broad-except
pass

return None

def _build_oauth2_credential(self, token: str) -> AuthCredential:
from .auth_credential import AuthCredentialTypes
from .auth_credential import OAuth2Auth

return AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token=token),
)

def generate_auth_request(self) -> AuthConfig:
if not isinstance(
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/auth/test_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,57 @@ def test_get_auth_response_not_exists(self, auth_config):
result = handler.get_auth_response(state)
assert result is None

def test_get_auth_response_temp_prefix_str_token(self, auth_config):
"""Test retrieving a string token stored under temp prefix in state."""
handler = AuthHandler(auth_config)
state = MockState()
credential_key = auth_config.credential_key
state["temp:" + credential_key] = "ya29.mock_token"

result = handler.get_auth_response(state)

assert result is not None
assert result.auth_type == AuthCredentialTypes.OAUTH2
assert result.oauth2.access_token == "ya29.mock_token"

def test_get_auth_response_no_prefix_credential(
self, auth_config, oauth2_credentials_with_auth_uri
):
"""Test retrieving a credential stored under the key without prefix."""
handler = AuthHandler(auth_config)
state = MockState()
credential_key = auth_config.credential_key
state[credential_key] = oauth2_credentials_with_auth_uri

result = handler.get_auth_response(state)

assert result == oauth2_credentials_with_auth_uri

def test_get_auth_response_no_prefix_str_token(self, auth_config):
"""Test retrieving a string token stored under the key without prefix."""
handler = AuthHandler(auth_config)
state = MockState()
credential_key = auth_config.credential_key
state[credential_key] = "ya29.mock_token_no_prefix"

result = handler.get_auth_response(state)

assert result is not None
assert result.auth_type == AuthCredentialTypes.OAUTH2
assert result.oauth2.access_token == "ya29.mock_token_no_prefix"

def test_get_auth_response_fallback_google_token(self, auth_config):
"""Test retrieving fallback Google token from state via scanning."""
handler = AuthHandler(auth_config)
state = MockState()
state["some_other_key"] = "ya29.fallback_google_token"

result = handler.get_auth_response(state)

assert result is not None
assert result.auth_type == AuthCredentialTypes.OAUTH2
assert result.oauth2.access_token == "ya29.fallback_google_token"


class TestParseAndStoreAuthResponse:
"""Tests for the parse_and_store_auth_response method."""
Expand Down
Loading