diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 8e8f5d340b..8ccdcdeec3 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -55,10 +55,13 @@ async def exchange_auth_token( return exchange_result.credential async def parse_and_store_auth_response(self, state: State) -> None: + credential_key = self.auth_config.credential_key + if not credential_key: + raise ValueError("credential_key is empty.") - credential_key = "temp:" + self.auth_config.credential_key + temp_credential_key = "temp:" + credential_key - state[credential_key] = self.auth_config.exchanged_auth_credential + state[temp_credential_key] = self.auth_config.exchanged_auth_credential if not isinstance( self.auth_config.auth_scheme, SecurityBase ) or self.auth_config.auth_scheme.type_ not in ( @@ -67,15 +70,54 @@ async def parse_and_store_auth_response(self, state: State) -> None: ): return - state[credential_key] = await self.exchange_auth_token() + state[temp_credential_key] = await self.exchange_auth_token() def _validate(self) -> None: - if not self.auth_scheme: + if not self.auth_config.auth_scheme: raise ValueError("auth_scheme is empty.") - def get_auth_response(self, state: State) -> AuthCredential: - credential_key = "temp:" + self.auth_config.credential_key - return state.get(credential_key, None) + def get_auth_response(self, state: State) -> AuthCredential | None: + # 1. Try reading the temp credential key (standard ADK flow) + credential_key = self.auth_config.credential_key + if not credential_key: + return None + + temp_credential_key = "temp:" + credential_key + val = state.get(temp_credential_key, None) + if val is not None: + if isinstance(val, AuthCredential): + return val + if isinstance(val, str) and val: + return self._build_oauth2_credential(val) + + # 2. Try reading the credential key without the 'temp:' prefix + val = state.get(credential_key, None) + if val is not None: + if isinstance(val, AuthCredential): + return val + if isinstance(val, str) and val: + return self._build_oauth2_credential(val) + + # 3. Fallback: scan the state for any active Google OAuth access token (ya29.*) + try: + state_dict = state.to_dict() if hasattr(state, "to_dict") else state + if isinstance(state_dict, dict): + for k, v in state_dict.items(): + if isinstance(v, str) and v.startswith("ya29."): + return self._build_oauth2_credential(v) + except Exception: # pylint: disable=broad-except + pass + + return None + + def _build_oauth2_credential(self, token: str) -> AuthCredential: + from .auth_credential import AuthCredentialTypes + from .auth_credential import OAuth2Auth + + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token=token), + ) def generate_auth_request(self) -> AuthConfig: if not isinstance( diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index c19a5d93fd..3b4c916f1b 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -503,6 +503,57 @@ def test_get_auth_response_not_exists(self, auth_config): result = handler.get_auth_response(state) assert result is None + def test_get_auth_response_temp_prefix_str_token(self, auth_config): + """Test retrieving a string token stored under temp prefix in state.""" + handler = AuthHandler(auth_config) + state = MockState() + credential_key = auth_config.credential_key + state["temp:" + credential_key] = "ya29.mock_token" + + result = handler.get_auth_response(state) + + assert result is not None + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.access_token == "ya29.mock_token" + + def test_get_auth_response_no_prefix_credential( + self, auth_config, oauth2_credentials_with_auth_uri + ): + """Test retrieving a credential stored under the key without prefix.""" + handler = AuthHandler(auth_config) + state = MockState() + credential_key = auth_config.credential_key + state[credential_key] = oauth2_credentials_with_auth_uri + + result = handler.get_auth_response(state) + + assert result == oauth2_credentials_with_auth_uri + + def test_get_auth_response_no_prefix_str_token(self, auth_config): + """Test retrieving a string token stored under the key without prefix.""" + handler = AuthHandler(auth_config) + state = MockState() + credential_key = auth_config.credential_key + state[credential_key] = "ya29.mock_token_no_prefix" + + result = handler.get_auth_response(state) + + assert result is not None + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.access_token == "ya29.mock_token_no_prefix" + + def test_get_auth_response_fallback_google_token(self, auth_config): + """Test retrieving fallback Google token from state via scanning.""" + handler = AuthHandler(auth_config) + state = MockState() + state["some_other_key"] = "ya29.fallback_google_token" + + result = handler.get_auth_response(state) + + assert result is not None + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.access_token == "ya29.fallback_google_token" + class TestParseAndStoreAuthResponse: """Tests for the parse_and_store_auth_response method."""