From b62b7b31b6cc9f47916521a5a8a8f68209e8547f Mon Sep 17 00:00:00 2001 From: Tony Coconate Date: Fri, 29 May 2026 17:04:56 -0500 Subject: [PATCH] fix(auth): Support fallback OAuth token and prefixless credential lookups in session state Session state might store authentication responses as raw string tokens instead of AuthCredential objects, or under custom credential keys without the standard "temp:" prefix. Add robust fallback handling to resolve raw token strings, check for prefixless keys, and scan state values for any Google OAuth access tokens starting with "ya29." --- src/google/adk/auth/auth_handler.py | 56 ++++++++++++++++++++--- tests/unittests/auth/test_auth_handler.py | 51 +++++++++++++++++++++ 2 files changed, 100 insertions(+), 7 deletions(-) diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 8e8f5d340b..8ccdcdeec3 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -55,10 +55,13 @@ async def exchange_auth_token( return exchange_result.credential async def parse_and_store_auth_response(self, state: State) -> None: + credential_key = self.auth_config.credential_key + if not credential_key: + raise ValueError("credential_key is empty.") - credential_key = "temp:" + self.auth_config.credential_key + temp_credential_key = "temp:" + credential_key - state[credential_key] = self.auth_config.exchanged_auth_credential + state[temp_credential_key] = self.auth_config.exchanged_auth_credential if not isinstance( self.auth_config.auth_scheme, SecurityBase ) or self.auth_config.auth_scheme.type_ not in ( @@ -67,15 +70,54 @@ async def parse_and_store_auth_response(self, state: State) -> None: ): return - state[credential_key] = await self.exchange_auth_token() + state[temp_credential_key] = await self.exchange_auth_token() def _validate(self) -> None: - if not self.auth_scheme: + if not self.auth_config.auth_scheme: raise ValueError("auth_scheme is empty.") - def get_auth_response(self, state: State) -> AuthCredential: - credential_key = "temp:" + self.auth_config.credential_key - return state.get(credential_key, None) + def get_auth_response(self, state: State) -> AuthCredential | None: + # 1. Try reading the temp credential key (standard ADK flow) + credential_key = self.auth_config.credential_key + if not credential_key: + return None + + temp_credential_key = "temp:" + credential_key + val = state.get(temp_credential_key, None) + if val is not None: + if isinstance(val, AuthCredential): + return val + if isinstance(val, str) and val: + return self._build_oauth2_credential(val) + + # 2. Try reading the credential key without the 'temp:' prefix + val = state.get(credential_key, None) + if val is not None: + if isinstance(val, AuthCredential): + return val + if isinstance(val, str) and val: + return self._build_oauth2_credential(val) + + # 3. Fallback: scan the state for any active Google OAuth access token (ya29.*) + try: + state_dict = state.to_dict() if hasattr(state, "to_dict") else state + if isinstance(state_dict, dict): + for k, v in state_dict.items(): + if isinstance(v, str) and v.startswith("ya29."): + return self._build_oauth2_credential(v) + except Exception: # pylint: disable=broad-except + pass + + return None + + def _build_oauth2_credential(self, token: str) -> AuthCredential: + from .auth_credential import AuthCredentialTypes + from .auth_credential import OAuth2Auth + + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token=token), + ) def generate_auth_request(self) -> AuthConfig: if not isinstance( diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index c19a5d93fd..3b4c916f1b 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -503,6 +503,57 @@ def test_get_auth_response_not_exists(self, auth_config): result = handler.get_auth_response(state) assert result is None + def test_get_auth_response_temp_prefix_str_token(self, auth_config): + """Test retrieving a string token stored under temp prefix in state.""" + handler = AuthHandler(auth_config) + state = MockState() + credential_key = auth_config.credential_key + state["temp:" + credential_key] = "ya29.mock_token" + + result = handler.get_auth_response(state) + + assert result is not None + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.access_token == "ya29.mock_token" + + def test_get_auth_response_no_prefix_credential( + self, auth_config, oauth2_credentials_with_auth_uri + ): + """Test retrieving a credential stored under the key without prefix.""" + handler = AuthHandler(auth_config) + state = MockState() + credential_key = auth_config.credential_key + state[credential_key] = oauth2_credentials_with_auth_uri + + result = handler.get_auth_response(state) + + assert result == oauth2_credentials_with_auth_uri + + def test_get_auth_response_no_prefix_str_token(self, auth_config): + """Test retrieving a string token stored under the key without prefix.""" + handler = AuthHandler(auth_config) + state = MockState() + credential_key = auth_config.credential_key + state[credential_key] = "ya29.mock_token_no_prefix" + + result = handler.get_auth_response(state) + + assert result is not None + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.access_token == "ya29.mock_token_no_prefix" + + def test_get_auth_response_fallback_google_token(self, auth_config): + """Test retrieving fallback Google token from state via scanning.""" + handler = AuthHandler(auth_config) + state = MockState() + state["some_other_key"] = "ya29.fallback_google_token" + + result = handler.get_auth_response(state) + + assert result is not None + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.access_token == "ya29.fallback_google_token" + class TestParseAndStoreAuthResponse: """Tests for the parse_and_store_auth_response method."""