From d1bda9d946581461df6065b620929a1588b1f64b Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Sat, 14 Jun 2025 12:55:27 -0700 Subject: [PATCH 01/79] chore: Allow working_in_progress feature for unittests PiperOrigin-RevId: 771500394 --- tests/unittests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index ad204005e..2b93226db 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -23,6 +23,7 @@ 'GOOGLE_API_KEY': 'fake_google_api_key', 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project', 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location', + 'ADK_ALLOW_WIP_FEATURES': 'true', } ENV_SETUPS = { From a4d432a9e62a5e759e94168463795cd0e7d7ab72 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Sat, 14 Jun 2025 13:39:14 -0700 Subject: [PATCH 02/79] chore: Add Service Account Credential Exchanger (Experimental) PiperOrigin-RevId: 771507089 --- .../service_account_credential_exchanger.py | 92 +++++ ...st_service_account_credential_exchanger.py | 341 ++++++++++++++++++ 2 files changed, 433 insertions(+) create mode 100644 src/google/adk/auth/service_account_credential_exchanger.py create mode 100644 tests/unittests/auth/test_service_account_credential_exchanger.py diff --git a/src/google/adk/auth/service_account_credential_exchanger.py b/src/google/adk/auth/service_account_credential_exchanger.py new file mode 100644 index 000000000..644501ee6 --- /dev/null +++ b/src/google/adk/auth/service_account_credential_exchanger.py @@ -0,0 +1,92 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential fetcher for Google Service Account.""" + +from __future__ import annotations + +import google.auth +from google.auth.transport.requests import Request +from google.oauth2 import service_account + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_credential import AuthCredentialTypes +from .auth_credential import HttpAuth +from .auth_credential import HttpCredentials + + +@experimental +class ServiceAccountCredentialExchanger: + """Exchanges Google Service Account credentials for an access token. + + Uses the default service credential if `use_default_credential = True`. + Otherwise, uses the service account credential provided in the auth + credential. + """ + + def __init__(self, credential: AuthCredential): + if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: + raise ValueError("Credential is not a service account credential.") + self._credential = credential + + def exchange(self) -> AuthCredential: + """Exchanges the service account auth credential for an access token. + + If the AuthCredential contains a service account credential, it will be used + to exchange for an access token. Otherwise, if use_default_credential is True, + the default application credential will be used for exchanging an access token. + + Returns: + An AuthCredential in HTTP Bearer format, containing the access token. 
+ + Raises: + ValueError: If service account credentials are missing or invalid. + Exception: If credential exchange or refresh fails. + """ + if ( + self._credential is None + or self._credential.service_account is None + or ( + self._credential.service_account.service_account_credential is None + and not self._credential.service_account.use_default_credential + ) + ): + raise ValueError( + "Service account credentials are missing. Please provide them, or set" + " `use_default_credential = True` to use application default" + " credential in a hosted service like Google Cloud Run." + ) + + try: + if self._credential.service_account.use_default_credential: + credentials, _ = google.auth.default() + else: + config = self._credential.service_account + credentials = service_account.Credentials.from_service_account_info( + config.service_account_credential.model_dump(), scopes=config.scopes + ) + + # Refresh credentials to ensure we have a valid access token + credentials.refresh(Request()) + + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=credentials.token), + ), + ) + except Exception as e: + raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/tests/unittests/auth/test_service_account_credential_exchanger.py b/tests/unittests/auth/test_service_account_credential_exchanger.py new file mode 100644 index 000000000..a5c668436 --- /dev/null +++ b/tests/unittests/auth/test_service_account_credential_exchanger.py @@ -0,0 +1,341 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
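# Illustrative usage sketch (added for exposition; not part of the committed
# patch). It assumes experimental/WIP features are allowed in the runtime
# (cf. the ADK_ALLOW_WIP_FEATURES flag enabled for unittests in PATCH 01) and
# that Application Default Credentials are available, e.g. on Cloud Run.
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger

sa_credential = AuthCredential(
    auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
    service_account=ServiceAccount(
        use_default_credential=True,
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    ),
)
# exchange() returns an HTTP bearer credential wrapping the access token.
exchanged = ServiceAccountCredentialExchanger(sa_credential).exchange()
assert exchanged.auth_type == AuthCredentialTypes.HTTP
access_token = exchanged.http.credentials.token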
+ +"""Unit tests for the ServiceAccountCredentialExchanger.""" + +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger +import pytest + + +class TestServiceAccountCredentialExchanger: + """Test cases for ServiceAccountCredentialExchanger.""" + + def test_init_valid_credential(self): + """Test successful initialization with valid service account credential.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE" + " KEY-----" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + assert exchanger._credential == credential + + def test_init_invalid_credential_type(self): + """Test initialization with invalid credential type raises ValueError.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key="test-key", + ) + + with pytest.raises( + ValueError, match="Credential is not a service account credential" + ): + ServiceAccountCredentialExchanger(credential) + + @patch( + "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) + @patch("google.adk.auth.service_account_credential_exchanger.Request") + def test_exchange_with_explicit_credentials_success( + self, mock_request_class, mock_from_service_account_info + ): + """Test successful exchange with explicit service account credentials.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "mock_access_token" + mock_from_service_account_info.return_value = mock_credentials + + # Create test credential + service_account_cred = ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=service_account_cred, + 
scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + result = exchanger.exchange() + + # Verify the result + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_access_token" + + # Verify mocks were called correctly + mock_from_service_account_info.assert_called_once_with( + service_account_cred.model_dump(), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + mock_credentials.refresh.assert_called_once_with(mock_request) + + @patch( + "google.adk.auth.service_account_credential_exchanger.google.auth.default" + ) + @patch("google.adk.auth.service_account_credential_exchanger.Request") + def test_exchange_with_default_credentials_success( + self, mock_request_class, mock_google_auth_default + ): + """Test successful exchange with default application credentials.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "default_access_token" + mock_google_auth_default.return_value = (mock_credentials, "test-project") + + # Create test credential with use_default_credential=True + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + result = exchanger.exchange() + + # Verify the result + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "default_access_token" + + # Verify mocks were called correctly + mock_google_auth_default.assert_called_once() + mock_credentials.refresh.assert_called_once_with(mock_request) + + def test_exchange_missing_service_account(self): + """Test exchange fails when service_account is None.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=None, + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Service account credentials are missing" + ): + exchanger.exchange() + + def test_exchange_missing_credentials_and_not_default(self): + """Test exchange fails when credentials are missing and use_default_credential is False.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=None, + use_default_credential=False, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Service account credentials are missing" + ): + exchanger.exchange() + + @patch( + "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) + def test_exchange_credential_creation_failure( + self, mock_from_service_account_info + ): + """Test exchange handles credential creation failure gracefully.""" + # Setup mock to raise exception + mock_from_service_account_info.side_effect = Exception( + "Invalid private key" + ) + + # Create test credential + service_account_cred = ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key="invalid-key", + 
client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=service_account_cred, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Failed to exchange service account token" + ): + exchanger.exchange() + + @patch( + "google.adk.auth.service_account_credential_exchanger.google.auth.default" + ) + def test_exchange_default_credential_failure(self, mock_google_auth_default): + """Test exchange handles default credential failure gracefully.""" + # Setup mock to raise exception + mock_google_auth_default.side_effect = Exception( + "No default credentials found" + ) + + # Create test credential with use_default_credential=True + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Failed to exchange service account token" + ): + exchanger.exchange() + + @patch( + "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) + @patch("google.adk.auth.service_account_credential_exchanger.Request") + def test_exchange_refresh_failure( + self, mock_request_class, mock_from_service_account_info + ): + """Test exchange handles credential refresh failure gracefully.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.refresh.side_effect = Exception( + "Network error during refresh" + ) + mock_from_service_account_info.return_value = mock_credentials + + # Create test credential + service_account_cred = ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=service_account_cred, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Failed to exchange service account token" + ): + exchanger.exchange() + + def test_exchange_none_credential_in_constructor(self): + """Test that passing None credential raises appropriate error during 
construction.""" + # This test verifies behavior when _credential is None, though this shouldn't + # happen in normal usage due to constructor validation + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + # Manually set to None to test the validation logic + exchanger._credential = None + + with pytest.raises( + ValueError, match="Service account credentials are missing" + ): + exchanger.exchange() From 675faefc670b5cd41991939fe0fc604df331111a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Jun 2025 08:34:53 -0700 Subject: [PATCH 03/79] feat: Allow data_store_specs pass into ADK VAIS built-in tool PiperOrigin-RevId: 772039465 --- src/google/adk/tools/vertex_ai_search_tool.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 5449f5090..b00cd0329 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -39,6 +39,9 @@ def __init__( self, *, data_store_id: Optional[str] = None, + data_store_specs: Optional[ + list[types.VertexAISearchDataStoreSpec] + ] = None, search_engine_id: Optional[str] = None, filter: Optional[str] = None, max_results: Optional[int] = None, @@ -49,6 +52,8 @@ def __init__( data_store_id: The Vertex AI search data store resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}". + data_store_specs: Specifications that define the specific DataStores to be + searched. It should only be set if engine is used. search_engine_id: The Vertex AI search engine resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}". @@ -64,7 +69,12 @@ def __init__( raise ValueError( 'Either data_store_id or search_engine_id must be specified.' ) + if data_store_specs is not None and search_engine_id is None: + raise ValueError( + 'search_engine_id must be specified if data_store_specs is specified.' 
+ ) self.data_store_id = data_store_id + self.data_store_specs = data_store_specs self.search_engine_id = search_engine_id self.filter = filter self.max_results = max_results @@ -89,6 +99,7 @@ async def process_llm_request( retrieval=types.Retrieval( vertex_ai_search=types.VertexAISearch( datastore=self.data_store_id, + data_store_specs=self.data_store_specs, engine=self.search_engine_id, filter=self.filter, max_results=self.max_results, From badbcbd7a464e6b323cf3164d2bcd4e27cbc057f Mon Sep 17 00:00:00 2001 From: SimonWei <119845914+simonwei97@users.noreply.github.com> Date: Tue, 17 Jun 2025 00:40:41 +0800 Subject: [PATCH 04/79] fix: agent generate config err (#1305) * fix: agent generate config err * fix: resovle comment --------- Co-authored-by: Hangfei Lin Co-authored-by: genquan9 <49327371+genquan9@users.noreply.github.com> --- src/google/adk/models/lite_llm.py | 65 +++++++++++++++++++++----- tests/unittests/models/test_litellm.py | 33 ++++++++++++- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf..c954711ad 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -481,16 +482,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[Dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format, and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -507,7 +514,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -518,12 +526,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = ( + llm_request.config.response_schema if llm_request.config else None + ) + + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. 
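+    # For example, GenerateContentConfig(temperature=0.33,
+    # max_output_tokens=123, stop_sequences=["foo", "bar"]) is forwarded to
+    # litellm as temperature=0.33, max_completion_tokens=123 and
+    # stop=["foo", "bar"]; only the keys iterated below are passed through,
+    # with max_output_tokens and stop_sequences renamed via param_mapping.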
+ generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] - if llm_request.config.response_schema: - response_format = llm_request.config.response_schema + if not generation_params: + generation_params = None - return messages, tools, response_format + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -660,7 +695,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) completion_args = { "model": self.model, @@ -668,7 +705,13 @@ async def generate_content_async( "tools": tools, "response_format": response_format, } - completion_args.update(self._additional_args) + + # Merge additional arguments and generation parameters safely + if hasattr(self, "_additional_args") and self._additional_args: + completion_args.update(self._additional_args) + + if generation_params: + completion_args.update(generation_params) if stream: text = "" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae..e600ee7f0 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,7 +13,6 @@ # limitations under the License. -import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1430,3 +1429,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_completion_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params From 1cfc555e70a45bdd23d0741176b3f17300f4b4ab Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Mon, 16 Jun 2025 10:19:37 -0700 Subject: [PATCH 05/79] ADK changes PiperOrigin-RevId: 772078053 --- src/google/adk/models/lite_llm.py | 65 +++++--------------------- tests/unittests/models/test_litellm.py | 33 +------------ 2 files changed, 12 
insertions(+), 86 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index c954711ad..ed54faecf 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,7 +23,6 @@ from typing import Dict from typing import Generator from typing import Iterable -from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -482,22 +481,16 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> Tuple[ - List[Message], - Optional[List[Dict]], - Optional[types.SchemaUnion], - Optional[Dict], -]: - """Converts an LlmRequest to litellm inputs and extracts generation params. +) -> tuple[Iterable[Message], Iterable[dict]]: + """Converts an LlmRequest to litellm inputs. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary, response format, and generation params). + The litellm inputs (message list, tool dictionary and response format). """ - # 1. Construct messages - messages: List[Message] = [] + messages = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -514,8 +507,7 @@ def _get_completion_inputs( ), ) - # 2. Convert tool declarations - tools: Optional[List[Dict]] = None + tools = None if ( llm_request.config and llm_request.config.tools @@ -526,39 +518,12 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - # 3. Handle response format - response_format: Optional[types.SchemaUnion] = ( - llm_request.config.response_schema if llm_request.config else None - ) - - # 4. Extract generation parameters - generation_params: Optional[Dict] = None - if llm_request.config: - config_dict = llm_request.config.model_dump(exclude_none=True) - # Generate LiteLlm parameters here, - # Following https://docs.litellm.ai/docs/completion/input. 
- generation_params = {} - param_mapping = { - "max_output_tokens": "max_completion_tokens", - "stop_sequences": "stop", - } - for key in ( - "temperature", - "max_output_tokens", - "top_p", - "top_k", - "stop_sequences", - "presence_penalty", - "frequency_penalty", - ): - if key in config_dict: - mapped_key = param_mapping.get(key, key) - generation_params[mapped_key] = config_dict[key] + response_format = None - if not generation_params: - generation_params = None + if llm_request.config.response_schema: + response_format = llm_request.config.response_schema - return messages, tools, response_format, generation_params + return messages, tools, response_format def _build_function_declaration_log( @@ -695,9 +660,7 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format, generation_params = ( - _get_completion_inputs(llm_request) - ) + messages, tools, response_format = _get_completion_inputs(llm_request) completion_args = { "model": self.model, @@ -705,13 +668,7 @@ async def generate_content_async( "tools": tools, "response_format": response_format, } - - # Merge additional arguments and generation parameters safely - if hasattr(self, "_additional_args") and self._additional_args: - completion_args.update(self._additional_args) - - if generation_params: - completion_args.update(generation_params) + completion_args.update(self._additional_args) if stream: text = "" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index e600ee7f0..f316e83ae 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,6 +13,7 @@ # limitations under the License. +import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1429,35 +1430,3 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} - - -@pytest.mark.asyncio -def test_get_completion_inputs_generation_params(): - # Test that generation_params are extracted and mapped correctly - req = LlmRequest( - contents=[ - types.Content(role="user", parts=[types.Part.from_text(text="hi")]), - ], - config=types.GenerateContentConfig( - temperature=0.33, - max_output_tokens=123, - top_p=0.88, - top_k=7, - stop_sequences=["foo", "bar"], - presence_penalty=0.1, - frequency_penalty=0.2, - ), - ) - from google.adk.models.lite_llm import _get_completion_inputs - - _, _, _, generation_params = _get_completion_inputs(req) - assert generation_params["temperature"] == 0.33 - assert generation_params["max_completion_tokens"] == 123 - assert generation_params["top_p"] == 0.88 - assert generation_params["top_k"] == 7 - assert generation_params["stop"] == ["foo", "bar"] - assert generation_params["presence_penalty"] == 0.1 - assert generation_params["frequency_penalty"] == 0.2 - # Should not include max_output_tokens - assert "max_output_tokens" not in generation_params - assert "stop_sequences" not in generation_params From 8201f9aebd62ab4cf1ab36e08e475c9aba3ffb57 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Mon, 16 Jun 2025 12:06:30 -0700 Subject: [PATCH 06/79] chore: Added live-streaming sample agent Also added a readme. 
PiperOrigin-RevId: 772120698 --- .../live_bidi_streaming_agent/__init__.py | 15 +++ .../live_bidi_streaming_agent/agent.py | 104 ++++++++++++++++++ .../live_bidi_streaming_agent/readme.md | 37 +++++++ 3 files changed, 156 insertions(+) create mode 100755 contributing/samples/live_bidi_streaming_agent/__init__.py create mode 100755 contributing/samples/live_bidi_streaming_agent/agent.py create mode 100644 contributing/samples/live_bidi_streaming_agent/readme.md diff --git a/contributing/samples/live_bidi_streaming_agent/__init__.py b/contributing/samples/live_bidi_streaming_agent/__init__.py new file mode 100755 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/live_bidi_streaming_agent/agent.py b/contributing/samples/live_bidi_streaming_agent/agent.py new file mode 100755 index 000000000..2896bd70f --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/agent.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk import Agent +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + result = random.randint(1, sides) + if not 'rolls' in tool_context.state: + tool_context.state['rolls'] = [] + + tool_context.state['rolls'] = tool_context.state['rolls'] + [result] + return result + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + 'No prime numbers found.' + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." 
+ ) + + +root_agent = Agent( + model='gemini-2.0-flash-live-preview-04-09', # for Vertex project + # model='gemini-2.0-flash-live-001', # for AI studio key + name='hello_world_agent', + description=( + 'hello world agent that can roll a dice of 8 sides and check prime' + ' numbers.' + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) diff --git a/contributing/samples/live_bidi_streaming_agent/readme.md b/contributing/samples/live_bidi_streaming_agent/readme.md new file mode 100644 index 000000000..6a9258f3e --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/readme.md @@ -0,0 +1,37 @@ +# Simplistic Live (Bidi-Streaming) Agent +This project provides a basic example of a live, bidirectional streaming agent +designed for testing and experimentation. + +You can see full documentation [here](https://google.github.io/adk-docs/streaming/). + +## Getting Started + +Follow these steps to get the agent up and running: + +1. **Start the ADK Web Server** + Open your terminal, navigate to the root directory that contains the + `live_bidi_streaming_agent` folder, and execute the following command: + ```bash + adk web + ``` + +2. **Access the ADK Web UI** + Once the server is running, open your web browser and navigate to the URL + provided in the terminal (it will typically be `http://localhost:8000`). + +3. **Select the Agent** + In the top-left corner of the ADK Web UI, use the dropdown menu to select + this agent. + +4. **Start Streaming** + Click on either the **Audio** or **Video** icon located near the chat input + box to begin the streaming session. + +5. **Interact with the Agent** + You can now begin talking to the agent, and it will respond in real-time. + +## Usage Notes + +* You only need to click the **Audio** or **Video** button once to initiate the + stream. 
The current version does not support stopping and restarting the stream + by clicking the button again during a session. From fe1d5aa439cc56b89d248a52556c0a9b4cbd15e4 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Mon, 16 Jun 2025 14:35:16 -0700 Subject: [PATCH 07/79] feat: add enable_affective_dialog and proactivity to run_config and llm_request PiperOrigin-RevId: 772175206 --- src/google/adk/agents/run_config.py | 6 ++++++ src/google/adk/flows/llm_flows/basic.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 5679f04e9..c9a50a0ae 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -73,6 +73,12 @@ class RunConfig(BaseModel): realtime_input_config: Optional[types.RealtimeInputConfig] = None """Realtime input config for live agents with audio input from user.""" + enable_affective_dialog: Optional[bool] = None + """If enabled, the model will detect emotions and adapt its responses accordingly.""" + + proactivity: Optional[types.ProactivityConfig] = None + """Configures the proactivity of the model. This allows the model to respond proactively to the input and to ignore irrelevant input.""" + max_llm_calls: int = 500 """ A limit on the total number of llm calls for a given run. diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 7efadd97e..ee5c83da1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -68,6 +68,12 @@ async def run_async( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) + llm_request.live_connect_config.enable_affective_dialog = ( + invocation_context.run_config.enable_affective_dialog + ) + llm_request.live_connect_config.proactivity = ( + invocation_context.run_config.proactivity + ) # TODO: handle tool append here, instead of in BaseTool.process_llm_request. From fef87784297b806914de307f48c51d83f977298f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Jun 2025 16:19:47 -0700 Subject: [PATCH 08/79] fix: liteLLM test failures Fix liteLLM test failures for function call responses. 
PiperOrigin-RevId: 772212629 --- tests/unittests/models/test_litellm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae..8b43cc48b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1194,11 +1194,11 @@ async def test_generate_content_async_stream( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" mock_completion.assert_called_once() _, kwargs = mock_completion.call_args @@ -1257,11 +1257,11 @@ async def test_generate_content_async_stream_with_usage_metadata( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" assert responses[3].usage_metadata.prompt_token_count == 10 assert responses[3].usage_metadata.candidates_token_count == 5 From 31b81a342d3438b1efb7557e362b9288810033d5 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 16:21:32 -0700 Subject: [PATCH 09/79] chore: Update streamable http mcp example agent PiperOrigin-RevId: 772213323 --- contributing/samples/mcp_streamablehttp_agent/README.md | 3 +-- contributing/samples/mcp_streamablehttp_agent/agent.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/contributing/samples/mcp_streamablehttp_agent/README.md b/contributing/samples/mcp_streamablehttp_agent/README.md index 1c211dd71..547a0788d 100644 --- a/contributing/samples/mcp_streamablehttp_agent/README.md +++ b/contributing/samples/mcp_streamablehttp_agent/README.md @@ -1,8 +1,7 @@ -This agent connects to a local MCP server via sse. +This agent connects to a local MCP server via Streamable HTTP. 
To run this agent, start the local MCP server first by : ```bash uv run filesystem_server.py ``` - diff --git a/contributing/samples/mcp_streamablehttp_agent/agent.py b/contributing/samples/mcp_streamablehttp_agent/agent.py index 61d59e051..f165c4c1b 100644 --- a/contributing/samples/mcp_streamablehttp_agent/agent.py +++ b/contributing/samples/mcp_streamablehttp_agent/agent.py @@ -18,7 +18,6 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPServerParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import SseServerParams _allowed_path = os.path.dirname(os.path.abspath(__file__)) From 4bda24517163a52dc227b525f26d3d83ce36f1ec Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 16:25:56 -0700 Subject: [PATCH 10/79] chore: fix oauth_calendar_agent example PiperOrigin-RevId: 772214855 --- .../samples/oauth_calendar_agent/agent.py | 120 ++++++++++++------ 1 file changed, 79 insertions(+), 41 deletions(-) diff --git a/contributing/samples/oauth_calendar_agent/agent.py b/contributing/samples/oauth_calendar_agent/agent.py index 9d56d3ff8..a1b1dea87 100644 --- a/contributing/samples/oauth_calendar_agent/agent.py +++ b/contributing/samples/oauth_calendar_agent/agent.py @@ -27,8 +27,6 @@ from google.adk.auth import AuthCredentialTypes from google.adk.auth import OAuth2Auth from google.adk.tools import ToolContext -from google.adk.tools.authenticated_tool.base_authenticated_tool import AuthenticatedFunctionTool -from google.adk.tools.authenticated_tool.credentials_store import ToolContextCredentialsStore from google.adk.tools.google_api_tool import CalendarToolset from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials @@ -58,7 +56,6 @@ def list_calendar_events( end_time: str, limit: int, tool_context: ToolContext, - credential: AuthCredential, ) -> list[dict]: """Search for calendar events. @@ -83,11 +80,84 @@ def list_calendar_events( Returns: list[dict]: A list of events that match the search criteria. """ - - creds = Credentials( - token=credential.oauth2.access_token, - refresh_token=credential.oauth2.refresh_token, - ) + creds = None + + # Check if the tokes were already in the session state, which means the user + # has already gone through the OAuth flow and successfully authenticated and + # authorized the tool to access their calendar. + if "calendar_tool_tokens" in tool_context.state: + creds = Credentials.from_authorized_user_info( + tool_context.state["calendar_tool_tokens"], SCOPES + ) + if not creds or not creds.valid: + # If the access token is expired, refresh it with the refresh token. + if creds and creds.expired and creds.refresh_token: + creds.refresh(Request()) + else: + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://accounts.google.com/o/oauth2/auth", + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + "https://www.googleapis.com/auth/calendar": ( + "See, edit, share, and permanently delete all the" + " calendars you can access using Google Calendar" + ) + }, + ) + ) + ) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=oauth_client_id, client_secret=oauth_client_secret + ), + ) + # If the user has not gone through the OAuth flow before, or the refresh + # token also expired, we need to ask users to go through the OAuth flow. 
+ # First we check whether the user has just gone through the OAuth flow and + # Oauth response is just passed back. + auth_response = tool_context.get_auth_response( + AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + ) + if auth_response: + # ADK exchanged the access token already for us + access_token = auth_response.oauth2.access_token + refresh_token = auth_response.oauth2.refresh_token + + creds = Credentials( + token=access_token, + refresh_token=refresh_token, + token_uri=auth_scheme.flows.authorizationCode.tokenUrl, + client_id=oauth_client_id, + client_secret=oauth_client_secret, + scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()), + ) + else: + # If there are no auth response which means the user has not gone + # through the OAuth flow yet, we need to ask users to go through the + # OAuth flow. + tool_context.request_credential( + AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential, + ) + ) + # The return value is optional and could be any dict object. It will be + # wrapped in a dict with key as 'result' and value as the return value + # if the object returned is not a dict. This response will be passed + # to LLM to generate a user friendly message. e.g. LLM will tell user: + # "I need your authorization to access your calendar. Please authorize + # me so I can check your meetings for today." + return "Need User Authorization to access their calendar." + # We store the access token and refresh token in the session state for the + # next runs. This is just an example. On production, a tool should store + # those credentials in some secure store or properly encrypt it before store + # it in the session state. + tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json()) service = build("calendar", "v3", credentials=creds) events_result = ( @@ -138,38 +208,6 @@ def update_time(callback_context: CallbackContext): Currnet time: {_time} """, - tools=[ - AuthenticatedFunctionTool( - func=list_calendar_events, - auth_config=AuthConfig( - auth_scheme=OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl=( - "https://accounts.google.com/o/oauth2/auth" - ), - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - "https://www.googleapis.com/auth/calendar": ( - "See, edit, share, and permanently delete" - " all the calendars you can access using" - " Google Calendar" - ) - }, - ) - ) - ), - raw_auth_credential=AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=oauth_client_id, - client_secret=oauth_client_secret, - ), - ), - ), - credential_store=ToolContextCredentialsStore(), - ), - calendar_toolset, - ], + tools=[list_calendar_events, calendar_toolset], before_agent_callback=update_time, ) From aafa80bd85a49fb1c1a255ac797587cffd3fa567 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Jun 2025 16:36:34 -0700 Subject: [PATCH 11/79] fix: stream in litellm + adk and add corresponding integration tests Fixes https://github.com/google/adk-python/issues/1368 PiperOrigin-RevId: 772218385 --- src/google/adk/models/lite_llm.py | 3 +- .../models/test_litellm_no_function.py | 109 +++++++++++++++++- .../models/test_litellm_with_function.py | 25 ++-- 3 files changed, 125 insertions(+), 12 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf..dce5ed7c4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -739,11 +739,12 @@ async 
def generate_content_async( _message_to_generate_content_response( ChatCompletionAssistantMessage( role="assistant", - content="", + content=text, tool_calls=tool_calls, ) ) ) + text = "" function_calls.clear() elif finish_reason == "stop" and text: aggregated_llm_response = _message_to_generate_content_response( diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py index e662384ce..ff5d3bb82 100644 --- a/tests/integration/models/test_litellm_no_function.py +++ b/tests/integration/models/test_litellm_no_function.py @@ -20,12 +20,26 @@ from google.genai.types import Part import pytest -_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """You are a helpful assistant.""" +def get_weather(city: str) -> str: + """Simulates a web search. Use it get information on weather. + + Args: + city: A string containing the location to get weather information for. + + Returns: + A string with the simulated weather information for the queried city. + """ + if "sf" in city.lower() or "san francisco" in city.lower(): + return "It's 70 degrees and foggy." + return "It's 80 degrees and sunny." + + @pytest.fixture def oss_llm(): return LiteLlm(model=_TEST_MODEL_NAME) @@ -44,6 +58,48 @@ def llm_request(): ) +@pytest.fixture +def llm_request_with_tools(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="What is the weather in San Francisco?") + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="get_weather", + description="Get the weather in a given location", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "city": types.Schema( + type=types.Type.STRING, + description=( + "The city to get the weather for." + ), + ), + }, + required=["city"], + ), + ) + ] + ) + ], + ), + ) + + @pytest.mark.asyncio async def test_generate_content_async(oss_llm, llm_request): async for response in oss_llm.generate_content_async(llm_request): @@ -51,10 +107,8 @@ async def test_generate_content_async(oss_llm, llm_request): assert response.content.parts[0].text -# Note that, this test disabled streaming because streaming is not supported -# properly in the current test model for now. 
@pytest.mark.asyncio -async def test_generate_content_async_stream(oss_llm, llm_request): +async def test_generate_content_async(oss_llm, llm_request): responses = [ resp async for resp in oss_llm.generate_content_async( @@ -63,3 +117,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request): ] part = responses[0].content.parts[0] assert len(part.text) > 0 + + +@pytest.mark.asyncio +async def test_generate_content_async_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=False + ) + ] + function_call = responses[0].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_async_stream(oss_llm, llm_request): + responses = [ + resp + async for resp in oss_llm.generate_content_async(llm_request, stream=True) + ] + text = "" + for i in range(len(responses) - 1): + assert responses[i].partial is True + assert responses[i].content.parts[0].text + text += responses[i].content.parts[0].text + + # Last message should be accumulated text + assert responses[-1].content.parts[0].text == text + assert not responses[-1].partial + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py index a2ceb540a..799c55e5c 100644 --- a/tests/integration/models/test_litellm_with_function.py +++ b/tests/integration/models/test_litellm_with_function.py @@ -13,7 +13,6 @@ # limitations under the License. from google.adk.models import LlmRequest -from google.adk.models import LlmResponse from google.adk.models.lite_llm import LiteLlm from google.genai import types from google.genai.types import Content @@ -23,12 +22,11 @@ litellm.add_function_to_prompt = True -_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" - +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """ You are a helpful assistant, and call tools optionally. -If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs. +If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs. """ @@ -40,7 +38,7 @@ "properties": { "city": { "type": "string", - "description": "The city, e.g. San Francisco", + "description": "The city to get the weather for.", }, }, "required": ["city"], @@ -87,8 +85,6 @@ def llm_request(): ) -# Note that, this test disabled streaming because streaming is not supported -# properly in the current test model for now. 
@pytest.mark.asyncio async def test_generate_content_asyn_with_function( oss_llm_with_function, llm_request @@ -102,3 +98,18 @@ async def test_generate_content_asyn_with_function( function_call = responses[0].content.parts[0].function_call assert function_call.name == "get_weather" assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_asyn_stream_with_function( + oss_llm_with_function, llm_request +): + responses = [ + resp + async for resp in oss_llm_with_function.generate_content_async( + llm_request, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" From e384fa4ad76114fa942a3be8bd51bbeb5225e00e Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Mon, 16 Jun 2025 16:57:05 -0700 Subject: [PATCH 12/79] chore: fix previously skipped isort issue PiperOrigin-RevId: 772224853 --- tests/integration/models/test_litellm_no_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py index ff5d3bb82..05072b899 100644 --- a/tests/integration/models/test_litellm_no_function.py +++ b/tests/integration/models/test_litellm_no_function.py @@ -20,7 +20,6 @@ from google.genai.types import Part import pytest - _TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """You are a helpful assistant.""" From a6b1baa61b5dbf4168a035609077094307171135 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 17:04:27 -0700 Subject: [PATCH 13/79] chore: Add base credential exchanger (Experimental) PiperOrigin-RevId: 772227201 --- src/google/adk/auth/exchanger/__init__.py | 25 ++++++++++ .../exchanger/base_credential_exchanger.py | 49 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 src/google/adk/auth/exchanger/__init__.py create mode 100644 src/google/adk/auth/exchanger/base_credential_exchanger.py diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py new file mode 100644 index 000000000..ce5c464c4 --- /dev/null +++ b/src/google/adk/auth/exchanger/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Credential exchanger module.""" + +from .base_credential_exchanger import BaseCredentialExchanger +from .credential_exchanger_registry import CredentialExchangerRegistry +from .service_account_credential_exchanger import ServiceAccountCredentialExchanger + +__all__ = [ + "BaseCredentialExchanger", + "CredentialExchangerRegistry", + "ServiceAccountCredentialExchanger", +] diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py new file mode 100644 index 000000000..1d7417cd0 --- /dev/null +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base credential exchanger interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_schemes import AuthScheme + + +@experimental +class BaseCredentialExchanger(abc.ABC): + """Base interface for credential exchangers.""" + + @abc.abstractmethod + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange credential if needed. + + Args: + auth_credential: The credential to exchange. + auth_scheme: The authentication scheme (optional, some exchangers don't need it). + + Returns: + The exchanged credential. + + Raises: + ValueError: If credential exchange fails. + """ + pass From 28dfcd25128e4cab34764abd1451f15529c4626d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 17:36:12 -0700 Subject: [PATCH 14/79] chore: Add experimental decorator to Oauth2 credential fethcer PiperOrigin-RevId: 772236406 --- src/google/adk/auth/oauth2_credential_fetcher.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py index 1a8692417..cbed70762 100644 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ b/src/google/adk/auth/oauth2_credential_fetcher.py @@ -20,6 +20,7 @@ from fastapi.openapi.models import OAuth2 +from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_schemes import AuthScheme from .auth_schemes import OAuthGrantType @@ -37,8 +38,9 @@ logger = logging.getLogger("google_adk." + __name__) +@experimental class OAuth2CredentialFetcher: - """Exchanges and refreshes an OAuth2 access token.""" + """Exchanges and refreshes an OAuth2 access token. 
(Experimental)""" def __init__( self, From e2a81365ec18cb4ed1a5422513992ccc21962937 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 18:21:03 -0700 Subject: [PATCH 15/79] chore: Add a base credential refresher interface PiperOrigin-RevId: 772248299 --- .../exchanger/base_credential_exchanger.py | 12 ++- src/google/adk/auth/refresher/__init__.py | 21 ++++++ .../refresher/base_credential_refresher.py | 74 +++++++++++++++++++ 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 src/google/adk/auth/refresher/__init__.py create mode 100644 src/google/adk/auth/refresher/base_credential_refresher.py diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py index 1d7417cd0..b09adb80a 100644 --- a/src/google/adk/auth/exchanger/base_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -24,9 +24,17 @@ from ..auth_schemes import AuthScheme +class CredentialExchangError(Exception): + """Base exception for credential exchange errors.""" + + @experimental class BaseCredentialExchanger(abc.ABC): - """Base interface for credential exchangers.""" + """Base interface for credential exchangers. + + Credential exchangers are responsible for exchanging credentials from + one format or scheme to another. + """ @abc.abstractmethod async def exchange( @@ -44,6 +52,6 @@ async def exchange( The exchanged credential. Raises: - ValueError: If credential exchange fails. + CredentialExchangError: If credential exchange fails. """ pass diff --git a/src/google/adk/auth/refresher/__init__.py b/src/google/adk/auth/refresher/__init__.py new file mode 100644 index 000000000..27d7245dc --- /dev/null +++ b/src/google/adk/auth/refresher/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher module.""" + +from .base_credential_refresher import BaseCredentialRefresher + +__all__ = [ + "BaseCredentialRefresher", +] diff --git a/src/google/adk/auth/refresher/base_credential_refresher.py b/src/google/adk/auth/refresher/base_credential_refresher.py new file mode 100644 index 000000000..230b07d09 --- /dev/null +++ b/src/google/adk/auth/refresher/base_credential_refresher.py @@ -0,0 +1,74 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base credential refresher interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.utils.feature_decorator import experimental + + +class CredentialRefresherError(Exception): + """Base exception for credential refresh errors.""" + + +@experimental +class BaseCredentialRefresher(abc.ABC): + """Base interface for credential refreshers. + + Credential refreshers are responsible for checking if a credential is expired + or needs to be refreshed, and for refreshing it if necessary. + """ + + @abc.abstractmethod + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Checks if a credential needs to be refreshed. + + Args: + auth_credential: The credential to check. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + pass + + @abc.abstractmethod + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refreshes a credential if needed. + + Args: + auth_credential: The credential to refresh. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + The refreshed credential. + + Raises: + CredentialRefresherError: If credential refresh fails. + """ + pass From 476805d5b9e6d598ca8bb71488a4923c162cfdbc Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 18:33:41 -0700 Subject: [PATCH 16/79] chore: Add a2a extra dependency for github UT workflows PiperOrigin-RevId: 772251530 --- .github/workflows/python-unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index a504fde0d..0d77402f9 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -43,7 +43,7 @@ jobs: run: | uv venv .venv source .venv/bin/activate - uv sync --extra test --extra eval + uv sync --extra test --extra eval --extra a2a - name: Run unit tests with pytest run: | From 94caccc148833c135b9b60af3c4c54986b10406c Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 19:02:31 -0700 Subject: [PATCH 17/79] refactor: Extract util method from OAuth2 credential fetcher for reuse Context: we'd like to separate fetcher into exchanger and refresher later. This cl help to extract the common utility that will be used by both exchanger and refresher. 
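
For illustration only, a minimal sketch of how a refresher built on the new BaseCredentialRefresher interface might reuse these extracted helpers. The class name and wiring below are assumptions for the example and are not part of this CL; the expiry check and refresh_token call mirror what OAuth2CredentialFetcher.refresh() does today.

from typing import Optional

from authlib.oauth2.rfc6749 import OAuth2Token

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.auth.oauth2_credential_util import create_oauth2_session
from google.adk.auth.oauth2_credential_util import update_credential_with_tokens
from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher


class ExampleOAuth2Refresher(BaseCredentialRefresher):
  """Illustrative only (hypothetical class, not part of this CL)."""

  async def is_refresh_needed(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> bool:
    if not auth_credential.oauth2:
      return False
    # Same expiry check OAuth2CredentialFetcher.refresh() performs.
    return OAuth2Token({
        "expires_at": auth_credential.oauth2.expires_at,
        "expires_in": auth_credential.oauth2.expires_in,
    }).is_expired()

  async def refresh(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> AuthCredential:
    # Build the Authlib session from the shared utility instead of
    # duplicating the scheme/credential validation logic.
    client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
    if not client:
      return auth_credential  # Cannot build a session; leave credential as-is.
    tokens = client.refresh_token(
        url=token_endpoint,
        refresh_token=auth_credential.oauth2.refresh_token,
    )
    update_credential_with_tokens(auth_credential, tokens)
    return auth_credential

Keeping session construction and token bookkeeping in oauth2_credential_util means an exchanger and a refresher would differ mainly in which Authlib call they make (fetch_token vs. refresh_token).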
PiperOrigin-RevId: 772257995 --- .../adk/auth/oauth2_credential_fetcher.py | 59 +--- src/google/adk/auth/oauth2_credential_util.py | 107 ++++++ tests/unittests/auth/test_auth_handler.py | 2 +- .../auth/test_oauth2_credential_fetcher.py | 332 +----------------- .../auth/test_oauth2_credential_util.py | 147 ++++++++ 5 files changed, 284 insertions(+), 363 deletions(-) create mode 100644 src/google/adk/auth/oauth2_credential_util.py create mode 100644 tests/unittests/auth/test_oauth2_credential_util.py diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py index cbed70762..c9e838b25 100644 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ b/src/google/adk/auth/oauth2_credential_fetcher.py @@ -15,19 +15,15 @@ from __future__ import annotations import logging -from typing import Optional -from typing import Tuple - -from fastapi.openapi.models import OAuth2 from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_schemes import AuthScheme from .auth_schemes import OAuthGrantType -from .auth_schemes import OpenIdConnectWithConfig +from .oauth2_credential_util import create_oauth2_session +from .oauth2_credential_util import update_credential_with_tokens try: - from authlib.integrations.requests_client import OAuth2Session from authlib.oauth2.rfc6749 import OAuth2Token AUTHLIB_AVIALABLE = True @@ -50,45 +46,6 @@ def __init__( self._auth_scheme = auth_scheme self._auth_credential = auth_credential - def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]: - auth_scheme = self._auth_scheme - auth_credential = self._auth_credential - - if isinstance(auth_scheme, OpenIdConnectWithConfig): - if not hasattr(auth_scheme, "token_endpoint"): - return None, None - token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes - elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): - return None, None - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) - else: - return None, None - - if ( - not auth_credential - or not auth_credential.oauth2 - or not auth_credential.oauth2.client_id - or not auth_credential.oauth2.client_secret - ): - return None, None - - return ( - OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ), - token_endpoint, - ) - def _update_credential(self, tokens: OAuth2Token) -> None: self._auth_credential.oauth2.access_token = tokens.get("access_token") self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") @@ -114,7 +71,9 @@ def exchange(self) -> AuthCredential: ): return self._auth_credential - client, token_endpoint = self._oauth2_session() + client, token_endpoint = create_oauth2_session( + self._auth_scheme, self._auth_credential + ) if not client: logger.warning("Could not create OAuth2 session for token exchange") return self._auth_credential @@ -126,7 +85,7 @@ def exchange(self) -> AuthCredential: code=self._auth_credential.oauth2.auth_code, grant_type=OAuthGrantType.AUTHORIZATION_CODE, ) - self._update_credential(tokens) + update_credential_with_tokens(self._auth_credential, tokens) logger.info("Successfully exchanged OAuth2 tokens") except Exception as e: logger.error("Failed to exchange OAuth2 
tokens: %s", e) @@ -151,7 +110,9 @@ def refresh(self) -> AuthCredential: "expires_at": credential.oauth2.expires_at, "expires_in": credential.oauth2.expires_in, }).is_expired(): - client, token_endpoint = self._oauth2_session() + client, token_endpoint = create_oauth2_session( + self._auth_scheme, self._auth_credential + ) if not client: logger.warning("Could not create OAuth2 session for token refresh") return credential @@ -161,7 +122,7 @@ def refresh(self) -> AuthCredential: url=token_endpoint, refresh_token=credential.oauth2.refresh_token, ) - self._update_credential(tokens) + update_credential_with_tokens(self._auth_credential, tokens) logger.info("Successfully refreshed OAuth2 tokens") except Exception as e: logger.error("Failed to refresh OAuth2 tokens: %s", e) diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py new file mode 100644 index 000000000..51ed4d29f --- /dev/null +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Tuple + +from fastapi.openapi.models import OAuth2 + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_schemes import AuthScheme +from .auth_schemes import OpenIdConnectWithConfig + +try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +def create_oauth2_session( + auth_scheme: AuthScheme, + auth_credential: AuthCredential, +) -> Tuple[Optional[OAuth2Session], Optional[str]]: + """Create an OAuth2 session for token operations. + + Args: + auth_scheme: The authentication scheme configuration. + auth_credential: The authentication credential. + + Returns: + Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session. 
+ """ + if isinstance(auth_scheme, OpenIdConnectWithConfig): + if not hasattr(auth_scheme, "token_endpoint"): + return None, None + token_endpoint = auth_scheme.token_endpoint + scopes = auth_scheme.scopes + elif isinstance(auth_scheme, OAuth2): + if ( + not auth_scheme.flows.authorizationCode + or not auth_scheme.flows.authorizationCode.tokenUrl + ): + return None, None + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + else: + return None, None + + if ( + not auth_credential + or not auth_credential.oauth2 + or not auth_credential.oauth2.client_id + or not auth_credential.oauth2.client_secret + ): + return None, None + + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) + + +@experimental +def update_credential_with_tokens( + auth_credential: AuthCredential, tokens: OAuth2Token +) -> None: + """Update the credential with new tokens. + + Args: + auth_credential: The authentication credential to update. + tokens: The OAuth2Token object containing new token information. + """ + auth_credential.oauth2.access_token = tokens.get("access_token") + auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index aaed35a19..2bfc7d4c9 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -538,7 +538,7 @@ def test_credentials_with_token( assert result == oauth2_credentials_with_token @patch( - "google.adk.auth.oauth2_credential_fetcher.OAuth2Session", + "google.adk.auth.oauth2_credential_util.OAuth2Session", MockOAuth2Session, ) def test_successful_token_exchange(self, auth_config_with_auth_code): diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py b/tests/unittests/auth/test_oauth2_credential_fetcher.py index 0b9b5a3c1..aba6a9923 100644 --- a/tests/unittests/auth/test_oauth2_credential_fetcher.py +++ b/tests/unittests/auth/test_oauth2_credential_fetcher.py @@ -14,7 +14,6 @@ import time from unittest.mock import Mock -from unittest.mock import patch from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import OAuth2 @@ -24,38 +23,15 @@ from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_schemes import OpenIdConnectWithConfig -from google.adk.auth.oauth2_credential_fetcher import OAuth2CredentialFetcher +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens -class TestOAuth2CredentialFetcher: - """Test suite for OAuth2CredentialFetcher.""" +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" - def test_init(self): - """Test OAuth2CredentialFetcher initialization.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - 
token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - assert fetcher._auth_scheme == scheme - assert fetcher._auth_credential == credential - - def test_oauth2_session_openid_connect(self): - """Test _oauth2_session with OpenID Connect scheme.""" + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( @@ -75,16 +51,15 @@ def test_oauth2_session_openid_connect(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" assert client.client_id == "test_client_id" assert client.client_secret == "test_client_secret" - def test_oauth2_session_oauth2_scheme(self): - """Test _oauth2_session with OAuth2 scheme.""" + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" flows = OAuthFlows( authorizationCode=OAuthFlowAuthorizationCode( authorizationUrl="https://example.com/auth", @@ -102,14 +77,13 @@ def test_oauth2_session_oauth2_scheme(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" - def test_oauth2_session_invalid_scheme(self): - """Test _oauth2_session with invalid scheme.""" + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" scheme = Mock() # Invalid scheme type credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -119,14 +93,13 @@ def test_oauth2_session_invalid_scheme(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is None assert token_endpoint is None - def test_oauth2_session_missing_credentials(self): - """Test _oauth2_session with missing credentials.""" + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( @@ -144,23 +117,13 @@ def test_oauth2_session_missing_credentials(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is None assert token_endpoint is None - def test_update_credential(self): - """Test _update_credential method.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) + def test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" credential = AuthCredential( 
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, oauth2=OAuth2Auth( @@ -169,7 +132,6 @@ def test_update_credential(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) tokens = OAuth2Token({ "access_token": "new_access_token", "refresh_token": "new_refresh_token", @@ -177,265 +139,9 @@ def test_update_credential(self): "expires_in": 3600, }) - fetcher._update_credential(tokens) + update_credential_with_tokens(credential, tokens) assert credential.oauth2.access_token == "new_access_token" assert credential.oauth2.refresh_token == "new_refresh_token" assert credential.oauth2.expires_at == int(time.time()) + 3600 assert credential.oauth2.expires_in == 3600 - - def test_exchange_with_existing_token(self): - """Test exchange method when access token already exists.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="existing_token", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token == "existing_token" - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_success(self, mock_oauth2_session): - """Test successful token exchange.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri=( - "https://example.com/callback?code=auth_code&state=test_state" - ), - ), - ) - - # Mock the OAuth2Session - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" - mock_client.fetch_token.assert_called_once() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_with_auth_code(self, mock_oauth2_session): - """Test token exchange with auth code.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_code="test_auth_code", - ), - ) - - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - } - 
mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - mock_client.fetch_token.assert_called_once() - - def test_exchange_no_session(self): - """Test exchange when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri="https://example.com/callback?code=auth_code", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token is None - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_not_expired( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test refresh when token is not expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="current_token", - refresh_token="refresh_token", - expires_at=int(time.time()) + 3600, - expires_in=3600, - ), - ) - - # Mock token not expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = False - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "current_token" - mock_oauth2_session.assert_not_called() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_expired_success( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test successful token refresh when token is expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, # Expired - expires_in=3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - # Mock refresh token response - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "refreshed_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.refresh_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result.oauth2.access_token == "refreshed_access_token" - assert 
result.oauth2.refresh_token == "new_refresh_token" - mock_client.refresh_token.assert_called_once_with( - url="https://example.com/token", - refresh_token="refresh_token", - ) - - def test_refresh_no_oauth2_credential(self): - """Test refresh when oauth2 credential is missing.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP) # No oauth2 - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - def test_refresh_no_session(self, mock_oauth2_token): - """Test refresh when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "expired_token" # Unchanged diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py new file mode 100644 index 000000000..aba6a9923 --- /dev/null +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from unittest.mock import Mock + +from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens + + +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" + + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" + scheme = Mock() # Invalid scheme type + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def 
test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + + update_credential_with_tokens(credential, tokens) + + assert credential.oauth2.access_token == "new_access_token" + assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.expires_at == int(time.time()) + 3600 + assert credential.oauth2.expires_in == 3600 From c755cf23c555a2173b0eafd774cc0cc027b5f3da Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 19:53:08 -0700 Subject: [PATCH 18/79] chore: Ignore a2a ut tests for python 3.9 given a2a-sdk only supports 3.10+ PiperOrigin-RevId: 772270172 --- .github/workflows/python-unit-tests.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 0d77402f9..d4af7b13a 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -48,6 +48,13 @@ jobs: - name: Run unit tests with pytest run: | source .venv/bin/activate - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + if [[ "${{ matrix.python-version }}" == "3.9" ]]; then + pytest tests/unittests \ + --ignore=tests/unittests/a2a \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + else + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + fi \ No newline at end of file From e1812797ad499a2503275e41d28b07338ca951f9 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 20:34:37 -0700 Subject: [PATCH 19/79] chore: Add A2A Part converter (WIP) PiperOrigin-RevId: 772282116 --- src/google/adk/a2a/converters/__init__.py | 13 + .../adk/a2a/converters/part_converter.py | 166 +++++++ tests/unittests/a2a/__init__.py | 13 + tests/unittests/a2a/converters/__init__.py | 13 + .../a2a/converters/test_part_converter.py | 443 ++++++++++++++++++ 5 files changed, 648 insertions(+) create mode 100644 src/google/adk/a2a/converters/__init__.py create mode 100644 src/google/adk/a2a/converters/part_converter.py create mode 100644 tests/unittests/a2a/__init__.py create mode 100644 tests/unittests/a2a/converters/__init__.py create mode 100644 tests/unittests/a2a/converters/test_part_converter.py diff --git a/src/google/adk/a2a/converters/__init__.py b/src/google/adk/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py new file mode 100644 index 000000000..1c51fd7c1 --- /dev/null +++ b/src/google/adk/a2a/converters/part_converter.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +module containing utilities for conversion betwen A2A Part and Google GenAI Part +""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from a2a import types as a2a_types +from google.genai import types as genai_types + +from ...utils.feature_decorator import working_in_progress + +logger = logging.getLogger('google_adk.' + __name__) + +A2A_DATA_PART_METADATA_TYPE_KEY = 'type' +A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call' +A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' + + +@working_in_progress +def convert_a2a_part_to_genai_part( + a2a_part: a2a_types.Part, +) -> Optional[genai_types.Part]: + """Convert an A2A Part to a Google GenAI Part.""" + part = a2a_part.root + if isinstance(part, a2a_types.TextPart): + return genai_types.Part(text=part.text) + + if isinstance(part, a2a_types.FilePart): + if isinstance(part.file, a2a_types.FileWithUri): + return genai_types.Part( + file_data=genai_types.FileData( + file_uri=part.file.uri, mime_type=part.file.mimeType + ) + ) + + elif isinstance(part.file, a2a_types.FileWithBytes): + return genai_types.Part( + inline_data=genai_types.Blob( + data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType + ) + ) + else: + logger.warning( + 'Cannot convert unsupported file type: %s for A2A part: %s', + type(part.file), + a2a_part, + ) + return None + + if isinstance(part, a2a_types.DataPart): + # Conver the Data Part to funcall and function reponse. + # This is mainly for converting human in the loop and auth request and + # response. 
+ # TODO once A2A defined how to suervice such information, migrate below + # logic accordinlgy + if part.metadata and A2A_DATA_PART_METADATA_TYPE_KEY in part.metadata: + if ( + part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ): + return genai_types.Part( + function_call=genai_types.FunctionCall.model_validate( + part.data, by_alias=True + ) + ) + if ( + part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ): + return genai_types.Part( + function_response=genai_types.FunctionResponse.model_validate( + part.data, by_alias=True + ) + ) + return genai_types.Part(text=json.dumps(part.data)) + + logger.warning( + 'Cannot convert unsupported part type: %s for A2A part: %s', + type(part), + a2a_part, + ) + return None + + +@working_in_progress +def convert_genai_part_to_a2a_part( + part: genai_types.Part, +) -> Optional[a2a_types.Part]: + """Convert a Google GenAI Part to an A2A Part.""" + if part.text: + return a2a_types.TextPart(text=part.text) + + if part.file_data: + return a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=part.file_data.file_uri, + mimeType=part.file_data.mime_type, + ) + ) + + if part.inline_data: + return a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=part.inline_data.data, + mimeType=part.inline_data.mime_type, + ) + ) + ) + + # Conver the funcall and function reponse to A2A DataPart. + # This is mainly for converting human in the loop and auth request and + # response. + # TODO once A2A defined how to suervice such information, migrate below + # logic accordinlgy + if part.function_call: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.function_call.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + }, + ) + ) + + if part.function_response: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.function_response.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + }, + ) + ) + + logger.warning( + 'Cannot convert unsupported part for Google GenAI part: %s', + part, + ) + return None diff --git a/tests/unittests/a2a/__init__.py b/tests/unittests/a2a/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/__init__.py b/tests/unittests/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py new file mode 100644 index 000000000..5ad6cd62d --- /dev/null +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -0,0 +1,443 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import Mock +from unittest.mock import patch + +from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part +from google.genai import types as genai_types +import pytest + + +class TestConvertA2aPartToGenaiPart: + """Test cases for convert_a2a_part_to_genai_part function.""" + + def test_convert_text_part(self): + """Test conversion of A2A TextPart to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="Hello, world!")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == "Hello, world!" 
+ + def test_convert_file_part_with_uri(self): + """Test conversion of A2A FilePart with URI to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri="gs://bucket/file.txt", mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.file_data is not None + assert result.file_data.file_uri == "gs://bucket/file.txt" + assert result.file_data.mime_type == "text/plain" + + def test_convert_file_part_with_bytes(self): + """Test conversion of A2A FilePart with bytes to GenAI Part.""" + # Arrange + test_bytes = b"test file content" + # Note: A2A FileWithBytes converts bytes to string automatically + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=test_bytes, mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.inline_data is not None + # Source code now properly converts A2A string back to bytes for GenAI Blob + assert result.inline_data.data == test_bytes + assert result.inline_data.mime_type == "text/plain" + + def test_convert_data_part_function_call(self): + """Test conversion of A2A DataPart with function call metadata.""" + # Arrange + function_call_data = { + "name": "test_function", + "args": {"param1": "value1", "param2": 42}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_call_data, + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_call is not None + assert result.function_call.name == "test_function" + assert result.function_call.args == {"param1": "value1", "param2": 42} + + def test_convert_data_part_function_response(self): + """Test conversion of A2A DataPart with function response metadata.""" + # Arrange + function_response_data = { + "name": "test_function", + "response": {"result": "success", "data": [1, 2, 3]}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_response_data, + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_response is not None + assert result.function_response.name == "test_function" + assert result.function_response.response == { + "result": "success", + "data": [1, 2, 3], + } + + def test_convert_data_part_without_special_metadata(self): + """Test conversion of A2A DataPart without special metadata to text.""" + # Arrange + data = {"key": "value", "number": 123} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata={"other": "metadata"}) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_data_part_no_metadata(self): + """Test conversion of A2A DataPart with no metadata to text.""" + # Arrange + data = {"key": "value", "array": [1, 2, 3]} + a2a_part = 
a2a_types.Part(root=a2a_types.DataPart(data=data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_unsupported_file_type(self): + """Test handling of unsupported file types.""" + + # Arrange - Create a mock unsupported file type + class UnsupportedFileType: + pass + + # Create a part manually since FilePart validation might reject it + mock_file_part = Mock() + mock_file_part.file = UnsupportedFileType() + a2a_part = Mock() + a2a_part.root = mock_file_part + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + def test_convert_unsupported_part_type(self): + """Test handling of unsupported part types.""" + + # Arrange - Create a mock unsupported part type + class UnsupportedPartType: + pass + + mock_part = Mock() + mock_part.root = UnsupportedPartType() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(mock_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestConvertGenaiPartToA2aPart: + """Test cases for convert_genai_part_to_a2a_part function.""" + + def test_convert_text_part(self): + """Test conversion of GenAI text Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part(text="Hello, world!") + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.TextPart) + assert result.text == "Hello, world!" 
+ + def test_convert_file_data_part(self): + """Test conversion of GenAI file_data Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part( + file_data=genai_types.FileData( + file_uri="gs://bucket/file.txt", mime_type="text/plain" + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.FilePart) + assert isinstance(result.file, a2a_types.FileWithUri) + assert result.file.uri == "gs://bucket/file.txt" + assert result.file.mimeType == "text/plain" + + def test_convert_inline_data_part(self): + """Test conversion of GenAI inline_data Part to A2A Part.""" + # Arrange + test_bytes = b"test file content" + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="text/plain") + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + # A2A FileWithBytes stores bytes as strings + assert result.root.file.bytes == test_bytes.decode("utf-8") + assert result.root.file.mimeType == "text/plain" + + def test_convert_function_call_part(self): + """Test conversion of GenAI function_call Part to A2A Part.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_call.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + + def test_convert_function_response_part(self): + """Test conversion of GenAI function_response Part to A2A Part.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_response.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + + def test_convert_unsupported_part(self): + """Test handling of unsupported GenAI Part types.""" + # Arrange - Create a GenAI Part with no recognized fields + genai_part = genai_types.Part() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestRoundTripConversions: + """Test cases for round-trip conversions to ensure consistency.""" + + def test_text_part_round_trip(self): + """Test round-trip conversion for text parts.""" + # Arrange + original_text = "Hello, world!" 
+ a2a_part = a2a_types.Part(root=a2a_types.TextPart(text=original_text)) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.TextPart) + assert result_a2a_part.text == original_text + + def test_file_uri_round_trip(self): + """Test round-trip conversion for file parts with URI.""" + # Arrange + original_uri = "gs://bucket/file.txt" + original_mime_type = "text/plain" + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=original_uri, mimeType=original_mime_type + ) + ) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.FilePart) + assert isinstance(result_a2a_part.file, a2a_types.FileWithUri) + assert result_a2a_part.file.uri == original_uri + assert result_a2a_part.file.mimeType == original_mime_type + + +class TestEdgeCases: + """Test cases for edge cases and error conditions.""" + + def test_empty_text_part(self): + """Test conversion of empty text part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == "" + + def test_none_input_a2a_to_genai(self): + """Test handling of None input for A2A to GenAI conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_a2a_part_to_genai_part(None) + + def test_none_input_genai_to_a2a(self): + """Test handling of None input for GenAI to A2A conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_genai_part_to_a2a_part(None) + + def test_data_part_with_complex_data(self): + """Test conversion of DataPart with complex nested data.""" + # Arrange + complex_data = { + "nested": { + "array": [1, 2, {"inner": "value"}], + "boolean": True, + "null_value": None, + }, + "unicode": "Hello 世界 🌍", + } + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=complex_data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(complex_data) + + def test_data_part_with_empty_metadata(self): + """Test conversion of DataPart with empty metadata dict.""" + # Arrange + data = {"key": "value"} + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data, metadata={})) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(data) From 694b71256c631d44bb4c4488279ea91d82f43e26 Mon Sep 17 00:00:00 2001 From: SimonWei <119845914+simonwei97@users.noreply.github.com> Date: Tue, 17 Jun 2025 23:48:00 +0800 Subject: [PATCH 20/79] fix: agent generate config error (#1450) --- src/google/adk/models/lite_llm.py | 60 +++++++++++++++++++++----- tests/unittests/models/test_litellm.py | 33 +++++++++++++- 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index dce5ed7c4..e34299f6f 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ 
-23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -481,16 +482,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -507,7 +514,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -518,12 +526,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None - - if llm_request.config.response_schema: + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = None + if llm_request.config and llm_request.config.response_schema: response_format = llm_request.config.response_schema - return messages, tools, response_format + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. + generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] + + if not generation_params: + generation_params = None + + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -660,7 +695,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) completion_args = { "model": self.model, @@ -670,6 +707,9 @@ async def generate_content_async( } completion_args.update(self._additional_args) + if generation_params: + completion_args.update(generation_params) + if stream: text = "" # Track function calls by index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8b43cc48b..0125872fd 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,7 +13,6 @@ # limitations under the License. 
-import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1430,3 +1429,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_completion_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params From 1ae176ad2fa2b691714ac979aec21f1cf7d35e45 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Jun 2025 10:30:58 -0700 Subject: [PATCH 21/79] fix: update conversion between Celsius and Fahrenheit #non-breaking The correct conversion from 25 degrees Celsius is 77 degrees Fahrenheit. The previous value of 41 was wrong. PiperOrigin-RevId: 772528757 --- contributing/samples/quickstart/agent.py | 2 +- src/google/adk/models/lite_llm.py | 60 ++++-------------------- tests/unittests/models/test_litellm.py | 33 +------------ 3 files changed, 12 insertions(+), 83 deletions(-) diff --git a/contributing/samples/quickstart/agent.py b/contributing/samples/quickstart/agent.py index fdd6b7f9d..b251069ad 100644 --- a/contributing/samples/quickstart/agent.py +++ b/contributing/samples/quickstart/agent.py @@ -29,7 +29,7 @@ def get_weather(city: str) -> dict: "status": "success", "report": ( "The weather in New York is sunny with a temperature of 25 degrees" - " Celsius (41 degrees Fahrenheit)." + " Celsius (77 degrees Fahrenheit)." ), } else: diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index e34299f6f..dce5ed7c4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,7 +23,6 @@ from typing import Dict from typing import Generator from typing import Iterable -from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -482,22 +481,16 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> Tuple[ - List[Message], - Optional[List[dict]], - Optional[types.SchemaUnion], - Optional[Dict], -]: - """Converts an LlmRequest to litellm inputs and extracts generation params. +) -> tuple[Iterable[Message], Iterable[dict]]: + """Converts an LlmRequest to litellm inputs. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary, response format and generation params). + The litellm inputs (message list, tool dictionary and response format). 
""" - # 1. Construct messages - messages: List[Message] = [] + messages = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -514,8 +507,7 @@ def _get_completion_inputs( ), ) - # 2. Convert tool declarations - tools: Optional[List[Dict]] = None + tools = None if ( llm_request.config and llm_request.config.tools @@ -526,39 +518,12 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - # 3. Handle response format - response_format: Optional[types.SchemaUnion] = None - if llm_request.config and llm_request.config.response_schema: - response_format = llm_request.config.response_schema - - # 4. Extract generation parameters - generation_params: Optional[Dict] = None - if llm_request.config: - config_dict = llm_request.config.model_dump(exclude_none=True) - # Generate LiteLlm parameters here, - # Following https://docs.litellm.ai/docs/completion/input. - generation_params = {} - param_mapping = { - "max_output_tokens": "max_completion_tokens", - "stop_sequences": "stop", - } - for key in ( - "temperature", - "max_output_tokens", - "top_p", - "top_k", - "stop_sequences", - "presence_penalty", - "frequency_penalty", - ): - if key in config_dict: - mapped_key = param_mapping.get(key, key) - generation_params[mapped_key] = config_dict[key] + response_format = None - if not generation_params: - generation_params = None + if llm_request.config.response_schema: + response_format = llm_request.config.response_schema - return messages, tools, response_format, generation_params + return messages, tools, response_format def _build_function_declaration_log( @@ -695,9 +660,7 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format, generation_params = ( - _get_completion_inputs(llm_request) - ) + messages, tools, response_format = _get_completion_inputs(llm_request) completion_args = { "model": self.model, @@ -707,9 +670,6 @@ async def generate_content_async( } completion_args.update(self._additional_args) - if generation_params: - completion_args.update(generation_params) - if stream: text = "" # Track function calls by index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 0125872fd..8b43cc48b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,6 +13,7 @@ # limitations under the License. 
+import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1429,35 +1430,3 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} - - -@pytest.mark.asyncio -def test_get_completion_inputs_generation_params(): - # Test that generation_params are extracted and mapped correctly - req = LlmRequest( - contents=[ - types.Content(role="user", parts=[types.Part.from_text(text="hi")]), - ], - config=types.GenerateContentConfig( - temperature=0.33, - max_output_tokens=123, - top_p=0.88, - top_k=7, - stop_sequences=["foo", "bar"], - presence_penalty=0.1, - frequency_penalty=0.2, - ), - ) - from google.adk.models.lite_llm import _get_completion_inputs - - _, _, _, generation_params = _get_completion_inputs(req) - assert generation_params["temperature"] == 0.33 - assert generation_params["max_completion_tokens"] == 123 - assert generation_params["top_p"] == 0.88 - assert generation_params["top_k"] == 7 - assert generation_params["stop"] == ["foo", "bar"] - assert generation_params["presence_penalty"] == 0.1 - assert generation_params["frequency_penalty"] == 0.2 - # Should not include max_output_tokens - assert "max_output_tokens" not in generation_params - assert "stop_sequences" not in generation_params From c04adaade118be242fcf110e24e96253ac6550ab Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 10:45:10 -0700 Subject: [PATCH 22/79] chore: Add in memory credential service (Experimental) PiperOrigin-RevId: 772534962 --- .../in_memory_credential_service.py | 64 ++++ .../test_in_memory_credential_service.py | 323 ++++++++++++++++++ 2 files changed, 387 insertions(+) create mode 100644 src/google/adk/auth/credential_service/in_memory_credential_service.py create mode 100644 tests/unittests/auth/credential_service/test_in_memory_credential_service.py diff --git a/src/google/adk/auth/credential_service/in_memory_credential_service.py b/src/google/adk/auth/credential_service/in_memory_credential_service.py new file mode 100644 index 000000000..f6f51b35a --- /dev/null +++ b/src/google/adk/auth/credential_service/in_memory_credential_service.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class InMemoryCredentialService(BaseCredentialService): + """Class for in memory implementation of credential service(Experimental)""" + + def __init__(self): + super().__init__() + self._credentials = {} + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + credential_bucket = self._get_bucket_for_current_context(tool_context) + return credential_bucket.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + credential_bucket = self._get_bucket_for_current_context(tool_context) + credential_bucket[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) + + def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str: + app_name = tool_context._invocation_context.app_name + user_id = tool_context._invocation_context.user_id + + if app_name not in self._credentials: + self._credentials[app_name] = {} + if user_id not in self._credentials[app_name]: + self._credentials[app_name][user_id] = {} + return self._credentials[app_name][user_id] diff --git a/tests/unittests/auth/credential_service/test_in_memory_credential_service.py b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py new file mode 100644 index 000000000..9312f72a3 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py @@ -0,0 +1,323 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestInMemoryCredentialService: + """Tests for the InMemoryCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create an InMemoryCredentialService instance for testing.""" + return InMemoryCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "test_app" + mock_invocation_context.user_id = "test_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different app/user for testing isolation.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "another_app" + mock_invocation_context.user_id = "another_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + def test_init(self, credential_service): + """Test that the service initializes with an empty store.""" + assert isinstance(credential_service._credentials, dict) + assert len(credential_service._credentials) == 0 + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and 
loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different app/user contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + def test_get_bucket_for_current_context_creates_nested_structure( + self, credential_service, 
tool_context + ): + """Test that _get_bucket_for_current_context creates the proper nested structure.""" + storage = credential_service._get_bucket_for_current_context(tool_context) + + # Verify the nested structure was created + assert "test_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert isinstance(storage, dict) + assert storage is credential_service._credentials["test_app"]["test_user"] + + def test_get_bucket_for_current_context_reuses_existing( + self, credential_service, tool_context + ): + """Test that _get_bucket_for_current_context reuses existing structure.""" + # Create initial structure + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage1["test_key"] = "test_value" + + # Get storage again + storage2 = credential_service._get_bucket_for_current_context(tool_context) + + # Verify it's the same storage instance + assert storage1 is storage2 + assert storage2["test_key"] == "test_value" + + def test_get_storage_different_apps( + self, credential_service, tool_context, another_tool_context + ): + """Test that different apps get different storage instances.""" + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage2 = credential_service._get_bucket_for_current_context( + another_tool_context + ) + + # Verify they are different storage instances + assert storage1 is not storage2 + + # Verify the structure + assert "test_app" in credential_service._credentials + assert "another_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert "another_user" in credential_service._credentials["another_app"] + + @pytest.mark.asyncio + async def test_same_user_different_apps( + self, credential_service, auth_config + ): + """Test that the same user in different apps get isolated storage.""" + # Create two contexts with same user but different apps + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "app1" + mock_invocation_context1.user_id = "same_user" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "app2" + mock_invocation_context2.user_id = "same_user" + context2._invocation_context = mock_invocation_context2 + + # Save credential in app1 + await credential_service.save_credential(auth_config, context1) + + # Try to load from app2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify app1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None + + @pytest.mark.asyncio + async def test_same_app_different_users( + self, credential_service, auth_config + ): + """Test that different users in the same app get isolated storage.""" + # Create two contexts with same app but different users + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "same_app" + mock_invocation_context1.user_id = "user1" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "same_app" + mock_invocation_context2.user_id = "user2" + context2._invocation_context = mock_invocation_context2 + + # Save credential for user1 + await 
credential_service.save_credential(auth_config, context1) + + # Try to load for user2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify user1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None From 6d174eba305a51fcf2122c0fd481378752d690ef Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Jun 2025 18:03:45 -0700 Subject: [PATCH 23/79] fix: Set explicit project in the BigQuery client This change sets an explicit project id in the BigQuery client from the conversation context. Without this the client was trying to set a project from the environment's application default credentials and running into issues where application default credentials is not available. PiperOrigin-RevId: 772695883 --- src/google/adk/tools/bigquery/client.py | 6 +- .../adk/tools/bigquery/metadata_tool.py | 16 ++- src/google/adk/tools/bigquery/query_tool.py | 4 +- .../tools/bigquery/test_bigquery_client.py | 125 ++++++++++++++++++ .../bigquery/test_bigquery_metadata_tool.py | 122 +++++++++++++++++ .../bigquery/test_bigquery_query_tool.py | 58 ++++++-- .../tools/bigquery/test_bigquery_toolset.py | 4 +- 7 files changed, 315 insertions(+), 20 deletions(-) create mode 100644 tests/unittests/tools/bigquery/test_bigquery_client.py create mode 100644 tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d72761b2d..ea1bebc7a 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -21,13 +21,15 @@ USER_AGENT = "adk-bigquery-tool" -def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client: +def get_bigquery_client( + *, project: str, credentials: Credentials +) -> bigquery.Client: """Get a BigQuery client.""" client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT) bigquery_client = bigquery.Client( - credentials=credentials, client_info=client_info + project=project, credentials=credentials, client_info=client_info ) return bigquery_client diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py index 6e279d59e..4f5400611 100644 --- a/src/google/adk/tools/bigquery/metadata_tool.py +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -42,7 +42,9 @@ def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]: 'bbc_news'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) datasets = [] for dataset in bq_client.list_datasets(project_id): @@ -106,7 +108,9 @@ def get_dataset_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) dataset = bq_client.get_dataset( bigquery.DatasetReference(project_id, dataset_id) ) @@ -137,7 +141,9 @@ def list_table_ids( 'local_data_for_better_health_county_data'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) tables = [] for table in bq_client.list_tables( @@ -251,7 +257,9 @@ def get_table_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + 
project=project_id, credentials=credentials + ) return bq_client.get_table( bigquery.TableReference( bigquery.DatasetReference(project_id, dataset_id), table_id diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 80b56aad3..d3a94fda7 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -72,7 +72,9 @@ def execute_sql( """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) if not config or config.write_mode == WriteMode.BLOCKED: query_job = bq_client.query( query, diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py new file mode 100644 index 000000000..612dddd6e --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -0,0 +1,125 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.bigquery.client import get_bigquery_client +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2.credentials import Credentials +import pytest + + +def test_bigquery_client_project(): + """Test BigQuery client project.""" + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the client has the desired project set + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_explicit(): + """Test BigQuery client creation does not invoke default auth.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). 
For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_default_auth(): + """Test BigQuery client creation invokes default auth to set the project.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate credentials + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Simulate output of the default auth + mock_default_auth.return_value = (mock_creds, "test-gcp-project") + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock_creds, + ) + + # Verify that default auth was called once to set the client project + mock_default_auth.assert_called_once() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_env(): + """Test BigQuery client creation sets the project from environment variable.""" + # Let's simulate the project set in environment variables + with mock.patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True + ): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_user_agent(): + """Test BigQuery client user agent.""" + with mock.patch( + "google.cloud.bigquery.client.Connection", autospec=True + ) as mock_connection: + # Trigger the BigQuery client creation + get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the tracking user agent was set + client_info_arg = mock_connection.call_args[1].get("client_info") + assert client_info_arg is not None + assert client_info_arg.user_agent == "adk-bigquery-tool" diff --git a/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py new file mode 100644 index 000000000..14ecea558 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.bigquery import metadata_tool +from google.auth.exceptions import DefaultCredentialsError +from google.cloud import bigquery +from google.oauth2.credentials import Credentials +import pytest + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_dataset_ids(mock_default_auth, mock_list_datasets): + """Test list_dataset_ids tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_datasets.return_value = [ + bigquery.DatasetReference(project, "dataset1"), + bigquery.DatasetReference(project, "dataset2"), + ] + result = metadata_tool.list_dataset_ids(project, mock_credentials) + assert result == ["dataset1", "dataset2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_dataset_info(mock_default_auth, mock_get_dataset): + """Test get_dataset_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_dataset.return_value = mock.create_autospec( + Credentials, instance=True + ) + result = metadata_tool.get_dataset_info( + "my_project_id", "my_dataset_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_table_ids(mock_default_auth, mock_list_tables): + """Test list_table_ids tool invocation.""" + project = "my_project_id" + dataset = "my_dataset_id" + dataset_ref = bigquery.DatasetReference(project, dataset) + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_tables.return_value = [ + bigquery.TableReference(dataset_ref, "table1"), + bigquery.TableReference(dataset_ref, "table2"), + ] + result = metadata_tool.list_table_ids(project, dataset, mock_credentials) + assert result == ["table1", "table2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_table_info(mock_default_auth, mock_get_table): + """Test get_table_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate 
the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_table.return_value = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_table_info( + "my_project_id", "my_dataset_id", "my_table_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index 35d44ef81..3cb8c3c4a 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import os import textwrap from typing import Optional from unittest import mock @@ -24,6 +25,7 @@ from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode from google.adk.tools.bigquery.query_tool import execute_sql +from google.auth.exceptions import DefaultCredentialsError from google.cloud import bigquery from google.oauth2.credentials import Credentials import pytest @@ -227,14 +229,8 @@ async def test_execute_sql_declaration_write(tool_config): @pytest.mark.parametrize( ("write_mode",), [ - pytest.param( - WriteMode.BLOCKED, - id="blocked", - ), - pytest.param( - WriteMode.ALLOWED, - id="allowed", - ), + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), ], ) def test_execute_sql_select_stmt(write_mode): @@ -279,7 +275,7 @@ def test_execute_sql_select_stmt(write_mode): ], ) def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): - """Test execute_sql tool for SELECT query when writes are blocked.""" + """Test execute_sql tool for non-SELECT query when writes are blocked.""" project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) @@ -318,7 +314,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): ], ) def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): - """Test execute_sql tool for SELECT query when writes are blocked.""" + """Test execute_sql tool for non-SELECT query when writes are blocked.""" project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) @@ -342,3 +338,45 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", } + + +@pytest.mark.parametrize( + ("write_mode",), + [ + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) +@mock.patch("google.cloud.bigquery.Client.query", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_execute_sql_no_default_auth( + mock_default_auth, mock_query, mock_query_and_wait, write_mode +): + """Test execute_sql tool invocation does not involve calling default auth.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + query_result = [{"num": 123}] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = 
BigQueryToolConfig(write_mode=write_mode) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + mock_query.return_value = query_job + + # Simulate the result of query_and_wait API + mock_query_and_wait.return_value = query_result + + # Test the tool worked without invoking default auth + result = execute_sql(project, query, credentials, tool_config) + assert result == {"status": "SUCCESS", "rows": query_result} + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index ea9990b9f..4129dc512 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -96,9 +96,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools): ], ) @pytest.mark.asyncio -async def test_bigquery_toolset_unknown_tool_raises( - selected_tools, returned_tools -): +async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools): """Test BigQuery toolset with filter. This test verifies the behavior of the BigQuery toolset when filter is From 5f89a469ec6a9bad5ab8625e71d6b4d54046e2cd Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 18:08:05 -0700 Subject: [PATCH 24/79] chore: Add credential service to runner and invocation context PiperOrigin-RevId: 772697298 --- src/google/adk/agents/invocation_context.py | 2 ++ .../credential_service/base_credential_service.py | 4 ++-- src/google/adk/cli/cli.py | 10 ++++++++++ src/google/adk/cli/fast_api.py | 5 +++++ src/google/adk/runners.py | 7 ++++++- tests/unittests/cli/utils/test_cli.py | 11 ++++++++--- 6 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index f70371535..765f22a2c 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -22,6 +22,7 @@ from pydantic import ConfigDict from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session @@ -115,6 +116,7 @@ class InvocationContext(BaseModel): artifact_service: Optional[BaseArtifactService] = None session_service: BaseSessionService memory_service: Optional[BaseMemoryService] = None + credential_service: Optional[BaseCredentialService] = None invocation_id: str """The id of this invocation context. 
Readonly.""" diff --git a/src/google/adk/auth/credential_service/base_credential_service.py b/src/google/adk/auth/credential_service/base_credential_service.py index 7416ccc65..fc6cd500d 100644 --- a/src/google/adk/auth/credential_service/base_credential_service.py +++ b/src/google/adk/auth/credential_service/base_credential_service.py @@ -19,12 +19,12 @@ from typing import Optional from ...tools.tool_context import ToolContext -from ...utils.feature_decorator import working_in_progress +from ...utils.feature_decorator import experimental from ..auth_credential import AuthCredential from ..auth_tool import AuthConfig -@working_in_progress("Implementation are in progress. Don't use it for now.") +@experimental class BaseCredentialService(ABC): """Abstract class for Service that loads / saves tool credentials from / to the backend credential store.""" diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index aceb3fcce..79d0bfe65 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -24,6 +24,8 @@ from ..agents.llm_agent import LlmAgent from ..artifacts import BaseArtifactService from ..artifacts import InMemoryArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService @@ -43,6 +45,7 @@ async def run_input_file( root_agent: LlmAgent, artifact_service: BaseArtifactService, session_service: BaseSessionService, + credential_service: BaseCredentialService, input_path: str, ) -> Session: runner = Runner( @@ -50,6 +53,7 @@ async def run_input_file( agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) @@ -75,12 +79,14 @@ async def run_interactively( artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, + credential_service: BaseCredentialService, ) -> None: runner = Runner( app_name=session.app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) while True: query = input('[user]: ') @@ -125,6 +131,7 @@ async def run_cli( artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() user_id = 'test_user' session = await session_service.create_session( @@ -141,6 +148,7 @@ async def run_cli( root_agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=input_file, ) elif saved_session_file: @@ -163,6 +171,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) else: click.echo(f'Running agent {root_agent.name}, type exit to exit.') @@ -171,6 +180,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) if save_session: diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4512174c5..46e008655 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -57,6 +57,7 @@ from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from 
..artifacts.in_memory_artifact_service import InMemoryArtifactService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..errors.not_found_error import NotFoundError from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput @@ -305,6 +306,9 @@ async def internal_lifespan(app: FastAPI): else: artifact_service = InMemoryArtifactService() + # Build the Credential service + credential_service = InMemoryCredentialService() + # initialize Agent Loader agent_loader = AgentLoader(agents_dir) @@ -929,6 +933,7 @@ async def _get_runner_async(app_name: str) -> Runner: artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, + credential_service=credential_service, ) runner_dict[app_name] = runner return runner diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index c4fcdfb9e..01412a2b3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -17,7 +17,6 @@ import asyncio import logging import queue -import threading from typing import AsyncGenerator from typing import Generator from typing import Optional @@ -34,6 +33,7 @@ from .agents.run_config import RunConfig from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService +from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event from .memory.base_memory_service import BaseMemoryService @@ -73,6 +73,8 @@ class Runner: """The session service for the runner.""" memory_service: Optional[BaseMemoryService] = None """The memory service for the runner.""" + credential_service: Optional[BaseCredentialService] = None + """The credential service for the runner.""" def __init__( self, @@ -82,6 +84,7 @@ def __init__( artifact_service: Optional[BaseArtifactService] = None, session_service: BaseSessionService, memory_service: Optional[BaseMemoryService] = None, + credential_service: Optional[BaseCredentialService] = None, ): """Initializes the Runner. 
@@ -97,6 +100,7 @@ def __init__( self.artifact_service = artifact_service self.session_service = session_service self.memory_service = memory_service + self.credential_service = credential_service def run( self, @@ -418,6 +422,7 @@ def _new_invocation_context( artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, + credential_service=self.credential_service, invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 1721885f3..2139a8c20 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -129,6 +129,7 @@ def _echo(msg: str) -> None: artifact_service = cli.InMemoryArtifactService() session_service = cli.InMemorySessionService() + credential_service = cli.InMemoryCredentialService() dummy_root = types.SimpleNamespace(name="root") session = await cli.run_input_file( @@ -137,6 +138,7 @@ def _echo(msg: str) -> None: root_agent=dummy_root, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=str(input_path), ) @@ -199,9 +201,10 @@ async def test_run_interactively_whitespace_and_exit( ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - svc = cli.InMemorySessionService() - sess = await svc.create_session(app_name="dummy", user_id="u") + session_service = cli.InMemorySessionService() + sess = await session_service.create_session(app_name="dummy", user_id="u") artifact_service = cli.InMemoryArtifactService() + credential_service = cli.InMemoryCredentialService() root_agent = types.SimpleNamespace(name="root") # fake user input: blank -> 'hello' -> 'exit' @@ -212,7 +215,9 @@ async def test_run_interactively_whitespace_and_exit( echoed: list[str] = [] monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) - await cli.run_interactively(root_agent, artifact_service, sess, svc) + await cli.run_interactively( + root_agent, artifact_service, sess, session_service, credential_service + ) # verify: assistant echoed once with 'echo:hello' assert any("echo:hello" in m for m in echoed) From f9fa7841df81bcfc38d11a3d059c3d02f8ec3794 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Jun 2025 18:30:51 -0700 Subject: [PATCH 25/79] chore: add google-adk/{version} to bigquery user agent PiperOrigin-RevId: 772703504 --- src/google/adk/tools/bigquery/client.py | 4 +++- tests/unittests/tools/bigquery/test_bigquery_client.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index ea1bebc7a..23f1befc5 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -18,7 +18,9 @@ from google.cloud import bigquery from google.oauth2.credentials import Credentials -USER_AGENT = "adk-bigquery-tool" +from ... 
import version + +USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" def get_bigquery_client( diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index 612dddd6e..e8b373416 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -14,6 +14,7 @@ from __future__ import annotations import os +import re from unittest import mock from google.adk.tools.bigquery.client import get_bigquery_client @@ -122,4 +123,7 @@ def test_bigquery_client_user_agent(): # Verify that the tracking user agent was set client_info_arg = mock_connection.call_args[1].get("client_info") assert client_info_arg is not None - assert client_info_arg.user_agent == "adk-bigquery-tool" + assert re.search( + r"adk-bigquery-tool google-adk/([0-9A-Za-z._\-+/]+)", + client_info_arg.user_agent, + ) From 0a9625317a7a511cae39fd566625e98dfab24486 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 18:55:11 -0700 Subject: [PATCH 26/79] refactor: Adapt service account credential exchanger to base credential exchanger interface PiperOrigin-RevId: 772710438 --- src/google/adk/auth/exchanger/__init__.py | 2 - .../service_account_credential_exchanger.py | 70 +++--- tests/unittests/auth/exchanger/__init__.py | 15 ++ ...st_service_account_credential_exchanger.py | 202 +++++++++++++----- 4 files changed, 203 insertions(+), 86 deletions(-) rename src/google/adk/auth/{ => exchanger}/service_account_credential_exchanger.py (57%) create mode 100644 tests/unittests/auth/exchanger/__init__.py rename tests/unittests/auth/{ => exchanger}/test_service_account_credential_exchanger.py (61%) diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py index ce5c464c4..4226ae715 100644 --- a/src/google/adk/auth/exchanger/__init__.py +++ b/src/google/adk/auth/exchanger/__init__.py @@ -15,11 +15,9 @@ """Credential exchanger module.""" from .base_credential_exchanger import BaseCredentialExchanger -from .credential_exchanger_registry import CredentialExchangerRegistry from .service_account_credential_exchanger import ServiceAccountCredentialExchanger __all__ = [ "BaseCredentialExchanger", - "CredentialExchangerRegistry", "ServiceAccountCredentialExchanger", ] diff --git a/src/google/adk/auth/service_account_credential_exchanger.py b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py similarity index 57% rename from src/google/adk/auth/service_account_credential_exchanger.py rename to src/google/adk/auth/exchanger/service_account_credential_exchanger.py index 644501ee6..415081ca5 100644 --- a/src/google/adk/auth/service_account_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py @@ -16,19 +16,22 @@ from __future__ import annotations +from typing import Optional + import google.auth from google.auth.transport.requests import Request from google.oauth2 import service_account +from typing_extensions import override -from ..utils.feature_decorator import experimental -from .auth_credential import AuthCredential -from .auth_credential import AuthCredentialTypes -from .auth_credential import HttpAuth -from .auth_credential import HttpCredentials +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_credential import AuthCredentialTypes +from ..auth_schemes import AuthScheme +from .base_credential_exchanger import 
BaseCredentialExchanger @experimental -class ServiceAccountCredentialExchanger: +class ServiceAccountCredentialExchanger(BaseCredentialExchanger): """Exchanges Google Service Account credentials for an access token. Uses the default service credential if `use_default_credential = True`. @@ -36,44 +39,56 @@ class ServiceAccountCredentialExchanger: credential. """ - def __init__(self, credential: AuthCredential): - if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: - raise ValueError("Credential is not a service account credential.") - self._credential = credential - - def exchange(self) -> AuthCredential: + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: """Exchanges the service account auth credential for an access token. If the AuthCredential contains a service account credential, it will be used to exchange for an access token. Otherwise, if use_default_credential is True, the default application credential will be used for exchanging an access token. + Args: + auth_scheme: The authentication scheme. + auth_credential: The credential to exchange. + Returns: - An AuthCredential in HTTP Bearer format, containing the access token. + An AuthCredential in OAUTH2 format, containing the exchanged credential JSON. Raises: ValueError: If service account credentials are missing or invalid. Exception: If credential exchange or refresh fails. """ + if auth_credential is None: + raise ValueError("Credential cannot be None.") + + if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: + raise ValueError("Credential is not a service account credential.") + + if auth_credential.service_account is None: + raise ValueError( + "Service account credentials are missing. Please provide them." + ) + if ( - self._credential is None - or self._credential.service_account is None - or ( - self._credential.service_account.service_account_credential is None - and not self._credential.service_account.use_default_credential - ) + auth_credential.service_account.service_account_credential is None + and not auth_credential.service_account.use_default_credential ): raise ValueError( - "Service account credentials are missing. Please provide them, or set" - " `use_default_credential = True` to use application default" - " credential in a hosted service like Google Cloud Run." + "Service account credentials are invalid. Please set the" + " service_account_credential field or set `use_default_credential =" + " True` to use application default credential in a hosted service" + " like Google Cloud Run." 
) try: - if self._credential.service_account.use_default_credential: + if auth_credential.service_account.use_default_credential: credentials, _ = google.auth.default() else: - config = self._credential.service_account + config = auth_credential.service_account credentials = service_account.Credentials.from_service_account_info( config.service_account_credential.model_dump(), scopes=config.scopes ) @@ -82,11 +97,8 @@ def exchange(self) -> AuthCredential: credentials.refresh(Request()) return AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token=credentials.token), - ), + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=credentials.to_json(), ) except Exception as e: raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/tests/unittests/auth/exchanger/__init__.py b/tests/unittests/auth/exchanger/__init__.py new file mode 100644 index 000000000..5fb8a262b --- /dev/null +++ b/tests/unittests/auth/exchanger/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for credential exchanger.""" diff --git a/tests/unittests/auth/test_service_account_credential_exchanger.py b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py similarity index 61% rename from tests/unittests/auth/test_service_account_credential_exchanger.py rename to tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py index a5c668436..195e143d3 100644 --- a/tests/unittests/auth/test_service_account_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py @@ -17,19 +17,20 @@ from unittest.mock import MagicMock from unittest.mock import patch +from fastapi.openapi.models import HTTPBearer from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import ServiceAccount from google.adk.auth.auth_credential import ServiceAccountCredential -from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger +from google.adk.auth.exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger import pytest class TestServiceAccountCredentialExchanger: """Test cases for ServiceAccountCredentialExchanger.""" - def test_init_valid_credential(self): - """Test successful initialization with valid service account credential.""" + def test_exchange_with_valid_credential(self): + """Test successful exchange with valid service account credential.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( @@ -55,26 +56,36 @@ def test_init_valid_credential(self): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - assert exchanger._credential == credential + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() - def 
test_init_invalid_credential_type(self): - """Test initialization with invalid credential type raises ValueError.""" + # This should not raise an exception + assert exchanger is not None + + @pytest.mark.asyncio + async def test_exchange_invalid_credential_type(self): + """Test exchange with invalid credential type raises ValueError.""" credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="test-key", ) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + with pytest.raises( ValueError, match="Credential is not a service account credential" ): - ServiceAccountCredentialExchanger(credential) + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_with_explicit_credentials_success( + async def test_exchange_with_explicit_credentials_success( self, mock_request_class, mock_from_service_account_info ): """Test successful exchange with explicit service account credentials.""" @@ -84,6 +95,9 @@ def test_exchange_with_explicit_credentials_success( mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "mock_access_token", "type": "authorized_user"}' + ) mock_from_service_account_info.return_value = mock_credentials # Create test credential @@ -113,13 +127,20 @@ def test_exchange_with_explicit_credentials_success( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - result = exchanger.exchange() + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) # Verify the result - assert result.auth_type == AuthCredentialTypes.HTTP - assert result.http.scheme == "bearer" - assert result.http.credentials.token == "mock_access_token" + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "mock_access_token" or "mock_access_token" in str(exchanged_creds) # Verify mocks were called correctly mock_from_service_account_info.assert_called_once_with( @@ -128,11 +149,14 @@ def test_exchange_with_explicit_credentials_success( ) mock_credentials.refresh.assert_called_once_with(mock_request) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.google.auth.default" + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_with_default_credentials_success( + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" + ) + async def test_exchange_with_default_credentials_success( self, mock_request_class, mock_google_auth_default ): """Test successful exchange with default application credentials.""" @@ -142,6 +166,9 @@ def test_exchange_with_default_credentials_success( mock_credentials = MagicMock() mock_credentials.token = "default_access_token" + 
mock_credentials.to_json.return_value = ( + '{"token": "default_access_token", "type": "authorized_user"}' + ) mock_google_auth_default.return_value = (mock_credentials, "test-project") # Create test credential with use_default_credential=True @@ -153,33 +180,45 @@ def test_exchange_with_default_credentials_success( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - result = exchanger.exchange() + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) # Verify the result - assert result.auth_type == AuthCredentialTypes.HTTP - assert result.http.scheme == "bearer" - assert result.http.credentials.token == "default_access_token" + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "default_access_token" or "default_access_token" in str( + exchanged_creds + ) # Verify mocks were called correctly mock_google_auth_default.assert_called_once() mock_credentials.refresh.assert_called_once_with(mock_request) - def test_exchange_missing_service_account(self): + @pytest.mark.asyncio + async def test_exchange_missing_service_account(self): """Test exchange fails when service_account is None.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=None, ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Service account credentials are missing" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) - def test_exchange_missing_credentials_and_not_default(self): + @pytest.mark.asyncio + async def test_exchange_missing_credentials_and_not_default(self): """Test exchange fails when credentials are missing and use_default_credential is False.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, @@ -190,17 +229,19 @@ def test_exchange_missing_credentials_and_not_default(self): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( - ValueError, match="Service account credentials are missing" + ValueError, match="Service account credentials are invalid" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" ) - def test_exchange_credential_creation_failure( + async def test_exchange_credential_creation_failure( self, mock_from_service_account_info ): """Test exchange handles credential creation failure gracefully.""" @@ -234,17 +275,21 @@ def test_exchange_credential_creation_failure( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.google.auth.default" + 
"google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" ) - def test_exchange_default_credential_failure(self, mock_google_auth_default): + async def test_exchange_default_credential_failure( + self, mock_google_auth_default + ): """Test exchange handles default credential failure gracefully.""" # Setup mock to raise exception mock_google_auth_default.side_effect = Exception( @@ -260,18 +305,22 @@ def test_exchange_default_credential_failure(self, mock_google_auth_default): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_refresh_failure( + async def test_exchange_refresh_failure( self, mock_request_class, mock_from_service_account_info ): """Test exchange handles credential refresh failure gracefully.""" @@ -312,30 +361,73 @@ def test_exchange_refresh_failure( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + + @pytest.mark.asyncio + async def test_exchange_none_credential_in_constructor(self): + """Test that passing None credential raises appropriate error during exchange.""" + # This test verifies behavior when credential is None + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + + with pytest.raises(ValueError, match="Credential cannot be None"): + await exchanger.exchange(None, auth_scheme) - def test_exchange_none_credential_in_constructor(self): - """Test that passing None credential raises appropriate error during construction.""" - # This test verifies behavior when _credential is None, though this shouldn't - # happen in normal usage due to constructor validation + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" + ) + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" + ) + async def test_exchange_with_service_account_no_explicit_credentials( + self, mock_request_class, mock_google_auth_default + ): + """Test exchange with service account that has no explicit credentials uses default.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "default_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "default_access_token", "type": "authorized_user"}' + ) + mock_google_auth_default.return_value = (mock_credentials, "test-project") + + # Create test credential with no explicit credentials but use_default_credential=True credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( + service_account_credential=None, 
use_default_credential=True, scopes=["https://www.googleapis.com/auth/cloud-platform"], ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - # Manually set to None to test the validation logic - exchanger._credential = None + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) - with pytest.raises( - ValueError, match="Service account credentials are missing" - ): - exchanger.exchange() + # Verify the result + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "default_access_token" or "default_access_token" in str( + exchanged_creds + ) + + # Verify mocks were called correctly + mock_google_auth_default.assert_called_once() + mock_credentials.refresh.assert_called_once_with(mock_request) From 55201cb6a1d59674e9aea1d25da37c6edbb7e0c7 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 19:08:45 -0700 Subject: [PATCH 27/79] chore: Add credential exchanger registry (Experimentals) PiperOrigin-RevId: 772713412 --- .../credential_exchanger_registry.py | 58 +++++ .../test_credential_exchanger_registry.py | 242 ++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 src/google/adk/auth/exchanger/credential_exchanger_registry.py create mode 100644 tests/unittests/auth/exchanger/test_credential_exchanger_registry.py diff --git a/src/google/adk/auth/exchanger/credential_exchanger_registry.py b/src/google/adk/auth/exchanger/credential_exchanger_registry.py new file mode 100644 index 000000000..5af7f3c1a --- /dev/null +++ b/src/google/adk/auth/exchanger/credential_exchanger_registry.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential exchanger registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredentialTypes +from .base_credential_exchanger import BaseCredentialExchanger + + +@experimental +class CredentialExchangerRegistry: + """Registry for credential exchanger instances.""" + + def __init__(self): + self._exchangers: Dict[AuthCredentialTypes, BaseCredentialExchanger] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register an exchanger instance for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchangers[credential_type] = exchanger_instance + + def get_exchanger( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialExchanger]: + """Get the exchanger instance for a credential type. 
+ + Args: + credential_type: The credential type to get exchanger for. + + Returns: + The exchanger instance if registered, None otherwise. + """ + return self._exchangers.get(credential_type) diff --git a/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py new file mode 100644 index 000000000..66b858232 --- /dev/null +++ b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the CredentialExchangerRegistry.""" + +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.exchanger.base_credential_exchanger import BaseCredentialExchanger +from google.adk.auth.exchanger.credential_exchanger_registry import CredentialExchangerRegistry +import pytest + + +class MockCredentialExchanger(BaseCredentialExchanger): + """Mock credential exchanger for testing.""" + + def __init__(self, exchange_result: Optional[AuthCredential] = None): + self.exchange_result = exchange_result or AuthCredential( + auth_type=AuthCredentialTypes.HTTP + ) + + def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Mock exchange method.""" + return self.exchange_result + + +class TestCredentialExchangerRegistry: + """Test cases for CredentialExchangerRegistry.""" + + def test_initialization(self): + """Test that the registry initializes with an empty exchangers dictionary.""" + registry = CredentialExchangerRegistry() + + # Access the private attribute for testing + assert hasattr(registry, '_exchangers') + assert isinstance(registry._exchangers, dict) + assert len(registry._exchangers) == 0 + + def test_register_single_exchanger(self): + """Test registering a single exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Verify the exchanger was registered + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_register_multiple_exchangers(self): + """Test registering multiple exchangers for different credential types.""" + registry = CredentialExchangerRegistry() + + api_key_exchanger = MockCredentialExchanger() + oauth2_exchanger = MockCredentialExchanger() + service_account_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, api_key_exchanger) + registry.register(AuthCredentialTypes.OAUTH2, oauth2_exchanger) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, service_account_exchanger + ) + + # Verify all exchangers were registered correctly + assert ( + 
registry.get_exchanger(AuthCredentialTypes.API_KEY) is api_key_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.OAUTH2) is oauth2_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.SERVICE_ACCOUNT) + is service_account_exchanger + ) + + def test_register_overwrites_existing_exchanger(self): + """Test that registering an exchanger for an existing type overwrites the previous one.""" + registry = CredentialExchangerRegistry() + + first_exchanger = MockCredentialExchanger() + second_exchanger = MockCredentialExchanger() + + # Register first exchanger + registry.register(AuthCredentialTypes.API_KEY, first_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is first_exchanger + ) + + # Register second exchanger for the same type + registry.register(AuthCredentialTypes.API_KEY, second_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is second_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) + is not first_exchanger + ) + + def test_get_exchanger_returns_correct_instance(self): + """Test that get_exchanger returns the correct exchanger instance.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.HTTP, mock_exchanger) + + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.HTTP) + assert retrieved_exchanger is mock_exchanger + assert isinstance(retrieved_exchanger, BaseCredentialExchanger) + + def test_get_exchanger_nonexistent_type_returns_none(self): + """Test that get_exchanger returns None for non-existent credential types.""" + registry = CredentialExchangerRegistry() + + # Try to get an exchanger that was never registered + result = registry.get_exchanger(AuthCredentialTypes.OAUTH2) + assert result is None + + def test_get_exchanger_after_registration_and_removal(self): + """Test behavior when an exchanger is registered and then the registry is cleared indirectly.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + # Register exchanger + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is mock_exchanger + + # Clear the internal dictionary (simulating some edge case) + registry._exchangers.clear() + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is None + + def test_register_with_all_credential_types(self): + """Test registering exchangers for all available credential types.""" + registry = CredentialExchangerRegistry() + + exchangers = {} + credential_types = [ + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + AuthCredentialTypes.SERVICE_ACCOUNT, + ] + + # Register an exchanger for each credential type + for cred_type in credential_types: + exchanger = MockCredentialExchanger() + exchangers[cred_type] = exchanger + registry.register(cred_type, exchanger) + + # Verify all exchangers can be retrieved + for cred_type in credential_types: + retrieved_exchanger = registry.get_exchanger(cred_type) + assert retrieved_exchanger is exchangers[cred_type] + + def test_register_with_mock_exchanger_using_magicmock(self): + """Test registering with a MagicMock exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MagicMock(spec=BaseCredentialExchanger) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + retrieved_exchanger = 
registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_registry_isolation(self): + """Test that different registry instances are isolated from each other.""" + registry1 = CredentialExchangerRegistry() + registry2 = CredentialExchangerRegistry() + + exchanger1 = MockCredentialExchanger() + exchanger2 = MockCredentialExchanger() + + # Register different exchangers in different registry instances + registry1.register(AuthCredentialTypes.API_KEY, exchanger1) + registry2.register(AuthCredentialTypes.API_KEY, exchanger2) + + # Verify isolation + assert registry1.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger1 + assert registry2.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger2 + assert ( + registry1.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger2 + ) + assert ( + registry2.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger1 + ) + + def test_exchanger_functionality_through_registry(self): + """Test that exchangers registered in the registry function correctly.""" + registry = CredentialExchangerRegistry() + + # Create a mock exchanger with specific return value + expected_result = AuthCredential(auth_type=AuthCredentialTypes.HTTP) + mock_exchanger = MockCredentialExchanger(exchange_result=expected_result) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Get the exchanger and test its functionality + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + input_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY) + + result = retrieved_exchanger.exchange(input_credential) + assert result is expected_result + + def test_register_none_exchanger(self): + """Test that registering None as an exchanger works (edge case).""" + registry = CredentialExchangerRegistry() + + # This should work but return None when retrieved + registry.register(AuthCredentialTypes.API_KEY, None) + + result = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert result is None + + def test_internal_dictionary_structure(self): + """Test the internal structure of the registry.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.OAUTH2, mock_exchanger) + + # Verify internal dictionary structure + assert AuthCredentialTypes.OAUTH2 in registry._exchangers + assert registry._exchangers[AuthCredentialTypes.OAUTH2] is mock_exchanger + assert len(registry._exchangers) == 1 From a17ebe6ebd7b58fa86a90cceb7650c2b3187933d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 21:11:56 -0700 Subject: [PATCH 28/79] chore: Add a credential refresher registry PiperOrigin-RevId: 772747251 --- .../credential_refresher_registry.py | 59 ++++++ .../test_credential_refresher_registry.py | 174 ++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 src/google/adk/auth/refresher/credential_refresher_registry.py create mode 100644 tests/unittests/auth/refresher/test_credential_refresher_registry.py diff --git a/src/google/adk/auth/refresher/credential_refresher_registry.py b/src/google/adk/auth/refresher/credential_refresher_registry.py new file mode 100644 index 000000000..90975d66d --- /dev/null +++ b/src/google/adk/auth/refresher/credential_refresher_registry.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.utils.feature_decorator import experimental + +from .base_credential_refresher import BaseCredentialRefresher + + +@experimental +class CredentialRefresherRegistry: + """Registry for credential refresher instances.""" + + def __init__(self): + self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + refresher_instance: BaseCredentialRefresher, + ) -> None: + """Register a refresher instance for a credential type. + + Args: + credential_type: The credential type to register for. + refresher_instance: The refresher instance to register. + """ + self._refreshers[credential_type] = refresher_instance + + def get_refresher( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialRefresher]: + """Get the refresher instance for a credential type. + + Args: + credential_type: The credential type to get refresher for. + + Returns: + The refresher instance if registered, None otherwise. + """ + return self._refreshers.get(credential_type) diff --git a/tests/unittests/auth/refresher/test_credential_refresher_registry.py b/tests/unittests/auth/refresher/test_credential_refresher_registry.py new file mode 100644 index 000000000..b00cc4da8 --- /dev/null +++ b/tests/unittests/auth/refresher/test_credential_refresher_registry.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
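For orientation, a minimal usage sketch of the two registries added in the last two commits (CredentialExchangerRegistry and CredentialRefresherRegistry), using only the register/get_* methods and import paths that appear in the diffs above. The refresher registered below is a stand-in Mock, since the concrete OAuth2 refresher only lands in a later commit of this series.

from unittest.mock import Mock

from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.exchanger.credential_exchanger_registry import CredentialExchangerRegistry
from google.adk.auth.exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger
from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher
from google.adk.auth.refresher.credential_refresher_registry import CredentialRefresherRegistry

# One exchanger per credential type; registering again overwrites the previous one.
exchanger_registry = CredentialExchangerRegistry()
exchanger_registry.register(
    AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger()
)

# The refresher registry mirrors the exchanger registry; a Mock stands in for a
# real BaseCredentialRefresher implementation here.
refresher_registry = CredentialRefresherRegistry()
refresher_registry.register(
    AuthCredentialTypes.OAUTH2, Mock(spec=BaseCredentialRefresher)
)

# Lookups return the registered instance, or None for unregistered types.
assert exchanger_registry.get_exchanger(AuthCredentialTypes.SERVICE_ACCOUNT) is not None
assert refresher_registry.get_refresher(AuthCredentialTypes.HTTP) is None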
+ +"""Tests for CredentialRefresherRegistry.""" + +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher +from google.adk.auth.refresher.credential_refresher_registry import CredentialRefresherRegistry + + +class TestCredentialRefresherRegistry: + """Tests for the CredentialRefresherRegistry class.""" + + def test_init(self): + """Test that registry initializes with empty refreshers dictionary.""" + registry = CredentialRefresherRegistry() + assert registry._refreshers == {} + + def test_register_refresher(self): + """Test registering a refresher instance for a credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher + + def test_register_multiple_refreshers(self): + """Test registering multiple refresher instances for different credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_openid_refresher = Mock(spec=BaseCredentialRefresher) + mock_service_account_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, mock_openid_refresher + ) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, mock_service_account_refresher + ) + + assert ( + registry._refreshers[AuthCredentialTypes.OAUTH2] + == mock_oauth2_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.OPEN_ID_CONNECT] + == mock_openid_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.SERVICE_ACCOUNT] + == mock_service_account_refresher + ) + + def test_register_overwrite_existing_refresher(self): + """Test that registering a refresher overwrites an existing one for the same credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher_1 = Mock(spec=BaseCredentialRefresher) + mock_refresher_2 = Mock(spec=BaseCredentialRefresher) + + # Register first refresher + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_1) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_1 + + # Register second refresher for same credential type + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_2) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_2 + + def test_get_refresher_existing(self): + """Test getting a refresher instance for a registered credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result == mock_refresher + + def test_get_refresher_non_existing(self): + """Test getting a refresher instance for a non-registered credential type returns None.""" + registry = CredentialRefresherRegistry() + + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None + + def test_get_refresher_after_registration(self): + """Test getting refresher instances for multiple credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_api_key_refresher = Mock(spec=BaseCredentialRefresher) + + 
registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register(AuthCredentialTypes.API_KEY, mock_api_key_refresher) + + # Get registered refreshers + oauth2_result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + api_key_result = registry.get_refresher(AuthCredentialTypes.API_KEY) + + assert oauth2_result == mock_oauth2_refresher + assert api_key_result == mock_api_key_refresher + + # Get non-registered refresher + http_result = registry.get_refresher(AuthCredentialTypes.HTTP) + assert http_result is None + + def test_register_all_credential_types(self): + """Test registering refreshers for all available credential types.""" + registry = CredentialRefresherRegistry() + + refreshers = {} + for credential_type in AuthCredentialTypes: + mock_refresher = Mock(spec=BaseCredentialRefresher) + refreshers[credential_type] = mock_refresher + registry.register(credential_type, mock_refresher) + + # Verify all refreshers are registered correctly + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result == refreshers[credential_type] + + def test_empty_registry_get_refresher(self): + """Test getting refresher from empty registry returns None for any credential type.""" + registry = CredentialRefresherRegistry() + + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result is None + + def test_registry_independence(self): + """Test that multiple registry instances are independent.""" + registry1 = CredentialRefresherRegistry() + registry2 = CredentialRefresherRegistry() + + mock_refresher1 = Mock(spec=BaseCredentialRefresher) + mock_refresher2 = Mock(spec=BaseCredentialRefresher) + + registry1.register(AuthCredentialTypes.OAUTH2, mock_refresher1) + registry2.register(AuthCredentialTypes.OAUTH2, mock_refresher2) + + # Verify registries are independent + assert ( + registry1.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher1 + ) + assert ( + registry2.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher2 + ) + assert registry1.get_refresher( + AuthCredentialTypes.OAUTH2 + ) != registry2.get_refresher(AuthCredentialTypes.OAUTH2) + + def test_register_with_none_refresher(self): + """Test registering None as a refresher instance.""" + registry = CredentialRefresherRegistry() + + # This should technically work as the registry accepts any value + registry.register(AuthCredentialTypes.OAUTH2, None) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None From 9a207cb832e86dd9fd643220139a0384388cdb6c Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 10:45:54 -0700 Subject: [PATCH 29/79] refactor: Refactor oauth2_credential_exchanger to exchanger and refresher separately PiperOrigin-RevId: 772979993 --- src/google/adk/auth/auth_handler.py | 19 +- src/google/adk/auth/auth_preprocessor.py | 6 +- .../exchanger/oauth2_credential_exchanger.py | 104 ++++++ .../adk/auth/oauth2_credential_fetcher.py | 132 -------- .../refresher/oauth2_credential_refresher.py | 154 +++++++++ .../integration_connector_tool.py | 4 +- .../openapi_spec_parser/rest_api_tool.py | 6 +- .../openapi_spec_parser/tool_auth_handler.py | 16 +- .../test_oauth2_credential_exchanger.py | 220 +++++++++++++ tests/unittests/auth/refresher/__init__.py | 13 + .../test_oauth2_credential_refresher.py | 297 ++++++++++++++++++ tests/unittests/auth/test_auth_handler.py | 72 +++-- .../auth/test_oauth2_credential_fetcher.py | 147 --------- 
.../test_integration_connector_tool.py | 40 +-- .../openapi_spec_parser/test_rest_api_tool.py | 17 +- .../test_tool_auth_handler.py | 47 +-- 16 files changed, 926 insertions(+), 368 deletions(-) create mode 100644 src/google/adk/auth/exchanger/oauth2_credential_exchanger.py delete mode 100644 src/google/adk/auth/oauth2_credential_fetcher.py create mode 100644 src/google/adk/auth/refresher/oauth2_credential_refresher.py create mode 100644 tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py create mode 100644 tests/unittests/auth/refresher/__init__.py create mode 100644 tests/unittests/auth/refresher/test_oauth2_credential_refresher.py delete mode 100644 tests/unittests/auth/test_oauth2_credential_fetcher.py diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 3e13cbac2..473f31413 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -22,7 +22,7 @@ from .auth_schemes import AuthSchemeType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig -from .oauth2_credential_fetcher import OAuth2CredentialFetcher +from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger if TYPE_CHECKING: from ..sessions.state import State @@ -36,18 +36,23 @@ class AuthHandler: + """A handler that handles the auth flow in Agent Development Kit to help + orchestrate the credential request and response flow (e.g. OAuth flow) + This class should only be used by Agent Development Kit. + """ def __init__(self, auth_config: AuthConfig): self.auth_config = auth_config - def exchange_auth_token( + async def exchange_auth_token( self, ) -> AuthCredential: - return OAuth2CredentialFetcher( - self.auth_config.auth_scheme, self.auth_config.exchanged_auth_credential - ).exchange() + exchanger = OAuth2CredentialExchanger() + return await exchanger.exchange( + self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme + ) - def parse_and_store_auth_response(self, state: State) -> None: + async def parse_and_store_auth_response(self, state: State) -> None: credential_key = "temp:" + self.auth_config.credential_key @@ -60,7 +65,7 @@ def parse_and_store_auth_response(self, state: State) -> None: ): return - state[credential_key] = self.exchange_auth_token() + state[credential_key] = await self.exchange_auth_token() def _validate(self) -> None: if not self.auth_scheme: diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 0c964ed96..b06774973 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -67,9 +67,9 @@ async def run_async( # function call request_euc_function_call_ids.add(function_call_response.id) auth_config = AuthConfig.model_validate(function_call_response.response) - AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - state=invocation_context.session.state - ) + await AuthHandler( + auth_config=auth_config + ).parse_and_store_auth_response(state=invocation_context.session.state) break if not request_euc_function_call_ids: diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py new file mode 100644 index 000000000..768457e1a --- /dev/null +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 credential exchanger implementation.""" + +from __future__ import annotations + +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import OAuthGrantType +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from typing_extensions import override + +from .base_credential_exchanger import BaseCredentialExchanger +from .base_credential_exchanger import CredentialExchangError + +try: + from authlib.integrations.requests_client import OAuth2Session + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialExchanger(BaseCredentialExchanger): + """Exchanges OAuth2 credentials from authorization responses.""" + + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange OAuth2 credential from authorization response. + if credential exchange failed, the original credential will be returned. + + Args: + auth_credential: The OAuth2 credential to exchange. + auth_scheme: The OAuth2 authentication scheme. + + Returns: + The exchanged credential with access token. + + Raises: + CredentialExchangError: If auth_scheme is missing. + """ + if not auth_scheme: + raise CredentialExchangError( + "auth_scheme is required for OAuth2 credential exchange" + ) + + if not AUTHLIB_AVIALABLE: + # If authlib is not available, we cannot exchange the credential. + # We return the original credential without exchange. + # The client using this tool can decide to exchange the credential + # themselves using other lib. + logger.warning( + "authlib is not available, skipping OAuth2 credential exchange." 
+ ) + return auth_credential + + if auth_credential.oauth2 and auth_credential.oauth2.access_token: + return auth_credential + + client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) + if not client: + logger.warning("Could not create OAuth2 session for token exchange") + return auth_credential + + try: + tokens = client.fetch_token( + token_endpoint, + authorization_response=auth_credential.oauth2.auth_response_uri, + code=auth_credential.oauth2.auth_code, + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully exchanged OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise errors in this case + logger.error("Failed to exchange OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py deleted file mode 100644 index c9e838b25..000000000 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import logging - -from ..utils.feature_decorator import experimental -from .auth_credential import AuthCredential -from .auth_schemes import AuthScheme -from .auth_schemes import OAuthGrantType -from .oauth2_credential_util import create_oauth2_session -from .oauth2_credential_util import update_credential_with_tokens - -try: - from authlib.oauth2.rfc6749 import OAuth2Token - - AUTHLIB_AVIALABLE = True -except ImportError: - AUTHLIB_AVIALABLE = False - - -logger = logging.getLogger("google_adk." + __name__) - - -@experimental -class OAuth2CredentialFetcher: - """Exchanges and refreshes an OAuth2 access token. (Experimental)""" - - def __init__( - self, - auth_scheme: AuthScheme, - auth_credential: AuthCredential, - ): - self._auth_scheme = auth_scheme - self._auth_credential = auth_credential - - def _update_credential(self, tokens: OAuth2Token) -> None: - self._auth_credential.oauth2.access_token = tokens.get("access_token") - self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") - self._auth_credential.oauth2.expires_at = ( - int(tokens.get("expires_at")) if tokens.get("expires_at") else None - ) - self._auth_credential.oauth2.expires_in = ( - int(tokens.get("expires_in")) if tokens.get("expires_in") else None - ) - - def exchange(self) -> AuthCredential: - """Exchange an oauth token from the authorization response. - - Returns: - An AuthCredential object containing the access token. 
- """ - if not AUTHLIB_AVIALABLE: - return self._auth_credential - - if ( - self._auth_credential.oauth2 - and self._auth_credential.oauth2.access_token - ): - return self._auth_credential - - client, token_endpoint = create_oauth2_session( - self._auth_scheme, self._auth_credential - ) - if not client: - logger.warning("Could not create OAuth2 session for token exchange") - return self._auth_credential - - try: - tokens = client.fetch_token( - token_endpoint, - authorization_response=self._auth_credential.oauth2.auth_response_uri, - code=self._auth_credential.oauth2.auth_code, - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - ) - update_credential_with_tokens(self._auth_credential, tokens) - logger.info("Successfully exchanged OAuth2 tokens") - except Exception as e: - logger.error("Failed to exchange OAuth2 tokens: %s", e) - # Return original credential on failure - return self._auth_credential - - return self._auth_credential - - def refresh(self) -> AuthCredential: - """Refresh an oauth token. - - Returns: - An AuthCredential object containing the refreshed access token. - """ - if not AUTHLIB_AVIALABLE: - return self._auth_credential - credential = self._auth_credential - if not credential.oauth2: - return credential - - if OAuth2Token({ - "expires_at": credential.oauth2.expires_at, - "expires_in": credential.oauth2.expires_in, - }).is_expired(): - client, token_endpoint = create_oauth2_session( - self._auth_scheme, self._auth_credential - ) - if not client: - logger.warning("Could not create OAuth2 session for token refresh") - return credential - - try: - tokens = client.refresh_token( - url=token_endpoint, - refresh_token=credential.oauth2.refresh_token, - ) - update_credential_with_tokens(self._auth_credential, tokens) - logger.info("Successfully refreshed OAuth2 tokens") - except Exception as e: - logger.error("Failed to refresh OAuth2 tokens: %s", e) - # Return original credential on failure - return credential - - return self._auth_credential diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py new file mode 100644 index 000000000..2d0a8b670 --- /dev/null +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -0,0 +1,154 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""OAuth2 credential refresher implementation.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from typing_extensions import override + +from .base_credential_refresher import BaseCredentialRefresher + +try: + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialRefresher(BaseCredentialRefresher): + """Refreshes OAuth2 credentials including Google OAuth2 JSON credentials.""" + + @override + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Check if the OAuth2 credential needs to be refreshed. + + Args: + auth_credential: The OAuth2 credential to check. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + # Handle Google OAuth2 credentials (from service account exchange) + if auth_credential.google_oauth2_json: + try: + google_credential = Credentials.from_authorized_user_info( + json.loads(auth_credential.google_oauth2_json) + ) + return google_credential.expired and bool( + google_credential.refresh_token + ) + except Exception as e: + logger.warning("Failed to parse Google OAuth2 JSON credential: %s", e) + return False + + # Handle regular OAuth2 credentials + elif auth_credential.oauth2 and auth_scheme: + if not AUTHLIB_AVIALABLE: + return False + + if not auth_credential.oauth2: + return False + + return OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired() + + return False + + @override + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refresh the OAuth2 credential. + If refresh failed, return the original credential. + + Args: + auth_credential: The OAuth2 credential to refresh. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + The refreshed credential. + + """ + # Handle Google OAuth2 credentials (from service account exchange) + if auth_credential.google_oauth2_json: + try: + google_credential = Credentials.from_authorized_user_info( + json.loads(auth_credential.google_oauth2_json) + ) + if google_credential.expired and google_credential.refresh_token: + google_credential.refresh(Request()) + auth_credential.google_oauth2_json = google_credential.to_json() + logger.info("Successfully refreshed Google OAuth2 JSON credential") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. 
+ logger.error("Failed to refresh Google OAuth2 JSON credential: %s", e) + + # Handle regular OAuth2 credentials + elif auth_credential.oauth2 and auth_scheme: + if not AUTHLIB_AVIALABLE: + return auth_credential + + if not auth_credential.oauth2: + return auth_credential + + if OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired(): + client, token_endpoint = create_oauth2_session( + auth_scheme, auth_credential + ) + if not client: + logger.warning("Could not create OAuth2 session for token refresh") + return auth_credential + + try: + tokens = client.refresh_token( + url=token_endpoint, + refresh_token=auth_credential.oauth2.refresh_token, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully refreshed OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. + logger.error("Failed to refresh OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 4e5be5959..5a50a7f0c 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -150,7 +150,7 @@ async def run_async( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self._auth_scheme, self._auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() if auth_result.state == 'pending': return { @@ -178,7 +178,7 @@ async def run_async( args['operation'] = self._operation args['action'] = self._action logger.info('Running tool: %s with args: %s', self.name, args) - return self._rest_api_tool.call(args=args, tool_context=tool_context) + return await self._rest_api_tool.call(args=args, tool_context=tool_context) def __str__(self): return ( diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 1e451fe0f..dee103932 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -345,9 +345,9 @@ def _prepare_request_params( async def run_async( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: - return self.call(args=args, tool_context=tool_context) + return await self.call(args=args, tool_context=tool_context) - def call( + async def call( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: """Executes the REST API call. 
@@ -364,7 +364,7 @@ def call( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self.auth_scheme, self.auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() auth_state, auth_scheme, auth_credential = ( auth_result.state, auth_result.auth_scheme, diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index c36793fdc..08e535d28 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -25,7 +25,7 @@ from ....auth.auth_schemes import AuthScheme from ....auth.auth_schemes import AuthSchemeType from ....auth.auth_tool import AuthConfig -from ....auth.oauth2_credential_fetcher import OAuth2CredentialFetcher +from ....auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher from ...tool_context import ToolContext from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError @@ -146,7 +146,7 @@ def from_tool_context( credential_store, ) - def _get_existing_credential( + async def _get_existing_credential( self, ) -> Optional[AuthCredential]: """Checks for and returns an existing, exchanged credential.""" @@ -156,9 +156,11 @@ def _get_existing_credential( ) if existing_credential: if existing_credential.oauth2: - existing_credential = OAuth2CredentialFetcher( - self.auth_scheme, existing_credential - ).refresh() + refresher = OAuth2CredentialRefresher() + if await refresher.is_refresh_needed(existing_credential): + existing_credential = await refresher.refresh( + existing_credential, self.auth_scheme + ) return existing_credential return None @@ -234,7 +236,7 @@ def _external_exchange_required(self, credential) -> bool: and not credential.google_oauth2_json ) - def prepare_auth_credentials( + async def prepare_auth_credentials( self, ) -> AuthPreparationResult: """Prepares authentication credentials, handling exchange and user interaction.""" @@ -244,7 +246,7 @@ def prepare_auth_credentials( return AuthPreparationResult(state="done") # Check for existing credential. - existing_credential = self._get_existing_credential() + existing_credential = await self._get_existing_credential() credential = existing_credential or self.auth_credential # fetch credential from adk framework diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py new file mode 100644 index 000000000..ef1dbbbee --- /dev/null +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -0,0 +1,220 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
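# The test cases below exercise OAuth2CredentialExchanger.exchange end to end.
# A condensed sketch of the behaviour they assert, assuming `scheme` is an
# OpenIdConnectWithConfig and `credential` is an OAuth2 AuthCredential carrying
# an auth_code (both are built inline inside each test):
#
#   exchanger = OAuth2CredentialExchanger()
#   result = await exchanger.exchange(credential, scheme)
#   # success             -> result.oauth2.access_token is populated
#   # failure paths       -> the original credential is returned unchanged
#   # auth_scheme is None -> CredentialExchangError is raised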
+ +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.exchanger.base_credential_exchanger import CredentialExchangError +from google.adk.auth.exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger +import pytest + + +class TestOAuth2CredentialExchanger: + """Test suite for OAuth2CredentialExchanger.""" + + @pytest.mark.asyncio + async def test_exchange_with_existing_token(self): + """Test exchange method when access token already exists.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return the same credential since access token already exists + assert result == credential + assert result.oauth2.access_token == "existing_token" + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_success(self, mock_oauth2_session): + """Test successful token exchange.""" + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Verify token exchange was successful + assert result.oauth2.access_token == "new_access_token" + assert result.oauth2.refresh_token == "new_refresh_token" + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_missing_auth_scheme(self): + """Test exchange with missing auth_scheme raises ValueError.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + try: + await exchanger.exchange(credential, None) + assert False, "Should have raised ValueError" + except CredentialExchangError as e: + assert "auth_scheme is required" in str(e) + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async 
def test_exchange_no_session(self, mock_oauth2_session): + """Test exchange when OAuth2Session cannot be created.""" + # Mock to return None for create_oauth2_session + mock_oauth2_session.return_value = None + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret to trigger session creation failure + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when session creation fails + assert result == credential + assert result.oauth2.access_token is None + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_fetch_token_failure(self, mock_oauth2_session): + """Test exchange when fetch_token fails.""" + # Setup mock to raise exception during fetch_token + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.fetch_token.side_effect = Exception("Token fetch failed") + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when fetch_token fails + assert result == credential + assert result.oauth2.access_token is None + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_authlib_not_available(self): + """Test exchange when authlib is not available.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + + # Mock AUTHLIB_AVIALABLE to False + with patch( + "google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVIALABLE", + False, + ): + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when authlib is not available + assert result == credential + assert result.oauth2.access_token is None diff --git a/tests/unittests/auth/refresher/__init__.py b/tests/unittests/auth/refresher/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/refresher/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py new file mode 100644 index 000000000..b22bf2ccd --- /dev/null +++ b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py @@ -0,0 +1,297 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher +import pytest + + +class TestOAuth2CredentialRefresher: + """Test suite for OAuth2CredentialRefresher.""" + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_not_expired(self, mock_oauth2_token): + """Test needs_refresh when token is not expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = False + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) + 3600, + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert not needs_refresh + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_expired(self, mock_oauth2_token): + """Test needs_refresh when token is expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + 
token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert needs_refresh + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @patch("google.adk.auth.oauth2_credential_util.OAuth2Token") + @pytest.mark.asyncio + async def test_refresh_token_expired_success( + self, mock_oauth2_token, mock_oauth2_session + ): + """Test successful token refresh when token is expired.""" + # Setup mock token + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + # Setup mock session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "refreshed_access_token", + "refresh_token": "refreshed_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.refresh_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="old_token", + refresh_token="old_refresh_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + # Verify token refresh was successful + assert result.oauth2.access_token == "refreshed_access_token" + assert result.oauth2.refresh_token == "refreshed_refresh_token" + mock_client.refresh_token.assert_called_once() + + @pytest.mark.asyncio + async def test_refresh_no_oauth2_credential(self): + """Test refresh with no OAuth2 credential returns original.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + assert result == credential + + @pytest.mark.asyncio + async def test_needs_refresh_google_oauth2_json_expired(self): + """Test needs_refresh with Google OAuth2 JSON credential that is expired.""" + import json + from unittest.mock import patch + + # Mock Google OAuth2 JSON credential data + google_oauth2_json = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + "type": "authorized_user", + }) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=google_oauth2_json, + ) + + # Mock the Google Credentials class + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" + ) as mock_credentials: + mock_google_credential 
= Mock() + mock_google_credential.expired = True + mock_google_credential.refresh_token = "test_refresh_token" + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert needs_refresh + + @pytest.mark.asyncio + async def test_needs_refresh_google_oauth2_json_not_expired(self): + """Test needs_refresh with Google OAuth2 JSON credential that is not expired.""" + import json + from unittest.mock import patch + + # Mock Google OAuth2 JSON credential data + google_oauth2_json = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + "type": "authorized_user", + }) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=google_oauth2_json, + ) + + # Mock the Google Credentials class + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" + ) as mock_credentials: + mock_google_credential = Mock() + mock_google_credential.expired = False + mock_google_credential.refresh_token = "test_refresh_token" + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert not needs_refresh + + @pytest.mark.asyncio + async def test_refresh_google_oauth2_json_success(self): + """Test successful refresh of Google OAuth2 JSON credential.""" + import json + from unittest.mock import patch + + # Mock Google OAuth2 JSON credential data + google_oauth2_json = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + "type": "authorized_user", + }) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=google_oauth2_json, + ) + + # Mock the Google Credentials and Request classes + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" + ) as mock_credentials: + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Request" + ) as mock_request: + mock_google_credential = Mock() + mock_google_credential.expired = True + mock_google_credential.refresh_token = "test_refresh_token" + mock_google_credential.to_json.return_value = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "new_refresh_token", + "access_token": "new_access_token", + "type": "authorized_user", + }) + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, None) + + mock_google_credential.refresh.assert_called_once() + assert ( + result.google_oauth2_json != google_oauth2_json + ) # Should be updated + + @pytest.mark.asyncio + async def test_needs_refresh_no_oauth2_credential(self): + """Test needs_refresh with no OAuth2 credential returns False.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert not needs_refresh diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 2bfc7d4c9..f0d730d02 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ 
b/tests/unittests/auth/test_auth_handler.py @@ -13,8 +13,11 @@ # limitations under the License. import copy +import time +from unittest.mock import Mock from unittest.mock import patch +from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import APIKey from fastapi.openapi.models import APIKeyIn from fastapi.openapi.models import OAuth2 @@ -405,7 +408,8 @@ def test_get_auth_response_not_exists(self, auth_config): class TestParseAndStoreAuthResponse: """Tests for the parse_and_store_auth_response method.""" - def test_non_oauth_scheme(self, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_exchanged): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_exchanged) @@ -416,7 +420,7 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): handler = AuthHandler(auth_config) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) credential_key = auth_config.credential_key assert ( @@ -424,7 +428,10 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): ) @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token") - def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_oauth_scheme( + self, mock_exchange_token, auth_config_with_exchanged + ): """Test with an OAuth auth scheme.""" mock_exchange_token.return_value = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -434,7 +441,7 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): handler = AuthHandler(auth_config_with_exchanged) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) credential_key = auth_config_with_exchanged.credential_key assert state["temp:" + credential_key] == mock_exchange_token.return_value @@ -444,20 +451,20 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): class TestExchangeAuthToken: """Tests for the exchange_auth_token method.""" - def test_token_exchange_not_supported( + @pytest.mark.asyncio + async def test_token_exchange_not_supported( self, auth_config_with_auth_code, monkeypatch ): """Test when token exchange is not supported.""" - monkeypatch.setattr( - "google.adk.auth.oauth2_credential_fetcher.AUTHLIB_AVIALABLE", False - ) + monkeypatch.setattr("google.adk.auth.auth_handler.AUTHLIB_AVIALABLE", False) handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config_with_auth_code.exchanged_auth_credential - def test_openid_missing_token_endpoint( + @pytest.mark.asyncio + async def test_openid_missing_token_endpoint( self, openid_auth_scheme, oauth2_credentials_with_auth_code ): """Test OpenID Connect without a token endpoint.""" @@ -472,11 +479,12 @@ def test_openid_missing_token_endpoint( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_oauth2_missing_token_url( + @pytest.mark.asyncio + async def test_oauth2_missing_token_url( self, oauth2_auth_scheme, oauth2_credentials_with_auth_code ): """Test OAuth2 without a token URL.""" @@ -491,11 +499,12 @@ def test_oauth2_missing_token_url( ) handler = AuthHandler(config) - result = 
handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_non_oauth_scheme(self, auth_config_with_auth_code): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_auth_code): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_auth_code) @@ -504,11 +513,12 @@ def test_non_oauth_scheme(self, auth_config_with_auth_code): ) handler = AuthHandler(auth_config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config.exchanged_auth_credential - def test_missing_credentials(self, oauth2_auth_scheme): + @pytest.mark.asyncio + async def test_missing_credentials(self, oauth2_auth_scheme): """Test with missing credentials.""" empty_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2) @@ -518,11 +528,12 @@ def test_missing_credentials(self, oauth2_auth_scheme): ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == empty_credential - def test_credentials_with_token( + @pytest.mark.asyncio + async def test_credentials_with_token( self, auth_config, oauth2_credentials_with_token ): """Test when credentials already have a token.""" @@ -533,18 +544,29 @@ def test_credentials_with_token( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_token - @patch( - "google.adk.auth.oauth2_credential_util.OAuth2Session", - MockOAuth2Session, - ) - def test_successful_token_exchange(self, auth_config_with_auth_code): + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_successful_token_exchange( + self, mock_oauth2_session, auth_config_with_auth_code + ): """Test a successful token exchange.""" + # Setup mock OAuth2Session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "mock_access_token", + "refresh_token": "mock_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result.oauth2.access_token == "mock_access_token" assert result.oauth2.refresh_token == "mock_refresh_token" diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py b/tests/unittests/auth/test_oauth2_credential_fetcher.py deleted file mode 100644 index aba6a9923..000000000 --- a/tests/unittests/auth/test_oauth2_credential_fetcher.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -from unittest.mock import Mock - -from authlib.oauth2.rfc6749 import OAuth2Token -from fastapi.openapi.models import OAuth2 -from fastapi.openapi.models import OAuthFlowAuthorizationCode -from fastapi.openapi.models import OAuthFlows -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_schemes import OpenIdConnectWithConfig -from google.adk.auth.oauth2_credential_util import create_oauth2_session -from google.adk.auth.oauth2_credential_util import update_credential_with_tokens - - -class TestOAuth2CredentialUtil: - """Test suite for OAuth2 credential utility functions.""" - - def test_create_oauth2_session_openid_connect(self): - """Test create_oauth2_session with OpenID Connect scheme.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is not None - assert token_endpoint == "https://example.com/token" - assert client.client_id == "test_client_id" - assert client.client_secret == "test_client_secret" - - def test_create_oauth2_session_oauth2_scheme(self): - """Test create_oauth2_session with OAuth2 scheme.""" - flows = OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://example.com/auth", - tokenUrl="https://example.com/token", - scopes={"read": "Read access", "write": "Write access"}, - ) - ) - scheme = OAuth2(type_="oauth2", flows=flows) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is not None - assert token_endpoint == "https://example.com/token" - - def test_create_oauth2_session_invalid_scheme(self): - """Test create_oauth2_session with invalid scheme.""" - scheme = Mock() # Invalid scheme type - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is None - assert token_endpoint is None - - def test_create_oauth2_session_missing_credentials(self): - """Test create_oauth2_session with missing credentials.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - # Missing client_secret - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is None - assert token_endpoint is None - - def 
test_update_credential_with_tokens(self): - """Test update_credential_with_tokens function.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - ), - ) - - tokens = OAuth2Token({ - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - }) - - update_credential_with_tokens(credential, tokens) - - assert credential.oauth2.access_token == "new_access_token" - assert credential.oauth2.refresh_token == "new_refresh_token" - assert credential.oauth2.expires_at == int(time.time()) + 3600 - assert credential.oauth2.expires_in == 3600 diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py index cd37a105e..c9b542e51 100644 --- a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -20,6 +20,7 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult from google.genai.types import FunctionDeclaration from google.genai.types import Schema from google.genai.types import Type @@ -50,7 +51,9 @@ def mock_rest_api_tool(): "required": ["user_id", "page_size", "filter", "connection_name"], } mock_tool._operation_parser = mock_parser - mock_tool.call.return_value = {"status": "success", "data": "mock_data"} + mock_tool.call = mock.AsyncMock( + return_value={"status": "success", "data": "mock_data"} + ) return mock_tool @@ -179,9 +182,6 @@ async def test_run_with_auth_async_none_token( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) # Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None mock_auth_credential_without_token = AuthCredential( auth_type=AuthCredentialTypes.HTTP, @@ -190,8 +190,12 @@ async def test_run_with_auth_async_none_token( credentials=HttpCredentials(token=None), # Token is None ), ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = ( - mock_auth_credential_without_token + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=( + AuthPreparationResult( + state="done", auth_credential=mock_auth_credential_without_token + ) + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance @@ -229,18 +233,18 @@ async def test_run_with_auth_async( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = 
AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token="mocked_token"), - ), + + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=AuthPreparationResult( + state="done", + auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="mocked_token"), + ), + ), + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance result = await integration_tool_with_auth.run_async( diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 303dda69d..c4cbea7b9 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -14,6 +14,7 @@ import json +from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -194,7 +195,8 @@ def test_get_declaration( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_success( + @pytest.mark.asyncio + async def test_call_success( self, mock_request, mock_tool_context, @@ -217,7 +219,7 @@ def test_call_success( ) # Call the method - result = tool.call(args={}, tool_context=mock_tool_context) + result = await tool.call(args={}, tool_context=mock_tool_context) # Check the result assert result == {"result": "success"} @@ -225,7 +227,8 @@ def test_call_success( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_auth_pending( + @pytest.mark.asyncio + async def test_call_auth_pending( self, mock_request, sample_endpoint, @@ -246,12 +249,14 @@ def test_call_auth_pending( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "pending" + mock_prepare_result = MagicMock() + mock_prepare_result.state = "pending" + mock_tool_auth_handler_instance.prepare_auth_credentials = AsyncMock( + return_value=mock_prepare_result ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance - response = tool.call(args={}, tool_context=None) + response = await tool.call(args={}, tool_context=None) assert response == { "pending": True, "message": "Needs your authorization to access your data.", diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index 8db151fc8..e405ce5b8 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -116,7 +116,8 @@ def openid_connect_credential(): return credential -def test_openid_connect_no_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_no_auth_response( openid_connect_scheme, openid_connect_credential ): # Setup Mock exchanger @@ -132,12 +133,13 @@ def test_openid_connect_no_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'pending' 
assert result.auth_credential == openid_connect_credential -def test_openid_connect_with_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_with_auth_response( openid_connect_scheme, openid_connect_credential, monkeypatch ): mock_exchanger = MockOpenIdConnectCredentialExchanger( @@ -166,7 +168,7 @@ def test_openid_connect_with_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP assert 'test_access_token' in result.auth_credential.http.credentials.token @@ -178,7 +180,8 @@ def test_openid_connect_with_auth_response( mock_auth_handler.get_auth_response.assert_called_once() -def test_openid_connect_existing_token( +@pytest.mark.asyncio +async def test_openid_connect_existing_token( openid_connect_scheme, openid_connect_credential ): _, existing_credential = token_to_scheme_credential( @@ -198,16 +201,17 @@ def test_openid_connect_existing_token( openid_connect_credential, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential == existing_credential @patch( - 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialFetcher' + 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher' ) -def test_openid_connect_existing_oauth2_token_refresh( - mock_oauth2_fetcher, openid_connect_scheme, openid_connect_credential +@pytest.mark.asyncio +async def test_openid_connect_existing_oauth2_token_refresh( + mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential ): """Test that OAuth2 tokens are refreshed when existing credentials are found.""" # Create existing OAuth2 credential @@ -232,10 +236,13 @@ def test_openid_connect_existing_oauth2_token_refresh( ), ) - # Setup mock OAuth2CredentialFetcher - mock_fetcher_instance = MagicMock() - mock_fetcher_instance.refresh.return_value = refreshed_credential - mock_oauth2_fetcher.return_value = mock_fetcher_instance + # Setup mock OAuth2CredentialRefresher + from unittest.mock import AsyncMock + + mock_refresher_instance = MagicMock() + mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential) + mock_oauth2_refresher.return_value = mock_refresher_instance tool_context = create_mock_tool_context() credential_store = ToolContextCredentialStore(tool_context=tool_context) @@ -253,13 +260,17 @@ def test_openid_connect_existing_oauth2_token_refresh( credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() + + # Verify OAuth2CredentialRefresher was called for refresh + mock_oauth2_refresher.assert_called_once() - # Verify OAuth2CredentialFetcher was called for refresh - mock_oauth2_fetcher.assert_called_once_with( - openid_connect_scheme, existing_credential + mock_refresher_instance.is_refresh_needed.assert_called_once_with( + existing_credential + ) + mock_refresher_instance.refresh.assert_called_once_with( + existing_credential, openid_connect_scheme ) - mock_fetcher_instance.refresh.assert_called_once() assert result.state == 'done' # The result should contain the refreshed credential after exchange From 
2c739ab5812d24686cee61a0a1ce808b63ceb883 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:00:41 -0700 Subject: [PATCH 30/79] chore: Add Credential Manager for managing tools credential (Experimental) PiperOrigin-RevId: 772986051 --- src/google/adk/auth/credential_manager.py | 265 +++++++++ .../unittests/auth/test_credential_manager.py | 559 ++++++++++++++++++ 2 files changed, 824 insertions(+) create mode 100644 src/google/adk/auth/credential_manager.py create mode 100644 tests/unittests/auth/test_credential_manager.py diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py new file mode 100644 index 000000000..7471bdffa --- /dev/null +++ b/src/google/adk/auth/credential_manager.py @@ -0,0 +1,265 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from ..tools.tool_context import ToolContext +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_credential import AuthCredentialTypes +from .auth_schemes import AuthSchemeType +from .auth_tool import AuthConfig +from .exchanger.base_credential_exchanger import BaseCredentialExchanger +from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry +from .refresher.base_credential_refresher import BaseCredentialRefresher +from .refresher.credential_refresher_registry import CredentialRefresherRegistry + + +@experimental +class CredentialManager: + """Manages authentication credentials through a structured workflow. + + The CredentialManager orchestrates the complete lifecycle of authentication + credentials, from initial loading to final preparation for use. It provides + a centralized interface for handling various credential types and authentication + schemes while maintaining proper credential hygiene (refresh, exchange, caching). + + This class is only for use by Agent Development Kit. 
+ + Args: + auth_config: Configuration containing authentication scheme and credentials + + Example: + ```python + auth_config = AuthConfig( + auth_scheme=oauth2_scheme, + raw_auth_credential=service_account_credential + ) + manager = CredentialManager(auth_config) + + # Register custom exchanger if needed + manager.register_credential_exchanger( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialExchanger() + ) + + # Register custom refresher if needed + manager.register_credential_refresher( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialRefresher() + ) + + # Load and prepare credential + credential = await manager.load_auth_credential(tool_context) + ``` + """ + + def __init__( + self, + auth_config: AuthConfig, + ): + self._auth_config = auth_config + self._exchanger_registry = CredentialExchangerRegistry() + self._refresher_registry = CredentialRefresherRegistry() + + # Register default exchangers and refreshers + from .exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger + + self._exchanger_registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger() + ) + from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher + + oauth2_refresher = OAuth2CredentialRefresher() + self._refresher_registry.register( + AuthCredentialTypes.OAUTH2, oauth2_refresher + ) + self._refresher_registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher + ) + + def register_credential_exchanger( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register a credential exchanger for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchanger_registry.register(credential_type, exchanger_instance) + + async def request_credential(self, tool_context: ToolContext) -> None: + tool_context.request_credential(self._auth_config) + + async def get_auth_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load and prepare authentication credential through a structured workflow.""" + + # Step 1: Validate credential configuration + await self._validate_credential() + + # Step 2: Check if credential is already ready (no processing needed) + if self._is_credential_ready(): + return self._auth_config.raw_auth_credential + + # Step 3: Try to load existing processed credential + credential = await self._load_existing_credential(tool_context) + + # Step 4: If no existing credential, load from auth response + # TODO instead of load from auth response, we can store auth response in + # credential service. 
+ was_from_auth_response = False + if not credential: + credential = await self._load_from_auth_response(tool_context) + was_from_auth_response = True + + # Step 5: If still no credential available, return None + if not credential: + return None + + # Step 6: Exchange credential if needed (e.g., service account to access token) + credential, was_exchanged = await self._exchange_credential(credential) + + # Step 7: Refresh credential if expired + if not was_exchanged: + credential, was_refreshed = await self._refresh_credential(credential) + + # Step 8: Save credential if it was modified + if was_from_auth_response or was_exchanged or was_refreshed: + await self._save_credential(tool_context, credential) + + return credential + + async def _load_existing_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load existing credential from credential service or cached exchanged credential.""" + + # Try loading from credential service first + credential = await self._load_from_credential_service(tool_context) + if credential: + return credential + + # Check if we have a cached exchanged credential + if self._auth_config.exchanged_auth_credential: + return self._auth_config.exchanged_auth_credential + + return None + + async def _load_from_credential_service( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Note: This should be made async in a future refactor + # For now, assuming synchronous operation + return await credential_service.load_credential( + self._auth_config, tool_context + ) + return None + + async def _load_from_auth_response( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from auth response in tool context.""" + return tool_context.get_auth_response(self._auth_config) + + async def _exchange_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Exchange credential if needed and return the credential and whether it was exchanged.""" + exchanger = self._exchanger_registry.get_exchanger(credential.auth_type) + if not exchanger: + return credential, False + + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchanged_credential, True + + async def _refresh_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Refresh credential if expired and return the credential and whether it was refreshed.""" + refresher = self._refresher_registry.get_refresher(credential.auth_type) + if not refresher: + return credential, False + + if await refresher.is_refresh_needed( + credential, self._auth_config.auth_scheme + ): + refreshed_credential = await refresher.refresh( + credential, self._auth_config.auth_scheme + ) + return refreshed_credential, True + + return credential, False + + def _is_credential_ready(self) -> bool: + """Check if credential is ready to use without further processing.""" + raw_credential = self._auth_config.raw_auth_credential + if not raw_credential: + return False + + # Simple credentials that don't need exchange or refresh + return raw_credential.auth_type in ( + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + # Add other simple auth types as needed + ) + + async def _validate_credential(self) -> None: + """Validate credential configuration and raise errors if invalid.""" + if not 
self._auth_config.raw_auth_credential: + if self._auth_config.auth_scheme.type_ in ( + AuthSchemeType.oauth2, + AuthSchemeType.openIdConnect, + ): + raise ValueError( + "raw_auth_credential is required for auth_scheme type " + f"{self._auth_config.auth_scheme.type_}" + ) + + raw_credential = self._auth_config.raw_auth_credential + if raw_credential: + if ( + raw_credential.auth_type + in ( + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + ) + and not raw_credential.oauth2 + ): + raise ValueError( + "auth_config.raw_credential.oauth2 required for credential type " + f"{raw_credential.auth_type}" + ) + # Additional validation can be added here + + async def _save_credential( + self, tool_context: ToolContext, credential: AuthCredential + ) -> None: + """Save credential to credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Update the exchanged credential in config + self._auth_config.exchanged_auth_credential = credential + await credential_service.save_credential(self._auth_config, tool_context) diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py new file mode 100644 index 000000000..283e865a7 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager.py @@ -0,0 +1,559 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
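# The test cases below walk CredentialManager.get_auth_credential through its
# workflow. A condensed sketch of that pipeline, assuming an AuthConfig plus a
# ToolContext whose invocation context may expose a credential service:
#
#   manager = CredentialManager(auth_config)
#   credential = await manager.get_auth_credential(tool_context)
#   # 1. validate the config (OAuth2/OIDC schemes require raw_auth_credential)
#   # 2. return simple credentials (API key, HTTP) without further processing
#   # 3. load an existing credential (credential service, then cached copy)
#   # 4. otherwise fall back to the auth response on the tool context
#   # 5. exchange (e.g. service account) or refresh (OAuth2/OIDC) as needed
#   # 6. save the processed credential back through the credential service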
+ +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from fastapi.openapi.models import HTTPBearer +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +import pytest + + +class TestCredentialManager: + """Test suite for CredentialManager.""" + + def test_init(self): + """Test CredentialManager initialization.""" + auth_config = Mock(spec=AuthConfig) + manager = CredentialManager(auth_config) + assert manager._auth_config == auth_config + + @pytest.mark.asyncio + async def test_request_credential(self): + """Test request_credential method.""" + auth_config = Mock(spec=AuthConfig) + tool_context = Mock() + tool_context.request_credential = Mock() + + manager = CredentialManager(auth_config) + await manager.request_credential(tool_context) + + tool_context.request_credential.assert_called_once_with(auth_config) + + @pytest.mark.asyncio + async def test_load_auth_credentials_success(self): + """Test load_auth_credential with successful flow.""" + # Create mocks + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + # Mock the credential that will be returned + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + tool_context = Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=mock_credential) + manager._exchange_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._refresh_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify all methods were called + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_called_once_with(mock_credential) + manager._refresh_credential.assert_called_once_with(mock_credential) + manager._save_credential.assert_called_once_with( + tool_context, mock_credential + ) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_auth_credentials_no_credential(self): + """Test load_auth_credential when no credential is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + tool_context = 
Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=None) + manager._exchange_credential = AsyncMock() + manager._refresh_credential = AsyncMock() + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify methods were called but no credential returned + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_not_called() + manager._refresh_credential.assert_not_called() + manager._save_credential.assert_not_called() + + assert result is None + + @pytest.mark.asyncio + async def test_load_existing_credential_already_exchanged(self): + """Test _load_existing_credential when credential is already exchanged.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + auth_config.exchanged_auth_credential = mock_credential + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock(return_value=None) + + result = await manager._load_existing_credential(tool_context) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_existing_credential_with_credential_service(self): + """Test _load_existing_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + + mock_credential = Mock(spec=AuthCredential) + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock( + return_value=mock_credential + ) + + result = await manager._load_existing_credential(tool_context) + + manager._load_from_credential_service.assert_called_once_with(tool_context) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_with_service(self): + """Test _load_from_credential_service from tool context when credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = Mock() + credential_service.load_credential = AsyncMock(return_value=mock_credential) + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await manager._load_from_credential_service(tool_context) + + credential_service.load_credential.assert_called_once_with( + auth_config, tool_context + ) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_no_service(self): + """Test _load_from_credential_service when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await 
manager._load_from_credential_service(tool_context) + + assert result is None + + @pytest.mark.asyncio + async def test_save_credential_with_service(self): + """Test _save_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = AsyncMock() + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + credential_service.save_credential.assert_called_once_with( + auth_config, tool_context + ) + assert auth_config.exchanged_auth_credential == mock_credential + + @pytest.mark.asyncio + async def test_save_credential_no_service(self): + """Test _save_credential when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + mock_credential = Mock(spec=AuthCredential) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + # Should not raise an error, and credential should not be set in auth_config + # when there's no credential service (according to implementation) + assert auth_config.exchanged_auth_credential is None + + @pytest.mark.asyncio + async def test_refresh_credential_oauth2(self): + """Test _refresh_credential with OAuth2 credential.""" + mock_oauth2_auth = Mock(spec=OAuth2Auth) + + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + # Mock refresher + mock_refresher = Mock() + mock_refresher.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher.refresh = AsyncMock(return_value=mock_credential) + + auth_config.raw_auth_credential = mock_credential + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return our mock refresher + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=mock_refresher, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + mock_refresher.is_refresh_needed.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + mock_refresher.refresh.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + assert result == mock_credential + assert was_refreshed is True + + @pytest.mark.asyncio + async def test_refresh_credential_no_refresher(self): + """Test _refresh_credential with credential that has no refresher.""" + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return None (no refresher available) + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=None, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + assert result == mock_credential + assert was_refreshed is False + + @pytest.mark.asyncio + async def test_is_credential_ready_api_key(self): + """Test 
_is_credential_ready with API key credential.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is True + + @pytest.mark.asyncio + async def test_is_credential_ready_oauth2(self): + """Test _is_credential_ready with OAuth2 credential (needs processing).""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is False + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_oauth2(self): + """Test _validate_credential with no raw credential for OAuth2.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_openid(self): + """Test _validate_credential with no raw credential for OpenID Connect.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.openIdConnect + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_other_scheme(self): + """Test _validate_credential with no raw credential for other schemes.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.apiKey + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + await manager._validate_credential() + + # Should return without error for non-OAuth2/OpenID schemes + + @pytest.mark.asyncio + async def test_validate_credential_oauth2_missing_oauth2_field(self): + """Test _validate_credential with OAuth2 credential missing oauth2 field.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + mock_raw_credential.oauth2 = None + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises( + ValueError, match="auth_config.raw_credential.oauth2 required" + ): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_exchange_credentials_service_account(self): + """Test _exchange_credential with service account credential.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT + + mock_exchanged_credential = Mock(spec=AuthCredential) + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + manager = CredentialManager(auth_config) + 
+ # Mock the exchanger that gets created during registration + with patch.object( + manager._exchanger_registry, "get_exchanger" + ) as mock_get_exchanger: + mock_exchanger = Mock() + mock_exchanger.exchange = AsyncMock( + return_value=mock_exchanged_credential + ) + mock_get_exchanger.return_value = mock_exchanger + + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_exchanged_credential + assert was_exchanged is True + mock_get_exchanger.assert_called_once_with( + AuthCredentialTypes.SERVICE_ACCOUNT + ) + mock_exchanger.exchange.assert_called_once_with( + mock_raw_credential, auth_config.auth_scheme + ) + + @pytest.mark.asyncio + async def test_exchange_credential_no_exchanger(self): + """Test _exchange_credential with credential that has no exchanger.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the exchanger registry to return None (no exchanger available) + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=None + ): + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_raw_credential + assert was_exchanged is False + + +# Test fixtures +@pytest.fixture +def oauth2_auth_scheme(): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + +@pytest.fixture +def openid_auth_scheme(): + """Create an OpenID Connect auth scheme for testing.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + + +@pytest.fixture +def bearer_auth_scheme(): + """Create a Bearer auth scheme for testing.""" + return HTTPBearer(bearerFormat="JWT") + + +@pytest.fixture +def oauth2_credential(): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + +@pytest.fixture +def service_account_credential(): + """Create service account credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE" + " KEY-----\n" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="123456789", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + ), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + +@pytest.fixture +def api_key_credential(): + """Create API key credentials for testing.""" + return AuthCredential( + 
auth_type=AuthCredentialTypes.API_KEY, + api_key="test-api-key", + ) + + +@pytest.fixture +def http_bearer_credential(): + """Create HTTP Bearer credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="bearer-token"), + ), + ) From dcea7767c67c7edfb694304df32dca10b74c9a71 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:15:27 -0700 Subject: [PATCH 31/79] feat: Add Authenticated Tool (Experimental) PiperOrigin-RevId: 772992074 --- .../adk/tools/authenticated_function_tool.py | 107 ++++ .../adk/tools/base_authenticated_tool.py | 107 ++++ .../tools/test_authenticated_function_tool.py | 541 ++++++++++++++++++ .../tools/test_base_authenticated_tool.py | 343 +++++++++++ 4 files changed, 1098 insertions(+) create mode 100644 src/google/adk/tools/authenticated_function_tool.py create mode 100644 src/google/adk/tools/base_authenticated_tool.py create mode 100644 tests/unittests/tools/test_authenticated_function_tool.py create mode 100644 tests/unittests/tools/test_base_authenticated_tool.py diff --git a/src/google/adk/tools/authenticated_function_tool.py b/src/google/adk/tools/authenticated_function_tool.py new file mode 100644 index 000000000..67cc5885f --- /dev/null +++ b/src/google/adk/tools/authenticated_function_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import logging +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .function_tool import FunctionTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class AuthenticatedFunctionTool(FunctionTool): + """A FunctionTool that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + func: Callable[..., Any], + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """Initializes the AuthenticatedFunctionTool. + + Args: + func: The function to be called. + auth_config: The authentication configuration. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. 
client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__(func=func) + self._ignore_params.append("credential") + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "credential" in signature.parameters: + args_to_call["credential"] = credential + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py new file mode 100644 index 000000000..4858e4953 --- /dev/null +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +import logging +from typing import Any +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .base_tool import BaseTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class BaseAuthenticatedTool(BaseTool): + """A base tool class that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + name, + description, + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """ + Args: + name: The name of the tool. + description: The description of the tool. + auth_config: The auth configuration of the tool. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. 
There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__( + name=name, + description=description, + ) + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, + tool_context=tool_context, + credential=credential, + ) + + @abstractmethod + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + pass diff --git a/tests/unittests/tools/test_authenticated_function_tool.py b/tests/unittests/tools/test_authenticated_function_tool.py new file mode 100644 index 000000000..88454032a --- /dev/null +++ b/tests/unittests/tools/test_authenticated_function_tool.py @@ -0,0 +1,541 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
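For orientation, the following is a minimal usage sketch of the AuthenticatedFunctionTool introduced above. It is illustrative only and not part of the patches; the auth scheme, credential values, and URLs are placeholders modeled on the OAuth2 fixtures used in the tests below.

from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_tool import AuthConfig
from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool

# Placeholder auth configuration (hypothetical values, for illustration only).
demo_auth_config = AuthConfig(
    auth_scheme=OAuth2(
        flows=OAuthFlows(
            authorizationCode=OAuthFlowAuthorizationCode(
                authorizationUrl="https://example.com/oauth2/authorize",
                tokenUrl="https://example.com/oauth2/token",
                scopes={"read": "Read access"},
            )
        )
    ),
    raw_auth_credential=AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(
            client_id="demo_client_id",
            client_secret="demo_client_secret",
            redirect_uri="https://example.com/callback",
        ),
    ),
)


def list_repos(org: str, credential: AuthCredential) -> dict:
  # `credential` is the reserved argument; the tool injects the resolved
  # (exchanged/refreshed) credential here before invoking the function.
  return {"org": org, "auth_type": credential.auth_type}


repo_tool = AuthenticatedFunctionTool(
    func=list_repos,
    auth_config=demo_auth_config,
    response_for_auth_required="Please authorize access first.",
)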
+ +import inspect +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool +from google.adk.tools.tool_context import ToolContext +import pytest + +# Test functions for different scenarios + + +def sync_function_no_credential(arg1: str, arg2: int) -> str: + """Test sync function without credential parameter.""" + return f"sync_result_{arg1}_{arg2}" + + +async def async_function_no_credential(arg1: str, arg2: int) -> str: + """Test async function without credential parameter.""" + return f"async_result_{arg1}_{arg2}" + + +def sync_function_with_credential(arg1: str, credential: AuthCredential) -> str: + """Test sync function with credential parameter.""" + return f"sync_cred_result_{arg1}_{credential.auth_type.value}" + + +async def async_function_with_credential( + arg1: str, credential: AuthCredential +) -> str: + """Test async function with credential parameter.""" + return f"async_cred_result_{arg1}_{credential.auth_type.value}" + + +def sync_function_with_tool_context( + arg1: str, tool_context: ToolContext +) -> str: + """Test sync function with tool_context parameter.""" + return f"sync_context_result_{arg1}" + + +async def async_function_with_both( + arg1: str, tool_context: ToolContext, credential: AuthCredential +) -> str: + """Test async function with both tool_context and credential parameters.""" + return f"async_both_result_{arg1}_{credential.auth_type.value}" + + +def function_with_optional_args( + arg1: str, arg2: str = "default", credential: AuthCredential = None +) -> str: + """Test function with optional arguments.""" + cred_type = credential.auth_type.value if credential else "none" + return f"optional_result_{arg1}_{arg2}_{cred_type}" + + +class MockCallable: + """Test callable class for testing.""" + + def __init__(self): + self.__name__ = "MockCallable" + self.__doc__ = "Test callable documentation" + + def __call__(self, arg1: str, credential: AuthCredential) -> str: + return f"callable_result_{arg1}_{credential.auth_type.value}" + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + # Create a mock auth_type that returns the expected value + mock_auth_type = Mock() + mock_auth_type.value = "oauth2" + credential.auth_type = mock_auth_type + return credential + + +class TestAuthenticatedFunctionTool: + """Test suite for AuthenticatedFunctionTool.""" + + def test_init_with_sync_function(self): + """Test initialization with synchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, + auth_config=auth_config, + response_for_auth_required="Please authenticate", + ) + + assert tool.name == "sync_function_no_credential" + assert ( + tool.description == "Test sync function without credential parameter." 
+ ) + assert tool.func == sync_function_no_credential + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == "Please authenticate" + assert "credential" in tool._ignore_params + + def test_init_with_async_function(self): + """Test initialization with asynchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=async_function_no_credential, auth_config=auth_config + ) + + assert tool.name == "async_function_no_credential" + assert ( + tool.description == "Test async function without credential parameter." + ) + assert tool.func == async_function_no_credential + assert tool._response_for_auth_required is None + + def test_init_with_callable(self): + """Test initialization with callable object.""" + auth_config = _create_mock_auth_config() + test_callable = MockCallable() + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + + assert tool.name == "MockCallable" + assert tool.description == "Test callable documentation" + assert tool.func == test_callable + + def test_init_no_auth_config(self): + """Test initialization without auth_config.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + + assert tool._credentials_manager is None + + @pytest.mark.asyncio + async def test_run_async_sync_function_no_credential_manager(self): + """Test run_async with sync function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_async_function_no_credential_manager(self): + """Test run_async with async function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=async_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "async_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"sync_cred_result_test_{credential.auth_type.value}" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_async_function_with_credential(self): + """Test run_async with async function 
that expects credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_cred_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, + auth_config=auth_config, + response_for_auth_required="Custom auth required", + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Custom auth required" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_default_message(self): + """Test run_async when no credential is available with default message.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." 
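As the two tests above show, when no credential can be resolved the tool calls request_credential and then returns response_for_auth_required, falling back to the default string. A small illustrative sketch, not part of the patch, of passing a structured payload so a caller can branch on the pending-authorization case instead of matching the default string; it reuses the helpers defined earlier in this test file.

pending_payload = {
    "status": "auth_required",
    "message": "Complete the OAuth flow, then retry the tool call.",
}
tool = AuthenticatedFunctionTool(
    func=sync_function_with_credential,
    auth_config=_create_mock_auth_config(),
    response_for_auth_required=pending_payload,
)
# While authorization is pending, run_async(...) returns pending_payload
# instead of the default "Pending User Authorization." string.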
+ + @pytest.mark.asyncio + async def test_run_async_function_without_credential_param(self): + """Test run_async with function that doesn't have credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Credential should not be passed to function since it doesn't have the parameter + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_function_with_tool_context(self): + """Test run_async with function that has tool_context parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_tool_context, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_context_result_test" + + @pytest.mark.asyncio + async def test_run_async_function_with_both_params(self): + """Test run_async with function that has both tool_context and credential parameters.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_both, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_both_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_function_with_optional_credential(self): + """Test run_async with function that has optional credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=function_with_optional_args, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert ( + result == f"optional_result_test_default_{credential.auth_type.value}" + ) + + @pytest.mark.asyncio + async def test_run_async_callable_object(self): + """Test run_async with callable object.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + test_callable = MockCallable() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + 
mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"callable_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_propagates_function_exception(self): + """Test that run_async propagates exceptions from the wrapped function.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + def failing_function(arg1: str, credential: AuthCredential) -> str: + raise ValueError("Function failed") + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=failing_function, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(ValueError, match="Function failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_missing_required_args(self): + """Test run_async with missing required arguments.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} # Missing arg2 + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should return error dict indicating missing parameters + assert isinstance(result, dict) + assert "error" in result + assert "arg2" in result["error"] + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_credential_in_ignore_params(self): + """Test that 'credential' is added to ignore_params during initialization.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + assert "credential" in tool._ignore_params + + @pytest.mark.asyncio + async def test_run_async_with_none_credential(self): + """Test run_async when credential is None but function expects it.""" + tool = AuthenticatedFunctionTool(func=function_with_optional_args) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "optional_result_test_default_none" + + def test_signature_inspection(self): + """Test that the tool correctly inspects function signatures.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + signature = inspect.signature(tool.func) + assert "credential" in signature.parameters + assert "arg1" 
in signature.parameters + + @pytest.mark.asyncio + async def test_args_to_call_modification(self): + """Test that args_to_call is properly modified with credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + # Create a spy function to check what arguments are passed + original_args = {} + + def spy_function(arg1: str, credential: AuthCredential) -> str: + nonlocal original_args + original_args = {"arg1": arg1, "credential": credential} + return "spy_result" + + tool = AuthenticatedFunctionTool(func=spy_function, auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "spy_result" + assert original_args is not None + assert original_args["arg1"] == "test" + assert original_args["credential"] == credential diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py new file mode 100644 index 000000000..55454224d --- /dev/null +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -0,0 +1,343 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
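For orientation, a minimal illustrative subclass sketch, not part of the patches: a BaseAuthenticatedTool subclass only implements _run_async_impl, while run_async performs the credential lookup and the early pending-authorization return. The _TestAuthenticatedTool defined in the tests below follows the same pattern.

from typing import Any
from typing import Optional

from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool
from google.adk.tools.tool_context import ToolContext


class EchoAuthenticatedTool(BaseAuthenticatedTool):
  """Echoes its arguments once a credential has been resolved."""

  def __init__(self, auth_config=None):
    super().__init__(
        name="echo_authenticated_tool",
        description="Echoes args plus the resolved credential type.",
        auth_config=auth_config,
    )

  async def _run_async_impl(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
      credential: Optional[AuthCredential],
  ) -> Any:
    # `credential` is None when no credentials manager was configured.
    return {
        "args": args,
        "auth_type": credential.auth_type if credential else None,
    }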
+ +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestAuthenticatedTool(BaseAuthenticatedTool): + """Test implementation of BaseAuthenticatedTool for testing purposes.""" + + def __init__( + self, + name="test_auth_tool", + description="Test authenticated tool", + auth_config=None, + unauthenticated_response=None, + ): + super().__init__( + name=name, + description=description, + auth_config=auth_config, + response_for_auth_required=unauthenticated_response, + ) + self.run_impl_called = False + self.run_impl_result = "test_result" + + async def _run_async_impl(self, *, args, tool_context, credential): + """Test implementation of the abstract method.""" + self.run_impl_called = True + self.last_args = args + self.last_tool_context = tool_context + self.last_credential = credential + return self.run_impl_result + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + credential.auth_type = AuthCredentialTypes.OAUTH2 + return credential + + +class TestBaseAuthenticatedTool: + """Test suite for BaseAuthenticatedTool.""" + + def test_init_with_auth_config(self): + """Test initialization with auth_config.""" + auth_config = _create_mock_auth_config() + unauthenticated_response = {"error": "Not authenticated"} + + tool = _TestAuthenticatedTool( + name="test_tool", + description="Test description", + auth_config=auth_config, + unauthenticated_response=unauthenticated_response, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test description" + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == unauthenticated_response + + def test_init_with_no_auth_config(self): + """Test initialization without auth_config.""" + tool = _TestAuthenticatedTool() + + assert tool.name == "test_auth_tool" + assert tool.description == "Test authenticated tool" + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._credentials_manager is None + + def test_init_with_default_unauthenticated_response(self): + """Test initialization with default unauthenticated response.""" + auth_config = _create_mock_auth_config() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._response_for_auth_required is None + + @pytest.mark.asyncio + async def test_run_async_no_credentials_manager(self): + """Test run_async when no credentials manager is configured.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + 
result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential is None + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential == credential + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." 
+ assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_custom_response(self): + """Test run_async when no credential is available with custom response.""" + auth_config = _create_mock_auth_config() + custom_response = { + "status": "authentication_required", + "message": "Please login", + } + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_string_response(self): + """Test run_async when no credential is available with string response.""" + auth_config = _create_mock_auth_config() + custom_response = "Custom authentication required message" + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + + @pytest.mark.asyncio + async def test_run_async_propagates_impl_exception(self): + """Test that run_async propagates exceptions from _run_async_impl.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + # Make the implementation raise an exception + async def failing_impl(*, args, tool_context, credential): + raise ValueError("Implementation failed") + + tool._run_async_impl = failing_impl + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(ValueError, match="Implementation failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_with_different_args_types(self): + """Test run_async with different argument types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + + # Test with empty args + result = await tool.run_async(args={}, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == {} + + # Test with complex args + complex_args = { + "string_param": 
"test", + "number_param": 42, + "list_param": [1, 2, 3], + "dict_param": {"nested": "value"}, + } + result = await tool.run_async(args=complex_args, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == complex_args + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_abstract_nature(self): + """Test that BaseAuthenticatedTool cannot be instantiated directly.""" + with pytest.raises(TypeError): + # This should fail because _run_async_impl is abstract + BaseAuthenticatedTool(name="test", description="test") + + @pytest.mark.asyncio + async def test_run_async_return_values(self): + """Test run_async with different return value types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {} + + # Test with None return + tool.run_impl_result = None + result = await tool.run_async(args=args, tool_context=tool_context) + assert result is None + + # Test with dict return + tool.run_impl_result = {"key": "value"} + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == {"key": "value"} + + # Test with list return + tool.run_impl_result = [1, 2, 3] + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == [1, 2, 3] From 18a541c8fa5d9cac2769c1875d5d9dc4f782ca75 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:33:09 -0700 Subject: [PATCH 32/79] chore: Ignore mcp_tool ut tests for python 3.9 given mcp sdk only supports 3.10+ PiperOrigin-RevId: 772999037 --- .github/workflows/python-unit-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index d4af7b13a..565ee1dca 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -51,6 +51,7 @@ jobs: if [[ "${{ matrix.python-version }}" == "3.9" ]]; then pytest tests/unittests \ --ignore=tests/unittests/a2a \ + --ignore=tests/unittests/tools/mcp_tool \ --ignore=tests/unittests/artifacts/test_artifact_service.py \ --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py else From 157d9be88d92f22320604832e5a334a6eb81e4af Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:43:32 -0700 Subject: [PATCH 33/79] feat: Enable MCP Tool Auth (Experimental) PiperOrigin-RevId: 773002759 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 237 ++++++----- src/google/adk/tools/mcp_tool/mcp_tool.py | 88 ++++- src/google/adk/tools/mcp_tool/mcp_toolset.py | 12 +- tests/unittests/tools/mcp_tool/__init__.py | 13 + .../mcp_tool/test_mcp_session_manager.py | 342 ++++++++++++++++ .../unittests/tools/mcp_tool/test_mcp_tool.py | 373 ++++++++++++++++++ .../tools/mcp_tool/test_mcp_toolset.py | 269 +++++++++++++ 7 files changed, 1231 insertions(+), 103 deletions(-) 
create mode 100644 tests/unittests/tools/mcp_tool/__init__.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_session_manager.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_tool.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_toolset.py diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 5bc06e398..90b39e6cb 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -18,9 +18,12 @@ from contextlib import AsyncExitStack from datetime import timedelta import functools +import hashlib +import json import logging import sys from typing import Any +from typing import Dict from typing import Optional from typing import TextIO from typing import Union @@ -105,74 +108,39 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(session_manager_field_name: str): - """Decorator to automatically reinitialize session and retry action. +def retry_on_closed_resource(func): + """Decorator to automatically retry action when MCP session is closed. - When MCP session was closed, the decorator will automatically recreate the - session and retry the action with the same parameters. - - Note: - 1. session_manager_field_name is the name of the class member field that - contains the MCPSessionManager instance. - 2. The session manager must have a reinitialize_session() async method. - - Usage: - class MCPTool: - def __init__(self): - self._mcp_session_manager = MCPSessionManager(...) - - @retry_on_closed_resource('_mcp_session_manager') - async def use_session(self): - session = await self._mcp_session_manager.create_session() - await session.call_tool() + When MCP session was closed, the decorator will automatically retry the + action once. The create_session method will handle creating a new session + if the old one was disconnected. Args: - session_manager_field_name: The name of the session manager field. + func: The function to decorate. Returns: The decorated function. """ - def decorator(func): - @functools.wraps(func) # Preserves original function metadata - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except anyio.ClosedResourceError as close_err: - try: - if hasattr(self, session_manager_field_name): - session_manager = getattr(self, session_manager_field_name) - if hasattr(session_manager, 'reinitialize_session') and callable( - getattr(session_manager, 'reinitialize_session') - ): - await session_manager.reinitialize_session() - else: - raise ValueError( - f'Session manager {session_manager_field_name} does not have' - ' reinitialize_session method.' - ) from close_err - else: - raise ValueError( - f'Session manager field {session_manager_field_name} does not' - ' exist in decorated class. Please check the field name in' - ' retry_on_closed_resource decorator.' 
- ) from close_err - except Exception as reinit_err: - raise RuntimeError( - f'Error reinitializing: {reinit_err}' - ) from reinit_err - return await func(self, *args, **kwargs) - - return wrapper - - return decorator + @functools.wraps(func) # Preserves original function metadata + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except anyio.ClosedResourceError: + # Simply retry the function - create_session will handle + # detecting and replacing disconnected sessions + logger.info('Retrying %s due to closed resource', func.__name__) + return await func(self, *args, **kwargs) + + return wrapper class MCPSessionManager: """Manages MCP client sessions. This class provides methods for creating and initializing MCP client sessions, - handling different connection parameters (Stdio and SSE). + handling different connection parameters (Stdio and SSE) and supporting + session pooling based on authentication headers. """ def __init__( @@ -209,30 +177,125 @@ def __init__( else: self._connection_params = connection_params self._errlog = errlog - # Each session manager maintains its own exit stack for proper cleanup - self._exit_stack: Optional[AsyncExitStack] = None - self._session: Optional[ClientSession] = None + + # Session pool: maps session keys to (session, exit_stack) tuples + self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + # Lock to prevent race conditions in session creation self._session_lock = asyncio.Lock() - async def create_session(self) -> ClientSession: + def _generate_session_key( + self, merged_headers: Optional[Dict[str, str]] = None + ) -> str: + """Generates a session key based on connection params and merged headers. + + For StdioConnectionParams, returns a constant key since headers are not + supported. For SSE and StreamableHTTP connections, generates a key based + on the provided merged headers. + + Args: + merged_headers: Already merged headers (base + additional). + + Returns: + A unique session key string. + """ + if isinstance(self._connection_params, StdioConnectionParams): + # For stdio connections, headers are not supported, so use constant key + return 'stdio_session' + + # For SSE and StreamableHTTP connections, use merged headers + if merged_headers: + headers_json = json.dumps(merged_headers, sort_keys=True) + headers_hash = hashlib.md5(headers_json.encode()).hexdigest() + return f'session_{headers_hash}' + else: + return 'session_no_headers' + + def _merge_headers( + self, additional_headers: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """Merges base connection headers with additional headers. + + Args: + additional_headers: Optional headers to merge with connection headers. + + Returns: + Merged headers dictionary, or None if no headers are provided. + """ + if isinstance(self._connection_params, StdioConnectionParams) or isinstance( + self._connection_params, StdioServerParameters + ): + # Stdio connections don't support headers + return None + + base_headers = {} + if ( + hasattr(self._connection_params, 'headers') + and self._connection_params.headers + ): + base_headers = self._connection_params.headers.copy() + + if additional_headers: + base_headers.update(additional_headers) + + return base_headers + + def _is_session_disconnected(self, session: ClientSession) -> bool: + """Checks if a session is disconnected or closed. + + Args: + session: The ClientSession to check. + + Returns: + True if the session is disconnected, False otherwise. 
+ """ + return session._read_stream._closed or session._write_stream._closed + + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: """Creates and initializes an MCP client session. + This method will check if an existing session for the given headers + is still connected. If it's disconnected, it will be cleaned up and + a new session will be created. + + Args: + headers: Optional headers to include in the session. These will be + merged with any existing connection headers. Only applicable + for SSE and StreamableHTTP connections. + Returns: ClientSession: The initialized MCP client session. """ - # Fast path: if session already exists, return it without acquiring lock - if self._session is not None: - return self._session + # Merge headers once at the beginning + merged_headers = self._merge_headers(headers) + + # Generate session key using merged headers + session_key = self._generate_session_key(merged_headers) # Use async lock to prevent race conditions async with self._session_lock: - # Double-check: session might have been created while waiting for lock - if self._session is not None: - return self._session - - # Create a new exit stack for this session - self._exit_stack = AsyncExitStack() + # Check if we have an existing session + if session_key in self._sessions: + session, exit_stack = self._sessions[session_key] + + # Check if the existing session is still connected + if not self._is_session_disconnected(session): + # Session is still good, return it + return session + else: + # Session is disconnected, clean it up + logger.info('Cleaning up disconnected session: %s', session_key) + try: + await exit_stack.aclose() + except Exception as e: + logger.warning('Error during disconnected session cleanup: %s', e) + finally: + del self._sessions[session_key] + + # Create a new session (either first time or replacing disconnected one) + exit_stack = AsyncExitStack() try: if isinstance(self._connection_params, StdioConnectionParams): @@ -243,7 +306,7 @@ async def create_session(self) -> ClientSession: elif isinstance(self._connection_params, SseConnectionParams): client = sse_client( url=self._connection_params.url, - headers=self._connection_params.headers, + headers=merged_headers, timeout=self._connection_params.timeout, sse_read_timeout=self._connection_params.sse_read_timeout, ) @@ -252,7 +315,7 @@ async def create_session(self) -> ClientSession: ): client = streamablehttp_client( url=self._connection_params.url, - headers=self._connection_params.headers, + headers=merged_headers, timeout=timedelta(seconds=self._connection_params.timeout), sse_read_timeout=timedelta( seconds=self._connection_params.sse_read_timeout @@ -266,11 +329,11 @@ async def create_session(self) -> ClientSession: f' {self._connection_params}' ) - transports = await self._exit_stack.enter_async_context(client) + transports = await exit_stack.enter_async_context(client) # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. 
if isinstance(self._connection_params, StdioConnectionParams): - session = await self._exit_stack.enter_async_context( + session = await exit_stack.enter_async_context( ClientSession( *transports[:2], read_timeout_seconds=timedelta( @@ -279,44 +342,38 @@ async def create_session(self) -> ClientSession: ) ) else: - session = await self._exit_stack.enter_async_context( + session = await exit_stack.enter_async_context( ClientSession(*transports[:2]) ) await session.initialize() - self._session = session + # Store session and exit stack in the pool + self._sessions[session_key] = (session, exit_stack) + logger.debug('Created new session: %s', session_key) return session except Exception: # If session creation fails, clean up the exit stack - if self._exit_stack: - await self._exit_stack.aclose() - self._exit_stack = None + if exit_stack: + await exit_stack.aclose() raise async def close(self): - """Closes the session and cleans up resources.""" - if not self._exit_stack: - return + """Closes all sessions and cleans up resources.""" async with self._session_lock: - if self._exit_stack: + for session_key in list(self._sessions.keys()): + _, exit_stack = self._sessions[session_key] try: - await self._exit_stack.aclose() + await exit_stack.aclose() except Exception as e: # Log the error but don't re-raise to avoid blocking shutdown print( - f'Warning: Error during MCP session cleanup: {e}', + 'Warning: Error during MCP session cleanup for' + f' {session_key}: {e}', file=self._errlog, ) finally: - self._exit_stack = None - self._session = None - - async def reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and create a new one - await self.close() - await self.create_session() + del self._sessions[session_key] SseServerParams = SseConnectionParams diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 6553bb2c0..24998c925 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,10 +14,13 @@ from __future__ import annotations +import base64 +import json import logging from typing import Optional from google.genai.types import FunctionDeclaration +from google.oauth2.credentials import Credentials from typing_extensions import override from .._gemini_schema_util import _to_gemini_schema @@ -42,13 +45,15 @@ from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme -from ..base_tool import BaseTool +from ...auth.auth_tool import AuthConfig +from ..base_authenticated_tool import BaseAuthenticatedTool +# import from ..tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) -class MCPTool(BaseTool): +class MCPTool(BaseAuthenticatedTool): """Turns an MCP Tool into an ADK Tool. Internally, the tool initializes from a MCP Tool, and uses the MCP Session to @@ -77,19 +82,17 @@ def __init__( Raises: ValueError: If mcp_tool or mcp_session_manager is None. """ - if mcp_tool is None: - raise ValueError("mcp_tool cannot be None") - if mcp_session_manager is None: - raise ValueError("mcp_session_manager cannot be None") super().__init__( name=mcp_tool.name, description=mcp_tool.description if mcp_tool.description else "", + auth_config=AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + if auth_scheme + else None, ) self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager - # TODO(cheliu): Support passing auth to MCP Server. 
- self._auth_scheme = auth_scheme - self._auth_credential = auth_credential @override def _get_declaration(self) -> FunctionDeclaration: @@ -105,8 +108,11 @@ def _get_declaration(self) -> FunctionDeclaration: ) return function_decl - @retry_on_closed_resource("_mcp_session_manager") - async def run_async(self, *, args, tool_context: ToolContext): + @retry_on_closed_resource + @override + async def _run_async_impl( + self, *, args, tool_context: ToolContext, credential: AuthCredential + ): """Runs the tool asynchronously. Args: @@ -116,8 +122,66 @@ async def run_async(self, *, args, tool_context: ToolContext): Returns: Any: The response from the tool. """ + # Extract headers from credential for session pooling + headers = await self._get_headers(tool_context, credential) + # Get the session from the session manager - session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session(headers=headers) response = await session.call_tool(self.name, arguments=args) return response + + async def _get_headers( + self, tool_context: ToolContext, credential: AuthCredential + ) -> Optional[dict[str, str]]: + headers = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.google_oauth2_json: + google_credential = Credentials.from_authorized_user_info( + json.loads(credential.google_oauth2_json) + ) + headers = {"Authorization": f"Bearer {google_credential.token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + # For API keys, we'll add them as headers since MCP typically uses header-based auth + # The specific header name would depend on the API, using a common default + # TODO Allow user to specify the header name for API keys. 
+ headers = {"X-API-Key": credential.api_key} + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + # If we reach here, we can try to use google_oauth2_json or log a warning + logger.warning( + "Service account credentials should be exchanged for access" + " tokens before MCP session creation" + ) + + return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index f55693e86..c01b0cec2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -22,6 +22,8 @@ from typing import Union from ...agents.readonly_context import ReadonlyContext +from ...auth.auth_credential import AuthCredential +from ...auth.auth_schemes import AuthScheme from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate @@ -94,6 +96,8 @@ def __init__( ], tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, errlog: TextIO = sys.stderr, + auth_scheme: Optional[AuthScheme] = None, + auth_credential: Optional[AuthCredential] = None, ): """Initializes the MCPToolset. @@ -110,6 +114,8 @@ def __init__( list of tool names to include - A ToolPredicate function for custom filtering logic errlog: TextIO stream for error logging. + auth_scheme: The auth scheme of the tool for tool calling + auth_credential: The auth credential of the tool for tool calling """ super().__init__(tool_filter=tool_filter) @@ -124,8 +130,10 @@ def __init__( connection_params=self._connection_params, errlog=self._errlog, ) + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential - @retry_on_closed_resource("_mcp_session_manager") + @retry_on_closed_resource async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, @@ -151,6 +159,8 @@ async def get_tools( mcp_tool = MCPTool( mcp_tool=tool, mcp_session_manager=self._mcp_session_manager, + auth_scheme=self._auth_scheme, + auth_credential=self._auth_credential, ) if self._is_tool_selected(mcp_tool, readonly_context): diff --git a/tests/unittests/tools/mcp_tool/__init__.py b/tests/unittests/tools/mcp_tool/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/tools/mcp_tool/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py new file mode 100644 index 000000000..448d41260 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -0,0 +1,342 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +from io import StringIO +import json +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +import pytest + +# Import real MCP classes +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockClientSession: + """Mock ClientSession for testing.""" + + def __init__(self): + self._read_stream = Mock() + self._write_stream = Mock() + self._read_stream._closed = False + self._write_stream._closed = False + self.initialize = AsyncMock() + + +class MockAsyncExitStack: + """Mock AsyncExitStack for testing.""" + + def __init__(self): + self.aclose = AsyncMock() + self.enter_async_context = AsyncMock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class TestMCPSessionManager: + """Test suite for MCPSessionManager class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_stdio_connection_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=5.0 + ) + + def test_init_with_stdio_server_parameters(self): + """Test initialization with StdioServerParameters (deprecated).""" + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.logger" + ) as mock_logger: + manager = MCPSessionManager(self.mock_stdio_params) + + # Should log deprecation warning + mock_logger.warning.assert_called_once() + assert "StdioServerParameters is not recommended" in str( + mock_logger.warning.call_args + ) + + # Should convert to StdioConnectionParams + assert isinstance(manager._connection_params, StdioConnectionParams) + assert manager._connection_params.server_params == self.mock_stdio_params + assert manager._connection_params.timeout == 5 + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + assert manager._connection_params == self.mock_stdio_connection_params + assert manager._errlog == sys.stderr + assert manager._sessions == {} + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", + headers={"Authorization": "Bearer token"}, + timeout=10.0, + ) + manager = MCPSessionManager(sse_params) + + assert manager._connection_params == sse_params + + def 
test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", timeout=15.0 + ) + manager = MCPSessionManager(http_params) + + assert manager._connection_params == http_params + + def test_generate_session_key_stdio(self): + """Test session key generation for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # For stdio, headers should be ignored and return constant key + key1 = manager._generate_session_key({"Authorization": "Bearer token"}) + key2 = manager._generate_session_key(None) + + assert key1 == "stdio_session" + assert key2 == "stdio_session" + assert key1 == key2 + + def test_generate_session_key_sse(self): + """Test session key generation for SSE connections.""" + sse_params = SseConnectionParams(url="https://example.com/mcp") + manager = MCPSessionManager(sse_params) + + headers1 = {"Authorization": "Bearer token1"} + headers2 = {"Authorization": "Bearer token2"} + + key1 = manager._generate_session_key(headers1) + key2 = manager._generate_session_key(headers2) + key3 = manager._generate_session_key(headers1) + + # Different headers should generate different keys + assert key1 != key2 + # Same headers should generate same key + assert key1 == key3 + + # Should be deterministic hash + headers_json = json.dumps(headers1, sort_keys=True) + expected_hash = hashlib.md5(headers_json.encode()).hexdigest() + assert key1 == f"session_{expected_hash}" + + def test_merge_headers_stdio(self): + """Test header merging for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Stdio connections don't support headers + headers = manager._merge_headers({"Authorization": "Bearer token"}) + assert headers is None + + def test_merge_headers_sse(self): + """Test header merging for SSE connections.""" + base_headers = {"Content-Type": "application/json"} + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers=base_headers + ) + manager = MCPSessionManager(sse_params) + + # With additional headers + additional = {"Authorization": "Bearer token"} + merged = manager._merge_headers(additional) + + expected = { + "Content-Type": "application/json", + "Authorization": "Bearer token", + } + assert merged == expected + + def test_is_session_disconnected(self): + """Test session disconnection detection.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session + session = MockClientSession() + + # Not disconnected + assert not manager._is_session_disconnected(session) + + # Disconnected - read stream closed + session._read_stream._closed = True + assert manager._is_session_disconnected(session) + + @pytest.mark.asyncio + async def test_create_session_stdio_new(self): + """Test creating a new stdio session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_session = MockClientSession() + mock_exit_stack = MockAsyncExitStack() + + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ) as mock_stdio: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_exit_stack_class: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" + ) as mock_session_class: + + # Setup mocks + mock_exit_stack_class.return_value = mock_exit_stack + mock_stdio.return_value = AsyncMock() + mock_exit_stack.enter_async_context.side_effect = [ + ("read", "write"), # 
First call returns transports + mock_session, # Second call returns session + ] + mock_session_class.return_value = mock_session + + # Create session + session = await manager.create_session() + + # Verify session creation + assert session == mock_session + assert len(manager._sessions) == 1 + assert "stdio_session" in manager._sessions + + # Verify session was initialized + mock_session.initialize.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_reuse_existing(self): + """Test reusing an existing connected session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock existing session + existing_session = MockClientSession() + existing_exit_stack = MockAsyncExitStack() + manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) + + # Session is connected + existing_session._read_stream._closed = False + existing_session._write_stream._closed = False + + session = await manager.create_session() + + # Should reuse existing session + assert session == existing_session + assert len(manager._sessions) == 1 + + # Should not create new session + existing_session.initialize.assert_not_called() + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup of all sessions.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + await manager.close() + + # All sessions should be closed + exit_stack1.aclose.assert_called_once() + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + @pytest.mark.asyncio + async def test_close_with_errors(self): + """Test cleanup when some sessions fail to close.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + exit_stack1.aclose.side_effect = Exception("Close error 1") + + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + custom_errlog = StringIO() + manager._errlog = custom_errlog + + # Should not raise exception + await manager.close() + + # Good session should still be closed + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + # Error should be logged + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCP session cleanup" in error_output + assert "Close error 1" in error_output + + +def test_retry_on_closed_resource_decorator(): + """Test the retry_on_closed_resource decorator.""" + + call_count = 0 + + @retry_on_closed_resource + async def mock_function(self): + nonlocal call_count + call_count += 1 + if call_count == 1: + import anyio + + raise anyio.ClosedResourceError("Resource closed") + return "success" + + @pytest.mark.asyncio + async def test_retry(): + nonlocal call_count + call_count = 0 + + mock_self = Mock() + result = await mock_function(mock_self) + + assert result == "success" + assert call_count == 2 # First call fails, second succeeds + + # Run the test + import asyncio + + asyncio.run(test_retry()) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py new file 
mode 100644 index 000000000..4d9cffb4d --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -0,0 +1,373 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +import pytest + + +# Mock MCP Tool from mcp.types +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name="test_tool", description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "First parameter"}, + "param2": {"type": "integer", "description": "Second parameter"}, + }, + "required": ["param1"], + } + + +class TestMCPTool: + """Test suite for MCPTool class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_mcp_tool = MockMCPTool() + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization without auth.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test tool description" + assert tool._mcp_tool == self.mock_mcp_tool + assert tool._mcp_session_manager == self.mock_session_manager + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances instead of mocks + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + # The auth config is stored in the parent class _credentials_manager + assert tool._credentials_manager is not None + assert tool._credentials_manager._auth_config.auth_scheme == auth_scheme + assert ( + tool._credentials_manager._auth_config.raw_auth_credential + == auth_credential + ) + + 
def test_init_with_empty_description(self): + """Test initialization with empty description.""" + mock_tool = MockMCPTool(description=None) + tool = MCPTool( + mcp_tool=mock_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.description == "" + + def test_get_declaration(self): + """Test function declaration generation.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + declaration = tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_tool" + assert declaration.description == "Test tool description" + assert declaration.parameters is not None + + @pytest.mark.asyncio + async def test_run_async_impl_no_auth(self): + """Test running tool without authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=None + ) + + assert result == expected_response + self.mock_session_manager.create_session.assert_called_once_with( + headers=None + ) + # Fix: call_tool uses 'arguments' parameter, not positional args + self.mock_session.call_tool.assert_called_once_with( + "test_tool", arguments=args + ) + + @pytest.mark.asyncio + async def test_run_async_impl_with_oauth2(self): + """Test running tool with OAuth2 authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create OAuth2 credential + oauth2_auth = OAuth2Auth(access_token="test_access_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + assert result == expected_response + # Check that headers were passed correctly + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == {"Authorization": "Bearer test_access_token"} + + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) 
+ + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer bearer_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key(self): + """Test header generation for API Key credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"X-API-Key": "my_api_key"} + + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_tool.json") + @patch("google.adk.tools.mcp_tool.mcp_tool.Credentials") + async def test_get_headers_google_oauth2_json( + self, mock_credentials, mock_json + ): + """Test header generation for Google OAuth2 JSON credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Mock the JSON parsing and Credentials creation + mock_json.loads.return_value = {"token": "google_token"} + mock_google_credential = Mock() + mock_google_credential.token = "google_access_token" + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json='{"token": "google_token"}', + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer google_access_token"} + mock_json.loads.assert_called_once_with('{"token": "google_token"}') + mock_credentials.from_authorized_user_info.assert_called_once_with( + {"token": "google_token"} + ) + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, None) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account_no_json(self): + """Test header generation for service account credentials without google_oauth2_json.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create service account credential without google_oauth2_json + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should 
return None as no google_oauth2_json is provided + assert headers is None + + @pytest.mark.asyncio + async def test_run_async_impl_retry_decorator(self): + """Test that the retry decorator is applied correctly.""" + # This is more of an integration test to ensure the decorator is present + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Check that the method has the retry decorator + assert hasattr(tool._run_async_impl, "__wrapped__") + + @pytest.mark.asyncio + async def test_get_headers_http_custom_scheme(self): + """Test header generation for custom HTTP scheme.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "custom custom_token"} + + def test_init_validation(self): + """Test that initialization validates required parameters.""" + # This test ensures that the MCPTool properly handles its dependencies + with pytest.raises(TypeError): + MCPTool() # Missing required parameters + + with pytest.raises(TypeError): + MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py new file mode 100644 index 000000000..0ba29b1da --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
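A usage-level sketch, not one of the tests in this patch: it wires an MCPToolset to an SSE MCP server using the auth_scheme and auth_credential parameters added above, so that once the OAuth2 flow yields an access token, each generated MCPTool forwards it as an Authorization header when creating its MCP session. The URL, client id, client secret, and tool names are placeholders; OAuth2(flows={}) is the same minimal scheme object these tests construct.

from fastapi.openapi.models import OAuth2

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset

toolset = MCPToolset(
    connection_params=SseConnectionParams(url="https://example.com/mcp"),
    tool_filter=["read_file", "write_file"],
    auth_scheme=OAuth2(flows={}),
    auth_credential=AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(client_id="client-id", client_secret="client-secret"),
    ),
)

# Inside an async context:
#   tools = await toolset.get_tools()
#   await toolset.close()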
+ +from io import StringIO +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +import pytest + +# Import the real MCP classes for proper instantiation +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name, description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": {"param": {"type": "string"}}, + } + + +class MockListToolsResult: + """Mock ListToolsResult for testing.""" + + def __init__(self, tools): + self.tools = tools + + +class TestMCPToolset: + """Test suite for MCPToolset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization with StdioServerParameters.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Note: StdioServerParameters gets converted to StdioConnectionParams internally + assert toolset._errlog == sys.stderr + assert toolset._auth_scheme is None + assert toolset._auth_credential is None + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + stdio_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=10.0 + ) + toolset = MCPToolset(connection_params=stdio_params) + + assert toolset._connection_params == stdio_params + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers={"Authorization": "Bearer token"} + ) + toolset = MCPToolset(connection_params=sse_params) + + assert toolset._connection_params == sse_params + + def test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", + headers={"Content-Type": "application/json"}, + ) + toolset = MCPToolset(connection_params=http_params) + + assert toolset._connection_params == http_params + + def test_init_with_tool_filter_list(self): + """Test initialization with tool filter as list.""" + tool_filter = ["tool1", "tool2"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + + # The tool 
filter is stored in the parent BaseToolset class + # We can verify it by checking the filtering behavior in get_tools + assert toolset._is_tool_selected is not None + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + from google.adk.auth.auth_credential import OAuth2Auth + + auth_credential = AuthCredential( + auth_type="oauth2", + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + assert toolset._auth_scheme == auth_scheme + assert toolset._auth_credential == auth_credential + + def test_init_missing_connection_params(self): + """Test initialization with missing connection params raises error.""" + with pytest.raises(ValueError, match="Missing connection params"): + MCPToolset(connection_params=None) + + @pytest.mark.asyncio + async def test_get_tools_basic(self): + """Test getting tools without filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 3 + for tool in tools: + assert isinstance(tool, MCPTool) + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + assert tools[2].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_list_filter(self): + """Test getting tools with list-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + tool_filter = ["tool1", "tool3"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_function_filter(self): + """Test getting tools with function-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("read_file"), + MockMCPTool("write_file"), + MockMCPTool("list_directory"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + def file_tools_filter(tool, context): + """Filter for file-related tools only.""" + return "file" in tool.name + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=file_tools_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "read_file" + assert tools[1].name == "write_file" + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + await toolset.close() + + self.mock_session_manager.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_exception(self): + """Test 
cleanup when session manager raises exception.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + # Mock close to raise an exception + self.mock_session_manager.close = AsyncMock( + side_effect=Exception("Cleanup error") + ) + + custom_errlog = StringIO() + toolset._errlog = custom_errlog + + # Should not raise exception + await toolset.close() + + # Should log the error + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCPToolset cleanup" in error_output + assert "Cleanup error" in error_output + + @pytest.mark.asyncio + async def test_get_tools_retry_decorator(self): + """Test that get_tools has retry decorator applied.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Check that the method has the retry decorator + assert hasattr(toolset.get_tools, "__wrapped__") From 58e07cae83048d5213d822be5197a96be9ce2950 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Wed, 18 Jun 2025 16:33:56 -0700 Subject: [PATCH 34/79] fix: Fix tracing for live the original code passed in wrong args. now fixed. tested locally. PiperOrigin-RevId: 773108589 --- src/google/adk/flows/llm_flows/functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2541ac664..2772550c2 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -288,8 +288,7 @@ async def handle_function_calls_live( trace_tool_call( tool=tool, args=function_args, - response_event_id=function_response_event.id, - function_response=function_response, + function_response_event=function_response_event, ) function_response_events.append(function_response_event) From 913d771d6dda4f0b4a5f9c82ab914f3495a92092 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 17:40:12 -0700 Subject: [PATCH 35/79] chore: Raise meaningful errors when importing a2a modules for python 3.9 PiperOrigin-RevId: 773128206 --- src/google/adk/a2a/converters/part_converter.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index 1c51fd7c1..2d94abd7c 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -20,9 +20,20 @@ import json import logging +import sys from typing import Optional -from a2a import types as a2a_types +try: + from a2a import types as a2a_types +except ImportError as e: + if sys.version_info < (3, 10): + raise ImportError( + 'A2A Tool requires Python 3.10 or above. Please upgrade your Python' + ' version.' + ) from e + else: + raise e + from google.genai import types as genai_types from ...utils.feature_decorator import working_in_progress From 9a1115c504427ed8285b5c2053946c11c5d7c0a6 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 18:18:17 -0700 Subject: [PATCH 36/79] chore: Remove service account support given it was not correctly supported. 
PiperOrigin-RevId: 773137317 --- src/google/adk/auth/auth_credential.py | 1 - src/google/adk/auth/credential_manager.py | 6 +- src/google/adk/auth/exchanger/__init__.py | 2 - .../service_account_credential_exchanger.py | 104 ----- .../refresher/oauth2_credential_refresher.py | 32 +- src/google/adk/tools/mcp_tool/mcp_tool.py | 10 +- .../openapi_spec_parser/tool_auth_handler.py | 1 - ...st_service_account_credential_exchanger.py | 433 ------------------ .../test_oauth2_credential_refresher.py | 118 ----- .../unittests/auth/test_credential_manager.py | 26 +- .../unittests/tools/mcp_tool/test_mcp_tool.py | 42 +- 11 files changed, 15 insertions(+), 760 deletions(-) delete mode 100644 src/google/adk/auth/exchanger/service_account_credential_exchanger.py delete mode 100644 tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 1009a50dd..34d04dde9 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -230,4 +230,3 @@ class AuthCredential(BaseModelWithConfig): http: Optional[HttpAuth] = None service_account: Optional[ServiceAccount] = None oauth2: Optional[OAuth2Auth] = None - google_oauth2_json: Optional[str] = None diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index 7471bdffa..0dbf006ab 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -76,11 +76,7 @@ def __init__( self._refresher_registry = CredentialRefresherRegistry() # Register default exchangers and refreshers - from .exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger - - self._exchanger_registry.register( - AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger() - ) + # TODO: support service account credential exchanger from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher oauth2_refresher = OAuth2CredentialRefresher() diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py index 4226ae715..3b0fbb246 100644 --- a/src/google/adk/auth/exchanger/__init__.py +++ b/src/google/adk/auth/exchanger/__init__.py @@ -15,9 +15,7 @@ """Credential exchanger module.""" from .base_credential_exchanger import BaseCredentialExchanger -from .service_account_credential_exchanger import ServiceAccountCredentialExchanger __all__ = [ "BaseCredentialExchanger", - "ServiceAccountCredentialExchanger", ] diff --git a/src/google/adk/auth/exchanger/service_account_credential_exchanger.py b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py deleted file mode 100644 index 415081ca5..000000000 --- a/src/google/adk/auth/exchanger/service_account_credential_exchanger.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Credential fetcher for Google Service Account.""" - -from __future__ import annotations - -from typing import Optional - -import google.auth -from google.auth.transport.requests import Request -from google.oauth2 import service_account -from typing_extensions import override - -from ...utils.feature_decorator import experimental -from ..auth_credential import AuthCredential -from ..auth_credential import AuthCredentialTypes -from ..auth_schemes import AuthScheme -from .base_credential_exchanger import BaseCredentialExchanger - - -@experimental -class ServiceAccountCredentialExchanger(BaseCredentialExchanger): - """Exchanges Google Service Account credentials for an access token. - - Uses the default service credential if `use_default_credential = True`. - Otherwise, uses the service account credential provided in the auth - credential. - """ - - @override - async def exchange( - self, - auth_credential: AuthCredential, - auth_scheme: Optional[AuthScheme] = None, - ) -> AuthCredential: - """Exchanges the service account auth credential for an access token. - - If the AuthCredential contains a service account credential, it will be used - to exchange for an access token. Otherwise, if use_default_credential is True, - the default application credential will be used for exchanging an access token. - - Args: - auth_scheme: The authentication scheme. - auth_credential: The credential to exchange. - - Returns: - An AuthCredential in OAUTH2 format, containing the exchanged credential JSON. - - Raises: - ValueError: If service account credentials are missing or invalid. - Exception: If credential exchange or refresh fails. - """ - if auth_credential is None: - raise ValueError("Credential cannot be None.") - - if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: - raise ValueError("Credential is not a service account credential.") - - if auth_credential.service_account is None: - raise ValueError( - "Service account credentials are missing. Please provide them." - ) - - if ( - auth_credential.service_account.service_account_credential is None - and not auth_credential.service_account.use_default_credential - ): - raise ValueError( - "Service account credentials are invalid. Please set the" - " service_account_credential field or set `use_default_credential =" - " True` to use application default credential in a hosted service" - " like Google Cloud Run." - ) - - try: - if auth_credential.service_account.use_default_credential: - credentials, _ = google.auth.default() - else: - config = auth_credential.service_account - credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes - ) - - # Refresh credentials to ensure we have a valid access token - credentials.refresh(Request()) - - return AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=credentials.to_json(), - ) - except Exception as e: - raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py index 2d0a8b670..4c19520ce 100644 --- a/src/google/adk/auth/refresher/oauth2_credential_refresher.py +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -60,27 +60,12 @@ async def is_refresh_needed( Returns: True if the credential needs to be refreshed, False otherwise. 
""" - # Handle Google OAuth2 credentials (from service account exchange) - if auth_credential.google_oauth2_json: - try: - google_credential = Credentials.from_authorized_user_info( - json.loads(auth_credential.google_oauth2_json) - ) - return google_credential.expired and bool( - google_credential.refresh_token - ) - except Exception as e: - logger.warning("Failed to parse Google OAuth2 JSON credential: %s", e) - return False # Handle regular OAuth2 credentials - elif auth_credential.oauth2 and auth_scheme: + if auth_credential.oauth2: if not AUTHLIB_AVIALABLE: return False - if not auth_credential.oauth2: - return False - return OAuth2Token({ "expires_at": auth_credential.oauth2.expires_at, "expires_in": auth_credential.oauth2.expires_in, @@ -105,22 +90,9 @@ async def refresh( The refreshed credential. """ - # Handle Google OAuth2 credentials (from service account exchange) - if auth_credential.google_oauth2_json: - try: - google_credential = Credentials.from_authorized_user_info( - json.loads(auth_credential.google_oauth2_json) - ) - if google_credential.expired and google_credential.refresh_token: - google_credential.refresh(Request()) - auth_credential.google_oauth2_json = google_credential.to_json() - logger.info("Successfully refreshed Google OAuth2 JSON credential") - except Exception as e: - # TODO reconsider whether we should raise error when refresh failed. - logger.error("Failed to refresh Google OAuth2 JSON credential: %s", e) # Handle regular OAuth2 credentials - elif auth_credential.oauth2 and auth_scheme: + if auth_credential.oauth2 and auth_scheme: if not AUTHLIB_AVIALABLE: return auth_credential diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 24998c925..310fc48f1 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -138,11 +138,6 @@ async def _get_headers( if credential: if credential.oauth2: headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} - elif credential.google_oauth2_json: - google_credential = Credentials.from_authorized_user_info( - json.loads(credential.google_oauth2_json) - ) - headers = {"Authorization": f"Bearer {google_credential.token}"} elif credential.http: # Handle HTTP authentication schemes if ( @@ -178,10 +173,9 @@ async def _get_headers( headers = {"X-API-Key": credential.api_key} elif credential.service_account: # Service accounts should be exchanged for access tokens before reaching this point - # If we reach here, we can try to use google_oauth2_json or log a warning logger.warning( - "Service account credentials should be exchanged for access" - " tokens before MCP session creation" + "Service account credentials should be exchanged before MCP" + " session creation" ) return headers diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 08e535d28..74166b00e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -233,7 +233,6 @@ def _external_exchange_required(self, credential) -> bool: AuthCredentialTypes.OPEN_ID_CONNECT, ) and not credential.oauth2.access_token - and not credential.google_oauth2_json ) async def prepare_auth_credentials( diff --git a/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py 
b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py deleted file mode 100644 index 195e143d3..000000000 --- a/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py +++ /dev/null @@ -1,433 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for the ServiceAccountCredentialExchanger.""" - -from unittest.mock import MagicMock -from unittest.mock import patch - -from fastapi.openapi.models import HTTPBearer -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import ServiceAccount -from google.adk.auth.auth_credential import ServiceAccountCredential -from google.adk.auth.exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger -import pytest - - -class TestServiceAccountCredentialExchanger: - """Test cases for ServiceAccountCredentialExchanger.""" - - def test_exchange_with_valid_credential(self): - """Test successful exchange with valid service account credential.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key=( - "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE" - " KEY-----" - ), - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - # This should not raise an exception - assert exchanger is not None - - @pytest.mark.asyncio - async def test_exchange_invalid_credential_type(self): - """Test exchange with invalid credential type raises ValueError.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, - api_key="test-key", - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Credential is not a service account credential" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_with_explicit_credentials_success( - self, mock_request_class, mock_from_service_account_info - ): - """Test successful exchange with explicit service account 
credentials.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.token = "mock_access_token" - mock_credentials.to_json.return_value = ( - '{"token": "mock_access_token", "type": "authorized_user"}' - ) - mock_from_service_account_info.return_value = mock_credentials - - # Create test credential - service_account_cred = ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key=( - "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" - ), - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=service_account_cred, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - result = await exchanger.exchange(credential, auth_scheme) - - # Verify the result - assert result.auth_type == AuthCredentialTypes.OAUTH2 - assert result.google_oauth2_json is not None - # Verify that google_oauth2_json contains the token - import json - - exchanged_creds = json.loads(result.google_oauth2_json) - assert exchanged_creds.get( - "token" - ) == "mock_access_token" or "mock_access_token" in str(exchanged_creds) - - # Verify mocks were called correctly - mock_from_service_account_info.assert_called_once_with( - service_account_cred.model_dump(), - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - mock_credentials.refresh.assert_called_once_with(mock_request) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_with_default_credentials_success( - self, mock_request_class, mock_google_auth_default - ): - """Test successful exchange with default application credentials.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.token = "default_access_token" - mock_credentials.to_json.return_value = ( - '{"token": "default_access_token", "type": "authorized_user"}' - ) - mock_google_auth_default.return_value = (mock_credentials, "test-project") - - # Create test credential with use_default_credential=True - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - result = await exchanger.exchange(credential, auth_scheme) - - # Verify the result - assert result.auth_type == AuthCredentialTypes.OAUTH2 - assert result.google_oauth2_json is not None - # Verify that google_oauth2_json contains the token - import json - - exchanged_creds = json.loads(result.google_oauth2_json) - assert exchanged_creds.get( - "token" - 
) == "default_access_token" or "default_access_token" in str( - exchanged_creds - ) - - # Verify mocks were called correctly - mock_google_auth_default.assert_called_once() - mock_credentials.refresh.assert_called_once_with(mock_request) - - @pytest.mark.asyncio - async def test_exchange_missing_service_account(self): - """Test exchange fails when service_account is None.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=None, - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Service account credentials are missing" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - async def test_exchange_missing_credentials_and_not_default(self): - """Test exchange fails when credentials are missing and use_default_credential is False.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=None, - use_default_credential=False, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Service account credentials are invalid" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" - ) - async def test_exchange_credential_creation_failure( - self, mock_from_service_account_info - ): - """Test exchange handles credential creation failure gracefully.""" - # Setup mock to raise exception - mock_from_service_account_info.side_effect = Exception( - "Invalid private key" - ) - - # Create test credential - service_account_cred = ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key="invalid-key", - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=service_account_cred, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Failed to exchange service account token" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" - ) - async def test_exchange_default_credential_failure( - self, mock_google_auth_default - ): - """Test exchange handles default credential failure gracefully.""" - # Setup mock to raise exception - mock_google_auth_default.side_effect = Exception( - "No default credentials found" - ) - - # Create test credential with use_default_credential=True - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - use_default_credential=True, - 
scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Failed to exchange service account token" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_refresh_failure( - self, mock_request_class, mock_from_service_account_info - ): - """Test exchange handles credential refresh failure gracefully.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.refresh.side_effect = Exception( - "Network error during refresh" - ) - mock_from_service_account_info.return_value = mock_credentials - - # Create test credential - service_account_cred = ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key=( - "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" - ), - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=service_account_cred, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Failed to exchange service account token" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - async def test_exchange_none_credential_in_constructor(self): - """Test that passing None credential raises appropriate error during exchange.""" - # This test verifies behavior when credential is None - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises(ValueError, match="Credential cannot be None"): - await exchanger.exchange(None, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_with_service_account_no_explicit_credentials( - self, mock_request_class, mock_google_auth_default - ): - """Test exchange with service account that has no explicit credentials uses default.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.token = "default_access_token" - mock_credentials.to_json.return_value = ( - '{"token": "default_access_token", "type": "authorized_user"}' - ) - mock_google_auth_default.return_value = (mock_credentials, "test-project") - - # Create test credential with no explicit credentials but use_default_credential=True - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - 
service_account=ServiceAccount( - service_account_credential=None, - use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - result = await exchanger.exchange(credential, auth_scheme) - - # Verify the result - assert result.auth_type == AuthCredentialTypes.OAUTH2 - assert result.google_oauth2_json is not None - # Verify that google_oauth2_json contains the token - import json - - exchanged_creds = json.loads(result.google_oauth2_json) - assert exchanged_creds.get( - "token" - ) == "default_access_token" or "default_access_token" in str( - exchanged_creds - ) - - # Verify mocks were called correctly - mock_google_auth_default.assert_called_once() - mock_credentials.refresh.assert_called_once_with(mock_request) diff --git a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py index b22bf2ccd..3342fcb05 100644 --- a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py +++ b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py @@ -165,124 +165,6 @@ async def test_refresh_no_oauth2_credential(self): assert result == credential - @pytest.mark.asyncio - async def test_needs_refresh_google_oauth2_json_expired(self): - """Test needs_refresh with Google OAuth2 JSON credential that is expired.""" - import json - from unittest.mock import patch - - # Mock Google OAuth2 JSON credential data - google_oauth2_json = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "test_refresh_token", - "type": "authorized_user", - }) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=google_oauth2_json, - ) - - # Mock the Google Credentials class - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" - ) as mock_credentials: - mock_google_credential = Mock() - mock_google_credential.expired = True - mock_google_credential.refresh_token = "test_refresh_token" - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - refresher = OAuth2CredentialRefresher() - needs_refresh = await refresher.is_refresh_needed(credential, None) - - assert needs_refresh - - @pytest.mark.asyncio - async def test_needs_refresh_google_oauth2_json_not_expired(self): - """Test needs_refresh with Google OAuth2 JSON credential that is not expired.""" - import json - from unittest.mock import patch - - # Mock Google OAuth2 JSON credential data - google_oauth2_json = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "test_refresh_token", - "type": "authorized_user", - }) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=google_oauth2_json, - ) - - # Mock the Google Credentials class - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" - ) as mock_credentials: - mock_google_credential = Mock() - mock_google_credential.expired = False - mock_google_credential.refresh_token = "test_refresh_token" - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - refresher = OAuth2CredentialRefresher() - needs_refresh = await refresher.is_refresh_needed(credential, None) - - assert not needs_refresh - - @pytest.mark.asyncio - async def test_refresh_google_oauth2_json_success(self): - """Test successful 
refresh of Google OAuth2 JSON credential.""" - import json - from unittest.mock import patch - - # Mock Google OAuth2 JSON credential data - google_oauth2_json = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "test_refresh_token", - "type": "authorized_user", - }) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=google_oauth2_json, - ) - - # Mock the Google Credentials and Request classes - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" - ) as mock_credentials: - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Request" - ) as mock_request: - mock_google_credential = Mock() - mock_google_credential.expired = True - mock_google_credential.refresh_token = "test_refresh_token" - mock_google_credential.to_json.return_value = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "new_refresh_token", - "access_token": "new_access_token", - "type": "authorized_user", - }) - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - refresher = OAuth2CredentialRefresher() - result = await refresher.refresh(credential, None) - - mock_google_credential.refresh.assert_called_once() - assert ( - result.google_oauth2_json != google_oauth2_json - ) # Should be updated - @pytest.mark.asyncio async def test_needs_refresh_no_oauth2_credential(self): """Test needs_refresh with no OAuth2 credential returns False.""" diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index 283e865a7..8e3638dd6 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -410,39 +410,25 @@ async def test_validate_credential_oauth2_missing_oauth2_field(self): @pytest.mark.asyncio async def test_exchange_credentials_service_account(self): - """Test _exchange_credential with service account credential.""" + """Test _exchange_credential with service account credential (no exchanger available).""" mock_raw_credential = Mock(spec=AuthCredential) mock_raw_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT - mock_exchanged_credential = Mock(spec=AuthCredential) - auth_config = Mock(spec=AuthConfig) auth_config.auth_scheme = Mock() manager = CredentialManager(auth_config) - # Mock the exchanger that gets created during registration + # Mock the exchanger registry to return None (no exchanger available) with patch.object( - manager._exchanger_registry, "get_exchanger" - ) as mock_get_exchanger: - mock_exchanger = Mock() - mock_exchanger.exchange = AsyncMock( - return_value=mock_exchanged_credential - ) - mock_get_exchanger.return_value = mock_exchanger - + manager._exchanger_registry, "get_exchanger", return_value=None + ): result, was_exchanged = await manager._exchange_credential( mock_raw_credential ) - assert result == mock_exchanged_credential - assert was_exchanged is True - mock_get_exchanger.assert_called_once_with( - AuthCredentialTypes.SERVICE_ACCOUNT - ) - mock_exchanger.exchange.assert_called_once_with( - mock_raw_credential, auth_config.auth_scheme - ) + assert result == mock_raw_credential + assert was_exchanged is False @pytest.mark.asyncio async def test_exchange_credential_no_exchanger(self): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 4d9cffb4d..d25a84eac 100644 --- 
a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -263,40 +263,6 @@ async def test_get_headers_api_key(self): assert headers == {"X-API-Key": "my_api_key"} - @pytest.mark.asyncio - @patch("google.adk.tools.mcp_tool.mcp_tool.json") - @patch("google.adk.tools.mcp_tool.mcp_tool.Credentials") - async def test_get_headers_google_oauth2_json( - self, mock_credentials, mock_json - ): - """Test header generation for Google OAuth2 JSON credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - # Mock the JSON parsing and Credentials creation - mock_json.loads.return_value = {"token": "google_token"} - mock_google_credential = Mock() - mock_google_credential.token = "google_access_token" - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json='{"token": "google_token"}', - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - assert headers == {"Authorization": "Bearer google_access_token"} - mock_json.loads.assert_called_once_with('{"token": "google_token"}') - mock_credentials.from_authorized_user_info.assert_called_once_with( - {"token": "google_token"} - ) - @pytest.mark.asyncio async def test_get_headers_no_credential(self): """Test header generation with no credentials.""" @@ -311,14 +277,14 @@ async def test_get_headers_no_credential(self): assert headers is None @pytest.mark.asyncio - async def test_get_headers_service_account_no_json(self): - """Test header generation for service account credentials without google_oauth2_json.""" + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" tool = MCPTool( mcp_tool=self.mock_mcp_tool, mcp_session_manager=self.mock_session_manager, ) - # Create service account credential without google_oauth2_json + # Create service account credential service_account = ServiceAccount(scopes=["test"]) credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, @@ -328,7 +294,7 @@ async def test_get_headers_service_account_no_json(self): tool_context = Mock(spec=ToolContext) headers = await tool._get_headers(tool_context, credential) - # Should return None as no google_oauth2_json is provided + # Should return None as service account credentials are not supported for direct header generation assert headers is None @pytest.mark.asyncio From 7f8dc8927aaa401fe7079e430c941597c14237a3 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 18:31:24 -0700 Subject: [PATCH 37/79] chore: fix the mcp_sse_agent PiperOrigin-RevId: 773140021 --- contributing/samples/mcp_sse_agent/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contributing/samples/mcp_sse_agent/agent.py b/contributing/samples/mcp_sse_agent/agent.py index 888a88b24..5423bfc6b 100755 --- a/contributing/samples/mcp_sse_agent/agent.py +++ b/contributing/samples/mcp_sse_agent/agent.py @@ -16,8 +16,8 @@ import os from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import SseServerParams _allowed_path = os.path.dirname(os.path.abspath(__file__)) @@ -31,7 +31,7 @@ """, tools=[ MCPToolset( - 
connection_params=SseServerParams( + connection_params=SseConnectionParams( url='http://localhost:3000/sse', headers={'Accept': 'text/event-stream'}, ), From 17beb32880235712025992f4553634088d32a56c Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 18 Jun 2025 19:08:03 -0700 Subject: [PATCH 38/79] chore: Bump version number and update changelog for 1.4.1 release PiperOrigin-RevId: 773148349 --- CHANGELOG.md | 35 +++++++++++++++++++++++++++++++++++ src/google/adk/version.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04740bb7a..a9873184b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # Changelog +## [1.4.1](https://github.com/google/adk-python/compare/v1.3.0...v1.4.1) (2025-06-18) + + +### Features + +* Add Authenticated Tool (Experimental) ([dcea776](https://github.com/google/adk-python/commit/dcea7767c67c7edfb694304df32dca10b74c9a71)) +* Add enable_affective_dialog and proactivity to run_config and llm_request ([fe1d5aa](https://github.com/google/adk-python/commit/fe1d5aa439cc56b89d248a52556c0a9b4cbd15e4)) +* Add import session API in the fast API ([233fd20](https://github.com/google/adk-python/commit/233fd2024346abd7f89a16c444de0cf26da5c1a1)) +* Add integration tests for litellm with and without turning on add_function_to_prompt ([8e28587](https://github.com/google/adk-python/commit/8e285874da7f5188ea228eb4d7262dbb33b1ae6f)) +* Allow data_store_specs pass into ADK VAIS built-in tool ([675faef](https://github.com/google/adk-python/commit/675faefc670b5cd41991939fe0fc604df331111a)) +* Enable MCP Tool Auth (Experimental) ([157d9be](https://github.com/google/adk-python/commit/157d9be88d92f22320604832e5a334a6eb81e4af)) +* Implement GcsEvalSetResultsManager to handle storage of eval sets on GCS, and refactor eval set results manager ([0a5cf45](https://github.com/google/adk-python/commit/0a5cf45a75aca7b0322136b65ca5504a0c3c7362)) +* Re-factor some eval sets manager logic, and implement GcsEvalSetsManager to handle storage of eval sets on GCS ([1551bd4](https://github.com/google/adk-python/commit/1551bd4f4d7042fffb497d9308b05f92d45d818f)) +* Support real time input config ([d22920b](https://github.com/google/adk-python/commit/d22920bd7f827461afd649601326b0c58aea6716)) +* Support refresh access token automatically for rest_api_tool ([1779801](https://github.com/google/adk-python/commit/177980106b2f7be9a8c0a02f395ff0f85faa0c5a)) + +### Bug Fixes + +* Fix Agent generate config err ([#1305](https://github.com/google/adk-python/issues/1305)) ([badbcbd](https://github.com/google/adk-python/commit/badbcbd7a464e6b323cf3164d2bcd4e27cbc057f)) +* Fix Agent generate config error ([#1450](https://github.com/google/adk-python/issues/1450)) ([694b712](https://github.com/google/adk-python/commit/694b71256c631d44bb4c4488279ea91d82f43e26)) +* Fix liteLLM test failures ([fef8778](https://github.com/google/adk-python/commit/fef87784297b806914de307f48c51d83f977298f)) +* Fix tracing for live ([58e07ca](https://github.com/google/adk-python/commit/58e07cae83048d5213d822be5197a96be9ce2950)) +* Merge custom http options with adk specific http options in model api request ([4ccda99](https://github.com/google/adk-python/commit/4ccda99e8ec7aa715399b4b83c3f101c299a95e8)) +* Remove unnecessary double quote on Claude docstring ([bbceb4f](https://github.com/google/adk-python/commit/bbceb4f2e89f720533b99cf356c532024a120dc4)) +* Set explicit project in the BigQuery client 
([6d174eb](https://github.com/google/adk-python/commit/6d174eba305a51fcf2122c0fd481378752d690ef)) +* Support streaming in litellm + adk and add corresponding integration tests ([aafa80b](https://github.com/google/adk-python/commit/aafa80bd85a49fb1c1a255ac797587cffd3fa567)) +* Support project-based gemini model path to use google_search_tool ([b2fc774](https://github.com/google/adk-python/commit/b2fc7740b363a4e33ec99c7377f396f5cee40b5a)) +* Update conversion between Celsius and Fahrenheit ([1ae176a](https://github.com/google/adk-python/commit/1ae176ad2fa2b691714ac979aec21f1cf7d35e45)) + +### Chores + +* Set `agent_engine_id` in the VertexAiSessionService constructor, also use the `agent_engine_id` field instead of overriding `app_name` in FastAPI endpoint ([fc65873](https://github.com/google/adk-python/commit/fc65873d7c31be607f6cd6690f142a031631582a)) + + + ## [1.3.0](https://github.com/google/adk-python/compare/v1.2.1...v1.3.0) (2025-06-11) diff --git a/src/google/adk/version.py b/src/google/adk/version.py index c0b08cc60..e39c67455 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.3.0" +__version__ = "1.4.1" From 2f716ada7fbcf8e03ff5ae16ce26a80ca6fd7bf6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Jun 2025 22:01:07 -0700 Subject: [PATCH 39/79] fix: Allow more credentials types for BigQuery tools This change accepts the `google.auth.credentials.Credentials` type for `BigQueryCredentialsConfig`, so any subclass of that, including `google.oauth2.credentials.Credentials` would work to integrate with BigQuery service. This opens up a whole range of possibilities, such as using service account credentials to deploy an agent using these tools. PiperOrigin-RevId: 773190440 --- contributing/samples/bigquery/README.md | 23 +++++-- contributing/samples/bigquery/agent.py | 27 +++++--- .../tools/bigquery/bigquery_credentials.py | 39 ++++++++---- .../adk/tools/bigquery/bigquery_tool.py | 2 +- src/google/adk/tools/bigquery/client.py | 2 +- .../adk/tools/bigquery/metadata_tool.py | 2 +- src/google/adk/tools/bigquery/query_tool.py | 2 +- .../bigquery/test_bigquery_credentials.py | 47 ++++++++++---- .../test_bigquery_credentials_manager.py | 62 +++++++++++++++---- 9 files changed, 155 insertions(+), 51 deletions(-) diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index cd4583c72..050ce1332 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -40,13 +40,28 @@ would set: ### With Application Default Credentials This mode is useful for quick development when the agent builder is the only -user interacting with the agent. The tools are initialized with the default -credentials present on the machine running the agent. +user interacting with the agent. The tools are run with these credentials. 1. Create application default credentials on the machine where the agent would be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc. -1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent +1. Set `CREDENTIALS_TYPE=None` in `agent.py` + +1. Run the agent + +### With Service Account Keys + +This mode is useful for quick development when the agent builder wants to run +the agent with service account credentials. The tools are run with these +credentials. + +1. 
Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py` + +1. Download the key file and replace `"service_account_key.json"` with the path + +1. Run the agent ### With Interactive OAuth @@ -72,7 +87,7 @@ type. Note: don't create a separate .env, instead put it to the same .env file that stores your Vertex AI or Dev ML credentials -1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent ## Sample prompts diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py index 0999ca12a..3cd1eb997 100644 --- a/contributing/samples/bigquery/agent.py +++ b/contributing/samples/bigquery/agent.py @@ -15,24 +15,21 @@ import os from google.adk.agents import llm_agent +from google.adk.auth import AuthCredentialTypes from google.adk.tools.bigquery import BigQueryCredentialsConfig from google.adk.tools.bigquery import BigQueryToolset from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode import google.auth -RUN_WITH_ADC = False +# Define an appropriate credential type +CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2 +# Define BigQuery tool config tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) -if RUN_WITH_ADC: - # Initialize the tools to use the application default credentials. - application_default_credentials, _ = google.auth.default() - credentials_config = BigQueryCredentialsConfig( - credentials=application_default_credentials - ) -else: +if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: # Initiaze the tools to do interactive OAuth # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET # must be set @@ -40,6 +37,20 @@ client_id=os.getenv("OAUTH_CLIENT_ID"), client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) +elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT: + # Initialize the tools to use the credentials in the service account key. + # If this flow is enabled, make sure to replace the file path with your own + # service account key file + # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys + creds, _ = google.auth.load_credentials_from_file("service_account_key.json") + credentials_config = BigQueryCredentialsConfig(credentials=creds) +else: + # Initialize the tools to use the application default credentials. 
+ # https://cloud.google.com/docs/authentication/provide-credentials-adc + application_default_credentials, _ = google.auth.default() + credentials_config = BigQueryCredentialsConfig( + credentials=application_default_credentials + ) bigquery_toolset = BigQueryToolset( credentials_config=credentials_config, bigquery_tool_config=tool_config diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index 0a99136c4..d0f3abe0e 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -21,9 +21,10 @@ from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode from fastapi.openapi.models import OAuthFlows +import google.auth.credentials from google.auth.exceptions import RefreshError from google.auth.transport.requests import Request -from google.oauth2.credentials import Credentials +import google.oauth2.credentials from pydantic import BaseModel from pydantic import model_validator @@ -40,26 +41,35 @@ @experimental class BigQueryCredentialsConfig(BaseModel): - """Configuration for Google API tools. (Experimental)""" + """Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ # Configure the model to allow arbitrary types like Credentials model_config = {"arbitrary_types_allowed": True} - credentials: Optional[Credentials] = None - """the existing oauth credentials to use. If set,this credential will be used + credentials: Optional[google.auth.credentials.Credentials] = None + """The existing auth credentials to use. If set, this credential will be used for every end user, end users don't need to be involved in the oauthflow. This field is mutually exclusive with client_id, client_secret and scopes. Don't set this field unless you are sure this credential has the permission to access every end user's data. - Example usage: when the agent is deployed in Google Cloud environment and + Example usage 1: When the agent is deployed in Google Cloud environment and the service account (used as application default credentials) has access to all the required BigQuery resource. Setting this credential to allow user to access the BigQuery resource without end users going through oauth flow. - To get application default credential: `google.auth.default(...)`. See more + To get application default credential, use: `google.auth.default(...)`. See more details in https://cloud.google.com/docs/authentication/application-default-credentials. + Example usage 2: When the agent wants to access the user's BigQuery resources + using the service account key credentials. + + To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`. + See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + When the deployed environment cannot provide a pre-existing credential, consider setting below client_id, client_secret and scope for end users to go through oauth flow, so that agent can access the user data. @@ -86,7 +96,9 @@ def __post_init__(self) -> BigQueryCredentialsConfig: " client_id/client_secret/scopes." 
) - if self.credentials: + if self.credentials and isinstance( + self.credentials, google.oauth2.credentials.Credentials + ): self.client_id = self.credentials.client_id self.client_secret = self.credentials.client_secret self.scopes = self.credentials.scopes @@ -115,7 +127,7 @@ def __init__(self, credentials_config: BigQueryCredentialsConfig): async def get_valid_credentials( self, tool_context: ToolContext - ) -> Optional[Credentials]: + ) -> Optional[google.auth.credentials.Credentials]: """Get valid credentials, handling refresh and OAuth flow as needed. Args: @@ -127,7 +139,7 @@ async def get_valid_credentials( # First, try to get credentials from the tool context creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None) creds = ( - Credentials.from_authorized_user_info( + google.oauth2.credentials.Credentials.from_authorized_user_info( json.loads(creds_json), self.credentials_config.scopes ) if creds_json @@ -138,6 +150,11 @@ async def get_valid_credentials( if not creds: creds = self.credentials_config.credentials + # If non-oauth credentials are provided then use them as is. This helps + # in flows such as service account keys + if creds and not isinstance(creds, google.oauth2.credentials.Credentials): + return creds + # Check if we have valid credentials if creds and creds.valid: return creds @@ -159,7 +176,7 @@ async def get_valid_credentials( async def _perform_oauth_flow( self, tool_context: ToolContext - ) -> Optional[Credentials]: + ) -> Optional[google.oauth2.credentials.Credentials]: """Perform OAuth flow to get new credentials. Args: @@ -199,7 +216,7 @@ async def _perform_oauth_flow( if auth_response: # OAuth flow completed, create credentials - creds = Credentials( + creds = google.oauth2.credentials.Credentials( token=auth_response.oauth2.access_token, refresh_token=auth_response.oauth2.refresh_token, token_uri=auth_scheme.flows.authorizationCode.tokenUrl, diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py index 182734188..50d49ff77 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/bigquery/bigquery_tool.py @@ -19,7 +19,7 @@ from typing import Callable from typing import Optional -from google.oauth2.credentials import Credentials +from google.auth.credentials import Credentials from typing_extensions import override from ...utils.feature_decorator import experimental diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index 23f1befc5..8b2816ebe 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -15,8 +15,8 @@ from __future__ import annotations import google.api_core.client_info +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from ... import version diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py index 4f5400611..64f23d07b 100644 --- a/src/google/adk/tools/bigquery/metadata_tool.py +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from . 
import client diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index d3a94fda7..147d0b4db 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -16,8 +16,8 @@ import types from typing import Callable +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from . import client from .config import BigQueryToolConfig diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 9fa152fc2..05af3aaf3 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest import mock from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig # Mock the Google OAuth and API dependencies -from google.oauth2.credentials import Credentials +import google.auth.credentials +import google.oauth2.credentials import pytest @@ -27,22 +28,46 @@ class TestBigQueryCredentials: either existing credentials or client ID/secret pairs are provided. """ - def test_valid_credentials_object(self): - """Test that providing valid Credentials object works correctly. + def test_valid_credentials_object_auth_credentials(self): + """Test that providing valid Credentials object works correctly with + google.auth.credentials.Credentials. When a user already has valid OAuth credentials, they should be able to pass them directly without needing to provide client ID/secret. """ - # Create a mock credentials object with the expected attributes - mock_creds = Mock(spec=Credentials) - mock_creds.client_id = "test_client_id" - mock_creds.client_secret = "test_client_secret" - mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"] + # Create a mock auth credentials object + # auth_creds = google.auth.credentials.Credentials() + auth_creds = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + config = BigQueryCredentialsConfig(credentials=auth_creds) + + # Verify that the credentials are properly stored and attributes are extracted + assert config.credentials == auth_creds + assert config.client_id is None + assert config.client_secret is None + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + + def test_valid_credentials_object_oauth2_credentials(self): + """Test that providing valid Credentials object works correctly with + google.oauth2.credentials.Credentials. + + When a user already has valid OAuth credentials, they should be able + to pass them directly without needing to provide client ID/secret. 
+ """ + # Create a mock oauth2 credentials object + oauth2_creds = google.oauth2.credentials.Credentials( + "test_token", + client_id="test_client_id", + client_secret="test_client_secret", + scopes=["https://www.googleapis.com/auth/calendar"], + ) - config = BigQueryCredentialsConfig(credentials=mock_creds) + config = BigQueryCredentialsConfig(credentials=oauth2_creds) # Verify that the credentials are properly stored and attributes are extracted - assert config.credentials == mock_creds + assert config.credentials == oauth2_creds assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" assert config.scopes == ["https://www.googleapis.com/auth/calendar"] diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py index 95d8b00d6..47d955906 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py @@ -22,9 +22,10 @@ from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager +from google.auth.credentials import Credentials as AuthCredentials from google.auth.exceptions import RefreshError # Mock the Google OAuth and API dependencies -from google.oauth2.credentials import Credentials +from google.oauth2.credentials import Credentials as OAuthCredentials import pytest @@ -64,9 +65,16 @@ def manager(self, credentials_config): """Create a credentials manager instance for testing.""" return BigQueryCredentialsManager(credentials_config) + @pytest.mark.parametrize( + ("credentials_class",), + [ + pytest.param(OAuthCredentials, id="oauth"), + pytest.param(AuthCredentials, id="auth"), + ], + ) @pytest.mark.asyncio async def test_get_valid_credentials_with_valid_existing_creds( - self, manager, mock_tool_context + self, manager, mock_tool_context, credentials_class ): """Test that valid existing credentials are returned immediately. @@ -74,7 +82,7 @@ async def test_get_valid_credentials_with_valid_existing_creds( should be needed. This is the optimal happy path scenario. """ # Create mock credentials that are already valid - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=credentials_class) mock_creds.valid = True manager.credentials_config.credentials = mock_creds @@ -85,6 +93,34 @@ async def test_get_valid_credentials_with_valid_existing_creds( mock_tool_context.get_auth_response.assert_not_called() mock_tool_context.request_credential.assert_not_called() + @pytest.mark.parametrize( + ("valid",), + [ + pytest.param(False, id="invalid"), + pytest.param(True, id="valid"), + ], + ) + @pytest.mark.asyncio + async def test_get_valid_credentials_with_existing_non_oauth_creds( + self, manager, mock_tool_context, valid + ): + """Test that existing non-oauth credentials are returned immediately. + + When credentials are of non-oauth type, no refresh or OAuth flow + is triggered irrespective of whether it is valid or not. 
+ """ + # Create mock credentials that are already valid + mock_creds = Mock(spec=AuthCredentials) + mock_creds.valid = valid + manager.credentials_config.credentials = mock_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + assert result == mock_creds + # Verify no OAuth flow was triggered + mock_tool_context.get_auth_response.assert_not_called() + mock_tool_context.request_credential.assert_not_called() + @pytest.mark.asyncio async def test_get_credentials_from_cache_when_none_in_manager( self, manager, mock_tool_context @@ -113,7 +149,7 @@ async def test_get_credentials_from_cache_when_none_in_manager( with patch( "google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = True mock_from_json.return_value = mock_creds @@ -179,7 +215,7 @@ async def test_refresh_cached_credentials_success( mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json # Create expired cached credentials with refresh token - mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = False mock_cached_creds.expired = True mock_cached_creds.refresh_token = "valid_refresh_token" @@ -227,7 +263,7 @@ async def test_get_valid_credentials_with_refresh_success( users from having to re-authenticate for every expired token. """ # Create expired credentials with refresh token - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = False mock_creds.expired = True mock_creds.refresh_token = "refresh_token" @@ -257,7 +293,7 @@ async def test_get_valid_credentials_with_refresh_failure( gracefully fall back to requesting a new OAuth flow. """ # Create expired credentials that fail to refresh - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = False mock_creds.expired = True mock_creds.refresh_token = "expired_refresh_token" @@ -287,7 +323,7 @@ async def test_oauth_flow_completion_with_caching( mock_tool_context.get_auth_response.return_value = mock_auth_response # Create a mock credentials instance that will represent our created credentials - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) # Make the JSON match what a real Credentials object would produce mock_creds_json = ( '{"token": "new_access_token", "refresh_token": "new_refresh_token",' @@ -300,7 +336,7 @@ async def test_oauth_flow_completion_with_caching( # Use the full module path as it appears in the project structure with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials", + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: result = await manager.get_valid_credentials(mock_tool_context) @@ -361,7 +397,7 @@ async def test_cache_persistence_across_manager_instances( mock_tool_context.get_auth_response.return_value = mock_auth_response # Create the mock credentials instance that will be returned by the constructor - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) # Make sure our mock JSON matches the structure that real Credentials objects produce mock_creds_json = ( '{"token": "cached_access_token", "refresh_token":' @@ -376,7 +412,7 @@ async def test_cache_persistence_across_manager_instances( # Use the correct module path - without the 'src.' 
prefix with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials", + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: # Complete OAuth flow with first manager @@ -396,9 +432,9 @@ async def test_cache_persistence_across_manager_instances( # Mock the from_authorized_user_info method for the second manager with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info" + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: - mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = True mock_from_json.return_value = mock_cached_creds From ffcba70686f4e06d40aac37445a11daa16a418cf Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 20 Jun 2025 11:45:56 -0700 Subject: [PATCH 40/79] chore: skip mcp and a2a tests for python 3.9 PiperOrigin-RevId: 773785385 --- .../a2a/converters/test_part_converter.py | 39 +++++++++++--- .../mcp_tool/test_mcp_session_manager.py | 32 ++++++++++-- .../unittests/tools/mcp_tool/test_mcp_tool.py | 35 ++++++++++--- .../tools/mcp_tool/test_mcp_toolset.py | 51 ++++++++++++------- 4 files changed, 121 insertions(+), 36 deletions(-) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 5ad6cd62d..d9c8e86d4 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -13,18 +13,43 @@ # limitations under the License. import json +import sys from unittest.mock import Mock from unittest.mock import patch -from a2a import types as a2a_types -from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL -from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE -from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY -from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part -from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part -from google.genai import types as genai_types import pytest +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a import types as a2a_types + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY + from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part + from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part + from google.genai import types as genai_types +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + a2a_types = DummyTypes() + genai_types = DummyTypes() + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = "function_call" + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = "function_response" + A2A_DATA_PART_METADATA_TYPE_KEY = "type" + 
convert_a2a_part_to_genai_part = lambda x: None + convert_genai_part_to_a2a_part = lambda x: None + else: + raise e + class TestConvertA2aPartToGenaiPart: """Test cases for convert_a2a_part_to_genai_part function.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 448d41260..559e51719 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -20,13 +20,35 @@ from unittest.mock import Mock from unittest.mock import patch -from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager -from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource -from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams -from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams -from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams import pytest +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager + from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource + from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyClass: + pass + + MCPSessionManager = DummyClass + retry_on_closed_resource = lambda x: x + SseConnectionParams = DummyClass + StdioConnectionParams = DummyClass + StreamableHTTPConnectionParams = DummyClass + else: + raise e + # Import real MCP classes try: from mcp import StdioServerParameters diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index d25a84eac..82e3f2234 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json +import sys +from typing import Any +from typing import Dict from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -23,14 +25,33 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount -from google.adk.auth.auth_schemes import AuthScheme -from google.adk.auth.auth_schemes import AuthSchemeType -from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager -from google.adk.tools.mcp_tool.mcp_tool import MCPTool -from google.adk.tools.tool_context import ToolContext -from google.genai.types import FunctionDeclaration import pytest +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager + from google.adk.tools.mcp_tool.mcp_tool import MCPTool + from google.adk.tools.tool_context import ToolContext + from google.genai.types import FunctionDeclaration +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyClass: + pass + + MCPSessionManager = DummyClass + MCPTool = DummyClass + ToolContext = DummyClass + FunctionDeclaration = DummyClass + else: + raise e + # Mock MCP Tool from mcp.types class MockMCPTool: diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 0ba29b1da..d5e6ae243 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -14,32 +14,49 @@ from io import StringIO import sys +import unittest from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch -from google.adk.agents.readonly_context import ReadonlyContext from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_schemes import AuthScheme -from google.adk.auth.auth_schemes import AuthSchemeType -from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager -from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams -from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams -from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -from google.adk.tools.mcp_tool.mcp_tool import MCPTool -from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset import pytest -# Import the real MCP classes for proper instantiation +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking try: + from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager + from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams + from google.adk.tools.mcp_tool.mcp_tool import MCPTool + from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from mcp import StdioServerParameters -except ImportError: - # Create a mock if MCP is not 
available - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyClass: + pass + + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + MCPSessionManager = DummyClass + SseConnectionParams = DummyClass + StdioConnectionParams = DummyClass + StreamableHTTPConnectionParams = DummyClass + MCPTool = DummyClass + MCPToolset = DummyClass + else: + raise e class MockMCPTool: From 742478fdb78d2178bce5e9f77dde519975294cc4 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 20 Jun 2025 12:13:31 -0700 Subject: [PATCH 41/79] chore: Add event converters to convert adk event to a2a event (WIP) PiperOrigin-RevId: 773795427 --- .../adk/a2a/converters/event_converter.py | 382 ++++++++++++ .../adk/a2a/converters/part_converter.py | 6 +- src/google/adk/a2a/converters/utils.py | 34 + .../a2a/converters/test_event_converter.py | 589 ++++++++++++++++++ .../a2a/converters/test_part_converter.py | 6 +- 5 files changed, 1013 insertions(+), 4 deletions(-) create mode 100644 src/google/adk/a2a/converters/event_converter.py create mode 100644 src/google/adk/a2a/converters/utils.py create mode 100644 tests/unittests/a2a/converters/test_event_converter.py diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py new file mode 100644 index 000000000..5594c0e63 --- /dev/null +++ b/src/google/adk/a2a/converters/event_converter.py @@ -0,0 +1,382 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import datetime +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +import uuid + +from a2a.server.events import Event as A2AEvent +from a2a.types import Artifact +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from ...utils.feature_decorator import working_in_progress +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_genai_part_to_a2a_part +from .utils import _get_adk_metadata_key + +# Constants + +ARTIFACT_ID_SEPARATOR = "-" +DEFAULT_ERROR_MESSAGE = "An error occurred during processing" + +# Logger +logger = logging.getLogger("google_adk." 
+ __name__) + + +def _serialize_metadata_value(value: Any) -> str: + """Safely serializes metadata values to string format. + + Args: + value: The value to serialize. + + Returns: + String representation of the value. + """ + if hasattr(value, "model_dump"): + try: + return value.model_dump(exclude_none=True, by_alias=True) + except Exception as e: + logger.warning("Failed to serialize metadata value: %s", e) + return str(value) + return str(value) + + +def _get_context_metadata( + event: Event, invocation_context: InvocationContext +) -> Dict[str, str]: + """Gets the context metadata for the event. + + Args: + event: The ADK event to extract metadata from. + invocation_context: The invocation context containing session information. + + Returns: + A dictionary containing the context metadata. + + Raises: + ValueError: If required fields are missing from event or context. + """ + if not event: + raise ValueError("Event cannot be None") + if not invocation_context: + raise ValueError("Invocation context cannot be None") + + try: + metadata = { + _get_adk_metadata_key("app_name"): invocation_context.app_name, + _get_adk_metadata_key("user_id"): invocation_context.user_id, + _get_adk_metadata_key("session_id"): invocation_context.session.id, + _get_adk_metadata_key("invocation_id"): event.invocation_id, + _get_adk_metadata_key("author"): event.author, + } + + # Add optional metadata fields if present + optional_fields = [ + ("branch", event.branch), + ("grounding_metadata", event.grounding_metadata), + ("custom_metadata", event.custom_metadata), + ("usage_metadata", event.usage_metadata), + ("error_code", event.error_code), + ] + + for field_name, field_value in optional_fields: + if field_value is not None: + metadata[_get_adk_metadata_key(field_name)] = _serialize_metadata_value( + field_value + ) + + return metadata + + except Exception as e: + logger.error("Failed to create context metadata: %s", e) + raise + + +def _create_artifact_id( + app_name: str, user_id: str, session_id: str, filename: str, version: int +) -> str: + """Creates a unique artifact ID. + + Args: + app_name: The application name. + user_id: The user ID. + session_id: The session ID. + filename: The artifact filename. + version: The artifact version. + + Returns: + A unique artifact ID string. + """ + components = [app_name, user_id, session_id, filename, str(version)] + return ARTIFACT_ID_SEPARATOR.join(components) + + +def _convert_artifact_to_a2a_events( + event: Event, + invocation_context: InvocationContext, + filename: str, + version: int, +) -> TaskArtifactUpdateEvent: + """Converts a new artifact version to an A2A TaskArtifactUpdateEvent. + + Args: + event: The ADK event containing the artifact information. + invocation_context: The invocation context. + filename: The name of the artifact file. + version: The version number of the artifact. + + Returns: + A TaskArtifactUpdateEvent representing the artifact update. + + Raises: + ValueError: If required parameters are invalid. + RuntimeError: If artifact loading fails. 
+ """ + if not filename: + raise ValueError("Filename cannot be empty") + if version < 0: + raise ValueError("Version must be non-negative") + + try: + artifact_part = invocation_context.artifact_service.load_artifact( + app_name=invocation_context.app_name, + user_id=invocation_context.user_id, + session_id=invocation_context.session.id, + filename=filename, + version=version, + ) + + converted_part = convert_genai_part_to_a2a_part(part=artifact_part) + if not converted_part: + raise RuntimeError(f"Failed to convert artifact part for {filename}") + + artifact_id = _create_artifact_id( + invocation_context.app_name, + invocation_context.user_id, + invocation_context.session.id, + filename, + version, + ) + + return TaskArtifactUpdateEvent( + taskId=str(uuid.uuid4()), + append=False, + contextId=invocation_context.session.id, + lastChunk=True, + artifact=Artifact( + artifactId=artifact_id, + name=filename, + metadata={ + "filename": filename, + "version": version, + }, + parts=[converted_part], + ), + ) + except Exception as e: + logger.error( + "Failed to convert artifact for %s, version %s: %s", + filename, + version, + e, + ) + raise RuntimeError(f"Artifact conversion failed: {e}") from e + + +def _process_long_running_tool(a2a_part, event: Event) -> None: + """Processes long-running tool metadata for an A2A part. + + Args: + a2a_part: The A2A part to potentially mark as long-running. + event: The ADK event containing long-running tool information. + """ + if ( + isinstance(a2a_part.root, DataPart) + and event.long_running_tool_ids + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and a2a_part.root.metadata.get("id") in event.long_running_tool_ids + ): + a2a_part.root.metadata[_get_adk_metadata_key("is_long_running")] = True + + +@working_in_progress +def convert_event_to_a2a_status_message( + event: Event, invocation_context: InvocationContext +) -> Optional[Message]: + """Converts an ADK event to an A2A message. + + Args: + event: The ADK event to convert. + invocation_context: The invocation context. + + Returns: + An A2A Message if the event has content, None otherwise. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + if not invocation_context: + raise ValueError("Invocation context cannot be None") + + if not event.content or not event.content.parts: + return None + + try: + a2a_parts = [] + for part in event.content.parts: + a2a_part = convert_genai_part_to_a2a_part(part) + if a2a_part: + a2a_parts.append(a2a_part) + _process_long_running_tool(a2a_part, event) + + if a2a_parts: + return Message( + messageId=str(uuid.uuid4()), role=Role.agent, parts=a2a_parts + ) + + except Exception as e: + logger.error("Failed to convert event to status message: %s", e) + raise + + return None + + +def _create_error_status_event( + event: Event, invocation_context: InvocationContext +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for error scenarios. + + Args: + event: The ADK event containing error information. + invocation_context: The invocation context. + + Returns: + A TaskStatusUpdateEvent with FAILED state. 
+ """ + error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + + return TaskStatusUpdateEvent( + taskId=str(uuid.uuid4()), + contextId=invocation_context.session.id, + final=False, + metadata=_get_context_metadata(event, invocation_context), + status=TaskStatus( + state=TaskState.failed, + message=Message( + messageId=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=error_message)], + ), + timestamp=datetime.datetime.now().isoformat(), + ), + ) + + +def _create_running_status_event( + message: Message, invocation_context: InvocationContext, event: Event +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for running scenarios. + + Args: + message: The A2A message to include. + invocation_context: The invocation context. + event: The ADK event. + + Returns: + A TaskStatusUpdateEvent with RUNNING state. + """ + return TaskStatusUpdateEvent( + taskId=str(uuid.uuid4()), + contextId=invocation_context.session.id, + final=False, + status=TaskStatus( + state=TaskState.working, + message=message, + timestamp=datetime.datetime.now().isoformat(), + ), + metadata=_get_context_metadata(event, invocation_context), + ) + + +@working_in_progress +def convert_event_to_a2a_events( + event: Event, invocation_context: InvocationContext +) -> List[A2AEvent]: + """Converts a GenAI event to a list of A2A events. + + Args: + event: The ADK event to convert. + invocation_context: The invocation context. + + Returns: + A list of A2A events representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + if not invocation_context: + raise ValueError("Invocation context cannot be None") + + a2a_events = [] + + try: + # Handle artifact deltas + if event.actions.artifact_delta: + for filename, version in event.actions.artifact_delta.items(): + artifact_event = _convert_artifact_to_a2a_events( + event, invocation_context, filename, version + ) + a2a_events.append(artifact_event) + + # Handle error scenarios + if event.error_code: + error_event = _create_error_status_event(event, invocation_context) + a2a_events.append(error_event) + + # Handle regular message content + message = convert_event_to_a2a_status_message(event, invocation_context) + if message: + running_event = _create_running_status_event( + message, invocation_context, event + ) + a2a_events.append(running_event) + + except Exception as e: + logger.error("Failed to convert event to A2A events: %s", e) + raise + + return a2a_events diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index 2d94abd7c..c47ac7276 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -23,6 +23,8 @@ import sys from typing import Optional +from .utils import _get_adk_metadata_key + try: from a2a import types as a2a_types except ImportError as e: @@ -84,7 +86,7 @@ def convert_a2a_part_to_genai_part( # logic accordinlgy if part.metadata and A2A_DATA_PART_METADATA_TYPE_KEY in part.metadata: if ( - part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ): return genai_types.Part( @@ -93,7 +95,7 @@ def convert_a2a_part_to_genai_part( ) ) if ( - part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] == 
A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE ): return genai_types.Part( diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py new file mode 100644 index 000000000..fe5f2e927 --- /dev/null +++ b/src/google/adk/a2a/converters/utils.py @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +ADK_METADATA_KEY_PREFIX = "adk_" + + +def _get_adk_metadata_key(key: str) -> str: + """Gets the A2A event metadata key for the given key. + + Args: + key: The metadata key to prefix. + + Returns: + The prefixed metadata key. + + Raises: + ValueError: If key is empty or None. + """ + if not key: + raise ValueError("Metadata key cannot be empty or None") + return f"{ADK_METADATA_KEY_PREFIX}{key}" diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py new file mode 100644 index 000000000..311ffc954 --- /dev/null +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -0,0 +1,589 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
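Before the unit tests below, a quick orientation on the metadata convention these converters rely on: every ADK field is namespaced with the `adk_` prefix before it is attached to A2A metadata, and the part converter now reads values back under the prefixed key. A minimal, self-contained sketch of that convention (a standalone re-statement for illustration, not the module itself; the dictionary values are made up):

    ADK_METADATA_KEY_PREFIX = "adk_"


    def get_adk_metadata_key(key: str) -> str:
      # Same idea as _get_adk_metadata_key above: reject empty keys, then prefix.
      if not key:
        raise ValueError("Metadata key cannot be empty or None")
      return f"{ADK_METADATA_KEY_PREFIX}{key}"


    # The event converter stores context fields under prefixed keys, e.g.:
    metadata = {
        get_adk_metadata_key("app_name"): "my_app",        # illustrative value
        get_adk_metadata_key("invocation_id"): "inv-123",  # illustrative value
    }
    assert metadata["adk_app_name"] == "my_app"
    assert get_adk_metadata_key("type") == "adk_type"  # the key form exercised in the tests below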
+ +import sys +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.types import DataPart + from a2a.types import Message + from a2a.types import Role + from a2a.types import TaskArtifactUpdateEvent + from a2a.types import TaskState + from a2a.types import TaskStatusUpdateEvent + from google.adk.a2a.converters.event_converter import _convert_artifact_to_a2a_events + from google.adk.a2a.converters.event_converter import _create_artifact_id + from google.adk.a2a.converters.event_converter import _create_error_status_event + from google.adk.a2a.converters.event_converter import _create_running_status_event + from google.adk.a2a.converters.event_converter import _get_adk_metadata_key + from google.adk.a2a.converters.event_converter import _get_context_metadata + from google.adk.a2a.converters.event_converter import _process_long_running_tool + from google.adk.a2a.converters.event_converter import _serialize_metadata_value + from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR + from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events + from google.adk.a2a.converters.event_converter import convert_event_to_a2a_status_message + from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE + from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX + from google.adk.agents.invocation_context import InvocationContext + from google.adk.events.event import Event + from google.adk.events.event_actions import EventActions +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + DataPart = DummyTypes() + Message = DummyTypes() + Role = DummyTypes() + TaskArtifactUpdateEvent = DummyTypes() + TaskState = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + _convert_artifact_to_a2a_events = lambda *args: None + _create_artifact_id = lambda *args: None + _create_error_status_event = lambda *args: None + _create_running_status_event = lambda *args: None + _get_adk_metadata_key = lambda *args: None + _get_context_metadata = lambda *args: None + _process_long_running_tool = lambda *args: None + _serialize_metadata_value = lambda *args: None + ADK_METADATA_KEY_PREFIX = "adk_" + ARTIFACT_ID_SEPARATOR = "_" + convert_event_to_a2a_events = lambda *args: None + convert_event_to_a2a_status_message = lambda *args: None + DEFAULT_ERROR_MESSAGE = "error" + InvocationContext = DummyTypes() + Event = DummyTypes() + EventActions = DummyTypes() + types = DummyTypes() + else: + raise e + + +class TestEventConverter: + """Test suite for event_converter module.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_session = Mock() + self.mock_session.id = "test-session-id" + + self.mock_artifact_service = Mock() + self.mock_invocation_context = Mock(spec=InvocationContext) + self.mock_invocation_context.app_name = "test-app" + self.mock_invocation_context.user_id = "test-user" + self.mock_invocation_context.session = self.mock_session + self.mock_invocation_context.artifact_service = self.mock_artifact_service + + self.mock_event = Mock(spec=Event) + self.mock_event.invocation_id = "test-invocation-id" + 
self.mock_event.author = "test-author" + self.mock_event.branch = None + self.mock_event.grounding_metadata = None + self.mock_event.custom_metadata = None + self.mock_event.usage_metadata = None + self.mock_event.error_code = None + self.mock_event.error_message = None + self.mock_event.content = None + self.mock_event.long_running_tool_ids = None + self.mock_event.actions = Mock(spec=EventActions) + self.mock_event.actions.artifact_delta = None + + def test_get_adk_event_metadata_key_success(self): + """Test successful metadata key generation.""" + key = "test_key" + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_get_adk_event_metadata_key_empty_string(self): + """Test metadata key generation with empty string.""" + with pytest.raises(ValueError) as exc_info: + _get_adk_metadata_key("") + assert "cannot be empty or None" in str(exc_info.value) + + def test_get_adk_event_metadata_key_none(self): + """Test metadata key generation with None.""" + with pytest.raises(ValueError) as exc_info: + _get_adk_metadata_key(None) + assert "cannot be empty or None" in str(exc_info.value) + + def test_serialize_metadata_value_with_model_dump(self): + """Test serialization of value with model_dump method.""" + mock_value = Mock() + mock_value.model_dump.return_value = {"key": "value"} + + result = _serialize_metadata_value(mock_value) + + assert result == {"key": "value"} + mock_value.model_dump.assert_called_once_with( + exclude_none=True, by_alias=True + ) + + def test_serialize_metadata_value_with_model_dump_exception(self): + """Test serialization when model_dump raises exception.""" + mock_value = Mock() + mock_value.model_dump.side_effect = Exception("Serialization failed") + + with patch( + "google.adk.a2a.converters.event_converter.logger" + ) as mock_logger: + result = _serialize_metadata_value(mock_value) + + assert result == str(mock_value) + mock_logger.warning.assert_called_once() + + def test_serialize_metadata_value_without_model_dump(self): + """Test serialization of value without model_dump method.""" + value = "simple_string" + result = _serialize_metadata_value(value) + assert result == "simple_string" + + def test_get_context_metadata_success(self): + """Test successful context metadata creation.""" + result = _get_context_metadata( + self.mock_event, self.mock_invocation_context + ) + + assert result is not None + expected_keys = [ + f"{ADK_METADATA_KEY_PREFIX}app_name", + f"{ADK_METADATA_KEY_PREFIX}user_id", + f"{ADK_METADATA_KEY_PREFIX}session_id", + f"{ADK_METADATA_KEY_PREFIX}invocation_id", + f"{ADK_METADATA_KEY_PREFIX}author", + ] + + for key in expected_keys: + assert key in result + + def test_get_context_metadata_with_optional_fields(self): + """Test context metadata creation with optional fields.""" + self.mock_event.branch = "test-branch" + self.mock_event.error_code = "ERROR_001" + + mock_metadata = Mock() + mock_metadata.model_dump.return_value = {"test": "value"} + self.mock_event.grounding_metadata = mock_metadata + + result = _get_context_metadata( + self.mock_event, self.mock_invocation_context + ) + + assert result is not None + assert f"{ADK_METADATA_KEY_PREFIX}branch" in result + assert f"{ADK_METADATA_KEY_PREFIX}grounding_metadata" in result + assert result[f"{ADK_METADATA_KEY_PREFIX}branch"] == "test-branch" + + # Check if error_code is in the result - it should be there since we set it + if f"{ADK_METADATA_KEY_PREFIX}error_code" in result: + assert result[f"{ADK_METADATA_KEY_PREFIX}error_code"] == 
"ERROR_001" + + def test_get_context_metadata_none_event(self): + """Test context metadata creation with None event.""" + with pytest.raises(ValueError) as exc_info: + _get_context_metadata(None, self.mock_invocation_context) + assert "Event cannot be None" in str(exc_info.value) + + def test_get_context_metadata_none_context(self): + """Test context metadata creation with None context.""" + with pytest.raises(ValueError) as exc_info: + _get_context_metadata(self.mock_event, None) + assert "Invocation context cannot be None" in str(exc_info.value) + + def test_create_artifact_id(self): + """Test artifact ID creation.""" + app_name = "test-app" + user_id = "user123" + session_id = "session456" + filename = "test.txt" + version = 1 + + result = _create_artifact_id( + app_name, user_id, session_id, filename, version + ) + expected = f"{app_name}{ARTIFACT_ID_SEPARATOR}{user_id}{ARTIFACT_ID_SEPARATOR}{session_id}{ARTIFACT_ID_SEPARATOR}{filename}{ARTIFACT_ID_SEPARATOR}{version}" + + assert result == expected + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + def test_convert_artifact_to_a2a_events_success(self, mock_convert_part): + """Test successful artifact delta conversion.""" + filename = "test.txt" + version = 1 + + mock_artifact_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types import TextPart + + text_part = TextPart(text="test content") + mock_converted_part = Part(root=text_part) + + self.mock_artifact_service.load_artifact.return_value = mock_artifact_part + mock_convert_part.return_value = mock_converted_part + + result = _convert_artifact_to_a2a_events( + self.mock_event, self.mock_invocation_context, filename, version + ) + + assert isinstance(result, TaskArtifactUpdateEvent) + assert result.contextId == self.mock_invocation_context.session.id + assert result.append is False + assert result.lastChunk is True + + # Check artifact properties + assert result.artifact.name == filename + assert result.artifact.metadata["filename"] == filename + assert result.artifact.metadata["version"] == version + assert len(result.artifact.parts) == 1 + assert result.artifact.parts[0].root.text == "test content" + + def test_convert_artifact_to_a2a_events_empty_filename(self): + """Test artifact delta conversion with empty filename.""" + with pytest.raises(ValueError) as exc_info: + _convert_artifact_to_a2a_events( + self.mock_event, self.mock_invocation_context, "", 1 + ) + assert "Filename cannot be empty" in str(exc_info.value) + + def test_convert_artifact_to_a2a_events_negative_version(self): + """Test artifact delta conversion with negative version.""" + with pytest.raises(ValueError) as exc_info: + _convert_artifact_to_a2a_events( + self.mock_event, self.mock_invocation_context, "test.txt", -1 + ) + assert "Version must be non-negative" in str(exc_info.value) + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + def test_convert_artifact_to_a2a_events_conversion_failure( + self, mock_convert_part + ): + """Test artifact delta conversion when part conversion fails.""" + filename = "test.txt" + version = 1 + + mock_artifact_part = Mock() + self.mock_artifact_service.load_artifact.return_value = mock_artifact_part + mock_convert_part.return_value = None # Simulate conversion failure + + with pytest.raises(RuntimeError) as exc_info: + _convert_artifact_to_a2a_events( + self.mock_event, self.mock_invocation_context, filename, version + ) + assert "Failed 
to convert artifact part" in str(exc_info.value) + + def test_process_long_running_tool_marks_tool(self): + """Test processing of long-running tool metadata.""" + mock_a2a_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"} + mock_a2a_part.root = mock_data_part + + self.mock_event.long_running_tool_ids = {"tool-123"} + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + ): + + _process_long_running_tool(mock_a2a_part, self.mock_event) + + expected_key = f"{ADK_METADATA_KEY_PREFIX}is_long_running" + assert mock_data_part.metadata[expected_key] is True + + def test_process_long_running_tool_no_marking(self): + """Test processing when tool should not be marked as long-running.""" + mock_a2a_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"} + mock_a2a_part.root = mock_data_part + + self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + ): + + _process_long_running_tool(mock_a2a_part, self.mock_event) + + expected_key = f"{ADK_METADATA_KEY_PREFIX}is_long_running" + assert expected_key not in mock_data_part.metadata + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_event_to_message_success(self, mock_uuid, mock_convert_part): + """Test successful event to message conversion.""" + mock_uuid.return_value = "test-uuid" + + mock_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types import TextPart + + text_part = TextPart(text="test message") + mock_a2a_part = Part(root=text_part) + mock_convert_part.return_value = mock_a2a_part + + mock_content = Mock() + mock_content.parts = [mock_part] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_status_message( + self.mock_event, self.mock_invocation_context + ) + + assert isinstance(result, Message) + assert result.messageId == "test-uuid" + assert result.role == Role.agent + assert len(result.parts) == 1 + assert result.parts[0].root.text == "test message" + + def test_convert_event_to_message_no_content(self): + """Test event to message conversion with no content.""" + self.mock_event.content = None + + result = convert_event_to_a2a_status_message( + self.mock_event, self.mock_invocation_context + ) + + assert result is None + + def test_convert_event_to_message_empty_parts(self): + """Test event to message conversion with empty parts.""" + mock_content = Mock() + mock_content.parts = [] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_status_message( + self.mock_event, self.mock_invocation_context + ) + + assert result is None + + def test_convert_event_to_message_none_event(self): + """Test event to message conversion with None event.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_status_message(None, self.mock_invocation_context) + assert "Event cannot be None" in str(exc_info.value) + + def 
test_convert_event_to_message_none_context(self): + """Test event to message conversion with None context.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_status_message(self.mock_event, None) + assert "Invocation context cannot be None" in str(exc_info.value) + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + def test_create_error_status_event(self, mock_datetime, mock_uuid): + """Test creation of error status event.""" + mock_uuid.return_value = "test-uuid" + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + self.mock_event.error_message = "Test error message" + + result = _create_error_status_event( + self.mock_event, self.mock_invocation_context + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.contextId == self.mock_invocation_context.session.id + assert result.status.state == TaskState.failed + assert result.status.message.parts[0].root.text == "Test error message" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): + """Test creation of error status event without error message.""" + mock_uuid.return_value = "test-uuid" + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + result = _create_error_status_event( + self.mock_event, self.mock_invocation_context + ) + + assert result.status.message.parts[0].root.text == DEFAULT_ERROR_MESSAGE + + @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + def test_create_running_status_event(self, mock_datetime): + """Test creation of running status event.""" + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + mock_message = Mock(spec=Message) + + result = _create_running_status_event( + mock_message, self.mock_invocation_context, self.mock_event + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.contextId == self.mock_invocation_context.session.id + assert result.status.state == TaskState.working + assert result.status.message == mock_message + + @patch( + "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" + ) + @patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + ) + @patch("google.adk.a2a.converters.event_converter._create_error_status_event") + @patch( + "google.adk.a2a.converters.event_converter._create_running_status_event" + ) + def test_convert_event_to_a2a_events_full_scenario( + self, + mock_create_running, + mock_create_error, + mock_convert_message, + mock_convert_artifact, + ): + """Test full event to A2A events conversion scenario.""" + # Setup artifact delta + self.mock_event.actions.artifact_delta = {"file1.txt": 1, "file2.txt": 2} + + # Setup error + self.mock_event.error_code = "ERROR_001" + + # Setup message + mock_message = Mock(spec=Message) + mock_convert_message.return_value = mock_message + + # Setup mock returns + mock_artifact_event1 = Mock() + mock_artifact_event2 = Mock() + mock_convert_artifact.side_effect = [ + mock_artifact_event1, + mock_artifact_event2, + ] + + mock_error_event = Mock() + mock_create_error.return_value = mock_error_event + + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + result = convert_event_to_a2a_events( + self.mock_event, 
self.mock_invocation_context + ) + + # Verify artifact delta events + assert mock_convert_artifact.call_count == 2 + + # Verify error event + mock_create_error.assert_called_once_with( + self.mock_event, self.mock_invocation_context + ) + + # Verify running event + mock_create_running.assert_called_once_with( + mock_message, self.mock_invocation_context, self.mock_event + ) + + # Verify result contains all events + assert len(result) == 4 # 2 artifact + 1 error + 1 running + assert mock_artifact_event1 in result + assert mock_artifact_event2 in result + assert mock_error_event in result + assert mock_running_event in result + + def test_convert_event_to_a2a_events_empty_scenario(self): + """Test event to A2A events conversion with empty event.""" + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + assert result == [] + + def test_convert_event_to_a2a_events_none_event(self): + """Test event to A2A events conversion with None event.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_events(None, self.mock_invocation_context) + assert "Event cannot be None" in str(exc_info.value) + + def test_convert_event_to_a2a_events_none_context(self): + """Test event to A2A events conversion with None context.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_events(self.mock_event, None) + assert "Invocation context cannot be None" in str(exc_info.value) + + @patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + ) + def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): + """Test event to A2A events conversion with message only.""" + mock_message = Mock(spec=Message) + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._create_running_status_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + assert len(result) == 1 + assert result[0] == mock_running_event + + @patch("google.adk.a2a.converters.event_converter.logger") + def test_convert_event_to_a2a_events_exception_handling(self, mock_logger): + """Test exception handling in event to A2A events conversion.""" + # Make convert_event_to_a2a_status_message raise an exception + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + ) as mock_convert: + mock_convert.side_effect = Exception("Conversion failed") + + with pytest.raises(Exception): + convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + mock_logger.error.assert_called_once() diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index d9c8e86d4..4b9bd47cf 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -125,7 +125,8 @@ def test_convert_data_part_function_call(self): metadata={ A2A_DATA_PART_METADATA_TYPE_KEY: ( A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - ) + ), + "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, }, ) ) @@ -153,7 +154,8 @@ def test_convert_data_part_function_response(self): metadata={ A2A_DATA_PART_METADATA_TYPE_KEY: ( A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - ) + ), + "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, }, ) ) From 4d72d31b13f352245baa72b78502206dcbe25406 Mon 
Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 20 Jun 2025 14:12:24 -0700 Subject: [PATCH 42/79] fix: Add type checking to handle different response type of genai API client Fixes https://github.com/google/adk-python/issues/1514 PiperOrigin-RevId: 773838035 --- .../adk/sessions/vertex_ai_session_service.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 258dcd933..bd1345162 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +import json import logging import re from typing import Any @@ -87,6 +88,7 @@ async def create_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions', request_dict=session_json_dict, ) + api_response = _convert_api_response(api_response) logger.info(f'Create Session response {api_response}') session_id = api_response['name'].split('/')[-3] @@ -100,6 +102,7 @@ async def create_session( path=f'operations/{operation_id}', request_dict={}, ) + lro_response = _convert_api_response(lro_response) if lro_response.get('done', None): break @@ -118,6 +121,7 @@ async def create_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) + get_session_api_response = _convert_api_response(get_session_api_response) update_timestamp = isoparse( get_session_api_response['updateTime'] @@ -149,6 +153,7 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) + get_session_api_response = _convert_api_response(get_session_api_response) session_id = get_session_api_response['name'].split('/')[-1] update_timestamp = isoparse( @@ -167,9 +172,12 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', request_dict={}, ) + list_events_api_response = _convert_api_response(list_events_api_response) # Handles empty response case - if list_events_api_response.get('httpHeaders', None): + if not list_events_api_response or list_events_api_response.get( + 'httpHeaders', None + ): return session session.events += [ @@ -226,9 +234,10 @@ async def list_sessions( path=path, request_dict={}, ) + api_response = _convert_api_response(api_response) # Handles empty response case - if api_response.get('httpHeaders', None): + if not api_response or api_response.get('httpHeaders', None): return ListSessionsResponse() sessions = [] @@ -303,6 +312,13 @@ def _get_api_client(self): return client._api_client +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response + + def _convert_event_to_json(event: Event) -> Dict[str, Any]: metadata_json = { 'partial': event.partial, From 8677d5c8dce985a4070a03917da37ff2f6e5391e Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 20 Jun 2025 15:41:10 -0700 Subject: [PATCH 43/79] chore: bump version number to 1.4.2 PiperOrigin-RevId: 773867075 --- CHANGELOG.md | 9 +++++++++ src/google/adk/version.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a9873184b..ce36dcdcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## [1.4.2](https://github.com/google/adk-python/compare/v1.4.1...v1.4.2) (2025-06-20) + + 
+### Bug Fixes + +* Add type checking to handle different response type of genai API client ([4d72d31](https://github.com/google/adk-python/commit/4d72d31b13f352245baa72b78502206dcbe25406)) + * This fixes the broken VertexAiSessionService +* Allow more credentials types for BigQuery tools ([2f716ad](https://github.com/google/adk-python/commit/2f716ada7fbcf8e03ff5ae16ce26a80ca6fd7bf6)) + ## [1.4.1](https://github.com/google/adk-python/compare/v1.3.0...v1.4.1) (2025-06-18) diff --git a/src/google/adk/version.py b/src/google/adk/version.py index e39c67455..9accc1025 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.4.1" +__version__ = "1.4.2" From 2fd8feb65d6ae59732fb3ec0652d5650f47132cc Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 20 Jun 2025 16:53:34 -0700 Subject: [PATCH 44/79] chore: Support `allow_origins` in cloud_run deployment Also reorganize the fast_api_common_options. This resolves https://github.com/google/adk-python/issues/1444. PiperOrigin-RevId: 773890111 --- src/google/adk/cli/cli_deploy.py | 11 ++- src/google/adk/cli/cli_tools_click.py | 93 ++++++++++---------- tests/unittests/cli/utils/test_cli_deploy.py | 2 + 3 files changed, 59 insertions(+), 47 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 99c7e9bb1..44d4a900d 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -55,7 +55,7 @@ EXPOSE {port} -CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} "/app/agents" +CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {allow_origins_option} "/app/agents" """ _AGENT_ENGINE_APP_TEMPLATE = """ @@ -121,8 +121,10 @@ def to_cloud_run( port: int, trace_to_cloud: bool, with_ui: bool, + log_level: str, verbosity: str, adk_version: str, + allow_origins: Optional[list[str]] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, @@ -150,6 +152,7 @@ def to_cloud_run( app_name: The name of the app, by default, it's basename of `agent_folder`. temp_folder: The temp folder for the generated Cloud Run source files. port: The port of the ADK api server. + allow_origins: The list of allowed origins for the ADK api server. trace_to_cloud: Whether to enable Cloud Trace. with_ui: Whether to deploy with UI. verbosity: The verbosity level of the CLI. 
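For context on the options documented above, a minimal sketch of how an `allow_origins` list is expected to surface in the generated Dockerfile command; the origin values are illustrative, and the flag construction mirrors the `allow_origins_option` expression added in the next hunk:

    # Illustrative origins; real values come from the deploy CLI.
    allow_origins = ["https://example.com", "http://localhost:4200"]

    # An empty list (or None) renders as an empty string, leaving the CMD line unchanged.
    allow_origins_option = (
        f'--allow_origins={",".join(allow_origins)}' if allow_origins else ""
    )

    assert allow_origins_option == (
        "--allow_origins=https://example.com,http://localhost:4200"
    )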
@@ -183,6 +186,9 @@ def to_cloud_run( # create Dockerfile click.echo('Creating Dockerfile...') host_option = '--host=0.0.0.0' if adk_version > '0.5.0' else '' + allow_origins_option = ( + f'--allow_origins={",".join(allow_origins)}' if allow_origins else '' + ) dockerfile_content = _DOCKERFILE_TEMPLATE.format( gcp_project_id=project, gcp_region=region, @@ -197,6 +203,7 @@ def to_cloud_run( memory_service_uri, ), trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '', + allow_origins_option=allow_origins_option, adk_version=adk_version, host_option=host_option, ) @@ -226,7 +233,7 @@ def to_cloud_run( '--port', str(port), '--verbosity', - verbosity, + log_level.lower() if log_level else verbosity, '--labels', 'created-by=adk', ], diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 8f45db96d..49ecee482 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -39,6 +39,11 @@ from .utils import envs from .utils import logs +LOG_LEVELS = click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + case_sensitive=False, +) + class HelpfulCommand(click.Command): """Command that shows full help on error instead of just the error message. @@ -498,13 +503,6 @@ def fast_api_common_options(): """Decorator to add common fast api options to click commands.""" def decorator(func): - @click.option( - "--host", - type=str, - help="Optional. The binding host of the server", - default="127.0.0.1", - show_default=True, - ) @click.option( "--port", type=int, @@ -518,10 +516,7 @@ def decorator(func): ) @click.option( "--log_level", - type=click.Choice( - ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - case_sensitive=False, - ), + type=LOG_LEVELS, default="INFO", help="Optional. Set the logging level", ) @@ -535,7 +530,10 @@ def decorator(func): @click.option( "--reload/--no-reload", default=True, - help="Optional. Whether to enable auto reload for server.", + help=( + "Optional. Whether to enable auto reload for server. Not supported" + " for Cloud Run." + ), ) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -547,6 +545,13 @@ def wrapper(*args, **kwargs): @main.command("web") +@click.option( + "--host", + type=str, + help="Optional. The binding host of the server", + default="127.0.0.1", + show_default=True, +) @fast_api_common_options() @adk_services_options() @deprecated_adk_services_options() @@ -578,7 +583,7 @@ def cli_web( Example: - adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir + adk web --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) @@ -628,6 +633,16 @@ async def _lifespan(app: FastAPI): @main.command("api_server") +@click.option( + "--host", + type=str, + help="Optional. The binding host of the server", + default="127.0.0.1", + show_default=True, +) +@fast_api_common_options() +@adk_services_options() +@deprecated_adk_services_options() # The directory of agents, where each sub-directory is a single agent. 
# By default, it is the current working directory @click.argument( @@ -637,9 +652,6 @@ async def _lifespan(app: FastAPI): ), default=os.getcwd(), ) -@fast_api_common_options() -@adk_services_options() -@deprecated_adk_services_options() def cli_api_server( agents_dir: str, log_level: str = "INFO", @@ -661,7 +673,7 @@ def cli_api_server( Example: - adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir + adk api_server --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) @@ -720,19 +732,7 @@ def cli_api_server( " of the AGENT source code)." ), ) -@click.option( - "--port", - type=int, - default=8000, - help="Optional. The port of the ADK API server (default: 8000).", -) -@click.option( - "--trace_to_cloud", - is_flag=True, - show_default=True, - default=False, - help="Optional. Whether to enable Cloud Trace for cloud run.", -) +@fast_api_common_options() @click.option( "--with_ui", is_flag=True, @@ -743,6 +743,11 @@ def cli_api_server( " only)" ), ) +@click.option( + "--verbosity", + type=LOG_LEVELS, + help="Deprecated. Use --log_level instead.", +) @click.option( "--temp_folder", type=str, @@ -756,20 +761,6 @@ def cli_api_server( " (default: a timestamped folder in the system temp directory)." ), ) -@click.option( - "--verbosity", - type=click.Choice( - ["debug", "info", "warning", "error", "critical"], case_sensitive=False - ), - default="WARNING", - help="Optional. Override the default verbosity level.", -) -@click.argument( - "agent", - type=click.Path( - exists=True, dir_okay=True, file_okay=False, resolve_path=True - ), -) @click.option( "--adk_version", type=str, @@ -782,6 +773,12 @@ def cli_api_server( ) @adk_services_options() @deprecated_adk_services_options() +@click.argument( + "agent", + type=click.Path( + exists=True, dir_okay=True, file_okay=False, resolve_path=True + ), +) def cli_deploy_cloud_run( agent: str, project: Optional[str], @@ -792,8 +789,11 @@ def cli_deploy_cloud_run( port: int, trace_to_cloud: bool, with_ui: bool, - verbosity: str, adk_version: str, + log_level: Optional[str] = None, + verbosity: str = "WARNING", + reload: bool = True, + allow_origins: Optional[list[str]] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, @@ -808,6 +808,7 @@ def cli_deploy_cloud_run( adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent """ + log_level = log_level or verbosity session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri try: @@ -820,7 +821,9 @@ def cli_deploy_cloud_run( temp_folder=temp_folder, port=port, trace_to_cloud=trace_to_cloud, + allow_origins=allow_origins, with_ui=with_ui, + log_level=log_level, verbosity=verbosity, adk_version=adk_version, session_service_uri=session_service_uri, diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index 312844db8..d3b2a538c 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -162,6 +162,7 @@ def _recording_copytree(*args: Any, **kwargs: Any): trace_to_cloud=True, with_ui=True, verbosity="info", + log_level="info", session_service_uri="sqlite://", artifact_service_uri="gs://bucket", memory_service_uri="rag://", @@ -206,6 +207,7 @@ def _fake_rmtree(path: str | Path, *a: Any, **k: Any) -> None: trace_to_cloud=False, with_ui=False, verbosity="info", + log_level="info", 
adk_version="1.0.0", session_service_uri=None, artifact_service_uri=None, From fb13963deda0ff0650ac27771711ea0411474bf5 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 20 Jun 2025 17:08:09 -0700 Subject: [PATCH 45/79] chore: Add request converter to convert a2a request to ADK request PiperOrigin-RevId: 773894462 --- .../adk/a2a/converters/request_converter.py | 90 ++++ src/google/adk/a2a/converters/utils.py | 37 ++ .../a2a/converters/test_request_converter.py | 497 ++++++++++++++++++ 3 files changed, 624 insertions(+) create mode 100644 src/google/adk/a2a/converters/request_converter.py create mode 100644 tests/unittests/a2a/converters/test_request_converter.py diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py new file mode 100644 index 000000000..293df46e6 --- /dev/null +++ b/src/google/adk/a2a/converters/request_converter.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +from typing import Any + +try: + from a2a.server.agent_execution import RequestContext +except ImportError as e: + if sys.version_info < (3, 10): + raise ImportError( + 'A2A Tool requires Python 3.10 or above. Please upgrade your Python' + ' version.' 
+ ) from e + else: + raise e + +from google.genai import types as genai_types + +from ...runners import RunConfig +from ...utils.feature_decorator import working_in_progress +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _from_a2a_context_id +from .utils import _get_adk_metadata_key + + +def _get_user_id(request: RequestContext, user_id_from_context: str) -> str: + # Get user from call context if available (auth is enabled on a2a server) + if request.call_context and request.call_context.user: + return request.call_context.user.user_name + + # Get user from context id if available + if user_id_from_context: + return user_id_from_context + + # Get user from message metadata if available (client is an ADK agent) + if request.message.metadata: + user_id = request.message.metadata.get(_get_adk_metadata_key('user_id')) + if user_id: + return f'ADK_USER_{user_id}' + + # Get user from task if available (client is a an ADK agent) + if request.current_task: + user_id = request.current_task.metadata.get( + _get_adk_metadata_key('user_id') + ) + if user_id: + return f'ADK_USER_{user_id}' + return ( + f'temp_user_{request.task_id}' + if request.task_id + else f'TEMP_USER_{request.message.messageId}' + ) + + +@working_in_progress +def convert_a2a_request_to_adk_run_args( + request: RequestContext, +) -> dict[str, Any]: + + if not request.message: + raise ValueError('Request message cannot be None') + + _, user_id, session_id = _from_a2a_context_id(request.context_id) + + return { + 'user_id': _get_user_id(request, user_id), + 'session_id': session_id, + 'new_message': genai_types.Content( + role='user', + parts=[ + convert_a2a_part_to_genai_part(part) + for part in request.message.parts + ], + ), + 'run_config': RunConfig(), + } diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py index fe5f2e927..ecbff1e10 100644 --- a/src/google/adk/a2a/converters/utils.py +++ b/src/google/adk/a2a/converters/utils.py @@ -15,6 +15,7 @@ from __future__ import annotations ADK_METADATA_KEY_PREFIX = "adk_" +ADK_CONTEXT_ID_PREFIX = "ADK" def _get_adk_metadata_key(key: str) -> str: @@ -32,3 +33,39 @@ def _get_adk_metadata_key(key: str) -> str: if not key: raise ValueError("Metadata key cannot be empty or None") return f"{ADK_METADATA_KEY_PREFIX}{key}" + + +def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str: + """Converts app name, user id and session id to an A2A context id. + + Args: + app_name: The app name. + user_id: The user id. + session_id: The session id. + + Returns: + The A2A context id. + """ + return [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id].join("$") + + +def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: + """Converts an A2A context id to app name, user id and session id. + if context_id is None, return None, None, None + if context_id is not None, but not in the format of + ADK$app_name$user_id$session_id, return None, None, None + + Args: + context_id: The A2A context id. + + Returns: + The app name, user id and session id. 
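  # A minimal illustrative round trip of the "$"-separated context id format
  # described above ("ADK$app_name$user_id$session_id"); the values are made up,
  # and the real helpers are _to_a2a_context_id / _from_a2a_context_id in this file.
  example_context_id = "$".join(["ADK", "my_app", "user-123", "session-456"])
  assert example_context_id == "ADK$my_app$user-123$session-456"
  assert example_context_id.split("$") == ["ADK", "my_app", "user-123", "session-456"]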
+ """ + if not context_id: + return None, None, None + + prefix, app_name, user_id, session_id = context_id.split("$") + if prefix == "ADK" and app_name and user_id and session_id: + return app_name, user_id, session_id + + return None, None, None diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py new file mode 100644 index 000000000..02c6400fc --- /dev/null +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -0,0 +1,497 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.server.agent_execution import RequestContext + from google.adk.a2a.converters.request_converter import _get_user_id + from google.adk.a2a.converters.request_converter import convert_a2a_request_to_adk_run_args + from google.adk.runners import RunConfig + from google.genai import types as genai_types +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + a2a_types = DummyTypes() + genai_types = DummyTypes() + RequestContext = DummyTypes() + RunConfig = DummyTypes() + _get_user_id = lambda x, y: None + convert_a2a_request_to_adk_run_args = lambda x: None + else: + raise e + + +class TestGetUserId: + """Test cases for _get_user_id function.""" + + def test_get_user_id_from_call_context(self): + """Test getting user ID from call context when auth is enabled.""" + # Arrange + mock_user = Mock() + mock_user.user_name = "authenticated_user" + + mock_call_context = Mock() + mock_call_context.user = mock_user + + request = Mock(spec=RequestContext) + request.call_context = mock_call_context + request.message = Mock() + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "context_user") + + # Assert + assert result == "authenticated_user" + + def test_get_user_id_from_context_when_no_call_context(self): + """Test getting user ID from context when call context is not available.""" + # Arrange + request = Mock(spec=RequestContext) + request.call_context = None + request.message = Mock() + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "context_user") + + # Assert + assert result == "context_user" + + def test_get_user_id_from_context_when_call_context_has_no_user(self): + """Test getting user ID from context when call context has no user.""" + # Arrange + mock_call_context = Mock() + mock_call_context.user = None + + request = Mock(spec=RequestContext) + request.call_context = mock_call_context + request.message = Mock() + 
request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "context_user") + + # Assert + assert result == "context_user" + + def test_get_user_id_from_message_metadata(self): + """Test getting user ID from message metadata when context user is not available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = {"adk_user_id": "message_user"} + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "ADK_USER_message_user" + + def test_get_user_id_from_task_metadata(self): + """Test getting user ID from task metadata when message metadata is not available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + + mock_task = Mock() + mock_task.metadata = {"adk_user_id": "task_user"} + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = mock_task + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "ADK_USER_task_user" + + def test_get_user_id_fallback_to_task_id(self): + """Test fallback to task ID when no other user ID is available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + mock_message.messageId = "msg456" + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "temp_user_task123" + + def test_get_user_id_fallback_to_message_id(self): + """Test fallback to message ID when no task ID is available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + mock_message.messageId = "msg456" + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = None + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "TEMP_USER_msg456" + + def test_get_user_id_message_metadata_empty(self): + """Test getting user ID when message metadata exists but doesn't contain user_id.""" + # Arrange + mock_message = Mock() + mock_message.metadata = {"other_key": "other_value"} + mock_message.messageId = "msg456" + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "temp_user_task123" + + def test_get_user_id_task_metadata_empty(self): + """Test getting user ID when task metadata exists but doesn't contain user_id.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + mock_message.messageId = "msg456" + + mock_task = Mock() + mock_task.metadata = {"other_key": "other_value"} + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = mock_task + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "temp_user_task123" + + +class TestConvertA2aRequestToAdkRunArgs: + """Test cases for convert_a2a_request_to_adk_run_args function.""" + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + 
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_basic( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test basic conversion of A2A request to ADK run args.""" + # Arrange + mock_part1 = Mock() + mock_part2 = Mock() + + mock_message = Mock() + mock_message.parts = [mock_part1, mock_part2] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = "ADK$app$user$session" + + mock_from_context_id.return_value = ( + "app_name", + "user_from_context", + "session123", + ) + mock_get_user_id.return_value = "final_user" + + # Create proper genai_types.Part objects instead of mocks + mock_genai_part1 = genai_types.Part(text="test part 1") + mock_genai_part2 = genai_types.Part(text="test part 2") + mock_convert_part.side_effect = [mock_genai_part1, mock_genai_part2] + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "final_user" + assert result["session_id"] == "session123" + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2] + assert isinstance(result["run_config"], RunConfig) + + # Verify calls + mock_from_context_id.assert_called_once_with("ADK$app$user$session") + mock_get_user_id.assert_called_once_with(request, "user_from_context") + assert mock_convert_part.call_count == 2 + mock_convert_part.assert_any_call(mock_part1) + mock_convert_part.assert_any_call(mock_part2) + + def test_convert_a2a_request_no_message_raises_error(self): + """Test that conversion raises ValueError when message is None.""" + # Arrange + request = Mock(spec=RequestContext) + request.message = None + + # Act & Assert + with pytest.raises(ValueError, match="Request message cannot be None"): + convert_a2a_request_to_adk_run_args(request) + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_empty_parts( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test conversion with empty parts list.""" + # Arrange + mock_message = Mock() + mock_message.parts = [] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = "ADK$app$user$session" + + mock_from_context_id.return_value = ( + "app_name", + "user_from_context", + "session123", + ) + mock_get_user_id.return_value = "final_user" + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "final_user" + assert result["session_id"] == "session123" + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [] + assert isinstance(result["run_config"], RunConfig) + + # Verify convert_part wasn't called + mock_convert_part.assert_not_called() + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_none_context_id( + self, 
mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test conversion when context_id is None.""" + # Arrange + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = None + + mock_from_context_id.return_value = (None, None, None) + mock_get_user_id.return_value = "fallback_user" + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "fallback_user" + assert result["session_id"] is None + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) + + # Verify calls + mock_from_context_id.assert_called_once_with(None) + mock_get_user_id.assert_called_once_with(request, None) + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_invalid_context_id( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test conversion when context_id is invalid format.""" + # Arrange + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = "invalid_format" + + mock_from_context_id.return_value = (None, None, None) + mock_get_user_id.return_value = "fallback_user" + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "fallback_user" + assert result["session_id"] is None + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) + + # Verify calls + mock_from_context_id.assert_called_once_with("invalid_format") + mock_get_user_id.assert_called_once_with(request, None) + + +class TestIntegration: + """Integration test cases combining both functions.""" + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + def test_end_to_end_conversion_with_auth_user(self, mock_convert_part): + """Test end-to-end conversion with authenticated user.""" + # Arrange + mock_user = Mock() + mock_user.user_name = "auth_user" + + mock_call_context = Mock() + mock_call_context.user = mock_user + + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + + request = Mock(spec=RequestContext) + request.call_context = mock_call_context + request.message = mock_message + request.context_id = "ADK$myapp$context_user$mysession" + request.current_task = None + request.task_id = "task123" + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = 
convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert ( + result["user_id"] == "auth_user" + ) # Should use authenticated user, not context user + assert result["session_id"] == "mysession" + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + def test_end_to_end_conversion_with_fallback_user( + self, mock_from_context_id, mock_convert_part + ): + """Test end-to-end conversion with fallback user ID.""" + # Arrange + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + mock_message.messageId = "msg789" + mock_message.metadata = None + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.context_id = "invalid_format" + request.current_task = None + request.task_id = None + + # Mock the utils function to return None values for invalid context + mock_from_context_id.return_value = (None, None, None) + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert ( + result["user_id"] == "TEMP_USER_msg789" + ) # Should fallback to message ID + assert result["session_id"] is None + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) From 7c670f638bc17374ceb08740bdd057e55c9c2e12 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 20 Jun 2025 17:14:47 -0700 Subject: [PATCH 46/79] chore: Send user message to the agent that returned a corresponding function call if user message is a function response PiperOrigin-RevId: 773895971 --- src/google/adk/runners.py | 43 +++ tests/unittests/test_runners.py | 481 ++++++++++++++++++++++++++++++++ 2 files changed, 524 insertions(+) create mode 100644 tests/unittests/test_runners.py diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 01412a2b3..936bc5205 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -337,6 +337,8 @@ def _find_agent_to_run( """Finds the agent to run to continue the session. A qualified agent must be either of: + - The agent that returned a function call and the last user message is a + function response to this function call. - The root agent; - An LlmAgent who replied last and is capable to transfer to any other agent in the agent hierarchy. @@ -348,6 +350,15 @@ def _find_agent_to_run( Returns: The agent of the last message in the session or the root agent. """ + # If the last event is a function response, should send this response to + # the agent that returned the corressponding function call regardless the + # type of the agent. e.g. a remote a2a agent may surface a credential + # request as a special long running function tool call. 
+ event = _find_function_call_event_if_last_event_is_function_response( + session + ) + if event and event.author: + return root_agent.find_agent(event.author) for event in filter(lambda e: e.author != 'user', reversed(session.events)): if event.author == root_agent.name: # Found root agent. @@ -527,3 +538,35 @@ def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'): session_service=self._in_memory_session_service, memory_service=InMemoryMemoryService(), ) + + +def _find_function_call_event_if_last_event_is_function_response( + session: Session, +) -> Optional[Event]: + events = session.events + if not events: + return None + + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py new file mode 100644 index 000000000..56d7667ab --- /dev/null +++ b/tests/unittests/test_runners.py @@ -0,0 +1,481 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event +from google.adk.runners import _find_function_call_event_if_last_event_is_function_response +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + + +class MockAgent(BaseAgent): + """Mock agent for unit testing.""" + + def __init__( + self, + name: str, + parent_agent: Optional[BaseAgent] = None, + ): + super().__init__(name=name, sub_agents=[]) + # BaseAgent doesn't have disallow_transfer_to_parent field + # This is intentional as we want to test non-LLM agents + if parent_agent: + self.parent_agent = parent_agent + + async def _run_async_impl(self, invocation_context): + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ), + ) + + +class MockLlmAgent(LlmAgent): + """Mock LLM agent for unit testing.""" + + def __init__( + self, + name: str, + disallow_transfer_to_parent: bool = False, + parent_agent: Optional[BaseAgent] = None, + ): + # Use a string model instead of mock + super().__init__(name=name, model="gemini-1.5-pro", sub_agents=[]) + self.disallow_transfer_to_parent = disallow_transfer_to_parent + self.parent_agent = parent_agent + + async def _run_async_impl(self, invocation_context): + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test LLM response")] + ), + ) + + +class TestFindFunctionCallEventIfLastEventIsFunctionResponse: + """Tests for _find_function_call_event_if_last_event_is_function_response function.""" + + def test_no_function_response_in_last_event(self): + """Test when last event has no function response.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="user", + content=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + ) + ], + ) + + result = _find_function_call_event_if_last_event_is_function_response( + session + ) + assert result is None + + def test_empty_session_events(self): + """Test when session has no events.""" + session = Session( + id="test_session", user_id="test_user", app_name="test_app", events=[] + ) + + result = _find_function_call_event_if_last_event_is_function_response( + session + ) + assert result is None + + def test_last_event_has_function_response_but_no_matching_call(self): + """Test when last event has function response but no matching call found.""" + # Create a function response + function_response = types.FunctionResponse( + id="func_123", name="test_func", response={} + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="agent1", + content=types.Content( + role="model", + parts=[types.Part(text="Some other response")], + ), + ), + Event( + invocation_id="inv2", + author="user", + content=types.Content( + role="user", + parts=[types.Part(function_response=function_response)], + ), + ), + ], + ) + + result = _find_function_call_event_if_last_event_is_function_response( + session + ) + assert result is None + + 
def test_last_event_has_function_response_with_matching_call(self): + """Test when last event has function response with matching function call.""" + # Create a function call + function_call = types.FunctionCall(id="func_123", name="test_func", args={}) + + # Create a function response with matching ID + function_response = types.FunctionResponse( + id="func_123", name="test_func", response={} + ) + + call_event = Event( + invocation_id="inv1", + author="agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id="inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event, response_event], + ) + + result = _find_function_call_event_if_last_event_is_function_response( + session + ) + assert result == call_event + + def test_last_event_has_multiple_function_responses(self): + """Test when last event has multiple function responses.""" + # Create function calls + function_call1 = types.FunctionCall( + id="func_123", name="test_func1", args={} + ) + function_call2 = types.FunctionCall( + id="func_456", name="test_func2", args={} + ) + + # Create function responses + function_response1 = types.FunctionResponse( + id="func_123", name="test_func1", response={} + ) + function_response2 = types.FunctionResponse( + id="func_456", name="test_func2", response={} + ) + + call_event1 = Event( + invocation_id="inv1", + author="agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call1)] + ), + ) + + call_event2 = Event( + invocation_id="inv2", + author="agent2", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call2)] + ), + ) + + response_event = Event( + invocation_id="inv3", + author="user", + content=types.Content( + role="user", + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event1, call_event2, response_event], + ) + + # Should return the first matching function call event found + result = _find_function_call_event_if_last_event_is_function_response( + session + ) + assert result == call_event1 # First match (func_123) + + +class TestRunnerFindAgentToRun: + """Tests for Runner._find_agent_to_run method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_service = InMemorySessionService() + self.artifact_service = InMemoryArtifactService() + + # Create test agents + self.root_agent = MockLlmAgent("root_agent") + self.sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=self.root_agent) + self.sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=self.root_agent) + self.non_transferable_agent = MockLlmAgent( + "non_transferable", + disallow_transfer_to_parent=True, + parent_agent=self.root_agent, + ) + + self.root_agent.sub_agents = [ + self.sub_agent1, + self.sub_agent2, + self.non_transferable_agent, + ] + + self.runner = Runner( + app_name="test_app", + agent=self.root_agent, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + def test_find_agent_to_run_with_function_response_scenario(self): + """Test finding agent when last event is function response.""" + # Create a function call from sub_agent1 + function_call = 
types.FunctionCall(id="func_123", name="test_func", args={}) + function_response = types.FunctionResponse( + id="func_123", name="test_func", response={} + ) + + call_event = Event( + invocation_id="inv1", + author="sub_agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id="inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event, response_event], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.sub_agent1 + + def test_find_agent_to_run_returns_root_agent_when_no_events(self): + """Test that root agent is returned when session has no non-user events.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="user", + content=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_returns_root_agent_when_found_in_events(self): + """Test that root agent is returned when it's found in session events.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="root_agent", + content=types.Content( + role="model", parts=[types.Part(text="Root response")] + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_returns_transferable_sub_agent(self): + """Test that transferable sub agent is returned when found.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="sub_agent1", + content=types.Content( + role="model", parts=[types.Part(text="Sub agent response")] + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.sub_agent1 + + def test_find_agent_to_run_skips_non_transferable_agent(self): + """Test that non-transferable agent is skipped and root agent is returned.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="non_transferable", + content=types.Content( + role="model", + parts=[types.Part(text="Non-transferable response")], + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_skips_unknown_agent(self): + """Test that unknown agent is skipped and root agent is returned.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="unknown_agent", + content=types.Content( + role="model", + parts=[types.Part(text="Unknown agent response")], + ), + ), + Event( + invocation_id="inv2", + author="root_agent", + content=types.Content( + role="model", parts=[types.Part(text="Root response")] + ), + ), + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_function_response_takes_precedence(self): + """Test that function response scenario takes 
precedence over other logic.""" + # Create a function call from sub_agent2 + function_call = types.FunctionCall(id="func_456", name="test_func", args={}) + function_response = types.FunctionResponse( + id="func_456", name="test_func", response={} + ) + + call_event = Event( + invocation_id="inv1", + author="sub_agent2", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Add another event from root_agent + root_event = Event( + invocation_id="inv2", + author="root_agent", + content=types.Content( + role="model", parts=[types.Part(text="Root response")] + ), + ) + + response_event = Event( + invocation_id="inv3", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event, root_event, response_event], + ) + + # Should return sub_agent2 due to function response, not root_agent + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.sub_agent2 + + def test_is_transferable_across_agent_tree_with_llm_agent(self): + """Test _is_transferable_across_agent_tree with LLM agent.""" + result = self.runner._is_transferable_across_agent_tree(self.sub_agent1) + assert result is True + + def test_is_transferable_across_agent_tree_with_non_transferable_agent(self): + """Test _is_transferable_across_agent_tree with non-transferable agent.""" + result = self.runner._is_transferable_across_agent_tree( + self.non_transferable_agent + ) + assert result is False + + def test_is_transferable_across_agent_tree_with_non_llm_agent(self): + """Test _is_transferable_across_agent_tree with non-LLM agent.""" + non_llm_agent = MockAgent("non_llm_agent") + # MockAgent inherits from BaseAgent, not LlmAgent, so it should return False + result = self.runner._is_transferable_across_agent_tree(non_llm_agent) + assert result is False From 3b1d9a8a3e631ca2d86d30f09640497f1728986c Mon Sep 17 00:00:00 2001 From: bck-ob-gh Date: Mon, 23 Jun 2025 09:24:00 -0700 Subject: [PATCH 47/79] fix: Use starred tuple unpacking on GCS artifact blob names Merges https://github.com/google/adk-python/pull/1471 Fixes google#1436 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1471 from bck-ob-gh:main 4c4f2b66ab1e6fde8b1a9d2b914dcb24040db144 PiperOrigin-RevId: 774809270 --- src/google/adk/artifacts/gcs_artifact_service.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index e4af21e15..35aa88622 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""An artifact service implementation using Google Cloud Storage (GCS).""" +from __future__ import annotations import logging from typing import Optional @@ -151,7 +152,7 @@ async def list_artifact_keys( self.bucket, prefix=session_prefix ) for blob in session_blobs: - _, _, _, filename, _ = blob.name.split("/") + *_, filename, _ = blob.name.split("/") filenames.add(filename) user_namespace_prefix = f"{app_name}/{user_id}/user/" @@ -159,7 +160,7 @@ async def list_artifact_keys( self.bucket, prefix=user_namespace_prefix ) for blob in user_namespace_blobs: - _, _, _, filename, _ = blob.name.split("/") + *_, filename, _ = blob.name.split("/") filenames.add(filename) return sorted(list(filenames)) From f033e405c10ff8d86550d1419a9d63c0099182f9 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Mon, 23 Jun 2025 10:11:47 -0700 Subject: [PATCH 48/79] chore: Clarify the behavior of Event.invocation_id PiperOrigin-RevId: 774827874 --- src/google/adk/events/event.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/google/adk/events/event.py b/src/google/adk/events/event.py index c3b8b8699..6dd617fff 100644 --- a/src/google/adk/events/event.py +++ b/src/google/adk/events/event.py @@ -34,9 +34,10 @@ class Event(LlmResponse): taken by the agents like function calls, etc. Attributes: - invocation_id: The invocation ID of the event. - author: "user" or the name of the agent, indicating who appended the event - to the session. + invocation_id: Required. The invocation ID of the event. Should be non-empty + before appending to a session. + author: Required. "user" or the name of the agent, indicating who appended + the event to the session. actions: The actions taken by the agent. long_running_tool_ids: The ids of the long running function calls. branch: The branch of the event. @@ -55,9 +56,8 @@ class Event(LlmResponse): ) """The pydantic model config.""" - # TODO: revert to be required after spark migration invocation_id: str = '' - """The invocation ID of the event.""" + """The invocation ID of the event. Should be non-empty before appending to a session.""" author: str """'user' or the name of the agent, indicating who appended the event to the session.""" From ea69c9093a16489afdf72657136c96f61c69cafd Mon Sep 17 00:00:00 2001 From: Keisuke Oohashi Date: Mon, 23 Jun 2025 10:27:41 -0700 Subject: [PATCH 49/79] feat: add usage span attributes to telemetry (#356) Merge https://github.com/google/adk-python/pull/1079 Fixes part of #356 Add usage attributes to span. Note: Since the handling of GenAI event bodies in OpenTelemetry has not yet been determined, I have temporarily added only attributes related to usage. 
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1079 from soundTricker:feature/356-support-more-opentelemetry-semantics 99a9d0352b4bca165baa645440e39ce7199f072b PiperOrigin-RevId: 774834279 --- src/google/adk/telemetry.py | 10 ++++++++++ tests/unittests/test_telemetry.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/google/adk/telemetry.py b/src/google/adk/telemetry.py index badaec46d..a09c2f55b 100644 --- a/src/google/adk/telemetry.py +++ b/src/google/adk/telemetry.py @@ -195,6 +195,16 @@ def trace_call_llm( llm_response_json, ) + if llm_response.usage_metadata is not None: + span.set_attribute( + 'gen_ai.usage.input_tokens', + llm_response.usage_metadata.prompt_token_count, + ) + span.set_attribute( + 'gen_ai.usage.output_tokens', + llm_response.usage_metadata.total_token_count, + ) + def trace_send_data( invocation_context: InvocationContext, diff --git a/tests/unittests/test_telemetry.py b/tests/unittests/test_telemetry.py index 1b8ee1b16..debdc802e 100644 --- a/tests/unittests/test_telemetry.py +++ b/tests/unittests/test_telemetry.py @@ -141,6 +141,36 @@ async def test_trace_call_llm_function_response_includes_part_from_bytes( assert llm_request_json_str.count('') == 2 +@pytest.mark.asyncio +async def test_trace_call_llm_usage_metadata(monkeypatch, mock_span_fixture): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, prompt_token_count=50 + ), + ) + trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) + + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.usage.input_tokens', 50), + mock.call('gen_ai.usage.output_tokens', 100), + ] + assert mock_span_fixture.set_attribute.call_count == 9 + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + def test_trace_tool_call_with_scalar_response( monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture ): From bd67e8480f6e8b4b0f8c22b94f15a8cda1336339 Mon Sep 17 00:00:00 2001 From: avidelatm Date: Mon, 23 Jun 2025 10:29:43 -0700 Subject: [PATCH 50/79] fix: make LiteLLM streaming truly asynchronous Merge https://github.com/google/adk-python/pull/1451 ## Description Fixes https://github.com/google/adk-python/issues/1306 by using `async for` with `await self.llm_client.acompletion()` instead of synchronous `for` loop. 
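
To illustrate the pattern, here is a minimal standalone sketch (not ADK code): the model name and message shape are placeholders, and it assumes LiteLLM's OpenAI-compatible streaming chunks.

```python
import asyncio

import litellm


async def stream_reply(prompt: str) -> str:
  """Sketch of the non-blocking streaming loop this PR switches to."""
  # `acompletion` is awaited first; the awaited result is an async iterator,
  # so the event loop stays free to run other tasks between chunks.
  response = await litellm.acompletion(
      model="gpt-4o-mini",  # placeholder model name
      messages=[{"role": "user", "content": prompt}],
      stream=True,
  )
  chunks = []
  async for part in response:
    chunks.append(part.choices[0].delta.content or "")
  return "".join(chunks)


# Example: print(asyncio.run(stream_reply("hello")))
```
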
## Changes - Updated test mocks to properly handle async streaming by creating an async generator - Ensured proper parameter handling to avoid duplicate stream parameter ## Testing Plan - All unit tests now pass with the async streaming implementation - Verified with `pytest tests/unittests/models/test_litellm.py` that all streaming tests pass - Manually tested with a sample agent using LiteLLM to confirm streaming works properly # Test Evidence: https://youtu.be/hSp3otI79DM Let me know if you need anything else from me for this PR COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1451 from avidelatm:fix/litellm-async-streaming d35b9dc90b2fd6fad44c3869de0fda2514e50055 PiperOrigin-RevId: 774835130 --- src/google/adk/models/lite_llm.py | 2 +- tests/unittests/models/test_litellm.py | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index dce5ed7c4..acc88ed19 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -679,7 +679,7 @@ async def generate_content_async( aggregated_llm_response_with_tool_call = None usage_metadata = None fallback_index = 0 - for part in self.llm_client.completion(**completion_args): + async for part in await self.llm_client.acompletion(**completion_args): for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): index = chunk.index or fallback_index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8b43cc48b..d058aa44d 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -416,9 +416,26 @@ def __init__(self, acompletion_mock, completion_mock): self.completion_mock = completion_mock async def acompletion(self, model, messages, tools, **kwargs): - return await self.acompletion_mock( - model=model, messages=messages, tools=tools, **kwargs - ) + if kwargs.get("stream", False): + kwargs_copy = dict(kwargs) + kwargs_copy.pop("stream", None) + + async def stream_generator(): + stream_data = self.completion_mock( + model=model, + messages=messages, + tools=tools, + stream=True, + **kwargs_copy, + ) + for item in stream_data: + yield item + + return stream_generator() + else: + return await self.acompletion_mock( + model=model, messages=messages, tools=tools, **kwargs + ) def completion(self, model, messages, tools, stream, **kwargs): return self.completion_mock( From 29cd183aa1b47dc4f5d8afe22f410f8546634abc Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 23 Jun 2025 12:15:26 -0700 Subject: [PATCH 51/79] chore: Add credential service backed by session state PiperOrigin-RevId: 774878336 --- .../session_state_credential_service.py | 83 ++++ tests/unittests/auth/__init__.py | 13 + .../auth/credential_service/__init__.py | 13 + .../test_session_state_credential_service.py | 355 ++++++++++++++++++ 4 files changed, 464 insertions(+) create mode 100644 src/google/adk/auth/credential_service/session_state_credential_service.py create mode 100644 tests/unittests/auth/__init__.py create mode 100644 tests/unittests/auth/credential_service/__init__.py create mode 100644 tests/unittests/auth/credential_service/test_session_state_credential_service.py diff --git a/src/google/adk/auth/credential_service/session_state_credential_service.py b/src/google/adk/auth/credential_service/session_state_credential_service.py new file mode 100644 index 000000000..e2ff7e07d --- /dev/null +++ 
b/src/google/adk/auth/credential_service/session_state_credential_service.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class SessionStateCredentialService(BaseCredentialService): + """Class for implementation of credential service using session state as the + store. + Note: store credential in session may not be secure, use at your own risk. + """ + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + """ + Loads the credential by auth config and current tool context from the + backend credential store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to load the credential. + + tool_context: The context of the current invocation when the tool is + trying to load the credential. + + Returns: + Optional[AuthCredential]: the credential saved in the store. + + """ + return tool_context.state.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + """ + Saves the exchanged_auth_credential in auth config to the backend credential + store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to save the credential. + + tool_context: The context of the current invocation when the tool is + trying to save the credential. + + Returns: + None + """ + + tool_context.state[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) diff --git a/tests/unittests/auth/__init__.py b/tests/unittests/auth/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
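
For context, a minimal usage sketch of the service added in this patch; the helper function and its wiring are illustrative only, while `AuthConfig`, `ToolContext`, and the import paths come from the patch itself.

```python
from typing import Optional

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_tool import AuthConfig
from google.adk.auth.credential_service.session_state_credential_service import (
    SessionStateCredentialService,
)
from google.adk.tools.tool_context import ToolContext


async def load_or_save_credential(
    auth_config: AuthConfig, tool_context: ToolContext
) -> Optional[AuthCredential]:
  """Sketch: reuse a stored credential, otherwise persist the exchanged one."""
  service = SessionStateCredentialService()
  credential = await service.load_credential(auth_config, tool_context)
  if credential is None and auth_config.exchanged_auth_credential is not None:
    # Stores auth_config.exchanged_auth_credential under
    # auth_config.credential_key in the session state.
    await service.save_credential(auth_config, tool_context)
    credential = auth_config.exchanged_auth_credential
  return credential
```
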
diff --git a/tests/unittests/auth/credential_service/__init__.py b/tests/unittests/auth/credential_service/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/credential_service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/credential_service/test_session_state_credential_service.py b/tests/unittests/auth/credential_service/test_session_state_credential_service.py new file mode 100644 index 000000000..610a9d3d1 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_session_state_credential_service.py @@ -0,0 +1,355 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.session_state_credential_service import SessionStateCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestSessionStateCredentialService: + """Tests for the SessionStateCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create a SessionStateCredentialService instance for testing.""" + return SessionStateCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + 
auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + # Create a state dictionary that behaves like session state + mock_context.state = {} + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different state for testing isolation.""" + mock_context = Mock(spec=ToolContext) + # Create a separate state dictionary to simulate different session + mock_context.state = {} + return mock_context + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different tool contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context (should not find it) + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two 
different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + @pytest.mark.asyncio + async def test_save_credential_with_none_exchanged_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving when exchanged_auth_credential is None.""" + # Set exchanged credential to None + auth_config.exchanged_auth_credential = None + + # Save the credential (should save None) + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify None was saved + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_load_credential_with_empty_credential_key( + self, credential_service, auth_config, tool_context + ): + """Test loading credential with empty credential key.""" + # Set credential key to empty string + auth_config.credential_key = "" + + # Save first to have something to load + await credential_service.save_credential(auth_config, tool_context) + + # Load should work with empty key + result = await credential_service.load_credential(auth_config, tool_context) + assert result == auth_config.exchanged_auth_credential + + @pytest.mark.asyncio + async def test_state_persistence_across_operations( + self, credential_service, auth_config, tool_context + ): + """Test that state persists correctly across multiple operations.""" + # Initially, no credential should exist + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + # Save a credential + await credential_service.save_credential(auth_config, tool_context) + + # Verify it was saved + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result == auth_config.exchanged_auth_credential + + # Update and save again + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="new_client_id", + client_secret="new_client_secret", + redirect_uri="https://new.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + await credential_service.save_credential(auth_config, tool_context) + + # Verify the update persisted + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not 
None + assert result.oauth2.client_id == "new_client_id" + + @pytest.mark.asyncio + async def test_credential_key_uniqueness( + self, credential_service, oauth2_auth_scheme, tool_context + ): + """Test that different credential keys create separate storage slots.""" + # Create credentials with same content but different keys + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="same_client", + client_secret="same_secret", + redirect_uri="https://same.com/callback", + ), + ) + + config_key1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=credential, + exchanged_auth_credential=credential, + credential_key="unique_key_1", + ) + + config_key2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=credential, + exchanged_auth_credential=credential, + credential_key="unique_key_2", + ) + + # Save credential with first key + await credential_service.save_credential(config_key1, tool_context) + + # Verify it's stored under first key + result1 = await credential_service.load_credential( + config_key1, tool_context + ) + assert result1 is not None + + # Verify it's not accessible under second key + result2 = await credential_service.load_credential( + config_key2, tool_context + ) + assert result2 is None + + # Save under second key + await credential_service.save_credential(config_key2, tool_context) + + # Now both should be accessible + result1 = await credential_service.load_credential( + config_key1, tool_context + ) + result2 = await credential_service.load_credential( + config_key2, tool_context + ) + assert result1 is not None + assert result2 is not None + assert result1 == result2 # Same credential content + + def test_direct_state_access( + self, credential_service, auth_config, tool_context + ): + """Test that the service correctly uses tool_context.state for storage.""" + # Verify that the state starts empty + assert len(tool_context.state) == 0 + + # Save a credential (this is async but we're testing the state directly) + credential_key = auth_config.credential_key + test_credential = auth_config.exchanged_auth_credential + + # Directly set the state to simulate save_credential behavior + tool_context.state[credential_key] = test_credential + + # Verify the credential is in the state + assert credential_key in tool_context.state + assert tool_context.state[credential_key] == test_credential + + # Verify we can retrieve it using the get method (simulating load_credential) + retrieved = tool_context.state.get(credential_key) + assert retrieved == test_credential From 120cbabeb23c16d9ce4be511e768885f19a8c2d2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 23 Jun 2025 12:22:53 -0700 Subject: [PATCH 52/79] refactor: Rename long util function name in runner.py and move it to functions.py PiperOrigin-RevId: 774880990 --- src/google/adk/flows/llm_flows/functions.py | 32 ++++ src/google/adk/runners.py | 37 +--- .../flows/llm_flows/test_functions_simple.py | 136 ++++++++++++++ tests/unittests/test_runners.py | 171 ------------------ 4 files changed, 170 insertions(+), 206 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2772550c2..5c690f1fd 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -519,3 +519,35 @@ def merge_parallel_function_response_events( # Use the base_event as the timestamp merged_event.timestamp = base_event.timestamp return merged_event + + +def 
find_matching_function_call( + events: list[Event], +) -> Optional[Event]: + """Finds the function call event that matches the function response id of the last event.""" + if not events: + return None + + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 936bc5205..017997bb3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -36,6 +36,7 @@ from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event +from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread @@ -354,9 +355,7 @@ def _find_agent_to_run( # the agent that returned the corressponding function call regardless the # type of the agent. e.g. a remote a2a agent may surface a credential # request as a special long running function tool call. - event = _find_function_call_event_if_last_event_is_function_response( - session - ) + event = find_matching_function_call(session.events) if event and event.author: return root_agent.find_agent(event.author) for event in filter(lambda e: e.author != 'user', reversed(session.events)): @@ -538,35 +537,3 @@ def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'): session_service=self._in_memory_session_service, memory_service=InMemoryMemoryService(), ) - - -def _find_function_call_event_if_last_event_is_function_response( - session: Session, -) -> Optional[Event]: - events = session.events - if not events: - return None - - last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): - - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event - return None diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 2c5ef9bce..720af516d 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -17,6 +17,9 @@ from typing import Callable from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.sessions.session import Session from google.adk.tools import ToolContext from google.adk.tools.function_tool import 
FunctionTool from google.genai import types @@ -256,3 +259,136 @@ def increase_by_one(x: int) -> int: assert part.function_response.id is None assert events[0].content.parts[0].function_call.id.startswith('adk-') assert events[1].content.parts[0].function_response.id.startswith('adk-') + + +def test_find_function_call_event_no_function_response_in_last_event(): + """Test when last event has no function response.""" + events = [ + Event( + invocation_id='inv1', + author='user', + content=types.Content(role='user', parts=[types.Part(text='Hello')]), + ) + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_empty_session_events(): + """Test when session has no events.""" + events = [] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_but_no_matching_call(): + """Test when last event has function response but no matching call found.""" + # Create a function response + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + events = [ + Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', + parts=[types.Part(text='Some other response')], + ), + ), + Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', + parts=[types.Part(function_response=function_response)], + ), + ), + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_with_matching_call(): + """Test when last event has function response with matching function call.""" + # Create a function call + function_call = types.FunctionCall(id='func_123', name='test_func', args={}) + + # Create a function response with matching ID + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + call_event = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', parts=[types.Part(function_response=function_response)] + ), + ) + + events = [call_event, response_event] + + result = find_matching_function_call(events) + assert result == call_event + + +def test_find_function_call_event_multiple_function_responses(): + """Test when last event has multiple function responses.""" + # Create function calls + function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={}) + function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={}) + + # Create function responses + function_response1 = types.FunctionResponse( + id='func_123', name='test_func1', response={} + ) + function_response2 = types.FunctionResponse( + id='func_456', name='test_func2', response={} + ) + + call_event1 = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call1)] + ), + ) + + call_event2 = Event( + invocation_id='inv2', + author='agent2', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call2)] + ), + ) + + response_event = Event( + invocation_id='inv3', + author='user', + content=types.Content( + role='user', + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + events = [call_event1, call_event2, 
response_event] + + # Should return the first matching function call event found + result = find_matching_function_call(events) + assert result == call_event1 # First match (func_123) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 56d7667ab..8d5bd2418 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -18,7 +18,6 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.events.event import Event -from google.adk.runners import _find_function_call_event_if_last_event_is_function_response from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session @@ -73,176 +72,6 @@ async def _run_async_impl(self, invocation_context): ) -class TestFindFunctionCallEventIfLastEventIsFunctionResponse: - """Tests for _find_function_call_event_if_last_event_is_function_response function.""" - - def test_no_function_response_in_last_event(self): - """Test when last event has no function response.""" - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[ - Event( - invocation_id="inv1", - author="user", - content=types.Content( - role="user", parts=[types.Part(text="Hello")] - ), - ) - ], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_empty_session_events(self): - """Test when session has no events.""" - session = Session( - id="test_session", user_id="test_user", app_name="test_app", events=[] - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_last_event_has_function_response_but_no_matching_call(self): - """Test when last event has function response but no matching call found.""" - # Create a function response - function_response = types.FunctionResponse( - id="func_123", name="test_func", response={} - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[ - Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", - parts=[types.Part(text="Some other response")], - ), - ), - Event( - invocation_id="inv2", - author="user", - content=types.Content( - role="user", - parts=[types.Part(function_response=function_response)], - ), - ), - ], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_last_event_has_function_response_with_matching_call(self): - """Test when last event has function response with matching function call.""" - # Create a function call - function_call = types.FunctionCall(id="func_123", name="test_func", args={}) - - # Create a function response with matching ID - function_response = types.FunctionResponse( - id="func_123", name="test_func", response={} - ) - - call_event = Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call)] - ), - ) - - response_event = Event( - invocation_id="inv2", - author="user", - content=types.Content( - role="user", parts=[types.Part(function_response=function_response)] - ), - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[call_event, response_event], - ) - - result = 
_find_function_call_event_if_last_event_is_function_response( - session - ) - assert result == call_event - - def test_last_event_has_multiple_function_responses(self): - """Test when last event has multiple function responses.""" - # Create function calls - function_call1 = types.FunctionCall( - id="func_123", name="test_func1", args={} - ) - function_call2 = types.FunctionCall( - id="func_456", name="test_func2", args={} - ) - - # Create function responses - function_response1 = types.FunctionResponse( - id="func_123", name="test_func1", response={} - ) - function_response2 = types.FunctionResponse( - id="func_456", name="test_func2", response={} - ) - - call_event1 = Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call1)] - ), - ) - - call_event2 = Event( - invocation_id="inv2", - author="agent2", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call2)] - ), - ) - - response_event = Event( - invocation_id="inv3", - author="user", - content=types.Content( - role="user", - parts=[ - types.Part(function_response=function_response1), - types.Part(function_response=function_response2), - ], - ), - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[call_event1, call_event2, response_event], - ) - - # Should return the first matching function call event found - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result == call_event1 # First match (func_123) - - class TestRunnerFindAgentToRun: """Tests for Runner._find_agent_to_run method.""" From fa025d755978e1506fa0da1fecc49775bebc1045 Mon Sep 17 00:00:00 2001 From: Joseph Pagadora Date: Mon, 23 Jun 2025 15:24:15 -0700 Subject: [PATCH 53/79] feat: Add a new option `eval_storage_uri` in adk web & adk eval to specify GCS bucket to store eval data PiperOrigin-RevId: 774947795 --- src/google/adk/cli/cli_tools_click.py | 66 +++++++++++++++++-- src/google/adk/cli/fast_api.py | 17 ++++- src/google/adk/cli/utils/evals.py | 53 +++++++++++++++ .../adk/evaluation/gcs_eval_sets_manager.py | 13 ++-- tests/unittests/cli/test_fast_api.py | 5 +- 5 files changed, 139 insertions(+), 15 deletions(-) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 49ecee482..9923b46c2 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -31,12 +31,15 @@ from . import cli_create from . import cli_deploy from .. import version +from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..sessions.in_memory_session_service import InMemorySessionService from .cli import run_cli from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE from .fast_api import get_fast_api_app from .utils import envs +from .utils import evals from .utils import logs LOG_LEVELS = click.Choice( @@ -282,11 +285,21 @@ def cli_run( default=False, help="Optional. Whether to print detailed results on console or not.", ) +@click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." 
+ ), + default=None, +) def cli_eval( agent_module_file_path: str, - eval_set_file_path: tuple[str], + eval_set_file_path: list[str], config_file_path: str, print_detailed_results: bool, + eval_storage_uri: Optional[str] = None, ): """Evaluates an agent given the eval sets. @@ -338,12 +351,33 @@ def cli_eval( root_agent = get_root_agent(agent_module_file_path) reset_func = try_get_reset_func(agent_module_file_path) + gcs_eval_sets_manager = None + eval_set_results_manager = None + if eval_storage_uri: + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( + eval_storage_uri + ) + gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_set_results_manager = gcs_eval_managers.eval_set_results_manager + else: + eval_set_results_manager = LocalEvalSetResultsManager( + agents_dir=os.path.dirname(agent_module_file_path) + ) eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path) eval_set_id_to_eval_cases = {} # Read the eval_set files and get the cases. for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items(): - eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) + if gcs_eval_sets_manager: + eval_set = gcs_eval_sets_manager._load_eval_set_from_blob( + eval_set_file_path + ) + if not eval_set: + raise click.ClickException( + f"Eval set {eval_set_file_path} not found in GCS." + ) + else: + eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) eval_cases = eval_set.eval_cases if eval_case_ids: @@ -378,16 +412,13 @@ async def _collect_eval_results() -> list[EvalCaseResult]: raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) # Write eval set results. - local_eval_set_results_manager = LocalEvalSetResultsManager( - agents_dir=os.path.dirname(agent_module_file_path) - ) eval_set_id_to_eval_results = collections.defaultdict(list) for eval_case_result in eval_results: eval_set_id = eval_case_result.eval_set_id eval_set_id_to_eval_results[eval_set_id].append(eval_case_result) for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items(): - local_eval_set_results_manager.save_eval_set_result( + eval_set_results_manager.save_eval_set_result( app_name=os.path.basename(agent_module_file_path), eval_set_id=eval_set_id, eval_case_results=eval_case_results, @@ -444,6 +475,15 @@ def decorator(func): ), default=None, ) + @click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." 
+ ), + default=None, + ) @click.option( "--memory_service_uri", type=str, @@ -564,6 +604,7 @@ def wrapper(*args, **kwargs): ) def cli_web( agents_dir: str, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -616,6 +657,7 @@ async def _lifespan(app: FastAPI): session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, web=True, trace_to_cloud=trace_to_cloud, @@ -654,6 +696,7 @@ async def _lifespan(app: FastAPI): ) def cli_api_server( agents_dir: str, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -685,6 +728,7 @@ def cli_api_server( session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, web=False, trace_to_cloud=trace_to_cloud, @@ -771,6 +815,15 @@ def cli_api_server( " version in the dev environment)" ), ) +@click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." + ), + default=None, +) @adk_services_options() @deprecated_adk_services_options() @click.argument( @@ -797,6 +850,7 @@ def cli_deploy_cloud_run( session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + eval_storage_uri: Optional[str] = None, session_db_url: Optional[str] = None, # Deprecated artifact_storage_uri: Optional[str] = None, # Deprecated ): diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 46e008655..4b2ed6c2e 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -65,6 +65,8 @@ from ..evaluation.eval_metrics import EvalMetricResult from ..evaluation.eval_metrics import EvalMetricResultPerInvocation from ..evaluation.eval_result import EvalSetResult +from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event @@ -198,6 +200,7 @@ def get_fast_api_app( session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + eval_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, trace_to_cloud: bool = False, @@ -256,8 +259,18 @@ async def internal_lifespan(app: FastAPI): runner_dict = {} - eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) - eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) + # Set up eval managers. 
+ eval_sets_manager = None + eval_set_results_manager = None + if eval_storage_uri: + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( + eval_storage_uri + ) + eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_set_results_manager = gcs_eval_managers.eval_set_results_manager + else: + eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) + eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) # Build the Memory service if memory_service_uri: diff --git a/src/google/adk/cli/utils/evals.py b/src/google/adk/cli/utils/evals.py index c8d1a3296..305d47544 100644 --- a/src/google/adk/cli/utils/evals.py +++ b/src/google/adk/cli/utils/evals.py @@ -14,17 +14,36 @@ from __future__ import annotations +import dataclasses +import os from typing import Any from typing import Tuple from google.genai import types as genai_types +from pydantic import alias_generators +from pydantic import BaseModel +from pydantic import ConfigDict from typing_extensions import deprecated from ...evaluation.eval_case import IntermediateData from ...evaluation.eval_case import Invocation +from ...evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ...evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ...sessions.session import Session +class GcsEvalManagers(BaseModel): + model_config = ConfigDict( + alias_generator=alias_generators.to_camel, + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + eval_sets_manager: GcsEvalSetsManager + + eval_set_results_manager: GcsEvalSetResultsManager + + @deprecated('Use convert_session_to_eval_invocations instead.') def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]: """Converts a session data into eval format. @@ -176,3 +195,37 @@ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]: ) return invocations + + +def create_gcs_eval_managers_from_uri( + eval_storage_uri: str, +) -> GcsEvalManagers: + """Creates GcsEvalManagers from eval_storage_uri. + + Args: + eval_storage_uri: The evals storage URI to use. Supported URIs: + gs://. If a path is provided, the bucket will be extracted. + + Returns: + GcsEvalManagers: The GcsEvalManagers object. + + Raises: + ValueError: If the eval_storage_uri is not supported. + """ + if eval_storage_uri.startswith('gs://'): + gcs_bucket = eval_storage_uri.split('://')[1] + eval_sets_manager = GcsEvalSetsManager( + bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT'] + ) + eval_set_results_manager = GcsEvalSetResultsManager( + bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT'] + ) + return GcsEvalManagers( + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, + ) + else: + raise ValueError( + f'Unsupported evals storage URI: {eval_storage_uri}. Supported URIs:' + ' gs://' + ) diff --git a/src/google/adk/evaluation/gcs_eval_sets_manager.py b/src/google/adk/evaluation/gcs_eval_sets_manager.py index fe5d8c9b5..c253e4cd5 100644 --- a/src/google/adk/evaluation/gcs_eval_sets_manager.py +++ b/src/google/adk/evaluation/gcs_eval_sets_manager.py @@ -72,6 +72,13 @@ def _validate_id(self, id_name: str, id_value: str): f"Invalid {id_name}. 
{id_name} should have the `{pattern}` format", ) + def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]: + blob = self.bucket.blob(blob_name) + if not blob.exists(): + return None + eval_set_data = blob.download_as_text() + return EvalSet.model_validate_json(eval_set_data) + def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet): """Writes an EvalSet to GCS.""" blob = self.bucket.blob(blob_name) @@ -88,11 +95,7 @@ def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: """Returns an EvalSet identified by an app_name and eval_set_id.""" eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) - blob = self.bucket.blob(eval_set_blob_name) - if not blob.exists(): - return None - eval_set_data = blob.download_as_text() - return EvalSet.model_validate_json(eval_set_data) + return self._load_eval_set_from_blob(eval_set_blob_name) @override def create_eval_set(self, app_name: str, eval_set_id: str): diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 65c1eee3b..aec7a020b 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -40,7 +40,7 @@ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) -logger = logging.getLogger(__name__) +logger = logging.getLogger("google_adk." + __name__) # Here we create a dummy agent module that get_fast_api_app expects @@ -138,6 +138,7 @@ async def mock_run_evals_for_fast_api(*args, **kwargs): final_eval_status=1, # Matches expected (assuming 1 is PASSED) user_id="test_user", # Placeholder, adapt if needed session_id="test_session_for_eval_case", # Placeholder + eval_set_file="test_eval_set_file", # Placeholder overall_eval_metric_results=[{ # Matches expected "metricName": "tool_trajectory_avg_score", "threshold": 0.5, @@ -372,7 +373,7 @@ def add_eval_case(self, app_name, eval_set_id, eval_case): @pytest.fixture def mock_eval_set_results_manager(): - """Create a mock eval set results manager.""" + """Create a mock local eval set results manager.""" # Storage for eval set results. 
  eval_set_results = {}

From 9597a446fdec63ad9e4c2692d6966b14f80ff8e2 Mon Sep 17 00:00:00 2001
From: Joseph Pagadora
Date: Mon, 23 Jun 2025 15:30:16 -0700
Subject: [PATCH 54/79] feat: Add rouge_score library to ADK eval dependencies, and implement RougeEvaluator that computes ROUGE-1 for "response_match_score" metric

PiperOrigin-RevId: 774949712
---
 pyproject.toml                                |   1 +
 .../adk/evaluation/final_response_match_v1.py | 110 ++++++++++++++
 .../adk/evaluation/response_evaluator.py      |  13 +-
 .../test_final_response_match_v1.py           | 140 ++++++++++++++++++
 .../evaluation/test_response_evaluator.py     |  39 ++++-
 5 files changed, 301 insertions(+), 2 deletions(-)
 create mode 100644 src/google/adk/evaluation/final_response_match_v1.py
 create mode 100644 tests/unittests/evaluation/test_final_response_match_v1.py

diff --git a/pyproject.toml b/pyproject.toml
index 8ece4db81..23dbcb537 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,7 @@ eval = [
   "google-cloud-aiplatform[evaluation]>=1.87.0",
   "pandas>=2.2.3",
   "tabulate>=0.9.0",
+  "rouge-score>=0.1.2",
   # go/keep-sorted end
 ]

diff --git a/src/google/adk/evaluation/final_response_match_v1.py b/src/google/adk/evaluation/final_response_match_v1.py
new file mode 100644
index 000000000..a034b470f
--- /dev/null
+++ b/src/google/adk/evaluation/final_response_match_v1.py
@@ -0,0 +1,110 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import annotations + +from typing import Optional + +from google.genai import types as genai_types +from rouge_score import rouge_scorer +from typing_extensions import override + +from .eval_case import Invocation +from .eval_metrics import EvalMetric +from .evaluator import EvalStatus +from .evaluator import EvaluationResult +from .evaluator import Evaluator +from .evaluator import PerInvocationResult + + +class RougeEvaluator(Evaluator): + """Calculates the ROUGE-1 metric to compare responses.""" + + def __init__(self, eval_metric: EvalMetric): + self._eval_metric = eval_metric + + @override + def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + total_score = 0.0 + num_invocations = 0 + per_invocation_results = [] + for actual, expected in zip(actual_invocations, expected_invocations): + reference = _get_text_from_content(expected.final_response) + response = _get_text_from_content(actual.final_response) + rouge_1_scores = _calculate_rouge_1_scores(response, reference) + score = rouge_1_scores.fmeasure + per_invocation_results.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=score, + eval_status=_get_eval_status(score, self._eval_metric.threshold), + ) + ) + total_score += score + num_invocations += 1 + + if per_invocation_results: + overall_score = total_score / num_invocations + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=_get_eval_status( + overall_score, self._eval_metric.threshold + ), + per_invocation_results=per_invocation_results, + ) + + return EvaluationResult() + + +def _get_text_from_content(content: Optional[genai_types.Content]) -> str: + if content and content.parts: + return "\n".join([part.text for part in content.parts if part.text]) + + return "" + + +def _get_eval_status(score: float, threshold: float): + return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED + + +def _calculate_rouge_1_scores(candidate: str, reference: str): + """Calculates the ROUGE-1 score between a candidate and reference text. + + ROUGE-1 measures the overlap of unigrams (single words) between the + candidate and reference texts. The score is broken down into: + - Precision: The proportion of unigrams in the candidate that are also in the + reference. + - Recall: The proportion of unigrams in the reference that are also in the + candidate. + - F-measure: The harmonic mean of precision and recall. + + Args: + candidate: The generated text to be evaluated. + reference: The ground-truth text to compare against. + + Returns: + A dictionary containing the ROUGE-1 precision, recall, and f-measure. + """ + scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) + + # The score method returns a dictionary where keys are the ROUGE types + # and values are Score objects (tuples) with precision, recall, and fmeasure. 
+ scores = scorer.score(reference, candidate) + + return scores["rouge1"] diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py index 52ab50c74..0826f8796 100644 --- a/src/google/adk/evaluation/response_evaluator.py +++ b/src/google/adk/evaluation/response_evaluator.py @@ -27,10 +27,12 @@ from .eval_case import IntermediateData from .eval_case import Invocation +from .eval_metrics import EvalMetric from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator from .evaluator import PerInvocationResult +from .final_response_match_v1 import RougeEvaluator class ResponseEvaluator(Evaluator): @@ -40,7 +42,7 @@ def __init__(self, threshold: float, metric_name: str): if "response_evaluation_score" == metric_name: self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE elif "response_match_score" == metric_name: - self._metric_name = "rouge_1" + self._metric_name = "response_match_score" else: raise ValueError(f"`{metric_name}` is not supported.") @@ -52,6 +54,15 @@ def evaluate_invocations( actual_invocations: list[Invocation], expected_invocations: list[Invocation], ) -> EvaluationResult: + # If the metric is response_match_score, just use the RougeEvaluator. + if self._metric_name == "response_match_score": + rouge_evaluator = RougeEvaluator( + EvalMetric(metric_name=self._metric_name, threshold=self._threshold) + ) + return rouge_evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + total_score = 0.0 num_invocations = 0 per_invocation_results = [] diff --git a/tests/unittests/evaluation/test_final_response_match_v1.py b/tests/unittests/evaluation/test_final_response_match_v1.py new file mode 100644 index 000000000..d5544a5a1 --- /dev/null +++ b/tests/unittests/evaluation/test_final_response_match_v1.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetric +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.final_response_match_v1 import _calculate_rouge_1_scores +from google.adk.evaluation.final_response_match_v1 import RougeEvaluator +from google.genai import types as genai_types +import pytest + + +def _create_test_rouge_evaluator(threshold: float) -> RougeEvaluator: + return RougeEvaluator( + EvalMetric(metric_name="response_match_score", threshold=threshold) + ) + + +def _create_test_invocations( + candidate: str, reference: str +) -> tuple[Invocation, Invocation]: + """Returns tuple of (actual_invocation, expected_invocation).""" + return Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=candidate)] + ), + ), Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=reference)] + ), + ) + + +def test_calculate_rouge_1_scores_empty_candidate_and_reference(): + candidate = "" + reference = "" + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores_empty_candidate(): + candidate = "" + reference = "This is a test reference." + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores_empty_reference(): + candidate = "This is a test candidate response." + reference = "" + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores(): + candidate = "This is a test candidate response." + reference = "This is a test reference." 
+ rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == pytest.approx(2 / 3) + assert rouge_1_score.recall == pytest.approx(4 / 5) + assert rouge_1_score.fmeasure == pytest.approx(8 / 11) + + +@pytest.mark.parametrize( + "candidates, references, expected_score, expected_status", + [ + ( + ["The quick brown fox jumps.", "hello world"], + ["The quick brown fox jumps over the lazy dog.", "hello"], + 0.69048, # (5/7 + 2/3) / 2 + EvalStatus.FAILED, + ), + ( + ["This is a test.", "Another test case."], + ["This is a test.", "This is a different test."], + 0.625, # (1 + 1/4) / 2 + EvalStatus.FAILED, + ), + ( + ["No matching words here.", "Second candidate."], + ["Completely different text.", "Another reference."], + 0.0, # (0 + 1/2) / 2 + EvalStatus.FAILED, + ), + ( + ["Same words", "Same words"], + ["Same words", "Same words"], + 1.0, + EvalStatus.PASSED, + ), + ], +) +def test_rouge_evaluator_multiple_invocations( + candidates: list[str], + references: list[str], + expected_score: float, + expected_status: EvalStatus, +): + rouge_evaluator = _create_test_rouge_evaluator(threshold=0.8) + actual_invocations = [] + expected_invocations = [] + for candidate, reference in zip(candidates, references): + actual_invocation, expected_invocation = _create_test_invocations( + candidate, reference + ) + actual_invocations.append(actual_invocation) + expected_invocations.append(expected_invocation) + + evaluation_result = rouge_evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + assert evaluation_result.overall_score == pytest.approx( + expected_score, rel=1e-3 + ) + assert evaluation_result.overall_eval_status == expected_status diff --git a/tests/unittests/evaluation/test_response_evaluator.py b/tests/unittests/evaluation/test_response_evaluator.py index bbaa694f2..839b7188a 100644 --- a/tests/unittests/evaluation/test_response_evaluator.py +++ b/tests/unittests/evaluation/test_response_evaluator.py @@ -16,7 +16,10 @@ from unittest.mock import MagicMock from unittest.mock import patch +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.evaluator import EvalStatus from google.adk.evaluation.response_evaluator import ResponseEvaluator +from google.genai import types as genai_types import pandas as pd import pytest from vertexai.preview.evaluation import MetricPromptTemplateExamples @@ -63,7 +66,7 @@ "google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval" ) class TestResponseEvaluator: - """A class to help organize "patch" that are applicabple to all tests.""" + """A class to help organize "patch" that are applicable to all tests.""" def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval): """Test evaluate function raises ValueError for an empty list.""" @@ -77,6 +80,40 @@ def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval): ResponseEvaluator.evaluate([], ["response_evaluation_score"]) mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called + def test_evaluate_invocations_rouge_metric(self, mock_perform_eval): + """Test evaluate_invocations function for Rouge metric.""" + actual_invocations = [ + Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[ + genai_types.Part(text="This is a test candidate response.") + ] + ), + ) + ] + expected_invocations = [ + Invocation( + user_content=genai_types.Content( + 
parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="This is a test reference.")] + ), + ) + ] + evaluator = ResponseEvaluator( + threshold=0.8, metric_name="response_match_score" + ) + evaluation_result = evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + assert evaluation_result.overall_score == pytest.approx(8 / 11) + # ROUGE-1 F1 is approx. 0.73 < 0.8 threshold, so eval status is FAILED. + assert evaluation_result.overall_eval_status == EvalStatus.FAILED + def test_evaluate_determines_metrics_correctly_for_perform_eval( self, mock_perform_eval ): From 00cc8cd6433fc45ecfc2dbaa04dbbc1a81213b4d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Jun 2025 11:33:05 -0700 Subject: [PATCH 55/79] feat: Add Vertex Express mode compatibility for VertexAiSessionService PiperOrigin-RevId: 775317848 --- .../adk/sessions/vertex_ai_session_service.py | 71 ++++++++++++++----- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index bd1345162..06a904c89 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -16,6 +16,7 @@ import asyncio import json import logging +import os import re from typing import Any from typing import Dict @@ -23,6 +24,7 @@ import urllib.parse from dateutil import parser +from google.genai.errors import ClientError from typing_extensions import override from google import genai @@ -95,25 +97,46 @@ async def create_session( operation_id = api_response['name'].split('/')[-1] max_retry_attempt = 5 - lro_response = None - while max_retry_attempt >= 0: - lro_response = await api_client.async_request( - http_method='GET', - path=f'operations/{operation_id}', - request_dict={}, - ) - lro_response = _convert_api_response(lro_response) - if lro_response.get('done', None): - break - - await asyncio.sleep(1) - max_retry_attempt -= 1 - - if lro_response is None or not lro_response.get('done', None): - raise TimeoutError( - f'Timeout waiting for operation {operation_id} to complete.' - ) + if _is_vertex_express_mode(self._project, self._location): + # Express mode doesn't support LRO, so we need to poll + # the session resource. + # TODO: remove this once LRO polling is supported in Express mode. + for i in range(max_retry_attempt): + try: + await api_client.async_request( + http_method='GET', + path=( + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' + ), + request_dict={}, + ) + break + except ClientError as e: + logger.info('Polling for session %s: %s', session_id, e) + # Add slight exponential backoff to avoid excessive polling. + await asyncio.sleep(1 + 0.5 * i) + else: + raise TimeoutError('Session creation failed.') + else: + lro_response = None + for _ in range(max_retry_attempt): + lro_response = await api_client.async_request( + http_method='GET', + path=f'operations/{operation_id}', + request_dict={}, + ) + lro_response = _convert_api_response(lro_response) + + if lro_response.get('done', None): + break + + await asyncio.sleep(1) + + if lro_response is None or not lro_response.get('done', None): + raise TimeoutError( + f'Timeout waiting for operation {operation_id} to complete.' 
+ ) # Get session resource get_session_api_response = await api_client.async_request( @@ -312,6 +335,18 @@ def _get_api_client(self): return client._api_client +def _is_vertex_express_mode( + project: Optional[str], location: Optional[str] +) -> bool: + """Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode.""" + return ( + os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in ['true', '1'] + and os.environ.get('GOOGLE_API_KEY', None) is not None + and project is None + and location is None + ) + + def _convert_api_response(api_response): """Converts the API response to a JSON object based on the type.""" if hasattr(api_response, 'body'): From abc89d2c811ba00805f81b27a3a07d56bdf55a0b Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 24 Jun 2025 11:56:28 -0700 Subject: [PATCH 56/79] feat: Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint PiperOrigin-RevId: 775327151 --- src/google/adk/cli/cli_tools_click.py | 3 +- src/google/adk/cli/fast_api.py | 11 ++ src/google/adk/memory/__init__.py | 4 +- .../memory/vertex_ai_memory_bank_service.py | 147 ++++++++++++++++ .../test_vertex_ai_memory_bank_service.py | 158 ++++++++++++++++++ 5 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 src/google/adk/memory/vertex_ai_memory_bank_service.py create mode 100644 tests/unittests/memory/test_vertex_ai_memory_bank_service.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 9923b46c2..c0935cceb 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -489,7 +489,8 @@ def decorator(func): type=str, help=( """Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service.""" + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Vertex AI Memory Bank Service. e.g. 
agentengine://12345""" ), default=None, ) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4b2ed6c2e..abe1961e7 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -71,6 +71,7 @@ from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService from ..runners import Runner from ..sessions.database_session_service import DatabaseSessionService @@ -282,6 +283,16 @@ async def internal_lifespan(app: FastAPI): memory_service = VertexAiRagMemoryService( rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}' ) + elif memory_service_uri.startswith("agentengine://"): + agent_engine_id = memory_service_uri.split("://")[1] + if not agent_engine_id: + raise click.ClickException("Agent engine id can not be empty.") + envs.load_dotenv_for_agent("", agents_dir) + memory_service = VertexAiMemoryBankService( + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, + ) else: raise click.ClickException( "Unsupported memory service URI: %s" % memory_service_uri diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index f2ac4f9b5..915d7e517 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -15,12 +15,14 @@ from .base_memory_service import BaseMemoryService from .in_memory_memory_service import InMemoryMemoryService +from .vertex_ai_memory_bank_service import VertexAiMemoryBankService logger = logging.getLogger('google_adk.' + __name__) __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', + 'VertexAiMemoryBankService', ] try: @@ -29,7 +31,7 @@ __all__.append('VertexAiRagMemoryService') except ImportError: logger.debug( - 'The Vertex sdk is not installed. If you want to use the' + 'The Vertex SDK is not installed. If you want to use the' ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py new file mode 100644 index 000000000..b5b70ab1c --- /dev/null +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import logging +from typing import Optional +from typing import TYPE_CHECKING + +from typing_extensions import override + +from google import genai + +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..sessions.session import Session + +logger = logging.getLogger('google_adk.' + __name__) + + +class VertexAiMemoryBankService(BaseMemoryService): + """Implementation of the BaseMemoryService using Vertex AI Memory Bank.""" + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, + ): + """Initializes a VertexAiMemoryBankService. + + Args: + project: The project ID of the Memory Bank to use. + location: The location of the Memory Bank to use. + agent_engine_id: The ID of the agent engine to use for the Memory Bank. + e.g. '456' in + 'projects/my-project/locations/us-central1/reasoningEngines/456'. + """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id + + @override + async def add_session_to_memory(self, session: Session): + api_client = self._get_api_client() + + if not self._agent_engine_id: + raise ValueError('Agent Engine ID is required for Memory Bank.') + + events = [] + for event in session.events: + if event.content and event.content.parts: + events.append({ + 'content': event.content.model_dump(exclude_none=True, mode='json') + }) + request_dict = { + 'direct_contents_source': { + 'events': events, + }, + 'scope': { + 'app_name': session.app_name, + 'user_id': session.user_id, + }, + } + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', + request_dict=request_dict, + ) + logger.info(f'Generate memory response: {api_response}') + + @override + async def search_memory(self, *, app_name: str, user_id: str, query: str): + api_client = self._get_api_client() + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve', + request_dict={ + 'scope': { + 'app_name': app_name, + 'user_id': user_id, + }, + 'similarity_search_params': { + 'search_query': query, + }, + }, + ) + api_response = _convert_api_response(api_response) + logger.info(f'Search memory response: {api_response}') + + if not api_response or not api_response.get('retrievedMemories', None): + return SearchMemoryResponse() + + memory_events = [] + for memory in api_response.get('retrievedMemories', []): + # TODO: add more complex error handling + memory_events.append( + MemoryEntry( + author='user', + content=genai.types.Content( + parts=[ + genai.types.Part(text=memory.get('memory').get('fact')) + ], + role='user', + ), + timestamp=memory.get('updateTime'), + ) + ) + return SearchMemoryResponse(memories=memory_events) + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + + Returns: + An API client for the given project and location. 
+ """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client + + +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py new file mode 100644 index 000000000..27e2bbdd5 --- /dev/null +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -0,0 +1,158 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any +from unittest import mock + +from google.adk.events import Event +from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService +from google.adk.sessions import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='333', + last_update_time=22333, + events=[ + Event( + id='444', + invocation_id='123', + author='user', + timestamp=12345, + content=types.Content(parts=[types.Part(text='test_content')]), + ), + # Empty event, should be ignored + Event( + id='555', + invocation_id='456', + author='user', + timestamp=12345, + ), + ], +) + + +RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' +GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' + + +class MockApiClient: + """Mocks the API Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.async_request = mock.AsyncMock() + self.async_request.side_effect = self._mock_async_request + + async def _mock_async_request( + self, http_method: str, path: str, request_dict: dict[str, Any] + ): + """Mocks the API Client request method.""" + if http_method == 'POST': + if re.match(GENERATE_MEMORIES_REGEX, path): + return {} + elif re.match(RETRIEVE_MEMORIES_REGEX, path): + if ( + request_dict.get('scope', None) + and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME + ): + return { + 'retrievedMemories': [ + { + 'memory': { + 'fact': 'test_content', + }, + 'updateTime': '2024-12-12T12:12:12.123456Z', + }, + ], + } + else: + return {'retrievedMemories': []} + else: + raise ValueError(f'Unsupported path: {path}') + else: + raise ValueError(f'Unsupported http method: {http_method}') + + +def mock_vertex_ai_memory_bank_service(): + """Creates a mock Vertex AI Memory Bank service for testing.""" + return VertexAiMemoryBankService( + project='test-project', + location='test-location', + agent_engine_id='123', + ) + + +@pytest.fixture +def mock_get_api_client(): + api_client = MockApiClient() + with mock.patch( + 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client', + return_value=api_client, + ): + yield api_client + + +@pytest.mark.asyncio 
+@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_session_to_memory(MOCK_SESSION) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:generate', + request_dict={ + 'direct_contents_source': { + 'events': [ + { + 'content': { + 'parts': [ + {'text': 'test_content'}, + ], + }, + }, + ], + }, + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_search_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:retrieve', + request_dict={ + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + 'similarity_search_params': {'search_query': 'query'}, + }, + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'test_content' From f33e0903b21b752168db3006dd034d7d43f7e84d Mon Sep 17 00:00:00 2001 From: Genquan Duan Date: Tue, 24 Jun 2025 13:07:57 -0700 Subject: [PATCH 57/79] feat: Add ADK examples for litellm with add_function_to_prompt Add examples for for https://github.com/google/adk-python/issues/1273 PiperOrigin-RevId: 775352677 --- .../__init__.py | 16 ++++ .../agent.py | 78 ++++++++++++++++++ .../main.py | 81 +++++++++++++++++++ src/google/adk/models/lite_llm.py | 8 ++ .../models/test_litellm_with_function.py | 3 - 5 files changed, 183 insertions(+), 3 deletions(-) create mode 100644 contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py create mode 100644 contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py create mode 100644 contributing/samples/hello_world_litellm_add_function_to_prompt/main.py diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py new file mode 100644 index 000000000..7d5bb0b1c --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from . import agent diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py new file mode 100644 index 000000000..0f10621ae --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random + +from google.adk import Agent +from google.adk.models.lite_llm import LiteLlm +from langchain_core.utils.function_calling import convert_to_openai_function + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +def check_prime(number: int) -> str: + """Check if a given number is prime. + + Args: + number: The input number to check. + + Returns: + A str indicating the number is prime or not. + """ + if number <= 1: + return f"{number} is not prime." + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + return f"{number} is prime." + else: + return f"{number} is not prime." + + +root_agent = Agent( + model=LiteLlm( + model="vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas", + # If the model is not trained with functions and you would like to + # enable function calling, you can add functions to the models, and the + # functions will be added to the prompts during inferences. + functions=[ + convert_to_openai_function(roll_die), + convert_to_openai_function(check_prime), + ], + ), + name="data_processing_agent", + description="""You are a helpful assistant.""", + instruction=""" + You are a helpful assistant, and call tools optionally. + If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs. + """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py new file mode 100644 index 000000000..123ba1368 --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +import time + +import agent +from dotenv import load_dotenv +from google.adk import Runner +from google.adk.artifacts import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.sessions import InMemorySessionService +from google.adk.sessions import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts: + part = event.content.parts[0] + if part.text: + print(f'** {event.author}: {part.text}') + if part.function_call: + print(f'** {event.author} calls tool: {part.function_call}') + if part.function_response: + print( + f'** {event.author} gets tool response: {part.function_response}' + ) + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi, introduce yourself.') + await run_prompt(session_11, 'Roll a die with 100 sides.') + await run_prompt(session_11, 'Check if it is prime.') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index acc88ed19..624b7adfc 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -29,6 +29,7 @@ from typing import Union from google.genai import types +import litellm from litellm import acompletion from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall @@ -53,6 +54,9 @@ from .llm_request import LlmRequest from .llm_response import LlmResponse +# This will add functions to prompts if functions are provided. +litellm.add_function_to_prompt = True + logger = logging.getLogger("google_adk." + __name__) _NEW_LINE = "\n" @@ -662,6 +666,10 @@ async def generate_content_async( messages, tools, response_format = _get_completion_inputs(llm_request) + if "functions" in self._additional_args: + # LiteLLM does not support both tools and functions together. 
+ tools = None + completion_args = { "model": self.model, "messages": messages, diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py index 799c55e5c..e0d2bc991 100644 --- a/tests/integration/models/test_litellm_with_function.py +++ b/tests/integration/models/test_litellm_with_function.py @@ -17,11 +17,8 @@ from google.genai import types from google.genai.types import Content from google.genai.types import Part -import litellm import pytest -litellm.add_function_to_prompt = True - _TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """ From a1e14411159fd9f3e114e15b39b4949d0fd6ecb1 Mon Sep 17 00:00:00 2001 From: Liang Wu <18244712+wuliang229@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:26:45 -0700 Subject: [PATCH 58/79] fix: update contributing links Merge https://github.com/google/adk-python/pull/1528 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1528 from google:doc ec8325e126aba7257de73ab26d8d3a30064859b4 PiperOrigin-RevId: 775383121 --- .gitignore | 1 + CONTRIBUTING.md | 2 +- README.md | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 6fb068d48..6f398cbf9 100644 --- a/.gitignore +++ b/.gitignore @@ -82,6 +82,7 @@ log/ .env.development.local .env.test.local .env.production.local +uv.lock # Google Cloud specific .gcloudignore diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c0f3d0069..0d7b2d67d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -200,7 +200,7 @@ For any changes that impact user-facing documentation (guides, API reference, tu ## Contributing Resources -[Contributing folder](https://github.com/google/adk-python/tree/main/contributing/samples) has resources that is helpful for contributors. +[Contributing folder](https://github.com/google/adk-python/tree/main/contributing) has resources that is helpful for contributors. ## Code reviews diff --git a/README.md b/README.md index 7bd5e7401..874658d07 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ adk eval \ ## 🤝 Contributing We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our -- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/#questions). +- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/). - Then if you want to contribute code, please read [Code Contributing Guidelines](./CONTRIBUTING.md) to get started. ## 📄 License From ed7a21e1890466fcdf04f7025775305dc71f603d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Jun 2025 14:57:11 -0700 Subject: [PATCH 59/79] chore: Update google-genai package and related deps to latest PiperOrigin-RevId: 775394737 --- pyproject.toml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 23dbcb537..6cf78ab40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start + "PyYAML>=6.0.2", # For APIHubToolset. 
"anyio>=4.9.0;python_version>='3.10'", # For MCP Session Manager "authlib>=1.5.1", # For RestAPI Tool "click>=8.1.8", # For CLI tools @@ -34,7 +35,7 @@ dependencies = [ "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool "google-cloud-speech>=2.30.0", # For Audio Transcription "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service - "google-genai>=1.17.0", # Google GenAI SDK + "google-genai>=1.21.1", # Google GenAI SDK "graphviz>=0.20.2", # Graphviz for graph rendering "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset "opentelemetry-api>=1.31.0", # OpenTelemetry @@ -43,7 +44,6 @@ dependencies = [ "pydantic>=2.0, <3.0.0", # For data validation/models "python-dateutil>=2.9.0.post0", # For Vertext AI Session Service "python-dotenv>=1.0.0", # To manage environment variables - "PyYAML>=6.0.2", # For APIHubToolset. "requests>=2.32.4", "sqlalchemy>=2.0", # SQL database ORM "starlette>=0.46.2", # For FastAPI CLI @@ -70,9 +70,9 @@ dev = [ # go/keep-sorted start "flit>=3.10.0", "isort>=6.0.0", + "mypy>=1.15.0", "pyink>=24.10.0", "pylint>=2.6.0", - "mypy>=1.15.0", # go/keep-sorted end ] @@ -98,7 +98,6 @@ test = [ "langgraph>=0.2.60", # For LangGraphAgent "litellm>=1.71.2", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests - "pytest-asyncio>=0.25.0", "pytest-mock>=3.14.0", "pytest-xdist>=3.6.1", From acbdca0d8400e292ba5525931175e0d6feab15f1 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 24 Jun 2025 15:03:23 -0700 Subject: [PATCH 60/79] fix: Make raw_auth_credential and exchanged_auth_credential optional given their default value is None PiperOrigin-RevId: 775397286 --- src/google/adk/auth/auth_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index 53c571d42..0316e5258 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -31,12 +31,12 @@ class AuthConfig(BaseModelWithConfig): auth_scheme: AuthScheme """The auth scheme used to collect credentials""" - raw_auth_credential: AuthCredential = None + raw_auth_credential: Optional[AuthCredential] = None """The raw auth credential used to collect credentials. The raw auth credentials are used in some auth scheme that needs to exchange auth credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None. """ - exchanged_auth_credential: AuthCredential = None + exchanged_auth_credential: Optional[AuthCredential] = None """The exchanged auth credential used to collect credentials. adk and client will work together to fill it. For those auth scheme that doesn't need to exchange auth credentials, e.g. API key, service account etc. 
It's filled by From 9e473e0abdded24e710fd857782356c15d04b515 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Jun 2025 15:10:49 -0700 Subject: [PATCH 61/79] fix: Include current turn context when include_contents='none' The intended behavior for include_contents='none' is to: - Exclude conversation history from previous turns - Still include current turn context (user input, tool calls/responses within current turn) https://google.github.io/adk-docs/agents/llm-agents/#managing-context-include_contents This resolves https://github.com/google/adk-python/issues/1124 PiperOrigin-RevId: 775400036 --- src/google/adk/agents/llm_agent.py | 8 +- src/google/adk/flows/llm_flows/contents.py | 53 +++- .../agents/test_llm_agent_include_contents.py | 242 ++++++++++++++++++ 3 files changed, 296 insertions(+), 7 deletions(-) create mode 100644 tests/unittests/agents/test_llm_agent_include_contents.py diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index fe145a60e..64c3628df 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -161,10 +161,12 @@ class LlmAgent(BaseAgent): # LLM-based agent transfer configs - End include_contents: Literal['default', 'none'] = 'default' - """Whether to include contents in the model request. + """Controls content inclusion in model requests. - When set to 'none', the model request will not include any contents, such as - user messages, tool results, etc. + Options: + default: Model receives relevant conversation history + none: Model receives no prior history, operates solely on current + instruction and input """ # Controlled input/output configurations - Start diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index ea418888f..039eaf8c5 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -43,12 +43,20 @@ async def run_async( if not isinstance(agent, LlmAgent): return - if agent.include_contents != 'none': + if agent.include_contents == 'default': + # Include full conversation history llm_request.contents = _get_contents( invocation_context.branch, invocation_context.session.events, agent.name, ) + else: + # Include current turn context only (no conversation history) + llm_request.contents = _get_current_turn_contents( + invocation_context.branch, + invocation_context.session.events, + agent.name, + ) # Maintain async generator behavior if False: # Ensures it behaves as a generator @@ -190,13 +198,15 @@ def _get_contents( ) -> list[types.Content]: """Get the contents for the LLM request. + Applies filtering, rearrangement, and content processing to events. + Args: current_branch: The current branch of the agent. - events: A list of events. + events: Events to process. agent_name: The name of the agent. Returns: - A list of contents. + A list of processed contents. """ filtered_events = [] # Parse the events, leaving the contents and the function calls and @@ -211,12 +221,13 @@ def _get_contents( # Skip events without content, or generated neither by user nor by model # or has empty text. # E.g. events purely for mutating session states. + continue if not _is_event_belongs_to_branch(current_branch, event): # Skip events not belong to current branch. continue if _is_auth_event(event): - # skip auth event + # Skip auth events. 
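+      # Auth events only carry the credential request/response exchange
+      # between the client and ADK, so they are not meaningful model-facing
+      # content.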
continue filtered_events.append( _convert_foreign_event(event) @@ -224,12 +235,15 @@ def _get_contents( else event ) + # Rearrange events for proper function call/response pairing result_events = _rearrange_events_for_latest_function_response( filtered_events ) result_events = _rearrange_events_for_async_function_responses_in_history( result_events ) + + # Convert events to contents contents = [] for event in result_events: content = copy.deepcopy(event.content) @@ -238,6 +252,37 @@ def _get_contents( return contents +def _get_current_turn_contents( + current_branch: Optional[str], events: list[Event], agent_name: str = '' +) -> list[types.Content]: + """Get contents for the current turn only (no conversation history). + + When include_contents='none', we want to include: + - The current user input + - Tool calls and responses from the current turn + But exclude conversation history from previous turns. + + In multi-agent scenarios, the "current turn" for an agent starts from an + actual user or from another agent. + + Args: + current_branch: The current branch of the agent. + events: A list of all session events. + agent_name: The name of the agent. + + Returns: + A list of contents for the current turn only, preserving context needed + for proper tool execution while excluding conversation history. + """ + # Find the latest event that starts the current turn and process from there + for i in range(len(events) - 1, -1, -1): + event = events[i] + if event.author == 'user' or _is_other_agent_reply(agent_name, event): + return _get_contents(current_branch, events[i:], agent_name) + + return [] + + def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool: """Whether the event is a reply from another agent.""" return bool( diff --git a/tests/unittests/agents/test_llm_agent_include_contents.py b/tests/unittests/agents/test_llm_agent_include_contents.py new file mode 100644 index 000000000..d4d76cf4e --- /dev/null +++ b/tests/unittests/agents/test_llm_agent_include_contents.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for LlmAgent include_contents field behavior.""" + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.genai import types +import pytest + +from .. 
import testing_utils + + +@pytest.mark.asyncio +async def test_include_contents_default_behavior(): + """Test that include_contents='default' preserves conversation history including tool interactions.""" + + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + "First response", + types.Part.from_function_call( + name="simple_tool", args={"message": "second"} + ), + "Second response", + ] + ) + + agent = LlmAgent( + name="test_agent", + model=mock_model, + include_contents="default", + instruction="You are a helpful assistant", + tools=[simple_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + runner.run("First message") + runner.run("Second message") + + # First turn requests + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "First message") + ] + + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ] + + # Second turn should include full conversation history + assert testing_utils.simplify_contents(mock_model.requests[2].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ("model", "First response"), + ("user", "Second message"), + ] + + # Second turn with tool should include full history + current tool interaction + assert testing_utils.simplify_contents(mock_model.requests[3].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ("model", "First response"), + ("user", "Second message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "second"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: second"} + ), + ), + ] + + +@pytest.mark.asyncio +async def test_include_contents_none_behavior(): + """Test that include_contents='none' excludes conversation history but includes current input.""" + + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + "First response", + "Second response", + ] + ) + + agent = LlmAgent( + name="test_agent", + model=mock_model, + include_contents="none", + instruction="You are a helpful assistant", + tools=[simple_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + runner.run("First message") + runner.run("Second message") + + # First turn behavior + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "First message") + ] + + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "First message"), + ( + "model", + 
types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ] + + # Second turn should only have current input, no history + assert testing_utils.simplify_contents(mock_model.requests[2].contents) == [ + ("user", "Second message") + ] + + # System instruction and tools should be preserved + assert ( + "You are a helpful assistant" + in mock_model.requests[0].config.system_instruction + ) + assert len(mock_model.requests[0].config.tools) > 0 + + +@pytest.mark.asyncio +async def test_include_contents_none_sequential_agents(): + """Test include_contents='none' with sequential agents.""" + + agent1_model = testing_utils.MockModel.create( + responses=["Agent1 response: XYZ"] + ) + agent1 = LlmAgent( + name="agent1", + model=agent1_model, + instruction="You are Agent1", + ) + + agent2_model = testing_utils.MockModel.create( + responses=["Agent2 final response"] + ) + agent2 = LlmAgent( + name="agent2", + model=agent2_model, + include_contents="none", + instruction="You are Agent2", + ) + + sequential_agent = SequentialAgent( + name="sequential_test_agent", sub_agents=[agent1, agent2] + ) + + runner = testing_utils.InMemoryRunner(sequential_agent) + events = runner.run("Original user request") + + assert len(events) == 2 + assert events[0].author == "agent1" + assert events[1].author == "agent2" + + # Agent1 sees original user request + agent1_contents = testing_utils.simplify_contents( + agent1_model.requests[0].contents + ) + assert ("user", "Original user request") in agent1_contents + + # Agent2 with include_contents='none' should not see original request + agent2_contents = testing_utils.simplify_contents( + agent2_model.requests[0].contents + ) + + assert not any( + "Original user request" in str(content) for _, content in agent2_contents + ) + assert any( + "Agent1 response" in str(content) for _, content in agent2_contents + ) From 09f1269bf7fa46ab4b9324e7f92b4f70ffc923e5 Mon Sep 17 00:00:00 2001 From: Dave Bunten Date: Tue, 24 Jun 2025 15:17:39 -0700 Subject: [PATCH 62/79] ci(tests): leverage official uv action for install Merge https://github.com/google/adk-python/pull/1547 This PR replaces the `curl`-based installation of `uv` to instead use the [official GitHub Action from Astral](https://github.com/astral-sh/setup-uv). 
Closes #1545 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1547 from d33bs:use-uv-action 05ab7a138cbb5babee30ea81e83f26064e041529 PiperOrigin-RevId: 775402484 --- .github/workflows/python-unit-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 565ee1dca..52e61b8a3 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -36,8 +36,8 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v6 - name: Install dependencies run: | From 88a4402d142672171d0a8ceae74671f47fa14289 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 24 Jun 2025 16:14:52 -0700 Subject: [PATCH 63/79] chore: Do not send api request when session does not have events PiperOrigin-RevId: 775423356 --- .../adk/memory/vertex_ai_memory_bank_service.py | 15 +++++++++------ .../memory/test_vertex_ai_memory_bank_service.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index b5b70ab1c..083b48e8d 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -78,12 +78,15 @@ async def add_session_to_memory(self, session: Session): }, } - api_response = await api_client.async_request( - http_method='POST', - path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', - request_dict=request_dict, - ) - logger.info(f'Generate memory response: {api_response}') + if events: + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', + request_dict=request_dict, + ) + logger.info(f'Generate memory response: {api_response}') + else: + logger.info('No events to add to memory.') @override async def search_memory(self, *, app_name: str, user_id: str, query: str): diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 27e2bbdd5..2fbf3291c 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -48,6 +48,13 @@ ], ) +MOCK_SESSION_WITH_EMPTY_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='444', + last_update_time=22333, +) + RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' @@ -136,6 +143,15 @@ async def test_add_session_to_memory(mock_get_api_client): ) +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_empty_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) + + mock_get_api_client.async_request.assert_not_called() + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_search_memory(mock_get_api_client): From ef3c745d655538ebd1ed735671be615f842341a8 Mon Sep 17 00:00:00 2001 From: Aditya Mulik Date: Tue, 24 Jun 2025 16:44:00 -0700 Subject: [PATCH 64/79] fix: typo fix in sample agent instruction Merge 
https://github.com/google/adk-python/pull/1623 fix: minor typo fix in the agent instruction COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1623 from adityamulik:minor_typo_fix 12ea09ae397b5c5e2a9ada48017cd1ca14add72e PiperOrigin-RevId: 775433411 --- contributing/samples/artifact_save_text/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributing/samples/artifact_save_text/agent.py b/contributing/samples/artifact_save_text/agent.py index 53a7f300d..3ce43bcd1 100755 --- a/contributing/samples/artifact_save_text/agent.py +++ b/contributing/samples/artifact_save_text/agent.py @@ -31,7 +31,7 @@ async def log_query(tool_context: ToolContext, query: str): model='gemini-2.0-flash', name='log_agent', description='Log user query.', - instruction="""Always log the user query and reploy "kk, I've logged." + instruction="""Always log the user query and reply "kk, I've logged." """, tools=[log_query], generate_content_config=types.GenerateContentConfig( From 917a8a19f794ba33fef08898937a73f0ceb809a2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 24 Jun 2025 16:45:45 -0700 Subject: [PATCH 65/79] chore: Adapt oauth calendar agent to use authenticated tool PiperOrigin-RevId: 775433950 --- .../samples/oauth_calendar_agent/agent.py | 116 ++++++------------ 1 file changed, 35 insertions(+), 81 deletions(-) diff --git a/contributing/samples/oauth_calendar_agent/agent.py b/contributing/samples/oauth_calendar_agent/agent.py index a1b1dea87..3f966b787 100644 --- a/contributing/samples/oauth_calendar_agent/agent.py +++ b/contributing/samples/oauth_calendar_agent/agent.py @@ -13,7 +13,6 @@ # limitations under the License. from datetime import datetime -import json import os from dotenv import load_dotenv @@ -27,8 +26,8 @@ from google.adk.auth import AuthCredentialTypes from google.adk.auth import OAuth2Auth from google.adk.tools import ToolContext +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool from google.adk.tools.google_api_tool import CalendarToolset -from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build @@ -56,6 +55,7 @@ def list_calendar_events( end_time: str, limit: int, tool_context: ToolContext, + credential: AuthCredential, ) -> list[dict]: """Search for calendar events. @@ -80,84 +80,11 @@ def list_calendar_events( Returns: list[dict]: A list of events that match the search criteria. """ - creds = None - - # Check if the tokes were already in the session state, which means the user - # has already gone through the OAuth flow and successfully authenticated and - # authorized the tool to access their calendar. - if "calendar_tool_tokens" in tool_context.state: - creds = Credentials.from_authorized_user_info( - tool_context.state["calendar_tool_tokens"], SCOPES - ) - if not creds or not creds.valid: - # If the access token is expired, refresh it with the refresh token. 
- if creds and creds.expired and creds.refresh_token: - creds.refresh(Request()) - else: - auth_scheme = OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://accounts.google.com/o/oauth2/auth", - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - "https://www.googleapis.com/auth/calendar": ( - "See, edit, share, and permanently delete all the" - " calendars you can access using Google Calendar" - ) - }, - ) - ) - ) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=oauth_client_id, client_secret=oauth_client_secret - ), - ) - # If the user has not gone through the OAuth flow before, or the refresh - # token also expired, we need to ask users to go through the OAuth flow. - # First we check whether the user has just gone through the OAuth flow and - # Oauth response is just passed back. - auth_response = tool_context.get_auth_response( - AuthConfig( - auth_scheme=auth_scheme, raw_auth_credential=auth_credential - ) - ) - if auth_response: - # ADK exchanged the access token already for us - access_token = auth_response.oauth2.access_token - refresh_token = auth_response.oauth2.refresh_token - - creds = Credentials( - token=access_token, - refresh_token=refresh_token, - token_uri=auth_scheme.flows.authorizationCode.tokenUrl, - client_id=oauth_client_id, - client_secret=oauth_client_secret, - scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()), - ) - else: - # If there are no auth response which means the user has not gone - # through the OAuth flow yet, we need to ask users to go through the - # OAuth flow. - tool_context.request_credential( - AuthConfig( - auth_scheme=auth_scheme, - raw_auth_credential=auth_credential, - ) - ) - # The return value is optional and could be any dict object. It will be - # wrapped in a dict with key as 'result' and value as the return value - # if the object returned is not a dict. This response will be passed - # to LLM to generate a user friendly message. e.g. LLM will tell user: - # "I need your authorization to access your calendar. Please authorize - # me so I can check your meetings for today." - return "Need User Authorization to access their calendar." - # We store the access token and refresh token in the session state for the - # next runs. This is just an example. On production, a tool should store - # those credentials in some secure store or properly encrypt it before store - # it in the session state. 
- tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json()) + + creds = Credentials( + token=credential.oauth2.access_token, + refresh_token=credential.oauth2.refresh_token, + ) service = build("calendar", "v3", credentials=creds) events_result = ( @@ -208,6 +135,33 @@ def update_time(callback_context: CallbackContext): Currnet time: {_time} """, - tools=[list_calendar_events, calendar_toolset], + tools=[ + AuthenticatedFunctionTool( + func=list_calendar_events, + auth_config=AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl=( + "https://accounts.google.com/o/oauth2/auth" + ), + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + "https://www.googleapis.com/auth/calendar": "", + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=oauth_client_id, + client_secret=oauth_client_secret, + ), + ), + ), + ), + calendar_toolset, + ], before_agent_callback=update_time, ) From 6729edd08e427e3be78d8e4665443f5bbabfd635 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Jun 2025 06:04:49 -0700 Subject: [PATCH 66/79] refactor: Rename the Google API based bigquery sample agent This change renames the sample agent based on the Google API based tools to reflect the larger purpose and avoid confusion with the built-in BigQuery tools. In addition, it also renames the root agent in the BigQuery sample agent to "bigquery_agent" PiperOrigin-RevId: 775655226 --- contributing/samples/bigquery/agent.py | 2 +- .../{bigquery_agent => google_api}/README.md | 33 ++++++++----------- .../__init__.py | 0 .../{bigquery_agent => google_api}/agent.py | 2 +- 4 files changed, 16 insertions(+), 21 deletions(-) rename contributing/samples/{bigquery_agent => google_api}/README.md (50%) rename contributing/samples/{bigquery_agent => google_api}/__init__.py (100%) rename contributing/samples/{bigquery_agent => google_api}/agent.py (98%) diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py index 3cd1eb997..c1b265c00 100644 --- a/contributing/samples/bigquery/agent.py +++ b/contributing/samples/bigquery/agent.py @@ -60,7 +60,7 @@ # debug CLI root_agent = llm_agent.Agent( model="gemini-2.0-flash", - name="hello_agent", + name="bigquery_agent", description=( "Agent to answer questions about BigQuery data and models and execute" " SQL queries." diff --git a/contributing/samples/bigquery_agent/README.md b/contributing/samples/google_api/README.md similarity index 50% rename from contributing/samples/bigquery_agent/README.md rename to contributing/samples/google_api/README.md index c7dc7fd8b..c1e6e8d4c 100644 --- a/contributing/samples/bigquery_agent/README.md +++ b/contributing/samples/google_api/README.md @@ -1,45 +1,40 @@ -# BigQuery Sample +# Google API Tools Sample ## Introduction -This sample tests and demos the BigQuery support in ADK via two tools: +This sample tests and demos Google API tools available in the +`google.adk.tools.google_api_tool` module. We pick the following BigQuery API +tools for this sample agent: -* 1. bigquery_datasets_list: +1. `bigquery_datasets_list`: List user's datasets. - List user's datasets. +2. `bigquery_datasets_get`: Get a dataset's details. -* 2. bigquery_datasets_get: - Get a dataset's details. +3. `bigquery_datasets_insert`: Create a new dataset. -* 3. bigquery_datasets_insert: - Create a new dataset. +4. `bigquery_tables_list`: List all tables in a dataset. -* 4. 
bigquery_tables_list: - List all tables in a dataset. +5. `bigquery_tables_get`: Get a table's details. -* 5. bigquery_tables_get: - Get a table's details. - -* 6. bigquery_tables_insert: - Insert a new table into a dataset. +6. `bigquery_tables_insert`: Insert a new table into a dataset. ## How to use -* 1. Follow https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. to get your client id and client secret. +1. Follow https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. to get your client id and client secret. Be sure to choose "web" as your client type. -* 2. Configure your `.env` file to add two variables: +2. Configure your `.env` file to add two variables: * OAUTH_CLIENT_ID={your client id} * OAUTH_CLIENT_SECRET={your client secret} Note: don't create a separate `.env` file , instead put it to the same `.env` file that stores your Vertex AI or Dev ML credentials -* 3. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". +3. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". Note: localhost here is just a hostname that you use to access the dev ui, replace it with the actual hostname you use to access the dev ui. -* 4. For 1st run, allow popup for localhost in Chrome. +4. For 1st run, allow popup for localhost in Chrome. ## Sample prompt diff --git a/contributing/samples/bigquery_agent/__init__.py b/contributing/samples/google_api/__init__.py similarity index 100% rename from contributing/samples/bigquery_agent/__init__.py rename to contributing/samples/google_api/__init__.py diff --git a/contributing/samples/bigquery_agent/agent.py b/contributing/samples/google_api/agent.py similarity index 98% rename from contributing/samples/bigquery_agent/agent.py rename to contributing/samples/google_api/agent.py index 976cea170..1cdbab9c6 100644 --- a/contributing/samples/bigquery_agent/agent.py +++ b/contributing/samples/google_api/agent.py @@ -40,7 +40,7 @@ root_agent = Agent( model="gemini-2.0-flash", - name="bigquery_agent", + name="google_api_bigquery_agent", instruction=""" You are a helpful Google BigQuery agent that help to manage users' data on Google BigQuery. Use the provided tools to conduct various operations on users' data in Google BigQuery. From f54b9b6ad10220ddb2a69e2b951c0bc57a50a8b6 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 09:05:23 -0700 Subject: [PATCH 67/79] chore: Add unit tests for contents.py PiperOrigin-RevId: 775713101 --- .../flows/llm_flows/test_contents.py | 361 ++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 tests/unittests/flows/llm_flows/test_contents.py diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py new file mode 100644 index 000000000..a330852a1 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -0,0 +1,361 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows import contents +from google.adk.flows.llm_flows.contents import _convert_foreign_event +from google.adk.flows.llm_flows.contents import _get_contents +from google.adk.flows.llm_flows.contents import _merge_function_response_events +from google.adk.flows.llm_flows.contents import _rearrange_events_for_async_function_responses_in_history +from google.adk.flows.llm_flows.contents import _rearrange_events_for_latest_function_response +from google.adk.models import LlmRequest +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_content_processor_no_contents(): + """Test ContentLlmRequestProcessor when include_contents is 'none'.""" + agent = Agent(model="gemini-1.5-flash", name="agent", include_contents="none") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events + assert len(events) == 0 + # Contents should not be set when include_contents is 'none' + assert llm_request.contents == [] + + +@pytest.mark.asyncio +async def test_content_processor_with_contents(): + """Test ContentLlmRequestProcessor when include_contents is not 'none'.""" + agent = Agent(model="gemini-1.5-flash", name="agent") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Add some test events to the session + test_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + invocation_context.session.events = [test_event] + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events (processor doesn't emit events, just modifies request) + assert len(events) == 0 + # Contents should be set + assert llm_request.contents is not None + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts[0].text == "Hello" + + +@pytest.mark.asyncio +async def test_content_processor_non_llm_agent(): + """Test ContentLlmRequestProcessor with non-LLM agent.""" + from google.adk.agents.base_agent import BaseAgent + + # Create a base agent (not LLM agent) + agent = BaseAgent(name="base_agent") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events and not modify request + 
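+  # (BaseAgent is not an LlmAgent, so the processor returns early without
+  # touching llm_request.)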
assert len(events) == 0 + assert llm_request.contents == [] + + +def test_get_contents_empty_events(): + """Test _get_contents with empty events list.""" + contents_result = _get_contents(None, [], "test_agent") + assert contents_result == [] + + +def test_get_contents_with_events(): + """Test _get_contents with valid events.""" + test_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents(None, [test_event], "test_agent") + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + +def test_get_contents_filters_empty_events(): + """Test _get_contents filters out events with empty content.""" + # Event with empty text + empty_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content(role="user", parts=[types.Part.from_text(text="")]), + ) + + # Event without content + no_content_event = Event( + invocation_id="test_inv", + author="user", + ) + + # Valid event + valid_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents( + None, [empty_event, no_content_event, valid_event], "test_agent" + ) + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + +def test_convert_foreign_event(): + """Test _convert_foreign_event function.""" + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Agent response")] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] said: Agent response" in converted_event.content.parts[1].text + ) + + +def test_convert_event_with_function_call(): + """Test _convert_foreign_event with function call.""" + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] called tool `test_function`" + in converted_event.content.parts[1].text + ) + assert "{'param': 'value'}" in converted_event.content.parts[1].text + + +def test_convert_event_with_function_response(): + """Test _convert_foreign_event with function response.""" + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert 
converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] `test_function` tool returned result:" + in converted_event.content.parts[1].text + ) + assert "{'result': 'success'}" in converted_event.content.parts[1].text + + +def test_merge_function_response_events(): + """Test _merge_function_response_events function.""" + # Create initial function response event + function_response1 = types.FunctionResponse( + id="func_123", name="test_function", response={"status": "pending"} + ) + + initial_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response1)] + ), + ) + + # Create final function response event + function_response2 = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + final_event = Event( + invocation_id="test_inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response2)] + ), + ) + + merged_event = _merge_function_response_events([initial_event, final_event]) + + assert ( + merged_event.invocation_id == "test_inv" + ) # Should keep initial event ID + assert len(merged_event.content.parts) == 1 + # The first part should be replaced with the final response + assert merged_event.content.parts[0].function_response.response == { + "result": "success" + } + + +def test_rearrange_events_for_async_function_responses(): + """Test _rearrange_events_for_async_function_responses_in_history function.""" + # Create function call event + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create function response event + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test rearrangement + events = [call_event, response_event] + rearranged = _rearrange_events_for_async_function_responses_in_history(events) + + # Should have both events in correct order + assert len(rearranged) == 2 + assert rearranged[0] == call_event + assert rearranged[1] == response_event + + +def test_rearrange_events_for_latest_function_response(): + """Test _rearrange_events_for_latest_function_response function.""" + # Create function call event + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create intermediate event + intermediate_event = Event( + invocation_id="test_inv2", + author="agent", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Processing...")] + ), + ) + + # Create function response event + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv3", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test with matching function call and response + events = [call_event, 
intermediate_event, response_event] + rearranged = _rearrange_events_for_latest_function_response(events) + + # Should remove intermediate events and merge responses + assert len(rearranged) == 2 + assert rearranged[0] == call_event From a623467299e768be93f516a9afb533c32172fd74 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 09:18:34 -0700 Subject: [PATCH 68/79] chore: Enhance a2a context id parsing and construction logic PiperOrigin-RevId: 775718282 --- src/google/adk/a2a/converters/utils.py | 26 ++- .../a2a/converters/test_request_converter.py | 8 +- tests/unittests/a2a/converters/test_utils.py | 213 ++++++++++++++++++ 3 files changed, 239 insertions(+), 8 deletions(-) create mode 100644 tests/unittests/a2a/converters/test_utils.py diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py index ecbff1e10..acb2581d4 100644 --- a/src/google/adk/a2a/converters/utils.py +++ b/src/google/adk/a2a/converters/utils.py @@ -16,6 +16,7 @@ ADK_METADATA_KEY_PREFIX = "adk_" ADK_CONTEXT_ID_PREFIX = "ADK" +ADK_CONTEXT_ID_SEPARATOR = "/" def _get_adk_metadata_key(key: str) -> str: @@ -45,8 +46,17 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str: Returns: The A2A context id. + + Raises: + ValueError: If any of the input parameters are empty or None. """ - return [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id].join("$") + if not all([app_name, user_id, session_id]): + raise ValueError( + "All parameters (app_name, user_id, session_id) must be non-empty" + ) + return ADK_CONTEXT_ID_SEPARATOR.join( + [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id] + ) def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: @@ -64,8 +74,16 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: if not context_id: return None, None, None - prefix, app_name, user_id, session_id = context_id.split("$") - if prefix == "ADK" and app_name and user_id and session_id: - return app_name, user_id, session_id + try: + parts = context_id.split(ADK_CONTEXT_ID_SEPARATOR) + if len(parts) != 4: + return None, None, None + + prefix, app_name, user_id, session_id = parts + if prefix == ADK_CONTEXT_ID_PREFIX and app_name and user_id and session_id: + return app_name, user_id, session_id + except ValueError: + # Handle any split errors gracefully + pass return None, None, None diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index 02c6400fc..08266751e 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -244,7 +244,7 @@ def test_convert_a2a_request_basic( request = Mock(spec=RequestContext) request.message = mock_message - request.context_id = "ADK$app$user$session" + request.context_id = "ADK/app/user/session" mock_from_context_id.return_value = ( "app_name", @@ -271,7 +271,7 @@ def test_convert_a2a_request_basic( assert isinstance(result["run_config"], RunConfig) # Verify calls - mock_from_context_id.assert_called_once_with("ADK$app$user$session") + mock_from_context_id.assert_called_once_with("ADK/app/user/session") mock_get_user_id.assert_called_once_with(request, "user_from_context") assert mock_convert_part.call_count == 2 mock_convert_part.assert_any_call(mock_part1) @@ -302,7 +302,7 @@ def test_convert_a2a_request_empty_parts( request = Mock(spec=RequestContext) request.message = mock_message - request.context_id = "ADK$app$user$session" + 
request.context_id = "ADK/app/user/session" mock_from_context_id.return_value = ( "app_name", @@ -431,7 +431,7 @@ def test_end_to_end_conversion_with_auth_user(self, mock_convert_part): request = Mock(spec=RequestContext) request.call_context = mock_call_context request.message = mock_message - request.context_id = "ADK$myapp$context_user$mysession" + request.context_id = "ADK/myapp/context_user/mysession" request.current_task = None request.task_id = "task123" diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py new file mode 100644 index 000000000..f919cbd00 --- /dev/null +++ b/tests/unittests/a2a/converters/test_utils.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +import pytest + + +class TestUtilsFunctions: + """Test suite for utils module functions.""" + + def test_get_adk_metadata_key_success(self): + """Test successful metadata key generation.""" + key = "test_key" + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_get_adk_metadata_key_empty_string(self): + """Test metadata key generation with empty string.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key("") + + def test_get_adk_metadata_key_none(self): + """Test metadata key generation with None.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key(None) + + def test_get_adk_metadata_key_whitespace(self): + """Test metadata key generation with whitespace string.""" + key = " " + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_to_a2a_context_id_success(self): + """Test successful context ID generation.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session" + assert result == expected + + def test_to_a2a_context_id_empty_app_name(self): + """Test context ID generation with empty app name.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("", "user", "session") + + def test_to_a2a_context_id_empty_user_id(self): + """Test context ID generation with empty user ID.""" + with pytest.raises( + ValueError, + match=( + 
"All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "", "session") + + def test_to_a2a_context_id_empty_session_id(self): + """Test context ID generation with empty session ID.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "user", "") + + def test_to_a2a_context_id_none_values(self): + """Test context ID generation with None values.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id(None, "user", "session") + + def test_to_a2a_context_id_special_characters(self): + """Test context ID generation with special characters.""" + app_name = "test-app@2024" + user_id = "user_123" + session_id = "session-456" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}/test-app@2024/user_123/session-456" + assert result == expected + + def test_from_a2a_context_id_success(self): + """Test successful context ID parsing.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app" + assert user_id == "test-user" + assert session_id == "test-session" + + def test_from_a2a_context_id_none_input(self): + """Test context ID parsing with None input.""" + result = _from_a2a_context_id(None) + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_string(self): + """Test context ID parsing with empty string.""" + result = _from_a2a_context_id("") + assert result == (None, None, None) + + def test_from_a2a_context_id_invalid_prefix(self): + """Test context ID parsing with invalid prefix.""" + context_id = "INVALID/test-app/test-user/test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_few_parts(self): + """Test context ID parsing with too few parts.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_many_parts(self): + """Test context ID parsing with too many parts.""" + context_id = ( + f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session/extra" + ) + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_components(self): + """Test context ID parsing with empty components.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}//test-user/test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_no_dollar_separator(self): + """Test context ID parsing without dollar separators.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}-test-app-test-user-test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_roundtrip_context_id(self): + """Test roundtrip conversion: to -> from.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + # Convert to context ID + context_id = _to_a2a_context_id(app_name, user_id, session_id) + + # Convert back + parsed_app, parsed_user, parsed_session = _from_a2a_context_id(context_id) + + assert parsed_app == app_name + assert parsed_user == user_id + assert parsed_session == session_id + + def 
test_from_a2a_context_id_special_characters(self): + """Test context ID parsing with special characters.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app@2024/user_123/session-456" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app@2024" + assert user_id == "user_123" + assert session_id == "session-456" From 5306ddad4dde29748fe9c75e01511fc59e28a8d1 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Wed, 25 Jun 2025 10:18:30 -0700 Subject: [PATCH 69/79] chore: Release 1.5.0 PiperOrigin-RevId: 775742049 --- CHANGELOG.md | 35 +++++++++++++++++++++++++++++++++++ src/google/adk/version.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce36dcdcf..b6bba2692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # Changelog +## [1.5.0](https://github.com/google/adk-python/compare/v1.4.2...v1.5.0) (2025-06-25) + + +### Features + +* Add a new option `eval_storage_uri` in adk web & adk eval to specify GCS bucket to store eval data ([fa025d7](https://github.com/google/adk-python/commit/fa025d755978e1506fa0da1fecc49775bebc1045)) +* Add ADK examples for litellm with add_function_to_prompt ([f33e090](https://github.com/google/adk-python/commit/f33e0903b21b752168db3006dd034d7d43f7e84d)) +* Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint ([abc89d2](https://github.com/google/adk-python/commit/abc89d2c811ba00805f81b27a3a07d56bdf55a0b)) +* Add rouge_score library to ADK eval dependencies, and implement RougeEvaluator that is computes ROUGE-1 for "response_match_score" metric ([9597a44](https://github.com/google/adk-python/commit/9597a446fdec63ad9e4c2692d6966b14f80ff8e2)) +* Add usage span attributes to telemetry ([#356](https://github.com/google/adk-python/issues/356)) ([ea69c90](https://github.com/google/adk-python/commit/ea69c9093a16489afdf72657136c96f61c69cafd)) +* Add Vertex Express mode compatibility for VertexAiSessionService ([00cc8cd](https://github.com/google/adk-python/commit/00cc8cd6433fc45ecfc2dbaa04dbbc1a81213b4d)) + + +### Bug Fixes + +* Include current turn context when include_contents='none' ([9e473e0](https://github.com/google/adk-python/commit/9e473e0abdded24e710fd857782356c15d04b515)) +* Make LiteLLM streaming truly asynchronous ([bd67e84](https://github.com/google/adk-python/commit/bd67e8480f6e8b4b0f8c22b94f15a8cda1336339)) +* Make raw_auth_credential and exchanged_auth_credential optional given their default value is None ([acbdca0](https://github.com/google/adk-python/commit/acbdca0d8400e292ba5525931175e0d6feab15f1)) +* Minor typo fix in the agent instruction ([ef3c745](https://github.com/google/adk-python/commit/ef3c745d655538ebd1ed735671be615f842341a8)) +* Typo fix in sample agent instruction ([ef3c745](https://github.com/google/adk-python/commit/ef3c745d655538ebd1ed735671be615f842341a8)) +* Update contributing links ([a1e1441](https://github.com/google/adk-python/commit/a1e14411159fd9f3e114e15b39b4949d0fd6ecb1)) +* Use starred tuple unpacking on GCS artifact blob names ([3b1d9a8](https://github.com/google/adk-python/commit/3b1d9a8a3e631ca2d86d30f09640497f1728986c)) + + +### Chore + +* Do not send api request when session does not have events ([88a4402](https://github.com/google/adk-python/commit/88a4402d142672171d0a8ceae74671f47fa14289)) +* Leverage official uv action for install([09f1269](https://github.com/google/adk-python/commit/09f1269bf7fa46ab4b9324e7f92b4f70ffc923e5)) +* Update google-genai package and related deps to 
latest([ed7a21e](https://github.com/google/adk-python/commit/ed7a21e1890466fcdf04f7025775305dc71f603d)) +* Add credential service backed by session state([29cd183](https://github.com/google/adk-python/commit/29cd183aa1b47dc4f5d8afe22f410f8546634abc)) +* Clarify the behavior of Event.invocation_id([f033e40](https://github.com/google/adk-python/commit/f033e405c10ff8d86550d1419a9d63c0099182f9)) +* Send user message to the agent that returned a corresponding function call if user message is a function response([7c670f6](https://github.com/google/adk-python/commit/7c670f638bc17374ceb08740bdd057e55c9c2e12)) +* Add request converter to convert a2a request to ADK request([fb13963](https://github.com/google/adk-python/commit/fb13963deda0ff0650ac27771711ea0411474bf5)) +* Support allow_origins in cloud_run deployment ([2fd8feb](https://github.com/google/adk-python/commit/2fd8feb65d6ae59732fb3ec0652d5650f47132cc)) + ## [1.4.2](https://github.com/google/adk-python/compare/v1.4.1...v1.4.2) (2025-06-20) diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 9accc1025..1c061dd03 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.4.2" +__version__ = "1.5.0" From 738d1a8b84f388dfc761f5818da9467aedc160cf Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 25 Jun 2025 10:19:06 -0700 Subject: [PATCH 70/79] chore: create an agent to check issue format and content for bugs and feature requests This agent will pose a comment to ask for more information according to the template if necessary. PiperOrigin-RevId: 775742256 --- .github/ISSUE_TEMPLATE/bug_report.md | 3 + .../adk_issue_formatting_agent/__init__.py | 15 ++ .../adk_issue_formatting_agent/agent.py | 241 ++++++++++++++++++ .../adk_issue_formatting_agent/settings.py | 33 +++ .../adk_issue_formatting_agent/utils.py | 53 ++++ 5 files changed, 345 insertions(+) create mode 100644 contributing/samples/adk_issue_formatting_agent/__init__.py create mode 100644 contributing/samples/adk_issue_formatting_agent/agent.py create mode 100644 contributing/samples/adk_issue_formatting_agent/settings.py create mode 100644 contributing/samples/adk_issue_formatting_agent/utils.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 7c2ffdd95..f04f3f039 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -31,5 +31,8 @@ If applicable, add screenshots to help explain your problem. - Python version(python -V): - ADK version(pip show google-adk): + **Model Information:** + For example, which model is being used. + **Additional context** Add any other context about the problem here. diff --git a/contributing/samples/adk_issue_formatting_agent/__init__.py b/contributing/samples/adk_issue_formatting_agent/__init__.py new file mode 100644 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/adk_issue_formatting_agent/agent.py b/contributing/samples/adk_issue_formatting_agent/agent.py new file mode 100644 index 000000000..78add9b83 --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/agent.py @@ -0,0 +1,241 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any + +from adk_issue_formatting_agent.settings import GITHUB_BASE_URL +from adk_issue_formatting_agent.settings import IS_INTERACTIVE +from adk_issue_formatting_agent.settings import OWNER +from adk_issue_formatting_agent.settings import REPO +from adk_issue_formatting_agent.utils import error_response +from adk_issue_formatting_agent.utils import get_request +from adk_issue_formatting_agent.utils import post_request +from adk_issue_formatting_agent.utils import read_file +from google.adk import Agent +import requests + +BUG_REPORT_TEMPLATE = read_file( + Path(__file__).parent / "../../../../.github/ISSUE_TEMPLATE/bug_report.md" +) +FREATURE_REQUEST_TEMPLATE = read_file( + Path(__file__).parent + / "../../../../.github/ISSUE_TEMPLATE/feature_request.md" +) + +APPROVAL_INSTRUCTION = ( + "**Do not** wait or ask for user approval or confirmation for adding the" + " comment." +) +if IS_INTERACTIVE: + APPROVAL_INSTRUCTION = ( + "Ask for user approval or confirmation for adding the comment." + ) + + +def list_open_issues(issue_count: int) -> dict[str, Any]: + """List most recent `issue_count` numer of open issues in the repo. + + Args: + issue_count: number of issues to return + + Returns: + The status of this request, with a list of issues when successful. + """ + url = f"{GITHUB_BASE_URL}/search/issues" + query = f"repo:{OWNER}/{REPO} is:open is:issue" + params = { + "q": query, + "sort": "created", + "order": "desc", + "per_page": issue_count, + "page": 1, + } + + try: + response = get_request(url, params) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + issues = response.get("items", None) + return {"status": "success", "issues": issues} + + +def get_issue(issue_number: int) -> dict[str, Any]: + """Get the details of the specified issue number. + + Args: + issue_number: issue number of the Github issue. + + Returns: + The status of this request, with the issue details when successful. + """ + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" + try: + response = get_request(url) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return {"status": "success", "issue": response} + + +def add_comment_to_issue(issue_number: int, comment: str) -> dict[str, any]: + """Add the specified comment to the given issue number. 
+ + Args: + issue_number: issue number of the Github issue + comment: comment to add + + Returns: + The the status of this request, with the applied comment when successful. + """ + print(f"Attempting to add comment '{comment}' to issue #{issue_number}") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/comments" + payload = {"body": comment} + + try: + response = post_request(url, payload) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return { + "status": "success", + "added_comment": response, + } + + +def list_comments_on_issue(issue_number: int) -> dict[str, any]: + """List all comments on the given issue number. + + Args: + issue_number: issue number of the Github issue + + Returns: + The the status of this request, with the list of comments when successful. + """ + print(f"Attempting to list comments on issue #{issue_number}") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/comments" + + try: + response = get_request(url) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return {"status": "success", "comments": response} + + +root_agent = Agent( + model="gemini-2.5-pro", + name="adk_issue_formatting_assistant", + description="Check ADK issue format and content.", + instruction=f""" + # 1. IDENTITY + You are an AI assistant designed to help maintain the quality and consistency of issues in our GitHub repository. + Your primary role is to act as a "GitHub Issue Format Validator." You will analyze new and existing **open** issues + to ensure they contain all the necessary information as required by our templates. You are helpful, polite, + and precise in your feedback. + + # 2. CONTEXT & RESOURCES + * **Repository:** You are operating on the GitHub repository `{OWNER}/{REPO}`. + * **Bug Report Template:** (`{BUG_REPORT_TEMPLATE}`) + * **Feature Request Template:** (`{FREATURE_REQUEST_TEMPLATE}`) + + # 3. CORE MISSION + Your goal is to check if a GitHub issue, identified as either a "bug" or a "feature request," + contains all the information required by the corresponding template. If it does not, your job is + to post a single, helpful comment asking the original author to provide the missing information. + {APPROVAL_INSTRUCTION} + + **IMPORTANT NOTE:** + * You add one comment at most each time you are invoked. + * Don't proceed to other issues which are not the target issues. + * Don't take any action on closed issues. + + # 4. BEHAVIORAL RULES & LOGIC + + ## Step 1: Identify Issue Type & Applicability + + Your first task is to determine if the issue is a valid target for validation. + + 1. **Assess Content Intent:** You must perform a quick semantic check of the issue's title, body, and comments. + If you determine the issue's content is fundamentally *not* a bug report or a feature request + (for example, it is a general question, a request for help, or a discussion prompt), then you must ignore it. + 2. **Exit Condition:** If the issue does not clearly fall into the categories of "bug" or "feature request" + based on both its labels and its content, **take no action**. + + ## Step 2: Analyze the Issue Content + + If you have determined the issue is a valid bug or feature request, your analysis depends on whether it has comments. + + **Scenario A: Issue has NO comments** + 1. Read the main body of the issue. + 2. Compare the content of the issue body against the required headings/sections in the relevant template (Bug or Feature). + 3. 
Check for the presence of content under each heading. A heading with no content below it is considered incomplete. + 4. If one or more sections are missing or empty, proceed to Step 3. + 5. If all sections are filled out, your task is complete. Do nothing. + + **Scenario B: Issue HAS one or more comments** + 1. First, analyze the main issue body to see which sections of the template are filled out. + 2. Next, read through **all** the comments in chronological order. + 3. As you read the comments, check if the information provided in them satisfies any of the template sections that were missing from the original issue body. + 4. After analyzing the body and all comments, determine if any required sections from the template *still* remain unaddressed. + 5. If one or more sections are still missing information, proceed to Step 3. + 6. If the issue body and comments *collectively* provide all the required information, your task is complete. Do nothing. + + ## Step 3: Formulate and Post a Comment (If Necessary) + + If you determined in Step 2 that information is missing, you must post a **single comment** on the issue. + + Please include a bolded note in your comment that this comment was added by an ADK agent. + + **Comment Guidelines:** + * **Be Polite and Helpful:** Start with a friendly tone. + * **Be Specific:** Clearly list only the sections from the template that are still missing. Do not list sections that have already been filled out. + * **Address the Author:** Mention the issue author by their username (e.g., `@username`). + * **Provide Context:** Explain *why* the information is needed (e.g., "to help us reproduce the bug" or "to better understand your request"). + * **Do not be repetitive:** If you have already commented on an issue asking for information, do not comment again unless new information has been added and it's still incomplete. + + **Example Comment for a Bug Report:** + > **Response from ADK Agent** + > + > Hello @[issue-author-username], thank you for submitting this issue! + > + > To help us investigate and resolve this bug effectively, could you please provide the missing details for the following sections of our bug report template: + > + > * **To Reproduce:** (Please provide the specific steps required to reproduce the behavior) + > * **Desktop (please complete the following information):** (Please provide OS, Python version, and ADK version) + > + > This information will give us the context we need to move forward. Thanks! + + **Example Comment for a Feature Request:** + > **Response from ADK Agent** + > + > Hi @[issue-author-username], thanks for this great suggestion! + > + > To help our team better understand and evaluate your feature request, could you please provide a bit more information on the following section: + > + > * **Is your feature request related to a problem? Please describe.** + > + > We look forward to hearing more about your idea! + + # 5. FINAL INSTRUCTION + + Execute this process for the given GitHub issue. Your final output should either be **[NO ACTION]** + if the issue is complete or invalid, or **[POST COMMENT]** followed by the exact text of the comment you will post. + + Please include your justification for your decision in your output. 
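The section check described in Steps 2 and 3 amounts to comparing the headings of the issue template against the text the reporter (and later commenters) actually provided. The agent performs that comparison with the model rather than with code, but a minimal sketch of the same idea may help when reviewing its behaviour; the bold-heading regex and the helper names below are illustrative assumptions, not part of this sample:

import re

def extract_headings(template_markdown: str) -> list[str]:
  # Issue templates in this repo mark their sections with bold headings,
  # e.g. "**Describe the bug**" or "**Additional context**".
  return re.findall(r"\*\*(.+?)\*\*", template_markdown)

def missing_sections(template_markdown: str, issue_text: str) -> list[str]:
  # A section counts as addressed if its heading text shows up anywhere in
  # the combined issue body and comments.
  return [
      heading
      for heading in extract_headings(template_markdown)
      if heading.lower() not in issue_text.lower()
  ]

template = "**Describe the bug**\n...\n**To Reproduce**\n...\n**Expected behavior**\n..."
issue_body = "**Describe the bug** The CLI crashes on startup."
print(missing_sections(template, issue_body))  # ['To Reproduce', 'Expected behavior']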
+ """, + tools={ + list_open_issues, + get_issue, + add_comment_to_issue, + list_comments_on_issue, + }, +) diff --git a/contributing/samples/adk_issue_formatting_agent/settings.py b/contributing/samples/adk_issue_formatting_agent/settings.py new file mode 100644 index 000000000..d29bda9b7 --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/settings.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dotenv import load_dotenv + +load_dotenv(override=True) + +GITHUB_BASE_URL = "https://api.github.com" + +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +if not GITHUB_TOKEN: + raise ValueError("GITHUB_TOKEN environment variable not set") + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +EVENT_NAME = os.getenv("EVENT_NAME") +ISSUE_NUMBER = os.getenv("ISSUE_NUMBER") +ISSUE_COUNT_TO_PROCESS = os.getenv("ISSUE_COUNT_TO_PROCESS") + +IS_INTERACTIVE = os.environ.get("INTERACTIVE", "1").lower() in ["true", "1"] diff --git a/contributing/samples/adk_issue_formatting_agent/utils.py b/contributing/samples/adk_issue_formatting_agent/utils.py new file mode 100644 index 000000000..2ee735d3d --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/utils.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +from adk_issue_formatting_agent.settings import GITHUB_TOKEN +import requests + +headers = { + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + + +def get_request( + url: str, params: dict[str, Any] | None = None +) -> dict[str, Any]: + if params is None: + params = {} + response = requests.get(url, headers=headers, params=params, timeout=60) + response.raise_for_status() + return response.json() + + +def post_request(url: str, payload: Any) -> dict[str, Any]: + response = requests.post(url, headers=headers, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def error_response(error_message: str) -> dict[str, Any]: + return {"status": "error", "message": error_message} + + +def read_file(file_path: str) -> str: + """Read the content of the given file.""" + try: + with open(file_path, "r") as f: + return f.read() + except FileNotFoundError: + print(f"Error: File not found: {file_path}.") + return "" From 832a6333512eb96fccc3d394939fd947a38fe409 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 13:58:14 -0700 Subject: [PATCH 71/79] chore: Enhance a2a part converters a. fix binary data conversion b. support thoughts, code execution result, executable codes conversion PiperOrigin-RevId: 775827259 --- .../adk/a2a/converters/part_converter.py | 106 +++++- .../a2a/converters/test_part_converter.py | 342 ++++++++++++++++-- 2 files changed, 403 insertions(+), 45 deletions(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index c47ac7276..8dab1097d 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -18,6 +18,7 @@ from __future__ import annotations +import base64 import json import logging import sys @@ -43,8 +44,11 @@ logger = logging.getLogger('google_adk.' + __name__) A2A_DATA_PART_METADATA_TYPE_KEY = 'type' +A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = 'is_long_running' A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call' A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' +A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result' +A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code' @working_in_progress @@ -67,7 +71,8 @@ def convert_a2a_part_to_genai_part( elif isinstance(part.file, a2a_types.FileWithBytes): return genai_types.Part( inline_data=genai_types.Blob( - data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType + data=base64.b64decode(part.file.bytes), + mime_type=part.file.mimeType, ) ) else: @@ -84,7 +89,11 @@ def convert_a2a_part_to_genai_part( # response. 
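The FileWithBytes change above treats the A2A bytes field as base64 text rather than raw UTF-8, and the mirrored change in convert_genai_part_to_a2a_part encodes on the way out. That is what lets binary payloads survive the round trip; a standalone illustration using only the standard library:

import base64

raw = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A])  # start of a PNG signature; not valid UTF-8

# GenAI -> A2A: binary data is carried as a base64 string.
wire = base64.b64encode(raw).decode("utf-8")

# A2A -> GenAI: decoding restores the original bytes exactly.
assert base64.b64decode(wire) == raw

# The previous behaviour, part.file.bytes.encode("utf-8"), handed the model the
# ASCII bytes of the base64 text instead of the original binary content.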
# TODO once A2A defined how to suervice such information, migrate below # logic accordinlgy - if part.metadata and A2A_DATA_PART_METADATA_TYPE_KEY in part.metadata: + if ( + part.metadata + and _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + in part.metadata + ): if ( part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL @@ -103,6 +112,24 @@ def convert_a2a_part_to_genai_part( part.data, by_alias=True ) ) + if ( + part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + ): + return genai_types.Part( + code_execution_result=genai_types.CodeExecutionResult.model_validate( + part.data, by_alias=True + ) + ) + if ( + part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + ): + return genai_types.Part( + executable_code=genai_types.ExecutableCode.model_validate( + part.data, by_alias=True + ) + ) return genai_types.Part(text=json.dumps(part.data)) logger.warning( @@ -118,27 +145,40 @@ def convert_genai_part_to_a2a_part( part: genai_types.Part, ) -> Optional[a2a_types.Part]: """Convert a Google GenAI Part to an A2A Part.""" + if part.text: - return a2a_types.TextPart(text=part.text) + a2a_part = a2a_types.TextPart(text=part.text) + if part.thought is not None: + a2a_part.metadata = {_get_adk_metadata_key('thought'): part.thought} + return a2a_types.Part(root=a2a_part) if part.file_data: - return a2a_types.FilePart( - file=a2a_types.FileWithUri( - uri=part.file_data.file_uri, - mimeType=part.file_data.mime_type, + return a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=part.file_data.file_uri, + mimeType=part.file_data.mime_type, + ) ) ) if part.inline_data: - return a2a_types.Part( - root=a2a_types.FilePart( - file=a2a_types.FileWithBytes( - bytes=part.inline_data.data, - mimeType=part.inline_data.mime_type, - ) + a2a_part = a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), + mimeType=part.inline_data.mime_type, ) ) + if part.video_metadata: + a2a_part.metadata = { + _get_adk_metadata_key( + 'video_metadata' + ): part.video_metadata.model_dump(by_alias=True, exclude_none=True) + } + + return a2a_types.Part(root=a2a_part) + # Conver the funcall and function reponse to A2A DataPart. # This is mainly for converting human in the loop and auth request and # response. 
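Function calls and responses cross the A2A boundary as DataParts: the payload is the Pydantic dump of the GenAI type, and an adk_-prefixed metadata key records which type it was so the reverse converter knows what to validate against. A condensed sketch of that convention, assuming Python 3.10+ with the a2a and google-genai packages installed; the weather-call values are made up:

from a2a import types as a2a_types
from google.genai import types as genai_types

call = genai_types.FunctionCall(name="get_weather", args={"city": "Paris"})

# GenAI -> A2A: the call becomes a DataPart tagged as a function_call.
data_part = a2a_types.DataPart(
    data=call.model_dump(by_alias=True, exclude_none=True),
    metadata={"adk_type": "function_call"},
)

# A2A -> GenAI: the tag tells the converter which model to rebuild.
restored = genai_types.FunctionCall.model_validate(data_part.data)
assert restored.name == "get_weather"
assert restored.args == {"city": "Paris"}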
@@ -151,9 +191,9 @@ def convert_genai_part_to_a2a_part( by_alias=True, exclude_none=True ), metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - ) + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL }, ) ) @@ -165,9 +205,37 @@ def convert_genai_part_to_a2a_part( by_alias=True, exclude_none=True ), metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - ) + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + }, + ) + ) + + if part.code_execution_result: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.code_execution_result.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + }, + ) + ) + + if part.executable_code: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.executable_code.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE }, ) ) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 4b9bd47cf..1e8f0d4a3 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -21,17 +21,20 @@ # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" ) # Import dependencies with version checking try: from a2a import types as a2a_types + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part + from google.adk.a2a.converters.utils import _get_adk_metadata_key from google.genai import types as genai_types except ImportError as e: if sys.version_info < (3, 10): @@ -44,9 +47,12 @@ class DummyTypes: genai_types = DummyTypes() A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = "function_call" A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = "function_response" + A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = "code_execution_result" + A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = "executable_code" A2A_DATA_PART_METADATA_TYPE_KEY = "type" convert_a2a_part_to_genai_part = lambda x: None convert_genai_part_to_a2a_part = lambda x: None + _get_adk_metadata_key = lambda x: f"adk_{x}" else: raise e @@ -92,11 +98,14 @@ def test_convert_file_part_with_bytes(self): """Test conversion of A2A FilePart with bytes to GenAI Part.""" # Arrange test_bytes = b"test file content" - # Note: A2A FileWithBytes converts bytes to string automatically + # A2A FileWithBytes expects base64-encoded string + import base64 + + base64_encoded = 
base64.b64encode(test_bytes).decode("utf-8") a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithBytes( - bytes=test_bytes, mimeType="text/plain" + bytes=base64_encoded, mimeType="text/plain" ) ) ) @@ -108,7 +117,7 @@ def test_convert_file_part_with_bytes(self): assert result is not None assert isinstance(result, genai_types.Part) assert result.inline_data is not None - # Source code now properly converts A2A string back to bytes for GenAI Blob + # The converter decodes base64 back to original bytes assert result.inline_data.data == test_bytes assert result.inline_data.mime_type == "text/plain" @@ -123,9 +132,9 @@ def test_convert_data_part_function_call(self): root=a2a_types.DataPart( data=function_call_data, metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - ), + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, }, ) @@ -152,9 +161,9 @@ def test_convert_data_part_function_response(self): root=a2a_types.DataPart( data=function_response_data, metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - ), + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, }, ) @@ -260,8 +269,25 @@ def test_convert_text_part(self): # Assert assert result is not None - assert isinstance(result, a2a_types.TextPart) - assert result.text == "Hello, world!" + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.TextPart) + assert result.root.text == "Hello, world!" + + def test_convert_text_part_with_thought(self): + """Test conversion of GenAI text Part with thought to A2A Part.""" + # Arrange - thought is a boolean field in genai_types.Part + genai_part = genai_types.Part(text="Hello, world!", thought=True) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.TextPart) + assert result.root.text == "Hello, world!" 
+ assert result.root.metadata is not None + assert result.root.metadata[_get_adk_metadata_key("thought")] == True def test_convert_file_data_part(self): """Test conversion of GenAI file_data Part to A2A Part.""" @@ -277,10 +303,11 @@ def test_convert_file_data_part(self): # Assert assert result is not None - assert isinstance(result, a2a_types.FilePart) - assert isinstance(result.file, a2a_types.FileWithUri) - assert result.file.uri == "gs://bucket/file.txt" - assert result.file.mimeType == "text/plain" + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithUri) + assert result.root.file.uri == "gs://bucket/file.txt" + assert result.root.file.mimeType == "text/plain" def test_convert_inline_data_part(self): """Test conversion of GenAI inline_data Part to A2A Part.""" @@ -298,10 +325,34 @@ def test_convert_inline_data_part(self): assert isinstance(result, a2a_types.Part) assert isinstance(result.root, a2a_types.FilePart) assert isinstance(result.root.file, a2a_types.FileWithBytes) - # A2A FileWithBytes stores bytes as strings - assert result.root.file.bytes == test_bytes.decode("utf-8") + # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility + import base64 + + expected_base64 = base64.b64encode(test_bytes).decode("utf-8") + assert result.root.file.bytes == expected_base64 assert result.root.file.mimeType == "text/plain" + def test_convert_inline_data_part_with_video_metadata(self): + """Test conversion of GenAI inline_data Part with video metadata to A2A Part.""" + # Arrange + test_bytes = b"test video content" + video_metadata = genai_types.VideoMetadata(fps=30.0) + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="video/mp4"), + video_metadata=video_metadata, + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + assert result.root.metadata is not None + assert _get_adk_metadata_key("video_metadata") in result.root.metadata + def test_convert_function_call_part(self): """Test conversion of GenAI function_call Part to A2A Part.""" # Arrange @@ -320,7 +371,9 @@ def test_convert_function_call_part(self): expected_data = function_call.model_dump(by_alias=True, exclude_none=True) assert result.root.data == expected_data assert ( - result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ) @@ -344,10 +397,62 @@ def test_convert_function_response_part(self): ) assert result.root.data == expected_data assert ( - result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE ) + def test_convert_code_execution_result_part(self): + """Test conversion of GenAI code_execution_result Part to A2A Part.""" + # Arrange + code_execution_result = genai_types.CodeExecutionResult( + outcome=genai_types.Outcome.OUTCOME_OK, output="Hello, World!" 
+ ) + genai_part = genai_types.Part(code_execution_result=code_execution_result) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = code_execution_result.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + ) + + def test_convert_executable_code_part(self): + """Test conversion of GenAI executable_code Part to A2A Part.""" + # Arrange + executable_code = genai_types.ExecutableCode( + language=genai_types.Language.PYTHON, code="print('Hello, World!')" + ) + genai_part = genai_types.Part(executable_code=executable_code) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = executable_code.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + ) + def test_convert_unsupported_part(self): """Test handling of unsupported GenAI Part types.""" # Arrange - Create a GenAI Part with no recognized fields @@ -379,8 +484,9 @@ def test_text_part_round_trip(self): # Assert assert result_a2a_part is not None - assert isinstance(result_a2a_part, a2a_types.TextPart) - assert result_a2a_part.text == original_text + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.TextPart) + assert result_a2a_part.root.text == original_text def test_file_uri_round_trip(self): """Test round-trip conversion for file parts with URI.""" @@ -401,10 +507,122 @@ def test_file_uri_round_trip(self): # Assert assert result_a2a_part is not None - assert isinstance(result_a2a_part, a2a_types.FilePart) - assert isinstance(result_a2a_part.file, a2a_types.FileWithUri) - assert result_a2a_part.file.uri == original_uri - assert result_a2a_part.file.mimeType == original_mime_type + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.FilePart) + assert isinstance(result_a2a_part.root.file, a2a_types.FileWithUri) + assert result_a2a_part.root.file.uri == original_uri + assert result_a2a_part.root.file.mimeType == original_mime_type + + def test_file_bytes_round_trip(self): + """Test round-trip conversion for file parts with bytes.""" + # Arrange + original_bytes = b"test file content for round trip" + original_mime_type = "application/octet-stream" + + # Start with GenAI part (the more common starting point) + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=original_bytes, mime_type=original_mime_type + ) + ) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.inline_data is not None + assert result_genai_part.inline_data.data == original_bytes + assert result_genai_part.inline_data.mime_type == original_mime_type + + def test_function_call_round_trip(self): + """Test round-trip 
conversion for function call parts.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.function_call is not None + assert result_genai_part.function_call.name == function_call.name + assert result_genai_part.function_call.args == function_call.args + + def test_function_response_round_trip(self): + """Test round-trip conversion for function response parts.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.function_response is not None + assert result_genai_part.function_response.name == function_response.name + assert ( + result_genai_part.function_response.response + == function_response.response + ) + + def test_code_execution_result_round_trip(self): + """Test round-trip conversion for code execution result parts.""" + # Arrange + code_execution_result = genai_types.CodeExecutionResult( + outcome=genai_types.Outcome.OUTCOME_OK, output="Hello, World!" + ) + genai_part = genai_types.Part(code_execution_result=code_execution_result) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.code_execution_result is not None + assert ( + result_genai_part.code_execution_result.outcome + == code_execution_result.outcome + ) + assert ( + result_genai_part.code_execution_result.output + == code_execution_result.output + ) + + def test_executable_code_round_trip(self): + """Test round-trip conversion for executable code parts.""" + # Arrange + executable_code = genai_types.ExecutableCode( + language=genai_types.Language.PYTHON, code="print('Hello, World!')" + ) + genai_part = genai_types.Part(executable_code=executable_code) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.executable_code is not None + assert ( + result_genai_part.executable_code.language == executable_code.language + ) + assert result_genai_part.executable_code.code == executable_code.code class TestEdgeCases: @@ -468,3 +686,75 @@ def test_data_part_with_empty_metadata(self): # Assert assert result is not None assert result.text == json.dumps(data) + + +class TestNewConstants: + """Test cases for new constants and functionality.""" + + def test_new_constants_exist(self): + """Test that new constants are defined.""" + assert ( + 
A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + == "code_execution_result" + ) + assert A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE == "executable_code" + + def test_convert_a2a_data_part_with_code_execution_result_metadata(self): + """Test conversion of A2A DataPart with code execution result metadata.""" + # Arrange + code_execution_result_data = { + "outcome": "OUTCOME_OK", + "output": "Hello, World!", + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=code_execution_result_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + # Now it should convert back to a proper CodeExecutionResult + assert result.code_execution_result is not None + assert ( + result.code_execution_result.outcome == genai_types.Outcome.OUTCOME_OK + ) + assert result.code_execution_result.output == "Hello, World!" + + def test_convert_a2a_data_part_with_executable_code_metadata(self): + """Test conversion of A2A DataPart with executable code metadata.""" + # Arrange + executable_code_data = { + "language": "PYTHON", + "code": "print('Hello, World!')", + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=executable_code_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + # Now it should convert back to a proper ExecutableCode + assert result.executable_code is not None + assert result.executable_code.language == genai_types.Language.PYTHON + assert result.executable_code.code == "print('Hello, World!')" From a71dbdf9e24a44f18b23be082ef9cb14d9b03692 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 15:31:42 -0700 Subject: [PATCH 72/79] chore: Enhance a2a event converter a. fix function call long running id matching logic b. fix error code conversion logic c. add input required and auth required status conversion logic d. add a2a Task/Message to ADK event converter f. 
set task id and context id from input argument PiperOrigin-RevId: 775860563 --- .../adk/a2a/converters/event_converter.py | 288 ++++++- .../a2a/converters/test_event_converter.py | 701 +++++++++++++++++- 2 files changed, 918 insertions(+), 71 deletions(-) diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index 5594c0e63..25183f6be 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -14,7 +14,8 @@ from __future__ import annotations -import datetime +from datetime import datetime +from datetime import timezone import logging from typing import Any from typing import Dict @@ -26,18 +27,24 @@ from a2a.types import Artifact from a2a.types import DataPart from a2a.types import Message +from a2a.types import Part as A2APart from a2a.types import Role +from a2a.types import Task from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart +from google.genai import types as genai_types from ...agents.invocation_context import InvocationContext from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from ...utils.feature_decorator import working_in_progress +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_a2a_part_to_genai_part from .part_converter import convert_genai_part_to_a2a_part from .utils import _get_adk_metadata_key @@ -143,6 +150,8 @@ def _convert_artifact_to_a2a_events( invocation_context: InvocationContext, filename: str, version: int, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> TaskArtifactUpdateEvent: """Converts a new artifact version to an A2A TaskArtifactUpdateEvent. @@ -151,6 +160,7 @@ def _convert_artifact_to_a2a_events( invocation_context: The invocation context. filename: The name of the artifact file. version: The version number of the artifact. + task_id: Optional task ID to use for generated events. If not provided, new UUIDs will be generated. Returns: A TaskArtifactUpdateEvent representing the artifact update. @@ -186,9 +196,9 @@ def _convert_artifact_to_a2a_events( ) return TaskArtifactUpdateEvent( - taskId=str(uuid.uuid4()), + taskId=task_id, append=False, - contextId=invocation_context.session.id, + contextId=context_id, lastChunk=True, artifact=Artifact( artifactId=artifact_id, @@ -210,7 +220,7 @@ def _convert_artifact_to_a2a_events( raise RuntimeError(f"Artifact conversion failed: {e}") from e -def _process_long_running_tool(a2a_part, event: Event) -> None: +def _process_long_running_tool(a2a_part: A2APart, event: Event) -> None: """Processes long-running tool metadata for an A2A part. 
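Part of the fix in this patch is where the matching id is read from: a long-running function call carries its id inside the dumped FunctionCall payload (part.root.data["id"]), not in the part's metadata, so that is the value compared against event.long_running_tool_ids. A small illustration of the shapes involved; the id and tool name are made up:

from google.genai import types as genai_types

call = genai_types.FunctionCall(id="adk-tool-123", name="ask_for_approval", args={})
payload = call.model_dump(by_alias=True, exclude_none=True)

long_running_tool_ids = {"adk-tool-123"}

# The id travels in the data payload, so the converter matches on payload["id"].
assert payload["id"] in long_running_tool_ids

# The metadata only carries the adk_type tag, plus the adk_is_long_running flag
# the converter adds once the id matches.
metadata = {"adk_type": "function_call"}
assert "id" not in metadata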
Args: @@ -220,18 +230,173 @@ def _process_long_running_tool(a2a_part, event: Event) -> None: if ( isinstance(a2a_part.root, DataPart) and event.long_running_tool_ids + and a2a_part.root.metadata and a2a_part.root.metadata.get( _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and a2a_part.root.metadata.get("id") in event.long_running_tool_ids + and a2a_part.root.data.get("id") in event.long_running_tool_ids ): - a2a_part.root.metadata[_get_adk_metadata_key("is_long_running")] = True + a2a_part.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ] = True @working_in_progress -def convert_event_to_a2a_status_message( - event: Event, invocation_context: InvocationContext +def convert_a2a_task_to_event( + a2a_task: Task, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, +) -> Event: + """Converts an A2A task to an ADK event. + + Args: + a2a_task: The A2A task to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + + Returns: + An ADK Event object representing the converted task. + + Raises: + ValueError: If a2a_task is None. + RuntimeError: If conversion of the underlying message fails. + """ + if a2a_task is None: + raise ValueError("A2A task cannot be None") + + try: + # Extract message from task status or history + message = None + if a2a_task.status and a2a_task.status.message: + message = a2a_task.status.message + elif a2a_task.history: + message = a2a_task.history[-1] + + # Convert message if available + if message: + try: + return convert_a2a_message_to_event(message, author, invocation_context) + except Exception as e: + logger.error("Failed to convert A2A task message to event: %s", e) + raise RuntimeError(f"Failed to convert task message: {e}") from e + + # Create minimal event if no message is available + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + ) + + except Exception as e: + logger.error("Failed to convert A2A task to event: %s", e) + raise + + +@working_in_progress +def convert_a2a_message_to_event( + a2a_message: Message, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, +) -> Event: + """Converts an A2A message to an ADK event. + + Args: + a2a_message: The A2A message to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + + Returns: + An ADK Event object with converted content and long-running tool metadata. + + Raises: + ValueError: If a2a_message is None. + RuntimeError: If conversion of message parts fails. 
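For orientation, the new inbound converter is driven by handing it an A2A Message; the resulting ADK Event mirrors the message parts in its content. A hedged usage sketch, assuming Python 3.10+, the a2a package installed, and work-in-progress features enabled (both converters are still flagged working_in_progress):

import uuid

from a2a import types as a2a_types
from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event

message = a2a_types.Message(
    messageId=str(uuid.uuid4()),
    role=a2a_types.Role.user,
    parts=[a2a_types.Part(root=a2a_types.TextPart(text="What is the weather?"))],
)

event = convert_a2a_message_to_event(message, author="remote_a2a_agent")
print(event.author)                 # remote_a2a_agent
print(event.content.parts[0].text)  # What is the weather?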
+ """ + if a2a_message is None: + raise ValueError("A2A message cannot be None") + + if not a2a_message.parts: + logger.warning( + "A2A message has no parts, creating event with empty content" + ) + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + content=genai_types.Content(role="model", parts=[]), + ) + + try: + parts = [] + long_running_tool_ids = set() + + for a2a_part in a2a_message.parts: + try: + part = convert_a2a_part_to_genai_part(a2a_part) + if part is None: + logger.warning("Failed to convert A2A part, skipping: %s", a2a_part) + continue + + # Check for long-running tools + if ( + a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY + ) + ) + is True + ): + long_running_tool_ids.add(part.function_call.id) + + parts.append(part) + + except Exception as e: + logger.error("Failed to convert A2A part: %s, error: %s", a2a_part, e) + # Continue processing other parts instead of failing completely + continue + + if not parts: + logger.warning( + "No parts could be converted from A2A message %s", a2a_message + ) + + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + long_running_tool_ids=long_running_tool_ids + if long_running_tool_ids + else None, + content=genai_types.Content( + role="model", + parts=parts, + ), + ) + + except Exception as e: + logger.error("Failed to convert A2A message to event: %s", e) + raise RuntimeError(f"Failed to convert message: {e}") from e + + +@working_in_progress +def convert_event_to_a2a_message( + event: Event, invocation_context: InvocationContext, role: Role = Role.agent ) -> Optional[Message]: """Converts an ADK event to an A2A message. @@ -262,9 +427,7 @@ def convert_event_to_a2a_status_message( _process_long_running_tool(a2a_part, event) if a2a_parts: - return Message( - messageId=str(uuid.uuid4()), role=Role.agent, parts=a2a_parts - ) + return Message(messageId=str(uuid.uuid4()), role=role, parts=a2a_parts) except Exception as e: logger.error("Failed to convert event to status message: %s", e) @@ -274,38 +437,57 @@ def convert_event_to_a2a_status_message( def _create_error_status_event( - event: Event, invocation_context: InvocationContext + event: Event, + invocation_context: InvocationContext, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> TaskStatusUpdateEvent: """Creates a TaskStatusUpdateEvent for error scenarios. Args: event: The ADK event containing error information. invocation_context: The invocation context. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. Returns: A TaskStatusUpdateEvent with FAILED state. 
""" error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + # Get context metadata and add error code + event_metadata = _get_context_metadata(event, invocation_context) + if event.error_code: + event_metadata[_get_adk_metadata_key("error_code")] = str(event.error_code) + return TaskStatusUpdateEvent( - taskId=str(uuid.uuid4()), - contextId=invocation_context.session.id, - final=False, - metadata=_get_context_metadata(event, invocation_context), + taskId=task_id, + contextId=context_id, + metadata=event_metadata, status=TaskStatus( state=TaskState.failed, message=Message( messageId=str(uuid.uuid4()), role=Role.agent, parts=[TextPart(text=error_message)], + metadata={ + _get_adk_metadata_key("error_code"): str(event.error_code) + } + if event.error_code + else {}, ), - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.now(timezone.utc).isoformat(), ), + final=False, ) -def _create_running_status_event( - message: Message, invocation_context: InvocationContext, event: Event +def _create_status_update_event( + message: Message, + invocation_context: InvocationContext, + event: Event, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> TaskStatusUpdateEvent: """Creates a TaskStatusUpdateEvent for running scenarios. @@ -313,32 +495,70 @@ def _create_running_status_event( message: The A2A message to include. invocation_context: The invocation context. event: The ADK event. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + Returns: A TaskStatusUpdateEvent with RUNNING state. """ + status = TaskStatus( + state=TaskState.working, + message=message, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + if any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME + for part in message.parts + if part.root.metadata + ): + status.state = TaskState.auth_required + elif any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + for part in message.parts + if part.root.metadata + ): + status.state = TaskState.input_required + return TaskStatusUpdateEvent( - taskId=str(uuid.uuid4()), - contextId=invocation_context.session.id, - final=False, - status=TaskStatus( - state=TaskState.working, - message=message, - timestamp=datetime.datetime.now().isoformat(), - ), + taskId=task_id, + contextId=context_id, + status=status, metadata=_get_context_metadata(event, invocation_context), + final=False, ) @working_in_progress def convert_event_to_a2a_events( - event: Event, invocation_context: InvocationContext + event: Event, + invocation_context: InvocationContext, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> List[A2AEvent]: """Converts a GenAI event to a list of A2A events. Args: event: The ADK event to convert. invocation_context: The invocation context. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. Returns: A list of A2A events representing the converted ADK event. 
@@ -358,20 +578,22 @@ def convert_event_to_a2a_events( if event.actions.artifact_delta: for filename, version in event.actions.artifact_delta.items(): artifact_event = _convert_artifact_to_a2a_events( - event, invocation_context, filename, version + event, invocation_context, filename, version, task_id, context_id ) a2a_events.append(artifact_event) # Handle error scenarios if event.error_code: - error_event = _create_error_status_event(event, invocation_context) + error_event = _create_error_status_event( + event, invocation_context, task_id, context_id + ) a2a_events.append(error_event) # Handle regular message content - message = convert_event_to_a2a_status_message(event, invocation_context) + message = convert_event_to_a2a_message(event, invocation_context) if message: - running_event = _create_running_status_event( - message, invocation_context, event + running_event = _create_status_update_event( + message, invocation_context, event, task_id, context_id ) a2a_events.append(running_event) diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index 311ffc954..2ba8e26b4 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -20,7 +20,7 @@ # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" ) # Import dependencies with version checking @@ -28,20 +28,21 @@ from a2a.types import DataPart from a2a.types import Message from a2a.types import Role + from a2a.types import Task from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent from google.adk.a2a.converters.event_converter import _convert_artifact_to_a2a_events from google.adk.a2a.converters.event_converter import _create_artifact_id from google.adk.a2a.converters.event_converter import _create_error_status_event - from google.adk.a2a.converters.event_converter import _create_running_status_event + from google.adk.a2a.converters.event_converter import _create_status_update_event from google.adk.a2a.converters.event_converter import _get_adk_metadata_key from google.adk.a2a.converters.event_converter import _get_context_metadata from google.adk.a2a.converters.event_converter import _process_long_running_tool from google.adk.a2a.converters.event_converter import _serialize_metadata_value from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_status_message + from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX from google.adk.agents.invocation_context import InvocationContext @@ -57,13 +58,14 @@ class DummyTypes: DataPart = DummyTypes() Message = DummyTypes() Role = DummyTypes() + Task = DummyTypes() TaskArtifactUpdateEvent = DummyTypes() TaskState = DummyTypes() TaskStatusUpdateEvent = DummyTypes() _convert_artifact_to_a2a_events = lambda *args: None _create_artifact_id = lambda *args: None _create_error_status_event = lambda *args: None - _create_running_status_event = lambda *args: None + 
_create_status_update_event = lambda *args: None _get_adk_metadata_key = lambda *args: None _get_context_metadata = lambda *args: None _process_long_running_tool = lambda *args: None @@ -71,7 +73,7 @@ class DummyTypes: ADK_METADATA_KEY_PREFIX = "adk_" ARTIFACT_ID_SEPARATOR = "_" convert_event_to_a2a_events = lambda *args: None - convert_event_to_a2a_status_message = lambda *args: None + convert_event_to_a2a_message = lambda *args: None DEFAULT_ERROR_MESSAGE = "error" InvocationContext = DummyTypes() Event = DummyTypes() @@ -233,6 +235,8 @@ def test_convert_artifact_to_a2a_events_success(self, mock_convert_part): """Test successful artifact delta conversion.""" filename = "test.txt" version = 1 + task_id = "test-task-id" + context_id = "test-context-id" mock_artifact_part = Mock() # Create a proper Part that Pydantic will accept @@ -246,11 +250,17 @@ def test_convert_artifact_to_a2a_events_success(self, mock_convert_part): mock_convert_part.return_value = mock_converted_part result = _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, filename, version + self.mock_event, + self.mock_invocation_context, + filename, + version, + task_id, + context_id, ) assert isinstance(result, TaskArtifactUpdateEvent) - assert result.contextId == self.mock_invocation_context.session.id + assert result.taskId == task_id + assert result.contextId == context_id assert result.append is False assert result.lastChunk is True @@ -265,7 +275,7 @@ def test_convert_artifact_to_a2a_events_empty_filename(self): """Test artifact delta conversion with empty filename.""" with pytest.raises(ValueError) as exc_info: _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, "", 1 + self.mock_event, self.mock_invocation_context, "", 1, "", "" ) assert "Filename cannot be empty" in str(exc_info.value) @@ -273,7 +283,7 @@ def test_convert_artifact_to_a2a_events_negative_version(self): """Test artifact delta conversion with negative version.""" with pytest.raises(ValueError) as exc_info: _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, "test.txt", -1 + self.mock_event, self.mock_invocation_context, "test.txt", -1, "", "" ) assert "Version must be non-negative" in str(exc_info.value) @@ -293,7 +303,12 @@ def test_convert_artifact_to_a2a_events_conversion_failure( with pytest.raises(RuntimeError) as exc_info: _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, filename, version + self.mock_event, + self.mock_invocation_context, + filename, + version, + "", + "", ) assert "Failed to convert artifact part" in str(exc_info.value) @@ -302,6 +317,8 @@ def test_process_long_running_tool_marks_tool(self): mock_a2a_part = Mock() mock_data_part = Mock(spec=DataPart) mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-123") mock_a2a_part.root = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} @@ -315,7 +332,11 @@ def test_process_long_running_tool_marks_tool(self): "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", "function_call", ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, ): + mock_get_key.side_effect = lambda key: f"adk_{key}" _process_long_running_tool(mock_a2a_part, self.mock_event) @@ -327,6 +348,8 @@ def test_process_long_running_tool_no_marking(self): mock_a2a_part = Mock() mock_data_part = 
Mock(spec=DataPart) mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-456") mock_a2a_part.root = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID @@ -340,7 +363,11 @@ def test_process_long_running_tool_no_marking(self): "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", "function_call", ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, ): + mock_get_key.side_effect = lambda key: f"adk_{key}" _process_long_running_tool(mock_a2a_part, self.mock_event) @@ -368,7 +395,7 @@ def test_convert_event_to_message_success(self, mock_uuid, mock_convert_part): mock_content.parts = [mock_part] self.mock_event.content = mock_content - result = convert_event_to_a2a_status_message( + result = convert_event_to_a2a_message( self.mock_event, self.mock_invocation_context ) @@ -382,7 +409,7 @@ def test_convert_event_to_message_no_content(self): """Test event to message conversion with no content.""" self.mock_event.content = None - result = convert_event_to_a2a_status_message( + result = convert_event_to_a2a_message( self.mock_event, self.mock_invocation_context ) @@ -394,7 +421,7 @@ def test_convert_event_to_message_empty_parts(self): mock_content.parts = [] self.mock_event.content = mock_content - result = convert_event_to_a2a_status_message( + result = convert_event_to_a2a_message( self.mock_event, self.mock_invocation_context ) @@ -403,17 +430,50 @@ def test_convert_event_to_message_empty_parts(self): def test_convert_event_to_message_none_event(self): """Test event to message conversion with None event.""" with pytest.raises(ValueError) as exc_info: - convert_event_to_a2a_status_message(None, self.mock_invocation_context) + convert_event_to_a2a_message(None, self.mock_invocation_context) assert "Event cannot be None" in str(exc_info.value) def test_convert_event_to_message_none_context(self): """Test event to message conversion with None context.""" with pytest.raises(ValueError) as exc_info: - convert_event_to_a2a_status_message(self.mock_event, None) + convert_event_to_a2a_message(self.mock_event, None) assert "Invocation context cannot be None" in str(exc_info.value) + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + def test_convert_event_to_message_with_custom_role( + self, mock_uuid, mock_convert_part + ): + """Test event to message conversion with custom role.""" + mock_uuid.return_value = "test-uuid" + + mock_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types import TextPart + + text_part = TextPart(text="test message") + mock_a2a_part = Part(root=text_part) + mock_convert_part.return_value = mock_a2a_part + + mock_content = Mock() + mock_content.parts = [mock_part] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_message( + self.mock_event, self.mock_invocation_context, role=Role.user + ) + + assert isinstance(result, Message) + assert result.messageId == "test-uuid" + assert result.role == Role.user + assert len(result.parts) == 1 + assert result.parts[0].root.text == "test message" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + @patch("google.adk.a2a.converters.event_converter.datetime") def 
test_create_error_status_event(self, mock_datetime, mock_uuid): """Test creation of error status event.""" mock_uuid.return_value = "test-uuid" @@ -422,18 +482,21 @@ def test_create_error_status_event(self, mock_datetime, mock_uuid): ) self.mock_event.error_message = "Test error message" + task_id = "test-task-id" + context_id = "test-context-id" result = _create_error_status_event( - self.mock_event, self.mock_invocation_context + self.mock_event, self.mock_invocation_context, task_id, context_id ) assert isinstance(result, TaskStatusUpdateEvent) - assert result.contextId == self.mock_invocation_context.session.id + assert result.taskId == task_id + assert result.contextId == context_id assert result.status.state == TaskState.failed assert result.status.message.parts[0].root.text == "Test error message" @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): """Test creation of error status event without error message.""" mock_uuid.return_value = "test-uuid" @@ -441,13 +504,16 @@ def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): "2023-01-01T00:00:00" ) + task_id = "test-task-id" + context_id = "test-context-id" + result = _create_error_status_event( - self.mock_event, self.mock_invocation_context + self.mock_event, self.mock_invocation_context, task_id, context_id ) assert result.status.message.parts[0].root.text == DEFAULT_ERROR_MESSAGE - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_running_status_event(self, mock_datetime): """Test creation of running status event.""" mock_datetime.now.return_value.isoformat.return_value = ( @@ -455,13 +521,21 @@ def test_create_running_status_event(self, mock_datetime): ) mock_message = Mock(spec=Message) - - result = _create_running_status_event( - mock_message, self.mock_invocation_context, self.mock_event + mock_message.parts = [] + task_id = "test-task-id" + context_id = "test-context-id" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, ) assert isinstance(result, TaskStatusUpdateEvent) - assert result.contextId == self.mock_invocation_context.session.id + assert result.taskId == task_id + assert result.contextId == context_id assert result.status.state == TaskState.working assert result.status.message == mock_message @@ -469,11 +543,11 @@ def test_create_running_status_event(self, mock_datetime): "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" ) @patch( - "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" ) @patch("google.adk.a2a.converters.event_converter._create_error_status_event") @patch( - "google.adk.a2a.converters.event_converter._create_running_status_event" + "google.adk.a2a.converters.event_converter._create_status_update_event" ) def test_convert_event_to_a2a_events_full_scenario( self, @@ -514,14 +588,14 @@ def test_convert_event_to_a2a_events_full_scenario( # Verify artifact delta events assert mock_convert_artifact.call_count == 2 - # Verify error event + # Verify error event - now called with task_id and context_id parameters mock_create_error.assert_called_once_with( - 
self.mock_event, self.mock_invocation_context + self.mock_event, self.mock_invocation_context, None, None ) - # Verify running event + # Verify running event - now called with task_id and context_id parameters mock_create_running.assert_called_once_with( - mock_message, self.mock_invocation_context, self.mock_event + mock_message, self.mock_invocation_context, self.mock_event, None, None ) # Verify result contains all events @@ -552,7 +626,7 @@ def test_convert_event_to_a2a_events_none_context(self): assert "Invocation context cannot be None" in str(exc_info.value) @patch( - "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" ) def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): """Test event to A2A events conversion with message only.""" @@ -560,7 +634,7 @@ def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): mock_convert_message.return_value = mock_message with patch( - "google.adk.a2a.converters.event_converter._create_running_status_event" + "google.adk.a2a.converters.event_converter._create_status_update_event" ) as mock_create_running: mock_running_event = Mock() mock_create_running.return_value = mock_running_event @@ -571,15 +645,23 @@ def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): assert len(result) == 1 assert result[0] == mock_running_event + # Verify the function is called with task_id and context_id parameters + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + None, + None, + ) @patch("google.adk.a2a.converters.event_converter.logger") def test_convert_event_to_a2a_events_exception_handling(self, mock_logger): - """Test exception handling in event to A2A events conversion.""" - # Make convert_event_to_a2a_status_message raise an exception + """Test exception handling in convert_event_to_a2a_events.""" + # Make convert_event_to_a2a_message raise an exception with patch( - "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" - ) as mock_convert: - mock_convert.side_effect = Exception("Conversion failed") + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.side_effect = Exception("Test exception") with pytest.raises(Exception): convert_event_to_a2a_events( @@ -587,3 +669,546 @@ def test_convert_event_to_a2a_events_exception_handling(self, mock_logger): ) mock_logger.error.assert_called_once() + + def test_convert_event_to_a2a_events_with_task_id_and_context_id(self): + """Test event to A2A events conversion with specific task_id and context_id.""" + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 1 + assert result[0] == mock_running_event + + # Verify the function is called with the specific task_id and context_id + 
mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + def test_convert_event_to_a2a_events_with_artifacts_and_custom_ids(self): + """Test event to A2A events conversion with artifacts and custom IDs.""" + # Setup artifact delta + self.mock_event.actions.artifact_delta = {"file1.txt": 1} + + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" + ) as mock_convert_artifact: + mock_artifact_event = Mock() + mock_convert_artifact.return_value = mock_artifact_event + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 2 # 1 artifact + 1 status + assert mock_artifact_event in result + assert mock_running_event in result + + # Verify artifact conversion is called with custom IDs + mock_convert_artifact.assert_called_once_with( + self.mock_event, + self.mock_invocation_context, + "file1.txt", + 1, + task_id, + context_id, + ) + + # Verify status update is called with custom IDs + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + def test_create_status_update_event_with_auth_required_state(self): + """Test creation of status update event with auth_required state.""" + from a2a.types import DataPart + from a2a.types import Part + + # Create a mock message with a part that triggers auth_required state + mock_message = Mock(spec=Message) + mock_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = { + "adk_type": "function_call", + "adk_is_long_running": True, + } + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="request_euc") + mock_part.root = mock_data_part + mock_message.parts = [mock_part] + + task_id = "test-task-id" + context_id = "test-context-id" + + with patch( + "google.adk.a2a.converters.event_converter.datetime" + ) as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", + "is_long_running", + ), + patch( + "google.adk.a2a.converters.event_converter.REQUEST_EUC_FUNCTION_CALL_NAME", + "request_euc", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert 
result.status.state == TaskState.auth_required + + def test_create_status_update_event_with_input_required_state(self): + """Test creation of status update event with input_required state.""" + from a2a.types import DataPart + from a2a.types import Part + + # Create a mock message with a part that triggers input_required state + mock_message = Mock(spec=Message) + mock_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = { + "adk_type": "function_call", + "adk_is_long_running": True, + } + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="some_other_function") + mock_part.root = mock_data_part + mock_message.parts = [mock_part] + + task_id = "test-task-id" + context_id = "test-context-id" + + with patch( + "google.adk.a2a.converters.event_converter.datetime" + ) as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", + "is_long_running", + ), + patch( + "google.adk.a2a.converters.event_converter.REQUEST_EUC_FUNCTION_CALL_NAME", + "request_euc", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.status.state == TaskState.input_required + + +class TestA2AToEventConverters: + """Test suite for A2A to Event conversion functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_invocation_context = Mock(spec=InvocationContext) + self.mock_invocation_context.branch = "test-branch" + self.mock_invocation_context.invocation_id = "test-invocation-id" + + def test_convert_a2a_task_to_event_with_status_message(self): + """Test converting A2A task with status message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_status = Mock() + mock_status.message = mock_message + mock_task = Mock(spec=Task) + mock_task.status = mock_status + mock_task.history = [] + + # Mock the convert_a2a_message_to_event function + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_event = Mock(spec=Event) + mock_event.invocation_id = "test-invocation-id" + mock_convert_message.return_value = mock_event + + result = convert_a2a_task_to_event( + mock_task, "test-author", self.mock_invocation_context + ) + + # Verify the message converter was called with correct parameters + mock_convert_message.assert_called_once_with( + mock_message, "test-author", self.mock_invocation_context + ) + assert result == mock_event + assert result.invocation_id == "test-invocation-id" + + def test_convert_a2a_task_to_event_with_history_message(self): + """Test converting A2A task with history message when no status message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock 
message and task + mock_message = Mock(spec=Message) + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [mock_message] + + # Mock the convert_a2a_message_to_event function + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_event = Mock(spec=Event) + mock_event.invocation_id = "test-invocation-id" + mock_convert_message.return_value = mock_event + + result = convert_a2a_task_to_event(mock_task, "test-author") + + # Verify the message converter was called with correct parameters + mock_convert_message.assert_called_once_with( + mock_message, "test-author", None + ) + assert result == mock_event + + def test_convert_a2a_task_to_event_no_message(self): + """Test converting A2A task with no message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock task with no message + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [] + + result = convert_a2a_task_to_event( + mock_task, "test-author", self.mock_invocation_context + ) + + # Verify minimal event was created with correct invocation_id + assert result.author == "test-author" + assert result.branch == "test-branch" + assert result.invocation_id == "test-invocation-id" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_a2a_task_to_event_default_author(self, mock_uuid): + """Test converting A2A task with default author and no invocation context.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock task with no message + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [] + + # Mock UUID generation + mock_uuid.return_value = "generated-uuid" + + result = convert_a2a_task_to_event(mock_task) + + # Verify default author was used and UUID was generated for invocation_id + assert result.author == "a2a agent" + assert result.branch is None + assert result.invocation_id == "generated-uuid" + + def test_convert_a2a_task_to_event_none_task(self): + """Test converting None task raises ValueError.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + with pytest.raises(ValueError, match="A2A task cannot be None"): + convert_a2a_task_to_event(None) + + def test_convert_a2a_task_to_event_message_conversion_error(self): + """Test error handling when message conversion fails.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_status = Mock() + mock_status.message = mock_message + mock_task = Mock(spec=Task) + mock_task.status = mock_status + mock_task.history = [] + + # Mock the convert_a2a_message_to_event function to raise an exception + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_convert_message.side_effect = Exception("Conversion failed") + + with pytest.raises(RuntimeError, match="Failed to convert task message"): + convert_a2a_task_to_event(mock_task, "test-author") + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_success(self, mock_convert_part): + """Test successful conversion of A2A message to event.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + from google.genai import types as genai_types + + # Create mock parts and 
message with valid genai Part + mock_a2a_part = Mock() + mock_genai_part = genai_types.Part(text="test content") + mock_convert_part.return_value = mock_genai_part + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify conversion was successful + assert result.author == "test-author" + assert result.branch == "test-branch" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 1 + assert result.content.parts[0].text == "test content" + mock_convert_part.assert_called_once_with(mock_a2a_part) + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_with_long_running_tools( + self, mock_convert_part + ): + """Test conversion with long-running tools by mocking the entire flow.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Create mock parts and message + mock_a2a_part = Mock() + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + # Mock the part conversion to return None to simulate long-running tool detection logic + mock_convert_part.return_value = None + + # Patch the long-running tool detection since the main logic is in the actual conversion + with patch( + "google.adk.a2a.converters.event_converter.logger" + ) as mock_logger: + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify basic conversion worked + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + # Parts will be empty since conversion returned None, but that's expected for this test + + def test_convert_a2a_message_to_event_empty_parts(self): + """Test conversion with empty parts list.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + mock_message = Mock(spec=Message) + mock_message.parts = [] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created with empty parts + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 0 + + def test_convert_a2a_message_to_event_none_message(self): + """Test converting None message raises ValueError.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + with pytest.raises(ValueError, match="A2A message cannot be None"): + convert_a2a_message_to_event(None) + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_part_conversion_fails( + self, mock_convert_part + ): + """Test handling when part conversion returns None.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Setup mock to return None (conversion failure) + mock_a2a_part = Mock() + mock_convert_part.return_value = None + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created but with no parts + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + 
assert result.content.role == "model" + assert len(result.content.parts) == 0 + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_part_conversion_exception( + self, mock_convert_part + ): + """Test handling when part conversion raises exception.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + from google.genai import types as genai_types + + # Setup mock to raise exception + mock_a2a_part1 = Mock() + mock_a2a_part2 = Mock() + mock_genai_part = genai_types.Part(text="successful conversion") + + mock_convert_part.side_effect = [ + Exception("Conversion failed"), # First part fails + mock_genai_part, # Second part succeeds + ] + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part1, mock_a2a_part2] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created with only the successfully converted part + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 1 + assert result.content.parts[0].text == "successful conversion" + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_missing_tool_id( + self, mock_convert_part + ): + """Test handling of message conversion when part conversion fails.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Create mock parts and message + mock_a2a_part = Mock() + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + # Mock the part conversion to return None + mock_convert_part.return_value = None + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify basic conversion worked + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + # Parts will be empty since conversion returned None + assert len(result.content.parts) == 0 + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_a2a_message_to_event_default_author(self, mock_uuid): + """Test conversion with default author and no invocation context.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + mock_message = Mock(spec=Message) + mock_message.parts = [] + + # Mock UUID generation + mock_uuid.return_value = "generated-uuid" + + result = convert_a2a_message_to_event(mock_message) + + # Verify default author was used and UUID was generated for invocation_id + assert result.author == "a2a agent" + assert result.branch is None + assert result.invocation_id == "generated-uuid" From 3901fade71486a1e9677fe74a120c3f08efe9d9e Mon Sep 17 00:00:00 2001 From: SimonWei <119845914+simonwei97@users.noreply.github.com> Date: Wed, 25 Jun 2025 17:40:11 -0700 Subject: [PATCH 73/79] fix: converts litellm generate config err Merge https://github.com/google/adk-python/pull/1509 Fixs: #1302 Previous PR: https://github.com/google/adk-python/pull/1450 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1509 from simonwei97:fix/litellm-gen-config-converting-err 3120887f29a21789f1b4a7c54af3ed35eb5055e3 PiperOrigin-RevId: 775903671 --- src/google/adk/models/lite_llm.py | 60 +++++++++++++++++++++----- tests/unittests/models/test_litellm.py | 
32 ++++++++++++++ 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 624b7adfc..39514d6f4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -485,16 +486,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[Dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -511,7 +518,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -522,12 +530,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None - - if llm_request.config.response_schema: + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = None + if llm_request.config and llm_request.config.response_schema: response_format = llm_request.config.response_schema - return messages, tools, response_format + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. + generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] + + if not generation_params: + generation_params = None + + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -664,7 +699,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) if "functions" in self._additional_args: # LiteLLM does not support both tools and functions together. 
@@ -678,6 +715,9 @@ async def generate_content_async( } completion_args.update(self._additional_args) + if generation_params: + completion_args.update(generation_params) + if stream: text = "" # Track function calls by index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index d058aa44d..b9b1fb409 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1447,3 +1447,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_completion_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params From 04de3e197d7a57935488eb7bfa647c7ab62cd9d9 Mon Sep 17 00:00:00 2001 From: Ankur Sharma Date: Wed, 25 Jun 2025 18:31:25 -0700 Subject: [PATCH 74/79] fix: Adding detailed information on each metric evaluation Additionally, few other small changes. * Updated a test fixture to support the latest eval data schema. Somehow I missed doing that previously. * Updated the `evaluation_generator.py` to use `run_async`, instead of `run`. * Also, raise an informed error when dependencies required eval are not installed. * Also, changed the behavior of AgentEvaluator.evaluate method to run all the evals, instead of failing at the first eval metric failure. 
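A minimal usage sketch of the changed behavior (the fixture paths and test name
below are illustrative, not part of this change): `AgentEvaluator.evaluate` now
evaluates every metric across all eval cases, optionally prints a per-invocation
table for each metric (the new `print_detailed_results` flag on
`evaluate_eval_set`, on by default), and reports all failures together in a
single assertion at the end instead of stopping at the first failing metric.

    import pytest

    from google.adk.evaluation.agent_evaluator import AgentEvaluator


    @pytest.mark.asyncio
    async def test_trip_planner_agent_evals():
      # Illustrative paths; point these at your own agent module and
      # *.test.json eval files.
      await AgentEvaluator.evaluate(
          agent_module="tests.integration.fixture.trip_planner_agent",
          eval_dataset_file_path_or_dir=(
              "tests/integration/fixture/trip_planner_agent"
          ),
          num_runs=2,
      )
      # Any metric failures are collected and raised together once every
      # metric has been evaluated, rather than aborting on the first one.
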
PiperOrigin-RevId: 775919127 --- src/google/adk/cli/cli_eval.py | 10 +- src/google/adk/cli/cli_tools_click.py | 2 +- src/google/adk/evaluation/agent_evaluator.py | 109 ++++++++++++-- src/google/adk/evaluation/constants.py | 20 +++ .../adk/evaluation/evaluation_generator.py | 2 +- .../trip_planner_agent/initial.session.json | 13 -- .../trip_planner_agent/trip_inquiry.test.json | 133 +++++++++++++++--- 7 files changed, 240 insertions(+), 49 deletions(-) create mode 100644 src/google/adk/evaluation/constants.py delete mode 100644 tests/integration/fixture/trip_planner_agent/initial.session.json diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index 13e205cb7..01b06135a 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -26,6 +26,7 @@ from ..agents import Agent from ..artifacts.base_artifact_service import BaseArtifactService +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.eval_case import EvalCase from ..evaluation.eval_metrics import EvalMetric from ..evaluation.eval_metrics import EvalMetricResult @@ -38,10 +39,6 @@ logger = logging.getLogger("google_adk." + __name__) -MISSING_EVAL_DEPENDENCIES_MESSAGE = ( - "Eval module is not installed, please install via `pip install" - " google-adk[eval]`." -) TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score" RESPONSE_MATCH_SCORE_KEY = "response_match_score" # This evaluation is not very stable. @@ -150,7 +147,7 @@ async def run_evals( artifact_service: The artifact service to use during inferencing. """ try: - from ..evaluation.agent_evaluator import EvaluationGenerator + from ..evaluation.evaluation_generator import EvaluationGenerator except ModuleNotFoundError as e: raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e @@ -252,7 +249,8 @@ async def run_evals( result = "❌ Failed" print(f"Result: {result}\n") - + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e except Exception: # Catching the general exception, so that we don't block other eval # cases. diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c0935cceb..1bc7d5662 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -31,12 +31,12 @@ from . import cli_create from . import cli_deploy from .. import version +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..sessions.in_memory_session_service import InMemorySessionService from .cli import run_cli -from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE from .fast_api import get_fast_api_app from .utils import envs from .utils import evals diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index 6ee001f9d..486d01cf1 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import json import logging import os @@ -23,16 +25,16 @@ from typing import Union import uuid +from google.genai import types as genai_types from pydantic import ValidationError +from .constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from .eval_case import IntermediateData from .eval_set import EvalSet -from .evaluation_generator import EvaluationGenerator from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator from .local_eval_sets_manager import convert_eval_set_to_pydanctic_schema -from .response_evaluator import ResponseEvaluator -from .trajectory_evaluator import TrajectoryEvaluator logger = logging.getLogger("google_adk." + __name__) @@ -96,6 +98,7 @@ async def evaluate_eval_set( criteria: dict[str, float], num_runs=NUM_RUNS, agent_name=None, + print_detailed_results: bool = True, ): """Evaluates an agent using the given EvalSet. @@ -109,7 +112,13 @@ async def evaluate_eval_set( num_runs: Number of times all entries in the eval dataset should be assessed. agent_name: The name of the agent. + print_detailed_results: Whether to print detailed results for each metric + evaluation. """ + try: + from .evaluation_generator import EvaluationGenerator + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e eval_case_responses_list = await EvaluationGenerator.generate_responses( eval_set=eval_set, agent_module_path=agent_module, @@ -117,6 +126,8 @@ async def evaluate_eval_set( agent_name=agent_name, ) + failures = [] + for eval_case_responses in eval_case_responses_list: actual_invocations = [ invocation @@ -139,10 +150,25 @@ async def evaluate_eval_set( ) ) - assert evaluation_result.overall_eval_status == EvalStatus.PASSED, ( - f"{metric_name} for {agent_module} Failed. Expected {threshold}," - f" but got {evaluation_result.overall_score}." - ) + if print_detailed_results: + AgentEvaluator._print_details( + evaluation_result=evaluation_result, + metric_name=metric_name, + threshold=threshold, + ) + + # Gather all the failures. + if evaluation_result.overall_eval_status != EvalStatus.PASSED: + failures.append( + f"{metric_name} for {agent_module} Failed. Expected {threshold}," + f" but got {evaluation_result.overall_score}." + ) + + assert not failures, ( + "Following are all the test failures. If you looking to get more" + " details on the failures, then please re-run this test with" + " `print_details` set to `True`.\n{}".format("\n".join(failures)) + ) @staticmethod async def evaluate( @@ -158,9 +184,10 @@ async def evaluate( agent_module: The path to python module that contains the definition of the agent. There is convention in place here, where the code is going to look for 'root_agent' in the loaded module. - eval_dataset_file_path_or_dir: The eval data set. This can be either a string representing - full path to the file containing eval dataset, or a directory that is - recursively explored for all files that have a `.test.json` suffix. + eval_dataset_file_path_or_dir: The eval data set. This can be either a + string representing full path to the file containing eval dataset, or a + directory that is recursively explored for all files that have a + `.test.json` suffix. num_runs: Number of times all entries in the eval dataset should be assessed. agent_name: The name of the agent. 
@@ -358,6 +385,11 @@ def _validate_input(eval_dataset, criteria): @staticmethod def _get_metric_evaluator(metric_name: str, threshold: float) -> Evaluator: + try: + from .response_evaluator import ResponseEvaluator + from .trajectory_evaluator import TrajectoryEvaluator + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e if metric_name == TOOL_TRAJECTORY_SCORE_KEY: return TrajectoryEvaluator(threshold=threshold) elif ( @@ -367,3 +399,60 @@ def _get_metric_evaluator(metric_name: str, threshold: float) -> Evaluator: return ResponseEvaluator(threshold=threshold, metric_name=metric_name) raise ValueError(f"Unsupported eval metric: {metric_name}") + + @staticmethod + def _print_details( + evaluation_result: EvaluationResult, metric_name: str, threshold: float + ): + try: + from pandas import pandas as pd + from tabulate import tabulate + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e + print( + f"Summary: `{evaluation_result.overall_eval_status}` for Metric:" + f" `{metric_name}`. Expected threshold: `{threshold}`, actual value:" + f" `{evaluation_result.overall_score}`." + ) + + data = [] + for per_invocation_result in evaluation_result.per_invocation_results: + data.append({ + "eval_status": per_invocation_result.eval_status, + "score": per_invocation_result.score, + "threshold": threshold, + "prompt": AgentEvaluator._convert_content_to_text( + per_invocation_result.expected_invocation.user_content + ), + "expected_response": AgentEvaluator._convert_content_to_text( + per_invocation_result.expected_invocation.final_response + ), + "actual_response": AgentEvaluator._convert_content_to_text( + per_invocation_result.actual_invocation.final_response + ), + "expected_tool_calls": AgentEvaluator._convert_tool_calls_to_text( + per_invocation_result.expected_invocation.intermediate_data + ), + "actual_tool_calls": AgentEvaluator._convert_tool_calls_to_text( + per_invocation_result.actual_invocation.intermediate_data + ), + }) + + print(tabulate(pd.DataFrame(data), headers="keys", tablefmt="grid")) + print("\n\n") # Few empty lines for visual clarity + + @staticmethod + def _convert_content_to_text(content: Optional[genai_types.Content]) -> str: + if content and content.parts: + return "\n".join([p.text for p in content.parts if p.text]) + + return "" + + @staticmethod + def _convert_tool_calls_to_text( + intermediate_data: Optional[IntermediateData], + ) -> str: + if intermediate_data and intermediate_data.tool_uses: + return "\n".join([str(t) for t in intermediate_data.tool_uses]) + + return "" diff --git a/src/google/adk/evaluation/constants.py b/src/google/adk/evaluation/constants.py new file mode 100644 index 000000000..74248ed18 --- /dev/null +++ b/src/google/adk/evaluation/constants.py @@ -0,0 +1,20 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +MISSING_EVAL_DEPENDENCIES_MESSAGE = ( + "Eval module is not installed, please install via `pip install" + " google-adk[eval]`." +) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index fbf6ea8e2..1359967bc 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -182,7 +182,7 @@ async def _generate_inferences_from_root_agent( tool_uses = [] invocation_id = "" - for event in runner.run( + async for event in runner.run_async( user_id=user_id, session_id=session_id, new_message=user_content ): invocation_id = ( diff --git a/tests/integration/fixture/trip_planner_agent/initial.session.json b/tests/integration/fixture/trip_planner_agent/initial.session.json deleted file mode 100644 index b33840cda..000000000 --- a/tests/integration/fixture/trip_planner_agent/initial.session.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "id": "test_id", - "app_name": "trip_planner_agent", - "user_id": "test_user", - "state": { - "origin": "San Francisco", - "interests": "Food, Shopping, Museums", - "range": "1000 miles", - "cities": "" - }, - "events": [], - "last_update_time": 1741218714.258285 -} diff --git a/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json b/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json index c504f68e3..317599c6b 100644 --- a/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json +++ b/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json @@ -1,19 +1,116 @@ -[ - { - "query": "Hi, who are you? What can you do?", - "expected_tool_use": [], - "reference": "I am trip_planner, and my goal is to plan the best trip ever. I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n" - }, - { - "query": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?", - "expected_tool_use": [ - { - "tool_name": "transfer_to_agent", - "tool_input": { - "agent_name": "indentify_agent" +{ + "eval_set_id": "e7996ccc-16bc-46bf-9a24-0a3ecc3dacd7", + "name": "e7996ccc-16bc-46bf-9a24-0a3ecc3dacd7", + "description": null, + "eval_cases": [ + { + "eval_id": "/google/src/cloud/ankusharma/CS-agent_evaluator-2025-06-17_115009/google3/third_party/py/google/adk/open_source_workspace/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json", + "conversation": [ + { + "invocation_id": "d7ff8ec1-290b-48c5-b3aa-05cb8f27b8ae", + "user_content": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "Hi, who are you? What can you do?" + } + ], + "role": "user" + }, + "final_response": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "I am trip_planner, and my goal is to plan the best trip ever. 
I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n" + } + ], + "role": "model" + }, + "intermediate_data": { + "tool_uses": [], + "intermediate_responses": [] + }, + "creation_timestamp": 1750190885.419684 + }, + { + "invocation_id": "f515ff57-ff21-488f-ab92-7d7de5bb76fe", + "user_content": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?" + } + ], + "role": "user" + }, + "final_response": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). After gathering this information, I'll provide a detailed report on my chosen city.\n" + } + ], + "role": "model" + }, + "intermediate_data": { + "tool_uses": [ + { + "id": null, + "args": { + "agent_name": "indentify_agent" + }, + "name": "transfer_to_agent" + } + ], + "intermediate_responses": [] + }, + "creation_timestamp": 1750190885.4197457 } - } - ], - "reference": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). 
After gathering this information, I'll provide a detailed report on my chosen city.\n" - } -] + ], + "session_input": { + "app_name": "trip_planner_agent", + "user_id": "test_user", + "state": { + "origin": "San Francisco", + "interests": "Food, Shopping, Museums", + "range": "1000 miles", + "cities": "" + } + }, + "creation_timestamp": 1750190885.4197533 + } + ], + "creation_timestamp": 1750190885.4197605 +} \ No newline at end of file From 77b869f5e35a66682cba35563824fd23a9028d7c Mon Sep 17 00:00:00 2001 From: Ray Iramaneerat Date: Wed, 25 Jun 2025 19:55:20 -0700 Subject: [PATCH 75/79] fix: Update google_search_tool.py to support updated Gemini LIVE model naming Merge https://github.com/google/adk-python/pull/1518 ## Description Fixes [#1512](https://github.com/google/adk-python/issues/1512) by updating google_search_tool.py to support new Gemini LIVE model naming ## Changes - Update the model name checking in google_search_tool.py COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1518 from rayira:rayira-patch-1 4c98f88290af6e0a4690652019ca1d7a08689340 PiperOrigin-RevId: 775941268 --- src/google/adk/tools/google_search_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py index d116e4b6c..9fe387df3 100644 --- a/src/google/adk/tools/google_search_tool.py +++ b/src/google/adk/tools/google_search_tool.py @@ -54,7 +54,7 @@ async def process_llm_request( llm_request.config.tools.append( types.Tool(google_search_retrieval=types.GoogleSearchRetrieval()) ) - elif llm_request.model and 'gemini-2' in llm_request.model: + elif llm_request.model and 'gemini-' in llm_request.model: llm_request.config.tools.append( types.Tool(google_search=types.GoogleSearch()) ) From 2f55de6ded26c0b55715b78fa081dbe126d968c2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 22:01:50 -0700 Subject: [PATCH 76/79] chore: Add a2a task result aggregator PiperOrigin-RevId: 775975982 --- src/google/adk/a2a/executor/__init__.py | 13 + .../a2a/executor/task_result_aggregator.py | 71 ++++ tests/unittests/a2a/executor/__init__.py | 13 + .../executor/test_task_result_aggregator.py | 337 ++++++++++++++++++ 4 files changed, 434 insertions(+) create mode 100644 src/google/adk/a2a/executor/__init__.py create mode 100644 src/google/adk/a2a/executor/task_result_aggregator.py create mode 100644 tests/unittests/a2a/executor/__init__.py create mode 100644 tests/unittests/a2a/executor/test_task_result_aggregator.py diff --git a/src/google/adk/a2a/executor/__init__.py b/src/google/adk/a2a/executor/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/executor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/google/adk/a2a/executor/task_result_aggregator.py b/src/google/adk/a2a/executor/task_result_aggregator.py new file mode 100644 index 000000000..609de40e9 --- /dev/null +++ b/src/google/adk/a2a/executor/task_result_aggregator.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from a2a.server.events import Event +from a2a.types import Message +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent + +from ...utils.feature_decorator import working_in_progress + + +@working_in_progress +class TaskResultAggregator: + """Aggregates the task status updates and provides the final task state.""" + + def __init__(self): + self._task_state = TaskState.working + self._task_status_message = None + + def process_event(self, event: Event): + """Process an event from the agent run and detect signals about the task status. + Priority of task state: + - failed + - auth_required + - input_required + - working + """ + if isinstance(event, TaskStatusUpdateEvent): + if event.status.state == TaskState.failed: + self._task_state = TaskState.failed + self._task_status_message = event.status.message + elif ( + event.status.state == TaskState.auth_required + and self._task_state != TaskState.failed + ): + self._task_state = TaskState.auth_required + self._task_status_message = event.status.message + elif ( + event.status.state == TaskState.input_required + and self._task_state + not in (TaskState.failed, TaskState.auth_required) + ): + self._task_state = TaskState.input_required + self._task_status_message = event.status.message + # final state is already recorded and make sure the intermediate state is + # always working because other state may terminate the event aggregation + # in a2a request handler + elif self._task_state == TaskState.working: + self._task_status_message = event.status.message + event.status.state = TaskState.working + + @property + def task_state(self) -> TaskState: + return self._task_state + + @property + def task_status_message(self) -> Message | None: + return self._task_status_message diff --git a/tests/unittests/a2a/executor/__init__.py b/tests/unittests/a2a/executor/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/executor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
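A brief, hypothetical consumption sketch of the new aggregator (the executor
wiring is not part of this patch; `adk_event_stream` and `event_queue` are
illustrative names): each streamed A2A event is passed through the aggregator,
which keeps the highest-priority task state it has seen and exposes the final
state and status message once the stream is exhausted. Note the class is gated
by the `working_in_progress` decorator, so its surface may still change.

    from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator


    async def stream_task_events(adk_event_stream, event_queue):
      # `adk_event_stream` and `event_queue` are hypothetical stand-ins for
      # the surrounding executor wiring.
      aggregator = TaskResultAggregator()
      async for a2a_event in adk_event_stream:
        # Records the highest-priority state seen so far
        # (failed > auth_required > input_required > working) and rewrites the
        # intermediate event's state to `working` before it is forwarded.
        aggregator.process_event(a2a_event)
        await event_queue.enqueue_event(a2a_event)
      # After the stream ends, the aggregate state and its status message
      # describe the final task outcome.
      return aggregator.task_state, aggregator.task_status_message
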
diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py new file mode 100644 index 000000000..b808cf0cf --- /dev/null +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -0,0 +1,337 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import Mock + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.types import Message + from a2a.types import Part + from a2a.types import Role + from a2a.types import TaskState + from a2a.types import TaskStatus + from a2a.types import TaskStatusUpdateEvent + from a2a.types import TextPart + from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + TaskState = DummyTypes() + TaskStatus = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + TaskResultAggregator = DummyTypes() + else: + raise e + + +def create_test_message(text: str) -> Message: + """Helper function to create a test Message object.""" + return Message( + messageId="test-msg", + role=Role.agent, + parts=[Part(root=TextPart(text=text))], + ) + + +class TestTaskResultAggregator: + """Test suite for TaskResultAggregator class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.aggregator = TaskResultAggregator() + + def test_initial_state(self): + """Test the initial state of the aggregator.""" + assert self.aggregator.task_state == TaskState.working + assert self.aggregator.task_status_message is None + + def test_process_failed_event(self): + """Test processing a failed event.""" + status_message = create_test_message("Failed to process") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=status_message), + final=True, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == status_message + # Verify the event state was modified to working + assert event.status.state == TaskState.working + + def test_process_auth_required_event(self): + """Test processing an auth_required event.""" + status_message = create_test_message("Authentication needed") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.auth_required, message=status_message + ), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == status_message + # Verify the event state was modified to 
working + assert event.status.state == TaskState.working + + def test_process_input_required_event(self): + """Test processing an input_required event.""" + status_message = create_test_message("Input required") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.input_required, message=status_message + ), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.input_required + assert self.aggregator.task_status_message == status_message + # Verify the event state was modified to working + assert event.status.state == TaskState.working + + def test_status_message_with_none_message(self): + """Test that status message handles None message properly.""" + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=None), + final=True, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message is None + + def test_priority_order_failed_over_auth(self): + """Test that failed state takes priority over auth_required.""" + # First set auth_required + auth_message = create_test_message("Auth required") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == auth_message + + # Then process failed - should override + failed_message = create_test_message("Failed") + failed_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + self.aggregator.process_event(failed_event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == failed_message + + def test_priority_order_auth_over_input(self): + """Test that auth_required state takes priority over input_required.""" + # First set input_required + input_message = create_test_message("Input needed") + input_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.input_required, message=input_message + ), + final=False, + ) + self.aggregator.process_event(input_event) + assert self.aggregator.task_state == TaskState.input_required + assert self.aggregator.task_status_message == input_message + + # Then process auth_required - should override + auth_message = create_test_message("Auth needed") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == auth_message + + def test_ignore_non_status_update_events(self): + """Test that non-TaskStatusUpdateEvent events are ignored.""" + mock_event = Mock() + + initial_state = self.aggregator.task_state + initial_message = self.aggregator.task_status_message + self.aggregator.process_event(mock_event) + + # State should remain unchanged + assert self.aggregator.task_state == initial_state + assert self.aggregator.task_status_message == initial_message + + def 
test_working_state_does_not_override_higher_priority(self): + """Test that working state doesn't override higher priority states.""" + # First set failed state + failed_message = create_test_message("Failure message") + failed_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + self.aggregator.process_event(failed_event) + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == failed_message + + # Then process working - should not override state and should not update message + # because the current task state is not working + working_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working), + final=False, + ) + self.aggregator.process_event(working_event) + assert self.aggregator.task_state == TaskState.failed + # Working events don't update the status message when task state is not working + assert self.aggregator.task_status_message == failed_message + + def test_status_message_priority_ordering(self): + """Test that status messages follow the same priority ordering as states.""" + # Start with input_required + input_message = create_test_message("Input message") + input_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus( + state=TaskState.input_required, message=input_message + ), + final=False, + ) + self.aggregator.process_event(input_event) + assert self.aggregator.task_status_message == input_message + + # Override with auth_required + auth_message = create_test_message("Auth message") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_status_message == auth_message + + # Override with failed + failed_message = create_test_message("Failed message") + failed_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.failed, message=failed_message), + final=True, + ) + self.aggregator.process_event(failed_event) + assert self.aggregator.task_status_message == failed_message + + # Working should not override failed message because current task state is failed + working_message = create_test_message("Working message") + working_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=working_message), + final=False, + ) + self.aggregator.process_event(working_event) + # State should still be failed, and message should remain the failed message + # because working events only update message when task state is working + assert self.aggregator.task_state == TaskState.failed + assert self.aggregator.task_status_message == failed_message + + def test_process_working_event_updates_message(self): + """Test that working state events update the status message.""" + working_message = create_test_message("Working on task") + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=working_message), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.working + assert self.aggregator.task_status_message == working_message + # Verify the event state was 
modified to working (should remain working) + assert event.status.state == TaskState.working + + def test_working_event_with_none_message(self): + """Test that working state events handle None message properly.""" + event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=None), + final=False, + ) + + self.aggregator.process_event(event) + assert self.aggregator.task_state == TaskState.working + assert self.aggregator.task_status_message is None + + def test_working_event_updates_message_regardless_of_state(self): + """Test that working events update message only when current task state is working.""" + # First set auth_required state + auth_message = create_test_message("Auth required") + auth_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.auth_required, message=auth_message), + final=False, + ) + self.aggregator.process_event(auth_event) + assert self.aggregator.task_state == TaskState.auth_required + assert self.aggregator.task_status_message == auth_message + + # Then process working - should not update message because task state is not working + working_message = create_test_message("Working on auth") + working_event = TaskStatusUpdateEvent( + taskId="test-task", + contextId="test-context", + status=TaskStatus(state=TaskState.working, message=working_message), + final=False, + ) + self.aggregator.process_event(working_event) + assert ( + self.aggregator.task_state == TaskState.auth_required + ) # State unchanged + assert ( + self.aggregator.task_status_message == auth_message + ) # Message unchanged because task state is not working From 630f1674cb99d4f8e07028871ffd37ef7c6bd455 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 22:31:44 -0700 Subject: [PATCH 77/79] chore: Add a2a agent executor PiperOrigin-RevId: 775983689 --- .../adk/a2a/executor/a2a_agent_executor.py | 249 ++++++ .../a2a/executor/test_a2a_agent_executor.py | 829 ++++++++++++++++++ 2 files changed, 1078 insertions(+) create mode 100644 src/google/adk/a2a/executor/a2a_agent_executor.py create mode 100644 tests/unittests/a2a/executor/test_a2a_agent_executor.py diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py new file mode 100644 index 000000000..953dd703d --- /dev/null +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -0,0 +1,249 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import inspect +import logging +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Optional +import uuid + +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.runners import Runner +from pydantic import BaseModel +from typing_extensions import override + +from ...utils.feature_decorator import working_in_progress +from ..converters.event_converter import convert_event_to_a2a_events +from ..converters.request_converter import convert_a2a_request_to_adk_run_args +from ..converters.utils import _get_adk_metadata_key +from .task_result_aggregator import TaskResultAggregator + +logger = logging.getLogger('google_adk.' + __name__) + + +@working_in_progress +class A2aAgentExecutorConfig(BaseModel): + """Configuration for the A2aAgentExecutor.""" + + pass + + +@working_in_progress +class A2aAgentExecutor(AgentExecutor): + """An AgentExecutor that runs an ADK Agent against an A2A request and + publishes updates to an event queue. + """ + + def __init__( + self, + *, + runner: Runner | Callable[..., Runner | Awaitable[Runner]], + config: Optional[A2aAgentExecutorConfig] = None, + ): + super().__init__() + self._runner = runner + self._config = config + + async def _resolve_runner(self) -> Runner: + """Resolve the runner, handling cases where it's a callable that returns a Runner.""" + # If already resolved and cached, return it + if isinstance(self._runner, Runner): + return self._runner + if callable(self._runner): + # Call the function to get the runner + result = self._runner() + + # Handle async callables + if inspect.iscoroutine(result): + resolved_runner = await result + else: + resolved_runner = result + + # Cache the resolved runner for future calls + self._runner = resolved_runner + return resolved_runner + + raise TypeError( + 'Runner must be a Runner instance or a callable that returns a' + f' Runner, got {type(self._runner)}' + ) + + @override + async def cancel(self, context: RequestContext, event_queue: EventQueue): + """Cancel the execution.""" + # TODO: Implement proper cancellation logic if needed + raise NotImplementedError('Cancellation is not supported') + + @override + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ): + """Executes an A2A request and publishes updates to the event queue + specified. 
It runs as following: + * Takes the input from the A2A request + * Convert the input to ADK input content, and runs the ADK agent + * Collects output events of the underlying ADK Agent + * Converts the ADK output events into A2A task updates + * Publishes the updates back to A2A server via event queue + """ + if not context.message: + raise ValueError('A2A request must have a message') + + # for new task, create a task submitted event + if not context.current_task: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + message=context.message, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + contextId=context.context_id, + final=False, + ) + ) + + # Handle the request and publish updates to the event queue + try: + await self._handle_request(context, event_queue) + except Exception as e: + logger.error('Error handling A2A request: %s', e, exc_info=True) + # Publish failure event + try: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + messageId=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=str(e))], + ), + ), + contextId=context.context_id, + final=True, + ) + ) + except Exception as enqueue_error: + logger.error( + 'Failed to publish failure event: %s', enqueue_error, exc_info=True + ) + + async def _handle_request( + self, + context: RequestContext, + event_queue: EventQueue, + ): + # Resolve the runner instance + runner = await self._resolve_runner() + + # Convert the a2a request to ADK run args + run_args = convert_a2a_request_to_adk_run_args(context) + + # ensure the session exists + session = await self._prepare_session(context, run_args, runner) + + # create invocation context + invocation_context = runner._new_invocation_context( + session=session, + new_message=run_args['new_message'], + run_config=run_args['run_config'], + ) + + # publish the task working event + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + contextId=context.context_id, + final=False, + metadata={ + _get_adk_metadata_key('app_name'): runner.app_name, + _get_adk_metadata_key('user_id'): run_args['user_id'], + _get_adk_metadata_key('session_id'): run_args['session_id'], + }, + ) + ) + + task_result_aggregator = TaskResultAggregator() + async for adk_event in runner.run_async(**run_args): + for a2a_event in convert_event_to_a2a_events( + adk_event, invocation_context, context.task_id, context.context_id + ): + task_result_aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) + + # publish the task result event - this is final + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=( + task_result_aggregator.task_state + if task_result_aggregator.task_state != TaskState.working + else TaskState.completed + ), + timestamp=datetime.now(timezone.utc).isoformat(), + message=task_result_aggregator.task_status_message, + ), + contextId=context.context_id, + final=True, + ) + ) + + async def _prepare_session( + self, context: RequestContext, run_args: dict[str, Any], runner: Runner + ): + + session_id = run_args['session_id'] + # create a new session if not exists + user_id = run_args['user_id'] + session = await 
runner.session_service.get_session( + app_name=runner.app_name, + user_id=user_id, + session_id=session_id, + ) + if session is None: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + session_id=session_id, + ) + # Update run_args with the new session_id + run_args['session_id'] = session.id + + return session diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py new file mode 100644 index 000000000..44d592fbc --- /dev/null +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -0,0 +1,829 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.server.agent_execution.context import RequestContext + from a2a.server.events.event_queue import EventQueue + from a2a.types import Message + from a2a.types import TaskState + from a2a.types import TextPart + from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor + from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig + from google.adk.events.event import Event + from google.adk.runners import Runner +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + RequestContext = DummyTypes() + EventQueue = DummyTypes() + Message = DummyTypes() + Role = DummyTypes() + TaskState = DummyTypes() + TaskStatus = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + TextPart = DummyTypes() + A2aAgentExecutor = DummyTypes() + A2aAgentExecutorConfig = DummyTypes() + Event = DummyTypes() + Runner = DummyTypes() + else: + raise e + + +class TestA2aAgentExecutor: + """Test suite for A2aAgentExecutor class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_runner = Mock(spec=Runner) + self.mock_runner.app_name = "test-app" + self.mock_runner.session_service = Mock() + self.mock_runner._new_invocation_context = Mock() + self.mock_runner.run_async = AsyncMock() + + self.mock_config = Mock(spec=A2aAgentExecutorConfig) + self.executor = A2aAgentExecutor( + runner=self.mock_runner, config=self.mock_config + ) + + self.mock_context = Mock(spec=RequestContext) + self.mock_context.message = Mock(spec=Message) + self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.current_task = None + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_event_queue = Mock(spec=EventQueue) + + async def _create_async_generator(self, items): + """Helper to create async 
generator from items.""" + for item in items: + yield item + + @pytest.mark.asyncio + async def test_execute_success_new_task(self): + """Test successful execution of a new task.""" + # Setup + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + # Verify working event was enqueued + working_event = self.mock_event_queue.enqueue_event.call_args_list[1][ + 0 + ][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False + + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ + 0 + ] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), so final state should be completed + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_execute_no_message_error(self): + """Test execution fails when no message is provided.""" + self.mock_context.message = None + + with pytest.raises(ValueError, match="A2A request must have a message"): + await self.executor.execute(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_existing_task(self): + """Test execution with existing task (no submitted event).""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "existing-task-id" + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + # Configure run_async to return the async generator when awaited + async def 
mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify no submitted event (first call should be working event) + working_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False + + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ + 0 + ] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), so final state should be completed + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_prepare_session_new_session(self): + """Test session preparation when session doesn't exist.""" + run_args = { + "user_id": "test-user", + "session_id": None, + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) + mock_session = Mock() + mock_session.id = "new-session-id" + self.mock_runner.session_service.create_session = AsyncMock( + return_value=mock_session + ) + + # Execute + result = await self.executor._prepare_session( + self.mock_context, run_args, self.mock_runner + ) + + # Verify session was created + assert result == mock_session + assert run_args["session_id"] is not None + self.mock_runner.session_service.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_prepare_session_existing_session(self): + """Test session preparation when session exists.""" + run_args = { + "user_id": "test-user", + "session_id": "existing-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "existing-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Execute + result = await self.executor._prepare_session( + self.mock_context, run_args, self.mock_runner + ) + + # Verify existing session was returned + assert result == mock_session + self.mock_runner.session_service.create_session.assert_not_called() + + def test_constructor_with_callable_runner(self): + """Test constructor with callable runner.""" + callable_runner = Mock() + executor = A2aAgentExecutor(runner=callable_runner, config=self.mock_config) + + assert executor._runner == callable_runner + assert executor._config == self.mock_config + + @pytest.mark.asyncio + async def test_resolve_runner_direct_instance(self): + """Test _resolve_runner with direct Runner instance.""" + # Setup - already using direct runner instance in setup_method + runner = await self.executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_sync_callable(self): + """Test _resolve_runner with sync callable that returns Runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def 
test_resolve_runner_async_callable(self): + """Test _resolve_runner with async callable that returns Runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_invalid_type(self): + """Test _resolve_runner with invalid runner type.""" + executor = A2aAgentExecutor(runner="invalid", config=self.mock_config) + + with pytest.raises( + TypeError, match="Runner must be a Runner instance or a callable" + ): + await executor._resolve_runner() + + @pytest.mark.asyncio + async def test_resolve_runner_callable_with_parameters(self): + """Test _resolve_runner with callable that normally takes parameters.""" + + def create_runner(*args, **kwargs): + # In real usage, this might use the args/kwargs to configure the runner + # For testing, we'll just return the mock runner + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_caching(self): + """Test that _resolve_runner caches the result and doesn't call the callable multiple times.""" + call_count = 0 + + def create_runner(): + nonlocal call_count + call_count += 1 + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + # First call should invoke the callable + runner1 = await executor._resolve_runner() + assert runner1 == self.mock_runner + assert call_count == 1 + + # Second call should return cached result, not invoke callable again + runner2 = await executor._resolve_runner() + assert runner2 == self.mock_runner + assert runner1 is runner2 # Same instance + assert call_count == 1 # Callable was not called again + + # Verify that self._runner is now the resolved Runner instance + assert executor._runner is self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_async_caching(self): + """Test that _resolve_runner caches async callable results correctly.""" + call_count = 0 + + async def create_runner(): + nonlocal call_count + call_count += 1 + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + # First call should invoke the async callable + runner1 = await executor._resolve_runner() + assert runner1 == self.mock_runner + assert call_count == 1 + + # Second call should return cached result, not invoke callable again + runner2 = await executor._resolve_runner() + assert runner2 == self.mock_runner + assert runner1 is runner2 # Same instance + assert call_count == 1 # Async callable was not called again + + # Verify that self._runner is now the resolved Runner instance + assert executor._runner is self.mock_runner + + @pytest.mark.asyncio + async def test_execute_with_sync_callable_runner(self): + """Test execution with sync callable runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + 
self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) + + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ + 0 + ] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), so final state should be completed + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_execute_with_async_callable_runner(self): + """Test execution with async callable runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) + + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + # Verify final event was enqueued with proper message field + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ + 0 + ] + assert final_event.final == True + # The TaskResultAggregator is created with default state (working), so final state should be completed + assert hasattr(final_event.status, "message") + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_handle_request_integration(self): + """Test the 
complete request handling flow.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Setup detailed mocks + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [Mock()] + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.working + # Mock the task_status_message property to return None by default + mock_aggregator.task_status_message = None + mock_aggregator_class.return_value = mock_aggregator + + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue + ) + + # Verify working event was enqueued + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + assert len(working_events) >= 1 + + # Verify aggregator processed events + assert mock_aggregator.process_event.call_count == len(mock_events) + + # Verify final event has message field from aggregator and state is completed when aggregator state is working + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert ( + final_event.status.message == mock_aggregator.task_status_message + ) + # When aggregator state is working, final event should be completed + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_cancel_with_task_id(self): + """Test cancellation with a task ID.""" + self.mock_context.task_id = "test-task-id" + + # The current implementation raises NotImplementedError + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_cancel_without_task_id(self): + """Test cancellation without a task ID.""" + self.mock_context.task_id = None + + # The current implementation raises NotImplementedError regardless of task_id + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_with_exception_handling(self): + """Test execution with exception 
handling.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.current_task = ( + None # Make sure it goes through submitted event creation + ) + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.side_effect = Exception("Test error") + + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify both submitted and failure events were enqueued + # First call should be submitted event, last should be failure event + assert self.mock_event_queue.enqueue_event.call_count >= 2 + + # Check submitted event (first) + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ + 0 + ] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True + + @pytest.mark.asyncio + async def test_handle_request_with_aggregator_message(self): + """Test that the final task status event includes message from aggregator.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Create a test message to be returned by the aggregator + from a2a.types import Message + from a2a.types import Role + from a2a.types import TextPart + + test_message = Mock(spec=Message) + test_message.messageId = "test-message-id" + test_message.role = Role.agent + test_message.parts = [Mock(spec=TextPart)] + + # Setup detailed mocks + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [Mock()] + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.completed + # Mock the task_status_message property to return a test message + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator + + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue + ) + + # Verify final event has message field from aggregator + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get 
the last final event + assert final_event.status.message == test_message + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_handle_request_with_non_working_aggregator_state(self): + """Test that when aggregator state is not working, it preserves the original state.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Create a test message to be returned by the aggregator + from a2a.types import Message + from a2a.types import Role + from a2a.types import TextPart + + test_message = Mock(spec=Message) + test_message.messageId = "test-message-id" + test_message.role = Role.agent + test_message.parts = [Mock(spec=TextPart)] + + # Setup detailed mocks + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [Mock()] + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + # Test with failed state - should preserve failed state + mock_aggregator.task_state = TaskState.failed + mock_aggregator.task_status_message = test_message + mock_aggregator_class.return_value = mock_aggregator + + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue + ) + + # Verify final event preserves the non-working state + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] # Get the last final event + assert final_event.status.message == test_message + # When aggregator state is failed (not working), final event should keep failed state + assert final_event.status.state == TaskState.failed From ed09cd840f5a294db9ae180cbec30e2503e782c7 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 23:04:30 -0700 Subject: [PATCH 78/79] chore: Add enable_a2a option to adk command line PiperOrigin-RevId: 775991652 --- src/google/adk/cli/cli_deploy.py | 5 ++++- src/google/adk/cli/cli_tools_click.py | 17 +++++++++++++++++ src/google/adk/cli/fast_api.py | 3 +++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 44d4a900d..0dedae6de 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -55,7 +55,7 @@ EXPOSE {port} -CMD adk {command} --port={port} {host_option} {service_option} 
{trace_to_cloud_option} {allow_origins_option} "/app/agents" +CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {allow_origins_option} {a2a_option}"/app/agents" """ _AGENT_ENGINE_APP_TEMPLATE = """ @@ -128,6 +128,7 @@ def to_cloud_run( session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + a2a: bool = False, ): """Deploys an agent to Google Cloud Run. @@ -189,6 +190,7 @@ def to_cloud_run( allow_origins_option = ( f'--allow_origins={",".join(allow_origins)}' if allow_origins else '' ) + a2a_option = '--a2a' if a2a else '' dockerfile_content = _DOCKERFILE_TEMPLATE.format( gcp_project_id=project, gcp_region=region, @@ -206,6 +208,7 @@ def to_cloud_run( allow_origins_option=allow_origins_option, adk_version=adk_version, host_option=host_option, + a2a_option=a2a_option, ) dockerfile_path = os.path.join(temp_folder, 'Dockerfile') os.makedirs(temp_folder, exist_ok=True) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 1bc7d5662..f3095c2ff 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -576,6 +576,13 @@ def decorator(func): " for Cloud Run." ), ) + @click.option( + "--a2a", + is_flag=True, + show_default=True, + default=False, + help="Optional. Whether to enable A2A endpoint.", + ) @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -617,6 +624,7 @@ def cli_web( memory_service_uri: Optional[str] = None, session_db_url: Optional[str] = None, # Deprecated artifact_storage_uri: Optional[str] = None, # Deprecated + a2a: bool = False, ): """Starts a FastAPI server with Web UI for agents. @@ -663,6 +671,9 @@ async def _lifespan(app: FastAPI): web=True, trace_to_cloud=trace_to_cloud, lifespan=_lifespan, + a2a=a2a, + host=host, + port=port, ) config = uvicorn.Config( app, @@ -709,6 +720,7 @@ def cli_api_server( memory_service_uri: Optional[str] = None, session_db_url: Optional[str] = None, # Deprecated artifact_storage_uri: Optional[str] = None, # Deprecated + a2a: bool = False, ): """Starts a FastAPI server for agents. @@ -733,6 +745,9 @@ def cli_api_server( allow_origins=allow_origins, web=False, trace_to_cloud=trace_to_cloud, + a2a=a2a, + host=host, + port=port, ), host=host, port=port, @@ -854,6 +869,7 @@ def cli_deploy_cloud_run( eval_storage_uri: Optional[str] = None, session_db_url: Optional[str] = None, # Deprecated artifact_storage_uri: Optional[str] = None, # Deprecated + a2a: bool = False, ): """Deploys an agent to Cloud Run. 
@@ -884,6 +900,7 @@ def cli_deploy_cloud_run( session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + a2a=a2a, ) except Exception as e: click.secho(f"Deploy failed: {e}", fg="red", err=True) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index abe1961e7..69d7c3a0e 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -204,6 +204,9 @@ def get_fast_api_app( eval_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, + a2a: bool = False, + host: str = "127.0.0.1", + port: int = 8000, trace_to_cloud: bool = False, lifespan: Optional[Lifespan[FastAPI]] = None, ) -> FastAPI: From 5356f20eadb1655e3c35d0bf77445d847d090e07 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 26 Jun 2025 01:11:49 -0700 Subject: [PATCH 79/79] chore: Add a2a log utils for formatting a2a request/response logs PiperOrigin-RevId: 776026554 --- src/google/adk/a2a/logs/__init__.py | 13 + src/google/adk/a2a/logs/log_utils.py | 349 ++++++++++++++ tests/unittests/a2a/logs/__init__.py | 13 + tests/unittests/a2a/logs/test_log_utils.py | 505 +++++++++++++++++++++ 4 files changed, 880 insertions(+) create mode 100644 src/google/adk/a2a/logs/__init__.py create mode 100644 src/google/adk/a2a/logs/log_utils.py create mode 100644 tests/unittests/a2a/logs/__init__.py create mode 100644 tests/unittests/a2a/logs/test_log_utils.py diff --git a/src/google/adk/a2a/logs/__init__.py b/src/google/adk/a2a/logs/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/logs/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/logs/log_utils.py b/src/google/adk/a2a/logs/log_utils.py new file mode 100644 index 000000000..b3891514c --- /dev/null +++ b/src/google/adk/a2a/logs/log_utils.py @@ -0,0 +1,349 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions for structured A2A request and response logging.""" + +from __future__ import annotations + +import json +import sys + +try: + from a2a.types import DataPart as A2ADataPart + from a2a.types import Message as A2AMessage + from a2a.types import Part as A2APart + from a2a.types import SendMessageRequest + from a2a.types import SendMessageResponse + from a2a.types import Task as A2ATask + from a2a.types import TextPart as A2ATextPart +except ImportError as e: + if sys.version_info < (3, 10): + raise ImportError( + "A2A Tool requires Python 3.10 or above. Please upgrade your Python" + " version." + ) from e + else: + raise e + + +# Constants +_NEW_LINE = "\n" +_EXCLUDED_PART_FIELD = {"file": {"bytes"}} + + +def _is_a2a_task(obj) -> bool: + """Check if an object is an A2A Task, with fallback for isinstance issues.""" + try: + return isinstance(obj, A2ATask) + except (TypeError, AttributeError): + return type(obj).__name__ == "Task" and hasattr(obj, "status") + + +def _is_a2a_message(obj) -> bool: + """Check if an object is an A2A Message, with fallback for isinstance issues.""" + try: + return isinstance(obj, A2AMessage) + except (TypeError, AttributeError): + return type(obj).__name__ == "Message" and hasattr(obj, "role") + + +def _is_a2a_text_part(obj) -> bool: + """Check if an object is an A2A TextPart, with fallback for isinstance issues.""" + try: + return isinstance(obj, A2ATextPart) + except (TypeError, AttributeError): + return type(obj).__name__ == "TextPart" and hasattr(obj, "text") + + +def _is_a2a_data_part(obj) -> bool: + """Check if an object is an A2A DataPart, with fallback for isinstance issues.""" + try: + return isinstance(obj, A2ADataPart) + except (TypeError, AttributeError): + return type(obj).__name__ == "DataPart" and hasattr(obj, "data") + + +def build_message_part_log(part: A2APart) -> str: + """Builds a log representation of an A2A message part. + + Args: + part: The A2A message part to log. + + Returns: + A string representation of the part. + """ + part_content = "" + if _is_a2a_text_part(part.root): + part_content = f"TextPart: {part.root.text[:100]}" + ( + "..." if len(part.root.text) > 100 else "" + ) + elif _is_a2a_data_part(part.root): + # For data parts, show the data keys but exclude large values + data_summary = { + k: ( + f"<{type(v).__name__}>" + if isinstance(v, (dict, list)) and len(str(v)) > 100 + else v + ) + for k, v in part.root.data.items() + } + part_content = f"DataPart: {json.dumps(data_summary, indent=2)}" + else: + part_content = ( + f"{type(part.root).__name__}:" + f" {part.model_dump_json(exclude_none=True, exclude=_EXCLUDED_PART_FIELD)}" + ) + + # Add part metadata if it exists + if hasattr(part.root, "metadata") and part.root.metadata: + metadata_str = json.dumps(part.root.metadata, indent=2).replace( + "\n", "\n " + ) + part_content += f"\n Part Metadata: {metadata_str}" + + return part_content + + +def build_a2a_request_log(req: SendMessageRequest) -> str: + """Builds a structured log representation of an A2A request. + + Args: + req: The A2A SendMessageRequest to log. + + Returns: + A formatted string representation of the request. 
+ """ + # Message parts logs + message_parts_logs = [] + if req.params.message.parts: + for i, part in enumerate(req.params.message.parts): + part_log = build_message_part_log(part) + # Replace any internal newlines with indented newlines to maintain formatting + part_log_formatted = part_log.replace("\n", "\n ") + message_parts_logs.append(f"Part {i}: {part_log_formatted}") + + # Configuration logs + config_log = "None" + if req.params.configuration: + config_data = { + "acceptedOutputModes": req.params.configuration.acceptedOutputModes, + "blocking": req.params.configuration.blocking, + "historyLength": req.params.configuration.historyLength, + "pushNotificationConfig": bool( + req.params.configuration.pushNotificationConfig + ), + } + config_log = json.dumps(config_data, indent=2) + + # Build message metadata section + message_metadata_section = "" + if req.params.message.metadata: + message_metadata_section = f""" + Metadata: + {json.dumps(req.params.message.metadata, indent=2).replace(chr(10), chr(10) + ' ')}""" + + # Build optional sections + optional_sections = [] + + if req.params.metadata: + optional_sections.append( + f"""----------------------------------------------------------- +Metadata: +{json.dumps(req.params.metadata, indent=2)}""" + ) + + optional_sections_str = _NEW_LINE.join(optional_sections) + + return f""" +A2A Request: +----------------------------------------------------------- +Request ID: {req.id} +Method: {req.method} +JSON-RPC: {req.jsonrpc} +----------------------------------------------------------- +Message: + ID: {req.params.message.messageId} + Role: {req.params.message.role} + Task ID: {req.params.message.taskId} + Context ID: {req.params.message.contextId}{message_metadata_section} +----------------------------------------------------------- +Message Parts: +{_NEW_LINE.join(message_parts_logs) if message_parts_logs else "No parts"} +----------------------------------------------------------- +Configuration: +{config_log} +{optional_sections_str} +----------------------------------------------------------- +""" + + +def build_a2a_response_log(resp: SendMessageResponse) -> str: + """Builds a structured log representation of an A2A response. + + Args: + resp: The A2A SendMessageResponse to log. + + Returns: + A formatted string representation of the response. 
+ """ + # Handle error responses + if hasattr(resp.root, "error"): + return f""" +A2A Response: +----------------------------------------------------------- +Type: ERROR +Error Code: {resp.root.error.code} +Error Message: {resp.root.error.message} +Error Data: {json.dumps(resp.root.error.data, indent=2) if resp.root.error.data else "None"} +----------------------------------------------------------- +Response ID: {resp.root.id} +JSON-RPC: {resp.root.jsonrpc} +----------------------------------------------------------- +""" + + # Handle success responses + result = resp.root.result + result_type = type(result).__name__ + + # Build result details based on type + result_details = [] + + if _is_a2a_task(result): + result_details.extend([ + f"Task ID: {result.id}", + f"Context ID: {result.contextId}", + f"Status State: {result.status.state}", + f"Status Timestamp: {result.status.timestamp}", + f"History Length: {len(result.history) if result.history else 0}", + f"Artifacts Count: {len(result.artifacts) if result.artifacts else 0}", + ]) + + # Add task metadata if it exists + if result.metadata: + result_details.append("Task Metadata:") + metadata_formatted = json.dumps(result.metadata, indent=2).replace( + "\n", "\n " + ) + result_details.append(f" {metadata_formatted}") + + elif _is_a2a_message(result): + result_details.extend([ + f"Message ID: {result.messageId}", + f"Role: {result.role}", + f"Task ID: {result.taskId}", + f"Context ID: {result.contextId}", + ]) + + # Add message parts + if result.parts: + result_details.append("Message Parts:") + for i, part in enumerate(result.parts): + part_log = build_message_part_log(part) + # Replace any internal newlines with indented newlines to maintain formatting + part_log_formatted = part_log.replace("\n", "\n ") + result_details.append(f" Part {i}: {part_log_formatted}") + + # Add metadata if it exists + if result.metadata: + result_details.append("Metadata:") + metadata_formatted = json.dumps(result.metadata, indent=2).replace( + "\n", "\n " + ) + result_details.append(f" {metadata_formatted}") + + else: + # Handle other result types by showing their JSON representation + if hasattr(result, "model_dump_json"): + try: + result_json = result.model_dump_json() + result_details.append(f"JSON Data: {result_json}") + except Exception: + result_details.append("JSON Data: ") + + # Build status message section + status_message_section = "None" + if _is_a2a_task(result) and result.status.message: + status_parts_logs = [] + if result.status.message.parts: + for i, part in enumerate(result.status.message.parts): + part_log = build_message_part_log(part) + # Replace any internal newlines with indented newlines to maintain formatting + part_log_formatted = part_log.replace("\n", "\n ") + status_parts_logs.append(f"Part {i}: {part_log_formatted}") + + # Build status message metadata section + status_metadata_section = "" + if result.status.message.metadata: + status_metadata_section = f""" +Metadata: +{json.dumps(result.status.message.metadata, indent=2)}""" + + status_message_section = f"""ID: {result.status.message.messageId} +Role: {result.status.message.role} +Task ID: {result.status.message.taskId} +Context ID: {result.status.message.contextId} +Message Parts: +{_NEW_LINE.join(status_parts_logs) if status_parts_logs else "No parts"}{status_metadata_section}""" + + # Build history section + history_section = "No history" + if _is_a2a_task(result) and result.history: + history_logs = [] + for i, message in enumerate(result.history): + message_parts_logs = [] 
+ if message.parts: + for j, part in enumerate(message.parts): + part_log = build_message_part_log(part) + # Replace any internal newlines with indented newlines to maintain formatting + part_log_formatted = part_log.replace("\n", "\n ") + message_parts_logs.append(f" Part {j}: {part_log_formatted}") + + # Build message metadata section + message_metadata_section = "" + if message.metadata: + message_metadata_section = f""" + Metadata: + {json.dumps(message.metadata, indent=2).replace(chr(10), chr(10) + ' ')}""" + + history_logs.append( + f"""Message {i + 1}: + ID: {message.messageId} + Role: {message.role} + Task ID: {message.taskId} + Context ID: {message.contextId} + Message Parts: +{_NEW_LINE.join(message_parts_logs) if message_parts_logs else " No parts"}{message_metadata_section}""" + ) + + history_section = _NEW_LINE.join(history_logs) + + return f""" +A2A Response: +----------------------------------------------------------- +Type: SUCCESS +Result Type: {result_type} +----------------------------------------------------------- +Result Details: +{_NEW_LINE.join(result_details)} +----------------------------------------------------------- +Status Message: +{status_message_section} +----------------------------------------------------------- +History: +{history_section} +----------------------------------------------------------- +Response ID: {resp.root.id} +JSON-RPC: {resp.root.jsonrpc} +----------------------------------------------------------- +""" diff --git a/tests/unittests/a2a/logs/__init__.py b/tests/unittests/a2a/logs/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/logs/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/logs/test_log_utils.py b/tests/unittests/a2a/logs/test_log_utils.py new file mode 100644 index 000000000..4a02a137f --- /dev/null +++ b/tests/unittests/a2a/logs/test_log_utils.py @@ -0,0 +1,505 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Tests for log_utils module."""
+
+import json
+from unittest.mock import Mock
+from unittest.mock import patch
+
+import pytest
+
+# Import the actual A2A types that we need to mock
+try:
+  from a2a.types import DataPart as A2ADataPart
+  from a2a.types import Message as A2AMessage
+  from a2a.types import Part as A2APart
+  from a2a.types import Role
+  from a2a.types import Task as A2ATask
+  from a2a.types import TaskState
+  from a2a.types import TaskStatus
+  from a2a.types import TextPart as A2ATextPart
+
+  A2A_AVAILABLE = True
+except ImportError:
+  A2A_AVAILABLE = False
+
+
+class TestBuildMessagePartLog:
+  """Test suite for build_message_part_log function."""
+
+  @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available")
+  def test_text_part_short_text(self):
+    """Test TextPart with short text."""
+    # Import here to avoid import issues at module level
+    from google.adk.a2a.logs.log_utils import build_message_part_log
+
+    # Create real A2A objects
+    text_part = A2ATextPart(text="Hello, world!")
+    part = A2APart(root=text_part)
+
+    result = build_message_part_log(part)
+
+    assert result == "TextPart: Hello, world!"
+
+  @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available")
+  def test_text_part_long_text(self):
+    """Test TextPart with long text that gets truncated."""
+    from google.adk.a2a.logs.log_utils import build_message_part_log
+
+    long_text = "x" * 150  # Long text that should be truncated
+    text_part = A2ATextPart(text=long_text)
+    part = A2APart(root=text_part)
+
+    result = build_message_part_log(part)
+
+    expected = f"TextPart: {'x' * 100}..."
+    assert result == expected
+
+  @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available")
+  def test_data_part_simple_data(self):
+    """Test DataPart with simple data."""
+    from google.adk.a2a.logs.log_utils import build_message_part_log
+
+    data_part = A2ADataPart(data={"key1": "value1", "key2": 42})
+    part = A2APart(root=data_part)
+
+    result = build_message_part_log(part)
+
+    expected_data = {"key1": "value1", "key2": 42}
+    expected = f"DataPart: {json.dumps(expected_data, indent=2)}"
+    assert result == expected
+
+  @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available")
+  def test_data_part_large_values(self):
+    """Test DataPart with large values that get summarized."""
+    from google.adk.a2a.logs.log_utils import build_message_part_log
+
+    large_dict = {f"key{i}": f"value{i}" for i in range(50)}
+    large_list = list(range(100))
+
+    data_part = A2ADataPart(
+        data={
+            "small_value": "hello",
+            "large_dict": large_dict,
+            "large_list": large_list,
+            "normal_int": 42,
+        }
+    )
+    part = A2APart(root=data_part)
+
+    result = build_message_part_log(part)
+
+    # Large values should be replaced with type names
+    assert "small_value" in result
+    assert "hello" in result
+    assert "<dict>" in result
+    assert "<list>" in result
+    assert "normal_int" in result
+    assert "42" in result
+
+  def test_other_part_type(self):
+    """Test handling of other part types (not Text or Data)."""
+    from google.adk.a2a.logs.log_utils import build_message_part_log
+
+    # Create a mock part that will fall through to the else case
+    mock_root = Mock()
+    mock_root.__class__.__name__ = "MockOtherPart"
+    # Ensure metadata attribute doesn't exist or returns None to avoid JSON serialization issues
+    mock_root.metadata = None
+
+    mock_part = Mock()
+    mock_part.root = mock_root
+    mock_part.model_dump_json.return_value = '{"some": "data"}'
+
+    result = build_message_part_log(mock_part)
+
+    expected = 'MockOtherPart: {"some": 
"data"}' + assert result == expected + + +class TestBuildA2ARequestLog: + """Test suite for build_a2a_request_log function.""" + + def test_request_with_parts_and_config(self): + """Test request logging with message parts and configuration.""" + from google.adk.a2a.logs.log_utils import build_a2a_request_log + + # Create mock request with all components + req = Mock() + req.id = "req-123" + req.method = "sendMessage" + req.jsonrpc = "2.0" + + # Mock message + req.params.message.messageId = "msg-456" + req.params.message.role = "user" + req.params.message.taskId = "task-789" + req.params.message.contextId = "ctx-101" + + # Mock message parts - use simple mocks since the function will call build_message_part_log + part1 = Mock() + part2 = Mock() + req.params.message.parts = [part1, part2] + + # Mock configuration + req.params.configuration.acceptedOutputModes = ["text", "image"] + req.params.configuration.blocking = True + req.params.configuration.historyLength = 10 + req.params.configuration.pushNotificationConfig = Mock() # Non-None + + # Mock metadata + req.params.metadata = {"key1": "value1"} + # Mock message metadata to avoid JSON serialization issues + req.params.message.metadata = {"msg_key": "msg_value"} + + with patch( + "google.adk.a2a.logs.log_utils.build_message_part_log" + ) as mock_build_part: + mock_build_part.side_effect = lambda part: f"Mock part: {id(part)}" + + result = build_a2a_request_log(req) + + # Verify all components are present + assert "req-123" in result + assert "sendMessage" in result + assert "2.0" in result + assert "msg-456" in result + assert "user" in result + assert "task-789" in result + assert "ctx-101" in result + assert "Part 0:" in result + assert "Part 1:" in result + assert '"blocking": true' in result + assert '"historyLength": 10' in result + assert '"key1": "value1"' in result + + def test_request_without_parts(self): + """Test request logging without message parts.""" + from google.adk.a2a.logs.log_utils import build_a2a_request_log + + req = Mock() + req.id = "req-123" + req.method = "sendMessage" + req.jsonrpc = "2.0" + + req.params.message.messageId = "msg-456" + req.params.message.role = "user" + req.params.message.taskId = "task-789" + req.params.message.contextId = "ctx-101" + req.params.message.parts = None # No parts + req.params.message.metadata = None # No message metadata + + req.params.configuration = None # No configuration + req.params.metadata = None # No metadata + + result = build_a2a_request_log(req) + + assert "No parts" in result + assert "Configuration:\nNone" in result + # When metadata is None, it's not included in the output + assert "Metadata:" not in result + + def test_request_with_empty_parts_list(self): + """Test request logging with empty parts list.""" + from google.adk.a2a.logs.log_utils import build_a2a_request_log + + req = Mock() + req.id = "req-123" + req.method = "sendMessage" + req.jsonrpc = "2.0" + + req.params.message.messageId = "msg-456" + req.params.message.role = "user" + req.params.message.taskId = "task-789" + req.params.message.contextId = "ctx-101" + req.params.message.parts = [] # Empty parts list + req.params.message.metadata = None # No message metadata + + req.params.configuration = None + req.params.metadata = None + + result = build_a2a_request_log(req) + + assert "No parts" in result + + +class TestBuildA2AResponseLog: + """Test suite for build_a2a_response_log function.""" + + def test_error_response(self): + """Test error response logging.""" + from google.adk.a2a.logs.log_utils import 
build_a2a_response_log + + resp = Mock() + resp.root.error.code = 500 + resp.root.error.message = "Internal Server Error" + resp.root.error.data = {"details": "Something went wrong"} + resp.root.id = "resp-error" + resp.root.jsonrpc = "2.0" + + result = build_a2a_response_log(resp) + + assert "Type: ERROR" in result + assert "Error Code: 500" in result + assert "Internal Server Error" in result + assert '"details": "Something went wrong"' in result + assert "resp-error" in result + assert "2.0" in result + + def test_error_response_no_data(self): + """Test error response logging without error data.""" + from google.adk.a2a.logs.log_utils import build_a2a_response_log + + resp = Mock() + resp.root.error.code = 404 + resp.root.error.message = "Not Found" + resp.root.error.data = None + resp.root.id = "resp-404" + resp.root.jsonrpc = "2.0" + + result = build_a2a_response_log(resp) + + assert "Type: ERROR" in result + assert "Error Code: 404" in result + assert "Not Found" in result + assert "Error Data: None" in result + + @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available") + def test_success_response_with_task(self): + """Test success response logging with Task result.""" + # Use module-level imported types consistently + from google.adk.a2a.logs.log_utils import build_a2a_response_log + + task_status = TaskStatus(state=TaskState.working) + task = A2ATask(id="task-123", contextId="ctx-456", status=task_status) + + resp = Mock() + resp.root.result = task + resp.root.id = "resp-789" + resp.root.jsonrpc = "2.0" + + # Remove error attribute to ensure success path + delattr(resp.root, "error") + + result = build_a2a_response_log(resp) + + assert "Type: SUCCESS" in result + assert "Result Type: Task" in result + assert "Task ID: task-123" in result + assert "Context ID: ctx-456" in result + # Handle both structured format and JSON fallback due to potential isinstance failures + assert ( + "Status State: TaskState.working" in result + or "Status State: working" in result + or '"state":"working"' in result + or '"state": "working"' in result + ) + + @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available") + def test_success_response_with_task_and_status_message(self): + """Test success response with Task that has status message.""" + from google.adk.a2a.logs.log_utils import build_a2a_response_log + + # Create status message using module-level imported types + status_message = A2AMessage( + messageId="status-msg-123", + role=Role.agent, + parts=[ + A2APart(root=A2ATextPart(text="Status part 1")), + A2APart(root=A2ATextPart(text="Status part 2")), + ], + ) + + task_status = TaskStatus(state=TaskState.working, message=status_message) + task = A2ATask( + id="task-123", + contextId="ctx-456", + status=task_status, + history=[], + artifacts=None, + ) + + resp = Mock() + resp.root.result = task + resp.root.id = "resp-789" + resp.root.jsonrpc = "2.0" + + # Remove error attribute to ensure success path + delattr(resp.root, "error") + + result = build_a2a_response_log(resp) + + assert "ID: status-msg-123" in result + # Handle both structured format and JSON fallback + assert ( + "Role: Role.agent" in result + or "Role: agent" in result + or '"role":"agent"' in result + or '"role": "agent"' in result + ) + assert "Message Parts:" in result + + @pytest.mark.skipif(not A2A_AVAILABLE, reason="A2A types not available") + def test_success_response_with_message(self): + """Test success response logging with Message result.""" + from google.adk.a2a.logs.log_utils import 
build_a2a_response_log + + # Use module-level imported types consistently + message = A2AMessage( + messageId="msg-123", + role=Role.agent, + taskId="task-456", + contextId="ctx-789", + parts=[A2APart(root=A2ATextPart(text="Message part 1"))], + ) + + resp = Mock() + resp.root.result = message + resp.root.id = "resp-101" + resp.root.jsonrpc = "2.0" + + # Remove error attribute to ensure success path + delattr(resp.root, "error") + + result = build_a2a_response_log(resp) + + assert "Type: SUCCESS" in result + assert "Result Type: Message" in result + assert "Message ID: msg-123" in result + # Handle both structured format and JSON fallback + assert ( + "Role: Role.agent" in result + or "Role: agent" in result + or '"role":"agent"' in result + or '"role": "agent"' in result + ) + assert "Task ID: task-456" in result + assert "Context ID: ctx-789" in result + + def test_success_response_with_message_no_parts(self): + """Test success response with Message that has no parts.""" + from google.adk.a2a.logs.log_utils import build_a2a_response_log + + # Use mock for this case since we want to test empty parts handling + message = Mock() + message.__class__.__name__ = "Message" + message.messageId = "msg-empty" + message.role = "agent" + message.taskId = "task-empty" + message.contextId = "ctx-empty" + message.parts = None # No parts + message.model_dump_json.return_value = '{"message": "empty"}' + + resp = Mock() + resp.root.result = message + resp.root.id = "resp-empty" + resp.root.jsonrpc = "2.0" + + # Remove error attribute to ensure success path + delattr(resp.root, "error") + + result = build_a2a_response_log(resp) + + assert "Type: SUCCESS" in result + assert "Result Type: Message" in result + + def test_success_response_with_other_result_type(self): + """Test success response with result type that's not Task or Message.""" + from google.adk.a2a.logs.log_utils import build_a2a_response_log + + other_result = Mock() + other_result.__class__.__name__ = "OtherResult" + other_result.model_dump_json.return_value = '{"other": "data"}' + + resp = Mock() + resp.root.result = other_result + resp.root.id = "resp-other" + resp.root.jsonrpc = "2.0" + + # Remove error attribute to ensure success path + delattr(resp.root, "error") + + result = build_a2a_response_log(resp) + + assert "Type: SUCCESS" in result + assert "Result Type: OtherResult" in result + assert "JSON Data:" in result + assert '"other": "data"' in result + + def test_success_response_without_model_dump_json(self): + """Test success response with result that doesn't have model_dump_json.""" + from google.adk.a2a.logs.log_utils import build_a2a_response_log + + other_result = Mock() + other_result.__class__.__name__ = "SimpleResult" + # Don't add model_dump_json method + del other_result.model_dump_json + + resp = Mock() + resp.root.result = other_result + resp.root.id = "resp-simple" + resp.root.jsonrpc = "2.0" + + # Remove error attribute to ensure success path + delattr(resp.root, "error") + + result = build_a2a_response_log(resp) + + assert "Type: SUCCESS" in result + assert "Result Type: SimpleResult" in result + + def test_build_message_part_log_with_metadata(self): + """Test build_message_part_log with metadata in the part.""" + from google.adk.a2a.logs.log_utils import build_message_part_log + + mock_root = Mock() + mock_root.__class__.__name__ = "MockPartWithMetadata" + mock_root.metadata = {"key": "value", "nested": {"data": "test"}} + + mock_part = Mock() + mock_part.root = mock_root + mock_part.model_dump_json.return_value = 
'{"content": "test"}' + + result = build_message_part_log(mock_part) + + assert "MockPartWithMetadata:" in result + assert "Part Metadata:" in result + assert '"key": "value"' in result + assert '"nested"' in result + + def test_build_a2a_request_log_with_message_metadata(self): + """Test request logging with message metadata.""" + from google.adk.a2a.logs.log_utils import build_a2a_request_log + + req = Mock() + req.id = "req-with-metadata" + req.method = "sendMessage" + req.jsonrpc = "2.0" + + req.params.message.messageId = "msg-with-metadata" + req.params.message.role = "user" + req.params.message.taskId = "task-metadata" + req.params.message.contextId = "ctx-metadata" + req.params.message.parts = [] + req.params.message.metadata = {"msg_type": "test", "priority": "high"} + + req.params.configuration = None + req.params.metadata = None + + result = build_a2a_request_log(req) + + assert "Metadata:" in result + assert '"msg_type": "test"' in result + assert '"priority": "high"' in result
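Usage sketch for the new logging helpers (illustrative only, assuming the a2a SDK is installed). The SimpleNamespace wrapper below is a hypothetical stand-in that mirrors the Mock pattern used in the tests above: it exposes only the attributes build_a2a_response_log reads (root.result, root.id, root.jsonrpc) and omits root.error so the success branch is taken. A real caller would pass the response object returned by the A2A client instead.

from types import SimpleNamespace

from a2a.types import Message
from a2a.types import Part
from a2a.types import Role
from a2a.types import TextPart

from google.adk.a2a.logs.log_utils import build_a2a_response_log

# A real A2A Message result, constructed the same way as in the tests above.
message = Message(
    messageId="msg-123",
    role=Role.agent,
    taskId="task-456",
    contextId="ctx-789",
    parts=[Part(root=TextPart(text="Hello from the agent"))],
)

# Hypothetical stand-in for the JSON-RPC response wrapper; only the fields
# read by build_a2a_response_log are provided, and `error` is absent.
response = SimpleNamespace(
    root=SimpleNamespace(result=message, id="resp-101", jsonrpc="2.0")
)

print(build_a2a_response_log(response))

The printed block follows the "A2A Response:" layout exercised by the tests, with the result details, status message, and history sections rendered for the Message result.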