diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py index ce5c464c4..4226ae715 100644 --- a/src/google/adk/auth/exchanger/__init__.py +++ b/src/google/adk/auth/exchanger/__init__.py @@ -15,11 +15,9 @@ """Credential exchanger module.""" from .base_credential_exchanger import BaseCredentialExchanger -from .credential_exchanger_registry import CredentialExchangerRegistry from .service_account_credential_exchanger import ServiceAccountCredentialExchanger __all__ = [ "BaseCredentialExchanger", - "CredentialExchangerRegistry", "ServiceAccountCredentialExchanger", ] diff --git a/src/google/adk/auth/service_account_credential_exchanger.py b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py similarity index 57% rename from src/google/adk/auth/service_account_credential_exchanger.py rename to src/google/adk/auth/exchanger/service_account_credential_exchanger.py index 644501ee6..415081ca5 100644 --- a/src/google/adk/auth/service_account_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py @@ -16,19 +16,22 @@ from __future__ import annotations +from typing import Optional + import google.auth from google.auth.transport.requests import Request from google.oauth2 import service_account +from typing_extensions import override -from ..utils.feature_decorator import experimental -from .auth_credential import AuthCredential -from .auth_credential import AuthCredentialTypes -from .auth_credential import HttpAuth -from .auth_credential import HttpCredentials +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_credential import AuthCredentialTypes +from ..auth_schemes import AuthScheme +from .base_credential_exchanger import BaseCredentialExchanger @experimental -class ServiceAccountCredentialExchanger: +class ServiceAccountCredentialExchanger(BaseCredentialExchanger): """Exchanges Google Service Account credentials for an access token. Uses the default service credential if `use_default_credential = True`. @@ -36,44 +39,56 @@ class ServiceAccountCredentialExchanger: credential. """ - def __init__(self, credential: AuthCredential): - if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: - raise ValueError("Credential is not a service account credential.") - self._credential = credential - - def exchange(self) -> AuthCredential: + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: """Exchanges the service account auth credential for an access token. If the AuthCredential contains a service account credential, it will be used to exchange for an access token. Otherwise, if use_default_credential is True, the default application credential will be used for exchanging an access token. + Args: + auth_scheme: The authentication scheme. + auth_credential: The credential to exchange. + Returns: - An AuthCredential in HTTP Bearer format, containing the access token. + An AuthCredential in OAUTH2 format, containing the exchanged credential JSON. Raises: ValueError: If service account credentials are missing or invalid. Exception: If credential exchange or refresh fails. """ + if auth_credential is None: + raise ValueError("Credential cannot be None.") + + if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: + raise ValueError("Credential is not a service account credential.") + + if auth_credential.service_account is None: + raise ValueError( + "Service account credentials are missing. Please provide them." + ) + if ( - self._credential is None - or self._credential.service_account is None - or ( - self._credential.service_account.service_account_credential is None - and not self._credential.service_account.use_default_credential - ) + auth_credential.service_account.service_account_credential is None + and not auth_credential.service_account.use_default_credential ): raise ValueError( - "Service account credentials are missing. Please provide them, or set" - " `use_default_credential = True` to use application default" - " credential in a hosted service like Google Cloud Run." + "Service account credentials are invalid. Please set the" + " service_account_credential field or set `use_default_credential =" + " True` to use application default credential in a hosted service" + " like Google Cloud Run." ) try: - if self._credential.service_account.use_default_credential: + if auth_credential.service_account.use_default_credential: credentials, _ = google.auth.default() else: - config = self._credential.service_account + config = auth_credential.service_account credentials = service_account.Credentials.from_service_account_info( config.service_account_credential.model_dump(), scopes=config.scopes ) @@ -82,11 +97,8 @@ def exchange(self) -> AuthCredential: credentials.refresh(Request()) return AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token=credentials.token), - ), + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=credentials.to_json(), ) except Exception as e: raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/tests/unittests/auth/exchanger/__init__.py b/tests/unittests/auth/exchanger/__init__.py new file mode 100644 index 000000000..5fb8a262b --- /dev/null +++ b/tests/unittests/auth/exchanger/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for credential exchanger.""" diff --git a/tests/unittests/auth/test_service_account_credential_exchanger.py b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py similarity index 61% rename from tests/unittests/auth/test_service_account_credential_exchanger.py rename to tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py index a5c668436..195e143d3 100644 --- a/tests/unittests/auth/test_service_account_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py @@ -17,19 +17,20 @@ from unittest.mock import MagicMock from unittest.mock import patch +from fastapi.openapi.models import HTTPBearer from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import ServiceAccount from google.adk.auth.auth_credential import ServiceAccountCredential -from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger +from google.adk.auth.exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger import pytest class TestServiceAccountCredentialExchanger: """Test cases for ServiceAccountCredentialExchanger.""" - def test_init_valid_credential(self): - """Test successful initialization with valid service account credential.""" + def test_exchange_with_valid_credential(self): + """Test successful exchange with valid service account credential.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( @@ -55,26 +56,36 @@ def test_init_valid_credential(self): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - assert exchanger._credential == credential + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() - def test_init_invalid_credential_type(self): - """Test initialization with invalid credential type raises ValueError.""" + # This should not raise an exception + assert exchanger is not None + + @pytest.mark.asyncio + async def test_exchange_invalid_credential_type(self): + """Test exchange with invalid credential type raises ValueError.""" credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="test-key", ) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + with pytest.raises( ValueError, match="Credential is not a service account credential" ): - ServiceAccountCredentialExchanger(credential) + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_with_explicit_credentials_success( + async def test_exchange_with_explicit_credentials_success( self, mock_request_class, mock_from_service_account_info ): """Test successful exchange with explicit service account credentials.""" @@ -84,6 +95,9 @@ def test_exchange_with_explicit_credentials_success( mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "mock_access_token", "type": "authorized_user"}' + ) mock_from_service_account_info.return_value = mock_credentials # Create test credential @@ -113,13 +127,20 @@ def test_exchange_with_explicit_credentials_success( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - result = exchanger.exchange() + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) # Verify the result - assert result.auth_type == AuthCredentialTypes.HTTP - assert result.http.scheme == "bearer" - assert result.http.credentials.token == "mock_access_token" + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "mock_access_token" or "mock_access_token" in str(exchanged_creds) # Verify mocks were called correctly mock_from_service_account_info.assert_called_once_with( @@ -128,11 +149,14 @@ def test_exchange_with_explicit_credentials_success( ) mock_credentials.refresh.assert_called_once_with(mock_request) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.google.auth.default" + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_with_default_credentials_success( + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" + ) + async def test_exchange_with_default_credentials_success( self, mock_request_class, mock_google_auth_default ): """Test successful exchange with default application credentials.""" @@ -142,6 +166,9 @@ def test_exchange_with_default_credentials_success( mock_credentials = MagicMock() mock_credentials.token = "default_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "default_access_token", "type": "authorized_user"}' + ) mock_google_auth_default.return_value = (mock_credentials, "test-project") # Create test credential with use_default_credential=True @@ -153,33 +180,45 @@ def test_exchange_with_default_credentials_success( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - result = exchanger.exchange() + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) # Verify the result - assert result.auth_type == AuthCredentialTypes.HTTP - assert result.http.scheme == "bearer" - assert result.http.credentials.token == "default_access_token" + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "default_access_token" or "default_access_token" in str( + exchanged_creds + ) # Verify mocks were called correctly mock_google_auth_default.assert_called_once() mock_credentials.refresh.assert_called_once_with(mock_request) - def test_exchange_missing_service_account(self): + @pytest.mark.asyncio + async def test_exchange_missing_service_account(self): """Test exchange fails when service_account is None.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=None, ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Service account credentials are missing" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) - def test_exchange_missing_credentials_and_not_default(self): + @pytest.mark.asyncio + async def test_exchange_missing_credentials_and_not_default(self): """Test exchange fails when credentials are missing and use_default_credential is False.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, @@ -190,17 +229,19 @@ def test_exchange_missing_credentials_and_not_default(self): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( - ValueError, match="Service account credentials are missing" + ValueError, match="Service account credentials are invalid" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" ) - def test_exchange_credential_creation_failure( + async def test_exchange_credential_creation_failure( self, mock_from_service_account_info ): """Test exchange handles credential creation failure gracefully.""" @@ -234,17 +275,21 @@ def test_exchange_credential_creation_failure( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.google.auth.default" + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" ) - def test_exchange_default_credential_failure(self, mock_google_auth_default): + async def test_exchange_default_credential_failure( + self, mock_google_auth_default + ): """Test exchange handles default credential failure gracefully.""" # Setup mock to raise exception mock_google_auth_default.side_effect = Exception( @@ -260,18 +305,22 @@ def test_exchange_default_credential_failure(self, mock_google_auth_default): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_refresh_failure( + async def test_exchange_refresh_failure( self, mock_request_class, mock_from_service_account_info ): """Test exchange handles credential refresh failure gracefully.""" @@ -312,30 +361,73 @@ def test_exchange_refresh_failure( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + + @pytest.mark.asyncio + async def test_exchange_none_credential_in_constructor(self): + """Test that passing None credential raises appropriate error during exchange.""" + # This test verifies behavior when credential is None + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + + with pytest.raises(ValueError, match="Credential cannot be None"): + await exchanger.exchange(None, auth_scheme) - def test_exchange_none_credential_in_constructor(self): - """Test that passing None credential raises appropriate error during construction.""" - # This test verifies behavior when _credential is None, though this shouldn't - # happen in normal usage due to constructor validation + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" + ) + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" + ) + async def test_exchange_with_service_account_no_explicit_credentials( + self, mock_request_class, mock_google_auth_default + ): + """Test exchange with service account that has no explicit credentials uses default.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "default_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "default_access_token", "type": "authorized_user"}' + ) + mock_google_auth_default.return_value = (mock_credentials, "test-project") + + # Create test credential with no explicit credentials but use_default_credential=True credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( + service_account_credential=None, use_default_credential=True, scopes=["https://www.googleapis.com/auth/cloud-platform"], ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - # Manually set to None to test the validation logic - exchanger._credential = None + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) - with pytest.raises( - ValueError, match="Service account credentials are missing" - ): - exchanger.exchange() + # Verify the result + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "default_access_token" or "default_access_token" in str( + exchanged_creds + ) + + # Verify mocks were called correctly + mock_google_auth_default.assert_called_once() + mock_credentials.refresh.assert_called_once_with(mock_request)