From 3b1d9a8a3e631ca2d86d30f09640497f1728986c Mon Sep 17 00:00:00 2001 From: bck-ob-gh Date: Mon, 23 Jun 2025 09:24:00 -0700 Subject: [PATCH 01/28] fix: Use starred tuple unpacking on GCS artifact blob names Merges https://github.com/google/adk-python/pull/1471 Fixes google#1436 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1471 from bck-ob-gh:main 4c4f2b66ab1e6fde8b1a9d2b914dcb24040db144 PiperOrigin-RevId: 774809270 --- src/google/adk/artifacts/gcs_artifact_service.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index e4af21e15..35aa88622 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -13,6 +13,7 @@ # limitations under the License. """An artifact service implementation using Google Cloud Storage (GCS).""" +from __future__ import annotations import logging from typing import Optional @@ -151,7 +152,7 @@ async def list_artifact_keys( self.bucket, prefix=session_prefix ) for blob in session_blobs: - _, _, _, filename, _ = blob.name.split("/") + *_, filename, _ = blob.name.split("/") filenames.add(filename) user_namespace_prefix = f"{app_name}/{user_id}/user/" @@ -159,7 +160,7 @@ async def list_artifact_keys( self.bucket, prefix=user_namespace_prefix ) for blob in user_namespace_blobs: - _, _, _, filename, _ = blob.name.split("/") + *_, filename, _ = blob.name.split("/") filenames.add(filename) return sorted(list(filenames)) From f033e405c10ff8d86550d1419a9d63c0099182f9 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Mon, 23 Jun 2025 10:11:47 -0700 Subject: [PATCH 02/28] chore: Clarify the behavior of Event.invocation_id PiperOrigin-RevId: 774827874 --- src/google/adk/events/event.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/google/adk/events/event.py b/src/google/adk/events/event.py index c3b8b8699..6dd617fff 100644 --- a/src/google/adk/events/event.py +++ b/src/google/adk/events/event.py @@ -34,9 +34,10 @@ class Event(LlmResponse): taken by the agents like function calls, etc. Attributes: - invocation_id: The invocation ID of the event. - author: "user" or the name of the agent, indicating who appended the event - to the session. + invocation_id: Required. The invocation ID of the event. Should be non-empty + before appending to a session. + author: Required. "user" or the name of the agent, indicating who appended + the event to the session. actions: The actions taken by the agent. long_running_tool_ids: The ids of the long running function calls. branch: The branch of the event. @@ -55,9 +56,8 @@ class Event(LlmResponse): ) """The pydantic model config.""" - # TODO: revert to be required after spark migration invocation_id: str = '' - """The invocation ID of the event.""" + """The invocation ID of the event. Should be non-empty before appending to a session.""" author: str """'user' or the name of the agent, indicating who appended the event to the session.""" From ea69c9093a16489afdf72657136c96f61c69cafd Mon Sep 17 00:00:00 2001 From: Keisuke Oohashi Date: Mon, 23 Jun 2025 10:27:41 -0700 Subject: [PATCH 03/28] feat: add usage span attributes to telemetry (#356) Merge https://github.com/google/adk-python/pull/1079 Fixes part of #356 Add usage attributes to span. Note: Since the handling of GenAI event bodies in OpenTelemetry has not yet been determined, I have temporarily added only attributes related to usage. 
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1079 from soundTricker:feature/356-support-more-opentelemetry-semantics 99a9d0352b4bca165baa645440e39ce7199f072b PiperOrigin-RevId: 774834279 --- src/google/adk/telemetry.py | 10 ++++++++++ tests/unittests/test_telemetry.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/google/adk/telemetry.py b/src/google/adk/telemetry.py index badaec46d..a09c2f55b 100644 --- a/src/google/adk/telemetry.py +++ b/src/google/adk/telemetry.py @@ -195,6 +195,16 @@ def trace_call_llm( llm_response_json, ) + if llm_response.usage_metadata is not None: + span.set_attribute( + 'gen_ai.usage.input_tokens', + llm_response.usage_metadata.prompt_token_count, + ) + span.set_attribute( + 'gen_ai.usage.output_tokens', + llm_response.usage_metadata.total_token_count, + ) + def trace_send_data( invocation_context: InvocationContext, diff --git a/tests/unittests/test_telemetry.py b/tests/unittests/test_telemetry.py index 1b8ee1b16..debdc802e 100644 --- a/tests/unittests/test_telemetry.py +++ b/tests/unittests/test_telemetry.py @@ -141,6 +141,36 @@ async def test_trace_call_llm_function_response_includes_part_from_bytes( assert llm_request_json_str.count('') == 2 +@pytest.mark.asyncio +async def test_trace_call_llm_usage_metadata(monkeypatch, mock_span_fixture): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, prompt_token_count=50 + ), + ) + trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) + + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.usage.input_tokens', 50), + mock.call('gen_ai.usage.output_tokens', 100), + ] + assert mock_span_fixture.set_attribute.call_count == 9 + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + def test_trace_tool_call_with_scalar_response( monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture ): From bd67e8480f6e8b4b0f8c22b94f15a8cda1336339 Mon Sep 17 00:00:00 2001 From: avidelatm Date: Mon, 23 Jun 2025 10:29:43 -0700 Subject: [PATCH 04/28] fix: make LiteLLM streaming truly asynchronous Merge https://github.com/google/adk-python/pull/1451 ## Description Fixes https://github.com/google/adk-python/issues/1306 by using `async for` with `await self.llm_client.acompletion()` instead of synchronous `for` loop. 
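A minimal sketch of the pattern this change adopts (illustrative only — the helper name is made up; it assumes a litellm-style client whose `acompletion(..., stream=True)` returns an async iterable of chunks, as in the diff below):

```python
# Illustrative sketch, not the actual implementation.
# Assumes `llm_client.acompletion(**completion_args)` (litellm-style) returns
# an async iterable of streamed chunks when stream=True is set.
async def stream_parts(llm_client, **completion_args):
    # Before: `for part in llm_client.completion(**completion_args)` iterated
    # a synchronous generator, blocking the event loop between chunks.
    # After: await the async client and iterate with `async for`, so other
    # coroutines can run while waiting for each chunk.
    async for part in await llm_client.acompletion(**completion_args):
        yield part
```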
## Changes - Updated test mocks to properly handle async streaming by creating an async generator - Ensured proper parameter handling to avoid duplicate stream parameter ## Testing Plan - All unit tests now pass with the async streaming implementation - Verified with `pytest tests/unittests/models/test_litellm.py` that all streaming tests pass - Manually tested with a sample agent using LiteLLM to confirm streaming works properly # Test Evidence: https://youtu.be/hSp3otI79DM Let me know if you need anything else from me for this PR COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1451 from avidelatm:fix/litellm-async-streaming d35b9dc90b2fd6fad44c3869de0fda2514e50055 PiperOrigin-RevId: 774835130 --- src/google/adk/models/lite_llm.py | 2 +- tests/unittests/models/test_litellm.py | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index dce5ed7c4..acc88ed19 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -679,7 +679,7 @@ async def generate_content_async( aggregated_llm_response_with_tool_call = None usage_metadata = None fallback_index = 0 - for part in self.llm_client.completion(**completion_args): + async for part in await self.llm_client.acompletion(**completion_args): for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): index = chunk.index or fallback_index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8b43cc48b..d058aa44d 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -416,9 +416,26 @@ def __init__(self, acompletion_mock, completion_mock): self.completion_mock = completion_mock async def acompletion(self, model, messages, tools, **kwargs): - return await self.acompletion_mock( - model=model, messages=messages, tools=tools, **kwargs - ) + if kwargs.get("stream", False): + kwargs_copy = dict(kwargs) + kwargs_copy.pop("stream", None) + + async def stream_generator(): + stream_data = self.completion_mock( + model=model, + messages=messages, + tools=tools, + stream=True, + **kwargs_copy, + ) + for item in stream_data: + yield item + + return stream_generator() + else: + return await self.acompletion_mock( + model=model, messages=messages, tools=tools, **kwargs + ) def completion(self, model, messages, tools, stream, **kwargs): return self.completion_mock( From 29cd183aa1b47dc4f5d8afe22f410f8546634abc Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 23 Jun 2025 12:15:26 -0700 Subject: [PATCH 05/28] chore: Add credential service backed by session state PiperOrigin-RevId: 774878336 --- .../session_state_credential_service.py | 83 ++++ tests/unittests/auth/__init__.py | 13 + .../auth/credential_service/__init__.py | 13 + .../test_session_state_credential_service.py | 355 ++++++++++++++++++ 4 files changed, 464 insertions(+) create mode 100644 src/google/adk/auth/credential_service/session_state_credential_service.py create mode 100644 tests/unittests/auth/__init__.py create mode 100644 tests/unittests/auth/credential_service/__init__.py create mode 100644 tests/unittests/auth/credential_service/test_session_state_credential_service.py diff --git a/src/google/adk/auth/credential_service/session_state_credential_service.py b/src/google/adk/auth/credential_service/session_state_credential_service.py new file mode 100644 index 000000000..e2ff7e07d --- /dev/null +++ 
b/src/google/adk/auth/credential_service/session_state_credential_service.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class SessionStateCredentialService(BaseCredentialService): + """Class for implementation of credential service using session state as the + store. + Note: store credential in session may not be secure, use at your own risk. + """ + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + """ + Loads the credential by auth config and current tool context from the + backend credential store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to load the credential. + + tool_context: The context of the current invocation when the tool is + trying to load the credential. + + Returns: + Optional[AuthCredential]: the credential saved in the store. + + """ + return tool_context.state.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + """ + Saves the exchanged_auth_credential in auth config to the backend credential + store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to save the credential. + + tool_context: The context of the current invocation when the tool is + trying to save the credential. + + Returns: + None + """ + + tool_context.state[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) diff --git a/tests/unittests/auth/__init__.py b/tests/unittests/auth/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unittests/auth/credential_service/__init__.py b/tests/unittests/auth/credential_service/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/credential_service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/credential_service/test_session_state_credential_service.py b/tests/unittests/auth/credential_service/test_session_state_credential_service.py new file mode 100644 index 000000000..610a9d3d1 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_session_state_credential_service.py @@ -0,0 +1,355 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.session_state_credential_service import SessionStateCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestSessionStateCredentialService: + """Tests for the SessionStateCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create a SessionStateCredentialService instance for testing.""" + return SessionStateCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + 
auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + # Create a state dictionary that behaves like session state + mock_context.state = {} + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different state for testing isolation.""" + mock_context = Mock(spec=ToolContext) + # Create a separate state dictionary to simulate different session + mock_context.state = {} + return mock_context + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different tool contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context (should not find it) + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two 
different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + @pytest.mark.asyncio + async def test_save_credential_with_none_exchanged_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving when exchanged_auth_credential is None.""" + # Set exchanged credential to None + auth_config.exchanged_auth_credential = None + + # Save the credential (should save None) + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify None was saved + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_load_credential_with_empty_credential_key( + self, credential_service, auth_config, tool_context + ): + """Test loading credential with empty credential key.""" + # Set credential key to empty string + auth_config.credential_key = "" + + # Save first to have something to load + await credential_service.save_credential(auth_config, tool_context) + + # Load should work with empty key + result = await credential_service.load_credential(auth_config, tool_context) + assert result == auth_config.exchanged_auth_credential + + @pytest.mark.asyncio + async def test_state_persistence_across_operations( + self, credential_service, auth_config, tool_context + ): + """Test that state persists correctly across multiple operations.""" + # Initially, no credential should exist + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + # Save a credential + await credential_service.save_credential(auth_config, tool_context) + + # Verify it was saved + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result == auth_config.exchanged_auth_credential + + # Update and save again + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="new_client_id", + client_secret="new_client_secret", + redirect_uri="https://new.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + await credential_service.save_credential(auth_config, tool_context) + + # Verify the update persisted + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not 
None + assert result.oauth2.client_id == "new_client_id" + + @pytest.mark.asyncio + async def test_credential_key_uniqueness( + self, credential_service, oauth2_auth_scheme, tool_context + ): + """Test that different credential keys create separate storage slots.""" + # Create credentials with same content but different keys + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="same_client", + client_secret="same_secret", + redirect_uri="https://same.com/callback", + ), + ) + + config_key1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=credential, + exchanged_auth_credential=credential, + credential_key="unique_key_1", + ) + + config_key2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=credential, + exchanged_auth_credential=credential, + credential_key="unique_key_2", + ) + + # Save credential with first key + await credential_service.save_credential(config_key1, tool_context) + + # Verify it's stored under first key + result1 = await credential_service.load_credential( + config_key1, tool_context + ) + assert result1 is not None + + # Verify it's not accessible under second key + result2 = await credential_service.load_credential( + config_key2, tool_context + ) + assert result2 is None + + # Save under second key + await credential_service.save_credential(config_key2, tool_context) + + # Now both should be accessible + result1 = await credential_service.load_credential( + config_key1, tool_context + ) + result2 = await credential_service.load_credential( + config_key2, tool_context + ) + assert result1 is not None + assert result2 is not None + assert result1 == result2 # Same credential content + + def test_direct_state_access( + self, credential_service, auth_config, tool_context + ): + """Test that the service correctly uses tool_context.state for storage.""" + # Verify that the state starts empty + assert len(tool_context.state) == 0 + + # Save a credential (this is async but we're testing the state directly) + credential_key = auth_config.credential_key + test_credential = auth_config.exchanged_auth_credential + + # Directly set the state to simulate save_credential behavior + tool_context.state[credential_key] = test_credential + + # Verify the credential is in the state + assert credential_key in tool_context.state + assert tool_context.state[credential_key] == test_credential + + # Verify we can retrieve it using the get method (simulating load_credential) + retrieved = tool_context.state.get(credential_key) + assert retrieved == test_credential From 120cbabeb23c16d9ce4be511e768885f19a8c2d2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 23 Jun 2025 12:22:53 -0700 Subject: [PATCH 06/28] refactor: Rename long util function name in runner.py and move it to functions.py PiperOrigin-RevId: 774880990 --- src/google/adk/flows/llm_flows/functions.py | 32 ++++ src/google/adk/runners.py | 37 +--- .../flows/llm_flows/test_functions_simple.py | 136 ++++++++++++++ tests/unittests/test_runners.py | 171 ------------------ 4 files changed, 170 insertions(+), 206 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2772550c2..5c690f1fd 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -519,3 +519,35 @@ def merge_parallel_function_response_events( # Use the base_event as the timestamp merged_event.timestamp = base_event.timestamp return merged_event + + +def 
find_matching_function_call( + events: list[Event], +) -> Optional[Event]: + """Finds the function call event that matches the function response id of the last event.""" + if not events: + return None + + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 936bc5205..017997bb3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -36,6 +36,7 @@ from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event +from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread @@ -354,9 +355,7 @@ def _find_agent_to_run( # the agent that returned the corressponding function call regardless the # type of the agent. e.g. a remote a2a agent may surface a credential # request as a special long running function tool call. - event = _find_function_call_event_if_last_event_is_function_response( - session - ) + event = find_matching_function_call(session.events) if event and event.author: return root_agent.find_agent(event.author) for event in filter(lambda e: e.author != 'user', reversed(session.events)): @@ -538,35 +537,3 @@ def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'): session_service=self._in_memory_session_service, memory_service=InMemoryMemoryService(), ) - - -def _find_function_call_event_if_last_event_is_function_response( - session: Session, -) -> Optional[Event]: - events = session.events - if not events: - return None - - last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): - - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event - return None diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 2c5ef9bce..720af516d 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -17,6 +17,9 @@ from typing import Callable from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.sessions.session import Session from google.adk.tools import ToolContext from google.adk.tools.function_tool import 
FunctionTool from google.genai import types @@ -256,3 +259,136 @@ def increase_by_one(x: int) -> int: assert part.function_response.id is None assert events[0].content.parts[0].function_call.id.startswith('adk-') assert events[1].content.parts[0].function_response.id.startswith('adk-') + + +def test_find_function_call_event_no_function_response_in_last_event(): + """Test when last event has no function response.""" + events = [ + Event( + invocation_id='inv1', + author='user', + content=types.Content(role='user', parts=[types.Part(text='Hello')]), + ) + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_empty_session_events(): + """Test when session has no events.""" + events = [] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_but_no_matching_call(): + """Test when last event has function response but no matching call found.""" + # Create a function response + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + events = [ + Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', + parts=[types.Part(text='Some other response')], + ), + ), + Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', + parts=[types.Part(function_response=function_response)], + ), + ), + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_with_matching_call(): + """Test when last event has function response with matching function call.""" + # Create a function call + function_call = types.FunctionCall(id='func_123', name='test_func', args={}) + + # Create a function response with matching ID + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + call_event = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', parts=[types.Part(function_response=function_response)] + ), + ) + + events = [call_event, response_event] + + result = find_matching_function_call(events) + assert result == call_event + + +def test_find_function_call_event_multiple_function_responses(): + """Test when last event has multiple function responses.""" + # Create function calls + function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={}) + function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={}) + + # Create function responses + function_response1 = types.FunctionResponse( + id='func_123', name='test_func1', response={} + ) + function_response2 = types.FunctionResponse( + id='func_456', name='test_func2', response={} + ) + + call_event1 = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call1)] + ), + ) + + call_event2 = Event( + invocation_id='inv2', + author='agent2', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call2)] + ), + ) + + response_event = Event( + invocation_id='inv3', + author='user', + content=types.Content( + role='user', + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + events = [call_event1, call_event2, 
response_event] + + # Should return the first matching function call event found + result = find_matching_function_call(events) + assert result == call_event1 # First match (func_123) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 56d7667ab..8d5bd2418 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -18,7 +18,6 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.events.event import Event -from google.adk.runners import _find_function_call_event_if_last_event_is_function_response from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session @@ -73,176 +72,6 @@ async def _run_async_impl(self, invocation_context): ) -class TestFindFunctionCallEventIfLastEventIsFunctionResponse: - """Tests for _find_function_call_event_if_last_event_is_function_response function.""" - - def test_no_function_response_in_last_event(self): - """Test when last event has no function response.""" - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[ - Event( - invocation_id="inv1", - author="user", - content=types.Content( - role="user", parts=[types.Part(text="Hello")] - ), - ) - ], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_empty_session_events(self): - """Test when session has no events.""" - session = Session( - id="test_session", user_id="test_user", app_name="test_app", events=[] - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_last_event_has_function_response_but_no_matching_call(self): - """Test when last event has function response but no matching call found.""" - # Create a function response - function_response = types.FunctionResponse( - id="func_123", name="test_func", response={} - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[ - Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", - parts=[types.Part(text="Some other response")], - ), - ), - Event( - invocation_id="inv2", - author="user", - content=types.Content( - role="user", - parts=[types.Part(function_response=function_response)], - ), - ), - ], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_last_event_has_function_response_with_matching_call(self): - """Test when last event has function response with matching function call.""" - # Create a function call - function_call = types.FunctionCall(id="func_123", name="test_func", args={}) - - # Create a function response with matching ID - function_response = types.FunctionResponse( - id="func_123", name="test_func", response={} - ) - - call_event = Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call)] - ), - ) - - response_event = Event( - invocation_id="inv2", - author="user", - content=types.Content( - role="user", parts=[types.Part(function_response=function_response)] - ), - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[call_event, response_event], - ) - - result = 
_find_function_call_event_if_last_event_is_function_response( - session - ) - assert result == call_event - - def test_last_event_has_multiple_function_responses(self): - """Test when last event has multiple function responses.""" - # Create function calls - function_call1 = types.FunctionCall( - id="func_123", name="test_func1", args={} - ) - function_call2 = types.FunctionCall( - id="func_456", name="test_func2", args={} - ) - - # Create function responses - function_response1 = types.FunctionResponse( - id="func_123", name="test_func1", response={} - ) - function_response2 = types.FunctionResponse( - id="func_456", name="test_func2", response={} - ) - - call_event1 = Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call1)] - ), - ) - - call_event2 = Event( - invocation_id="inv2", - author="agent2", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call2)] - ), - ) - - response_event = Event( - invocation_id="inv3", - author="user", - content=types.Content( - role="user", - parts=[ - types.Part(function_response=function_response1), - types.Part(function_response=function_response2), - ], - ), - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[call_event1, call_event2, response_event], - ) - - # Should return the first matching function call event found - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result == call_event1 # First match (func_123) - - class TestRunnerFindAgentToRun: """Tests for Runner._find_agent_to_run method.""" From fa025d755978e1506fa0da1fecc49775bebc1045 Mon Sep 17 00:00:00 2001 From: Joseph Pagadora Date: Mon, 23 Jun 2025 15:24:15 -0700 Subject: [PATCH 07/28] feat: Add a new option `eval_storage_uri` in adk web & adk eval to specify GCS bucket to store eval data PiperOrigin-RevId: 774947795 --- src/google/adk/cli/cli_tools_click.py | 66 +++++++++++++++++-- src/google/adk/cli/fast_api.py | 17 ++++- src/google/adk/cli/utils/evals.py | 53 +++++++++++++++ .../adk/evaluation/gcs_eval_sets_manager.py | 13 ++-- tests/unittests/cli/test_fast_api.py | 5 +- 5 files changed, 139 insertions(+), 15 deletions(-) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 49ecee482..9923b46c2 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -31,12 +31,15 @@ from . import cli_create from . import cli_deploy from .. import version +from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..sessions.in_memory_session_service import InMemorySessionService from .cli import run_cli from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE from .fast_api import get_fast_api_app from .utils import envs +from .utils import evals from .utils import logs LOG_LEVELS = click.Choice( @@ -282,11 +285,21 @@ def cli_run( default=False, help="Optional. Whether to print detailed results on console or not.", ) +@click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." 
+ ), + default=None, +) def cli_eval( agent_module_file_path: str, - eval_set_file_path: tuple[str], + eval_set_file_path: list[str], config_file_path: str, print_detailed_results: bool, + eval_storage_uri: Optional[str] = None, ): """Evaluates an agent given the eval sets. @@ -338,12 +351,33 @@ def cli_eval( root_agent = get_root_agent(agent_module_file_path) reset_func = try_get_reset_func(agent_module_file_path) + gcs_eval_sets_manager = None + eval_set_results_manager = None + if eval_storage_uri: + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( + eval_storage_uri + ) + gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_set_results_manager = gcs_eval_managers.eval_set_results_manager + else: + eval_set_results_manager = LocalEvalSetResultsManager( + agents_dir=os.path.dirname(agent_module_file_path) + ) eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path) eval_set_id_to_eval_cases = {} # Read the eval_set files and get the cases. for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items(): - eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) + if gcs_eval_sets_manager: + eval_set = gcs_eval_sets_manager._load_eval_set_from_blob( + eval_set_file_path + ) + if not eval_set: + raise click.ClickException( + f"Eval set {eval_set_file_path} not found in GCS." + ) + else: + eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) eval_cases = eval_set.eval_cases if eval_case_ids: @@ -378,16 +412,13 @@ async def _collect_eval_results() -> list[EvalCaseResult]: raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) # Write eval set results. - local_eval_set_results_manager = LocalEvalSetResultsManager( - agents_dir=os.path.dirname(agent_module_file_path) - ) eval_set_id_to_eval_results = collections.defaultdict(list) for eval_case_result in eval_results: eval_set_id = eval_case_result.eval_set_id eval_set_id_to_eval_results[eval_set_id].append(eval_case_result) for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items(): - local_eval_set_results_manager.save_eval_set_result( + eval_set_results_manager.save_eval_set_result( app_name=os.path.basename(agent_module_file_path), eval_set_id=eval_set_id, eval_case_results=eval_case_results, @@ -444,6 +475,15 @@ def decorator(func): ), default=None, ) + @click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." 
+ ), + default=None, + ) @click.option( "--memory_service_uri", type=str, @@ -564,6 +604,7 @@ def wrapper(*args, **kwargs): ) def cli_web( agents_dir: str, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -616,6 +657,7 @@ async def _lifespan(app: FastAPI): session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, web=True, trace_to_cloud=trace_to_cloud, @@ -654,6 +696,7 @@ async def _lifespan(app: FastAPI): ) def cli_api_server( agents_dir: str, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -685,6 +728,7 @@ def cli_api_server( session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, web=False, trace_to_cloud=trace_to_cloud, @@ -771,6 +815,15 @@ def cli_api_server( " version in the dev environment)" ), ) +@click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." + ), + default=None, +) @adk_services_options() @deprecated_adk_services_options() @click.argument( @@ -797,6 +850,7 @@ def cli_deploy_cloud_run( session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + eval_storage_uri: Optional[str] = None, session_db_url: Optional[str] = None, # Deprecated artifact_storage_uri: Optional[str] = None, # Deprecated ): diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 46e008655..4b2ed6c2e 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -65,6 +65,8 @@ from ..evaluation.eval_metrics import EvalMetricResult from ..evaluation.eval_metrics import EvalMetricResultPerInvocation from ..evaluation.eval_result import EvalSetResult +from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event @@ -198,6 +200,7 @@ def get_fast_api_app( session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + eval_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, trace_to_cloud: bool = False, @@ -256,8 +259,18 @@ async def internal_lifespan(app: FastAPI): runner_dict = {} - eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) - eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) + # Set up eval managers. 
+ eval_sets_manager = None + eval_set_results_manager = None + if eval_storage_uri: + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( + eval_storage_uri + ) + eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_set_results_manager = gcs_eval_managers.eval_set_results_manager + else: + eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) + eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) # Build the Memory service if memory_service_uri: diff --git a/src/google/adk/cli/utils/evals.py b/src/google/adk/cli/utils/evals.py index c8d1a3296..305d47544 100644 --- a/src/google/adk/cli/utils/evals.py +++ b/src/google/adk/cli/utils/evals.py @@ -14,17 +14,36 @@ from __future__ import annotations +import dataclasses +import os from typing import Any from typing import Tuple from google.genai import types as genai_types +from pydantic import alias_generators +from pydantic import BaseModel +from pydantic import ConfigDict from typing_extensions import deprecated from ...evaluation.eval_case import IntermediateData from ...evaluation.eval_case import Invocation +from ...evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ...evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ...sessions.session import Session +class GcsEvalManagers(BaseModel): + model_config = ConfigDict( + alias_generator=alias_generators.to_camel, + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + eval_sets_manager: GcsEvalSetsManager + + eval_set_results_manager: GcsEvalSetResultsManager + + @deprecated('Use convert_session_to_eval_invocations instead.') def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]: """Converts a session data into eval format. @@ -176,3 +195,37 @@ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]: ) return invocations + + +def create_gcs_eval_managers_from_uri( + eval_storage_uri: str, +) -> GcsEvalManagers: + """Creates GcsEvalManagers from eval_storage_uri. + + Args: + eval_storage_uri: The evals storage URI to use. Supported URIs: + gs://. If a path is provided, the bucket will be extracted. + + Returns: + GcsEvalManagers: The GcsEvalManagers object. + + Raises: + ValueError: If the eval_storage_uri is not supported. + """ + if eval_storage_uri.startswith('gs://'): + gcs_bucket = eval_storage_uri.split('://')[1] + eval_sets_manager = GcsEvalSetsManager( + bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT'] + ) + eval_set_results_manager = GcsEvalSetResultsManager( + bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT'] + ) + return GcsEvalManagers( + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, + ) + else: + raise ValueError( + f'Unsupported evals storage URI: {eval_storage_uri}. Supported URIs:' + ' gs://' + ) diff --git a/src/google/adk/evaluation/gcs_eval_sets_manager.py b/src/google/adk/evaluation/gcs_eval_sets_manager.py index fe5d8c9b5..c253e4cd5 100644 --- a/src/google/adk/evaluation/gcs_eval_sets_manager.py +++ b/src/google/adk/evaluation/gcs_eval_sets_manager.py @@ -72,6 +72,13 @@ def _validate_id(self, id_name: str, id_value: str): f"Invalid {id_name}. 
{id_name} should have the `{pattern}` format", ) + def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]: + blob = self.bucket.blob(blob_name) + if not blob.exists(): + return None + eval_set_data = blob.download_as_text() + return EvalSet.model_validate_json(eval_set_data) + def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet): """Writes an EvalSet to GCS.""" blob = self.bucket.blob(blob_name) @@ -88,11 +95,7 @@ def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: """Returns an EvalSet identified by an app_name and eval_set_id.""" eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) - blob = self.bucket.blob(eval_set_blob_name) - if not blob.exists(): - return None - eval_set_data = blob.download_as_text() - return EvalSet.model_validate_json(eval_set_data) + return self._load_eval_set_from_blob(eval_set_blob_name) @override def create_eval_set(self, app_name: str, eval_set_id: str): diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 65c1eee3b..aec7a020b 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -40,7 +40,7 @@ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) -logger = logging.getLogger(__name__) +logger = logging.getLogger("google_adk." + __name__) # Here we create a dummy agent module that get_fast_api_app expects @@ -138,6 +138,7 @@ async def mock_run_evals_for_fast_api(*args, **kwargs): final_eval_status=1, # Matches expected (assuming 1 is PASSED) user_id="test_user", # Placeholder, adapt if needed session_id="test_session_for_eval_case", # Placeholder + eval_set_file="test_eval_set_file", # Placeholder overall_eval_metric_results=[{ # Matches expected "metricName": "tool_trajectory_avg_score", "threshold": 0.5, @@ -372,7 +373,7 @@ def add_eval_case(self, app_name, eval_set_id, eval_case): @pytest.fixture def mock_eval_set_results_manager(): - """Create a mock eval set results manager.""" + """Create a mock local eval set results manager.""" # Storage for eval set results. 
eval_set_results = {} From 9597a446fdec63ad9e4c2692d6966b14f80ff8e2 Mon Sep 17 00:00:00 2001 From: Joseph Pagadora Date: Mon, 23 Jun 2025 15:30:16 -0700 Subject: [PATCH 08/28] feat: Add rouge_score library to ADK eval dependencies, and implement RougeEvaluator that is computes ROUGE-1 for "response_match_score" metric PiperOrigin-RevId: 774949712 --- pyproject.toml | 1 + .../adk/evaluation/final_response_match_v1.py | 110 ++++++++++++++ .../adk/evaluation/response_evaluator.py | 13 +- .../test_final_response_match_v1.py | 140 ++++++++++++++++++ .../evaluation/test_response_evaluator.py | 39 ++++- 5 files changed, 301 insertions(+), 2 deletions(-) create mode 100644 src/google/adk/evaluation/final_response_match_v1.py create mode 100644 tests/unittests/evaluation/test_final_response_match_v1.py diff --git a/pyproject.toml b/pyproject.toml index 8ece4db81..23dbcb537 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ eval = [ "google-cloud-aiplatform[evaluation]>=1.87.0", "pandas>=2.2.3", "tabulate>=0.9.0", + "rouge-score>=0.1.2", # go/keep-sorted end ] diff --git a/src/google/adk/evaluation/final_response_match_v1.py b/src/google/adk/evaluation/final_response_match_v1.py new file mode 100644 index 000000000..a034b470f --- /dev/null +++ b/src/google/adk/evaluation/final_response_match_v1.py @@ -0,0 +1,110 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Optional + +from google.genai import types as genai_types +from rouge_score import rouge_scorer +from typing_extensions import override + +from .eval_case import Invocation +from .eval_metrics import EvalMetric +from .evaluator import EvalStatus +from .evaluator import EvaluationResult +from .evaluator import Evaluator +from .evaluator import PerInvocationResult + + +class RougeEvaluator(Evaluator): + """Calculates the ROUGE-1 metric to compare responses.""" + + def __init__(self, eval_metric: EvalMetric): + self._eval_metric = eval_metric + + @override + def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + total_score = 0.0 + num_invocations = 0 + per_invocation_results = [] + for actual, expected in zip(actual_invocations, expected_invocations): + reference = _get_text_from_content(expected.final_response) + response = _get_text_from_content(actual.final_response) + rouge_1_scores = _calculate_rouge_1_scores(response, reference) + score = rouge_1_scores.fmeasure + per_invocation_results.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=score, + eval_status=_get_eval_status(score, self._eval_metric.threshold), + ) + ) + total_score += score + num_invocations += 1 + + if per_invocation_results: + overall_score = total_score / num_invocations + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=_get_eval_status( + overall_score, self._eval_metric.threshold + ), + per_invocation_results=per_invocation_results, + ) + + return EvaluationResult() + + +def _get_text_from_content(content: Optional[genai_types.Content]) -> str: + if content and content.parts: + return "\n".join([part.text for part in content.parts if part.text]) + + return "" + + +def _get_eval_status(score: float, threshold: float): + return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED + + +def _calculate_rouge_1_scores(candidate: str, reference: str): + """Calculates the ROUGE-1 score between a candidate and reference text. + + ROUGE-1 measures the overlap of unigrams (single words) between the + candidate and reference texts. The score is broken down into: + - Precision: The proportion of unigrams in the candidate that are also in the + reference. + - Recall: The proportion of unigrams in the reference that are also in the + candidate. + - F-measure: The harmonic mean of precision and recall. + + Args: + candidate: The generated text to be evaluated. + reference: The ground-truth text to compare against. + + Returns: + A dictionary containing the ROUGE-1 precision, recall, and f-measure. + """ + scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) + + # The score method returns a dictionary where keys are the ROUGE types + # and values are Score objects (tuples) with precision, recall, and fmeasure. 
+ scores = scorer.score(reference, candidate) + + return scores["rouge1"] diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py index 52ab50c74..0826f8796 100644 --- a/src/google/adk/evaluation/response_evaluator.py +++ b/src/google/adk/evaluation/response_evaluator.py @@ -27,10 +27,12 @@ from .eval_case import IntermediateData from .eval_case import Invocation +from .eval_metrics import EvalMetric from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator from .evaluator import PerInvocationResult +from .final_response_match_v1 import RougeEvaluator class ResponseEvaluator(Evaluator): @@ -40,7 +42,7 @@ def __init__(self, threshold: float, metric_name: str): if "response_evaluation_score" == metric_name: self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE elif "response_match_score" == metric_name: - self._metric_name = "rouge_1" + self._metric_name = "response_match_score" else: raise ValueError(f"`{metric_name}` is not supported.") @@ -52,6 +54,15 @@ def evaluate_invocations( actual_invocations: list[Invocation], expected_invocations: list[Invocation], ) -> EvaluationResult: + # If the metric is response_match_score, just use the RougeEvaluator. + if self._metric_name == "response_match_score": + rouge_evaluator = RougeEvaluator( + EvalMetric(metric_name=self._metric_name, threshold=self._threshold) + ) + return rouge_evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + total_score = 0.0 num_invocations = 0 per_invocation_results = [] diff --git a/tests/unittests/evaluation/test_final_response_match_v1.py b/tests/unittests/evaluation/test_final_response_match_v1.py new file mode 100644 index 000000000..d5544a5a1 --- /dev/null +++ b/tests/unittests/evaluation/test_final_response_match_v1.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
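The ROUGE-1 fractions asserted in the tests that follow can be reproduced directly with the rouge-score library added to the eval dependencies. A minimal sketch, not part of the patch, assuming rouge-score>=0.1.2 is installed and run standalone:

from rouge_score import rouge_scorer

# "This is a test candidate response." shares 4 unigrams (this, is, a, test)
# with "This is a test reference.", out of 6 candidate and 5 reference tokens,
# so precision = 4/6 = 2/3, recall = 4/5, and F1 = 2PR / (P + R) = 8/11.
scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
score = scorer.score(
    "This is a test reference.",  # reference (target) text
    "This is a test candidate response.",  # candidate (prediction) text
)["rouge1"]
print(score.precision, score.recall, score.fmeasure)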
+ +from __future__ import annotations + +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetric +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.final_response_match_v1 import _calculate_rouge_1_scores +from google.adk.evaluation.final_response_match_v1 import RougeEvaluator +from google.genai import types as genai_types +import pytest + + +def _create_test_rouge_evaluator(threshold: float) -> RougeEvaluator: + return RougeEvaluator( + EvalMetric(metric_name="response_match_score", threshold=threshold) + ) + + +def _create_test_invocations( + candidate: str, reference: str +) -> tuple[Invocation, Invocation]: + """Returns tuple of (actual_invocation, expected_invocation).""" + return Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=candidate)] + ), + ), Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=reference)] + ), + ) + + +def test_calculate_rouge_1_scores_empty_candidate_and_reference(): + candidate = "" + reference = "" + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores_empty_candidate(): + candidate = "" + reference = "This is a test reference." + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores_empty_reference(): + candidate = "This is a test candidate response." + reference = "" + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores(): + candidate = "This is a test candidate response." + reference = "This is a test reference." 
+ rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == pytest.approx(2 / 3) + assert rouge_1_score.recall == pytest.approx(4 / 5) + assert rouge_1_score.fmeasure == pytest.approx(8 / 11) + + +@pytest.mark.parametrize( + "candidates, references, expected_score, expected_status", + [ + ( + ["The quick brown fox jumps.", "hello world"], + ["The quick brown fox jumps over the lazy dog.", "hello"], + 0.69048, # (5/7 + 2/3) / 2 + EvalStatus.FAILED, + ), + ( + ["This is a test.", "Another test case."], + ["This is a test.", "This is a different test."], + 0.625, # (1 + 1/4) / 2 + EvalStatus.FAILED, + ), + ( + ["No matching words here.", "Second candidate."], + ["Completely different text.", "Another reference."], + 0.0, # (0 + 1/2) / 2 + EvalStatus.FAILED, + ), + ( + ["Same words", "Same words"], + ["Same words", "Same words"], + 1.0, + EvalStatus.PASSED, + ), + ], +) +def test_rouge_evaluator_multiple_invocations( + candidates: list[str], + references: list[str], + expected_score: float, + expected_status: EvalStatus, +): + rouge_evaluator = _create_test_rouge_evaluator(threshold=0.8) + actual_invocations = [] + expected_invocations = [] + for candidate, reference in zip(candidates, references): + actual_invocation, expected_invocation = _create_test_invocations( + candidate, reference + ) + actual_invocations.append(actual_invocation) + expected_invocations.append(expected_invocation) + + evaluation_result = rouge_evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + assert evaluation_result.overall_score == pytest.approx( + expected_score, rel=1e-3 + ) + assert evaluation_result.overall_eval_status == expected_status diff --git a/tests/unittests/evaluation/test_response_evaluator.py b/tests/unittests/evaluation/test_response_evaluator.py index bbaa694f2..839b7188a 100644 --- a/tests/unittests/evaluation/test_response_evaluator.py +++ b/tests/unittests/evaluation/test_response_evaluator.py @@ -16,7 +16,10 @@ from unittest.mock import MagicMock from unittest.mock import patch +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.evaluator import EvalStatus from google.adk.evaluation.response_evaluator import ResponseEvaluator +from google.genai import types as genai_types import pandas as pd import pytest from vertexai.preview.evaluation import MetricPromptTemplateExamples @@ -63,7 +66,7 @@ "google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval" ) class TestResponseEvaluator: - """A class to help organize "patch" that are applicabple to all tests.""" + """A class to help organize "patch" that are applicable to all tests.""" def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval): """Test evaluate function raises ValueError for an empty list.""" @@ -77,6 +80,40 @@ def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval): ResponseEvaluator.evaluate([], ["response_evaluation_score"]) mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called + def test_evaluate_invocations_rouge_metric(self, mock_perform_eval): + """Test evaluate_invocations function for Rouge metric.""" + actual_invocations = [ + Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[ + genai_types.Part(text="This is a test candidate response.") + ] + ), + ) + ] + expected_invocations = [ + Invocation( + user_content=genai_types.Content( + 
parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="This is a test reference.")] + ), + ) + ] + evaluator = ResponseEvaluator( + threshold=0.8, metric_name="response_match_score" + ) + evaluation_result = evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + assert evaluation_result.overall_score == pytest.approx(8 / 11) + # ROUGE-1 F1 is approx. 0.73 < 0.8 threshold, so eval status is FAILED. + assert evaluation_result.overall_eval_status == EvalStatus.FAILED + def test_evaluate_determines_metrics_correctly_for_perform_eval( self, mock_perform_eval ): From 00cc8cd6433fc45ecfc2dbaa04dbbc1a81213b4d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Jun 2025 11:33:05 -0700 Subject: [PATCH 09/28] feat: Add Vertex Express mode compatibility for VertexAiSessionService PiperOrigin-RevId: 775317848 --- .../adk/sessions/vertex_ai_session_service.py | 71 ++++++++++++++----- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index bd1345162..06a904c89 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -16,6 +16,7 @@ import asyncio import json import logging +import os import re from typing import Any from typing import Dict @@ -23,6 +24,7 @@ import urllib.parse from dateutil import parser +from google.genai.errors import ClientError from typing_extensions import override from google import genai @@ -95,25 +97,46 @@ async def create_session( operation_id = api_response['name'].split('/')[-1] max_retry_attempt = 5 - lro_response = None - while max_retry_attempt >= 0: - lro_response = await api_client.async_request( - http_method='GET', - path=f'operations/{operation_id}', - request_dict={}, - ) - lro_response = _convert_api_response(lro_response) - if lro_response.get('done', None): - break - - await asyncio.sleep(1) - max_retry_attempt -= 1 - - if lro_response is None or not lro_response.get('done', None): - raise TimeoutError( - f'Timeout waiting for operation {operation_id} to complete.' - ) + if _is_vertex_express_mode(self._project, self._location): + # Express mode doesn't support LRO, so we need to poll + # the session resource. + # TODO: remove this once LRO polling is supported in Express mode. + for i in range(max_retry_attempt): + try: + await api_client.async_request( + http_method='GET', + path=( + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' + ), + request_dict={}, + ) + break + except ClientError as e: + logger.info('Polling for session %s: %s', session_id, e) + # Add slight exponential backoff to avoid excessive polling. + await asyncio.sleep(1 + 0.5 * i) + else: + raise TimeoutError('Session creation failed.') + else: + lro_response = None + for _ in range(max_retry_attempt): + lro_response = await api_client.async_request( + http_method='GET', + path=f'operations/{operation_id}', + request_dict={}, + ) + lro_response = _convert_api_response(lro_response) + + if lro_response.get('done', None): + break + + await asyncio.sleep(1) + + if lro_response is None or not lro_response.get('done', None): + raise TimeoutError( + f'Timeout waiting for operation {operation_id} to complete.' 
+ ) # Get session resource get_session_api_response = await api_client.async_request( @@ -312,6 +335,18 @@ def _get_api_client(self): return client._api_client +def _is_vertex_express_mode( + project: Optional[str], location: Optional[str] +) -> bool: + """Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode.""" + return ( + os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in ['true', '1'] + and os.environ.get('GOOGLE_API_KEY', None) is not None + and project is None + and location is None + ) + + def _convert_api_response(api_response): """Converts the API response to a JSON object based on the type.""" if hasattr(api_response, 'body'): From abc89d2c811ba00805f81b27a3a07d56bdf55a0b Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 24 Jun 2025 11:56:28 -0700 Subject: [PATCH 10/28] feat: Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint PiperOrigin-RevId: 775327151 --- src/google/adk/cli/cli_tools_click.py | 3 +- src/google/adk/cli/fast_api.py | 11 ++ src/google/adk/memory/__init__.py | 4 +- .../memory/vertex_ai_memory_bank_service.py | 147 ++++++++++++++++ .../test_vertex_ai_memory_bank_service.py | 158 ++++++++++++++++++ 5 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 src/google/adk/memory/vertex_ai_memory_bank_service.py create mode 100644 tests/unittests/memory/test_vertex_ai_memory_bank_service.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 9923b46c2..c0935cceb 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -489,7 +489,8 @@ def decorator(func): type=str, help=( """Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service.""" + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Vertex AI Memory Bank Service. e.g. 
agentengine://12345""" ), default=None, ) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4b2ed6c2e..abe1961e7 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -71,6 +71,7 @@ from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService from ..runners import Runner from ..sessions.database_session_service import DatabaseSessionService @@ -282,6 +283,16 @@ async def internal_lifespan(app: FastAPI): memory_service = VertexAiRagMemoryService( rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}' ) + elif memory_service_uri.startswith("agentengine://"): + agent_engine_id = memory_service_uri.split("://")[1] + if not agent_engine_id: + raise click.ClickException("Agent engine id can not be empty.") + envs.load_dotenv_for_agent("", agents_dir) + memory_service = VertexAiMemoryBankService( + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, + ) else: raise click.ClickException( "Unsupported memory service URI: %s" % memory_service_uri diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index f2ac4f9b5..915d7e517 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -15,12 +15,14 @@ from .base_memory_service import BaseMemoryService from .in_memory_memory_service import InMemoryMemoryService +from .vertex_ai_memory_bank_service import VertexAiMemoryBankService logger = logging.getLogger('google_adk.' + __name__) __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', + 'VertexAiMemoryBankService', ] try: @@ -29,7 +31,7 @@ __all__.append('VertexAiRagMemoryService') except ImportError: logger.debug( - 'The Vertex sdk is not installed. If you want to use the' + 'The Vertex SDK is not installed. If you want to use the' ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py new file mode 100644 index 000000000..b5b70ab1c --- /dev/null +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
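Besides the agentengine:// URI handled in fast_api.py above, the new service can be constructed directly. A minimal usage sketch, not part of the patch, assuming Vertex AI credentials are configured; the project, location, and agent engine id below are placeholders rather than values from this change:

import asyncio

from google.adk.memory import VertexAiMemoryBankService


async def demo():
  # Placeholder identifiers; substitute a real project, location, and
  # reasoning engine id before running.
  service = VertexAiMemoryBankService(
      project="my-project",
      location="us-central1",
      agent_engine_id="123",
  )
  response = await service.search_memory(
      app_name="my_app", user_id="user1", query="previous discussion"
  )
  for memory in response.memories:
    print(memory.content)


# asyncio.run(demo())  # requires access to a provisioned Memory Bank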
+ +from __future__ import annotations + +import json +import logging +from typing import Optional +from typing import TYPE_CHECKING + +from typing_extensions import override + +from google import genai + +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..sessions.session import Session + +logger = logging.getLogger('google_adk.' + __name__) + + +class VertexAiMemoryBankService(BaseMemoryService): + """Implementation of the BaseMemoryService using Vertex AI Memory Bank.""" + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, + ): + """Initializes a VertexAiMemoryBankService. + + Args: + project: The project ID of the Memory Bank to use. + location: The location of the Memory Bank to use. + agent_engine_id: The ID of the agent engine to use for the Memory Bank. + e.g. '456' in + 'projects/my-project/locations/us-central1/reasoningEngines/456'. + """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id + + @override + async def add_session_to_memory(self, session: Session): + api_client = self._get_api_client() + + if not self._agent_engine_id: + raise ValueError('Agent Engine ID is required for Memory Bank.') + + events = [] + for event in session.events: + if event.content and event.content.parts: + events.append({ + 'content': event.content.model_dump(exclude_none=True, mode='json') + }) + request_dict = { + 'direct_contents_source': { + 'events': events, + }, + 'scope': { + 'app_name': session.app_name, + 'user_id': session.user_id, + }, + } + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', + request_dict=request_dict, + ) + logger.info(f'Generate memory response: {api_response}') + + @override + async def search_memory(self, *, app_name: str, user_id: str, query: str): + api_client = self._get_api_client() + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve', + request_dict={ + 'scope': { + 'app_name': app_name, + 'user_id': user_id, + }, + 'similarity_search_params': { + 'search_query': query, + }, + }, + ) + api_response = _convert_api_response(api_response) + logger.info(f'Search memory response: {api_response}') + + if not api_response or not api_response.get('retrievedMemories', None): + return SearchMemoryResponse() + + memory_events = [] + for memory in api_response.get('retrievedMemories', []): + # TODO: add more complex error handling + memory_events.append( + MemoryEntry( + author='user', + content=genai.types.Content( + parts=[ + genai.types.Part(text=memory.get('memory').get('fact')) + ], + role='user', + ), + timestamp=memory.get('updateTime'), + ) + ) + return SearchMemoryResponse(memories=memory_events) + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + + Returns: + An API client for the given project and location. 
+ """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client + + +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py new file mode 100644 index 000000000..27e2bbdd5 --- /dev/null +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -0,0 +1,158 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any +from unittest import mock + +from google.adk.events import Event +from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService +from google.adk.sessions import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='333', + last_update_time=22333, + events=[ + Event( + id='444', + invocation_id='123', + author='user', + timestamp=12345, + content=types.Content(parts=[types.Part(text='test_content')]), + ), + # Empty event, should be ignored + Event( + id='555', + invocation_id='456', + author='user', + timestamp=12345, + ), + ], +) + + +RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' +GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' + + +class MockApiClient: + """Mocks the API Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.async_request = mock.AsyncMock() + self.async_request.side_effect = self._mock_async_request + + async def _mock_async_request( + self, http_method: str, path: str, request_dict: dict[str, Any] + ): + """Mocks the API Client request method.""" + if http_method == 'POST': + if re.match(GENERATE_MEMORIES_REGEX, path): + return {} + elif re.match(RETRIEVE_MEMORIES_REGEX, path): + if ( + request_dict.get('scope', None) + and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME + ): + return { + 'retrievedMemories': [ + { + 'memory': { + 'fact': 'test_content', + }, + 'updateTime': '2024-12-12T12:12:12.123456Z', + }, + ], + } + else: + return {'retrievedMemories': []} + else: + raise ValueError(f'Unsupported path: {path}') + else: + raise ValueError(f'Unsupported http method: {http_method}') + + +def mock_vertex_ai_memory_bank_service(): + """Creates a mock Vertex AI Memory Bank service for testing.""" + return VertexAiMemoryBankService( + project='test-project', + location='test-location', + agent_engine_id='123', + ) + + +@pytest.fixture +def mock_get_api_client(): + api_client = MockApiClient() + with mock.patch( + 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client', + return_value=api_client, + ): + yield api_client + + +@pytest.mark.asyncio 
+@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_session_to_memory(MOCK_SESSION) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:generate', + request_dict={ + 'direct_contents_source': { + 'events': [ + { + 'content': { + 'parts': [ + {'text': 'test_content'}, + ], + }, + }, + ], + }, + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_search_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:retrieve', + request_dict={ + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + 'similarity_search_params': {'search_query': 'query'}, + }, + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'test_content' From f33e0903b21b752168db3006dd034d7d43f7e84d Mon Sep 17 00:00:00 2001 From: Genquan Duan Date: Tue, 24 Jun 2025 13:07:57 -0700 Subject: [PATCH 11/28] feat: Add ADK examples for litellm with add_function_to_prompt Add examples for for https://github.com/google/adk-python/issues/1273 PiperOrigin-RevId: 775352677 --- .../__init__.py | 16 ++++ .../agent.py | 78 ++++++++++++++++++ .../main.py | 81 +++++++++++++++++++ src/google/adk/models/lite_llm.py | 8 ++ .../models/test_litellm_with_function.py | 3 - 5 files changed, 183 insertions(+), 3 deletions(-) create mode 100644 contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py create mode 100644 contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py create mode 100644 contributing/samples/hello_world_litellm_add_function_to_prompt/main.py diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py new file mode 100644 index 000000000..7d5bb0b1c --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from . import agent diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py new file mode 100644 index 000000000..0f10621ae --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random + +from google.adk import Agent +from google.adk.models.lite_llm import LiteLlm +from langchain_core.utils.function_calling import convert_to_openai_function + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +def check_prime(number: int) -> str: + """Check if a given number is prime. + + Args: + number: The input number to check. + + Returns: + A str indicating the number is prime or not. + """ + if number <= 1: + return f"{number} is not prime." + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + return f"{number} is prime." + else: + return f"{number} is not prime." + + +root_agent = Agent( + model=LiteLlm( + model="vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas", + # If the model is not trained with functions and you would like to + # enable function calling, you can add functions to the models, and the + # functions will be added to the prompts during inferences. + functions=[ + convert_to_openai_function(roll_die), + convert_to_openai_function(check_prime), + ], + ), + name="data_processing_agent", + description="""You are a helpful assistant.""", + instruction=""" + You are a helpful assistant, and call tools optionally. + If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs. + """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py new file mode 100644 index 000000000..123ba1368 --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +import time + +import agent +from dotenv import load_dotenv +from google.adk import Runner +from google.adk.artifacts import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.sessions import InMemorySessionService +from google.adk.sessions import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts: + part = event.content.parts[0] + if part.text: + print(f'** {event.author}: {part.text}') + if part.function_call: + print(f'** {event.author} calls tool: {part.function_call}') + if part.function_response: + print( + f'** {event.author} gets tool response: {part.function_response}' + ) + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi, introduce yourself.') + await run_prompt(session_11, 'Roll a die with 100 sides.') + await run_prompt(session_11, 'Check if it is prime.') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index acc88ed19..624b7adfc 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -29,6 +29,7 @@ from typing import Union from google.genai import types +import litellm from litellm import acompletion from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall @@ -53,6 +54,9 @@ from .llm_request import LlmRequest from .llm_response import LlmResponse +# This will add functions to prompts if functions are provided. +litellm.add_function_to_prompt = True + logger = logging.getLogger("google_adk." + __name__) _NEW_LINE = "\n" @@ -662,6 +666,10 @@ async def generate_content_async( messages, tools, response_format = _get_completion_inputs(llm_request) + if "functions" in self._additional_args: + # LiteLLM does not support both tools and functions together. 
+ tools = None + completion_args = { "model": self.model, "messages": messages, diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py index 799c55e5c..e0d2bc991 100644 --- a/tests/integration/models/test_litellm_with_function.py +++ b/tests/integration/models/test_litellm_with_function.py @@ -17,11 +17,8 @@ from google.genai import types from google.genai.types import Content from google.genai.types import Part -import litellm import pytest -litellm.add_function_to_prompt = True - _TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """ From a1e14411159fd9f3e114e15b39b4949d0fd6ecb1 Mon Sep 17 00:00:00 2001 From: Liang Wu <18244712+wuliang229@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:26:45 -0700 Subject: [PATCH 12/28] fix: update contributing links Merge https://github.com/google/adk-python/pull/1528 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1528 from google:doc ec8325e126aba7257de73ab26d8d3a30064859b4 PiperOrigin-RevId: 775383121 --- .gitignore | 1 + CONTRIBUTING.md | 2 +- README.md | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 6fb068d48..6f398cbf9 100644 --- a/.gitignore +++ b/.gitignore @@ -82,6 +82,7 @@ log/ .env.development.local .env.test.local .env.production.local +uv.lock # Google Cloud specific .gcloudignore diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c0f3d0069..0d7b2d67d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -200,7 +200,7 @@ For any changes that impact user-facing documentation (guides, API reference, tu ## Contributing Resources -[Contributing folder](https://github.com/google/adk-python/tree/main/contributing/samples) has resources that is helpful for contributors. +[Contributing folder](https://github.com/google/adk-python/tree/main/contributing) has resources that is helpful for contributors. ## Code reviews diff --git a/README.md b/README.md index 7bd5e7401..874658d07 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ adk eval \ ## 🤝 Contributing We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our -- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/#questions). +- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/). - Then if you want to contribute code, please read [Code Contributing Guidelines](./CONTRIBUTING.md) to get started. ## 📄 License From ed7a21e1890466fcdf04f7025775305dc71f603d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Jun 2025 14:57:11 -0700 Subject: [PATCH 13/28] chore: Update google-genai package and related deps to latest PiperOrigin-RevId: 775394737 --- pyproject.toml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 23dbcb537..6cf78ab40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start + "PyYAML>=6.0.2", # For APIHubToolset. 
"anyio>=4.9.0;python_version>='3.10'", # For MCP Session Manager "authlib>=1.5.1", # For RestAPI Tool "click>=8.1.8", # For CLI tools @@ -34,7 +35,7 @@ dependencies = [ "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool "google-cloud-speech>=2.30.0", # For Audio Transcription "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service - "google-genai>=1.17.0", # Google GenAI SDK + "google-genai>=1.21.1", # Google GenAI SDK "graphviz>=0.20.2", # Graphviz for graph rendering "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset "opentelemetry-api>=1.31.0", # OpenTelemetry @@ -43,7 +44,6 @@ dependencies = [ "pydantic>=2.0, <3.0.0", # For data validation/models "python-dateutil>=2.9.0.post0", # For Vertext AI Session Service "python-dotenv>=1.0.0", # To manage environment variables - "PyYAML>=6.0.2", # For APIHubToolset. "requests>=2.32.4", "sqlalchemy>=2.0", # SQL database ORM "starlette>=0.46.2", # For FastAPI CLI @@ -70,9 +70,9 @@ dev = [ # go/keep-sorted start "flit>=3.10.0", "isort>=6.0.0", + "mypy>=1.15.0", "pyink>=24.10.0", "pylint>=2.6.0", - "mypy>=1.15.0", # go/keep-sorted end ] @@ -98,7 +98,6 @@ test = [ "langgraph>=0.2.60", # For LangGraphAgent "litellm>=1.71.2", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests - "pytest-asyncio>=0.25.0", "pytest-mock>=3.14.0", "pytest-xdist>=3.6.1", From acbdca0d8400e292ba5525931175e0d6feab15f1 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 24 Jun 2025 15:03:23 -0700 Subject: [PATCH 14/28] fix: Make raw_auth_credential and exchanged_auth_credential optional given their default value is None PiperOrigin-RevId: 775397286 --- src/google/adk/auth/auth_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index 53c571d42..0316e5258 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -31,12 +31,12 @@ class AuthConfig(BaseModelWithConfig): auth_scheme: AuthScheme """The auth scheme used to collect credentials""" - raw_auth_credential: AuthCredential = None + raw_auth_credential: Optional[AuthCredential] = None """The raw auth credential used to collect credentials. The raw auth credentials are used in some auth scheme that needs to exchange auth credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None. """ - exchanged_auth_credential: AuthCredential = None + exchanged_auth_credential: Optional[AuthCredential] = None """The exchanged auth credential used to collect credentials. adk and client will work together to fill it. For those auth scheme that doesn't need to exchange auth credentials, e.g. API key, service account etc. 
It's filled by From 9e473e0abdded24e710fd857782356c15d04b515 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Jun 2025 15:10:49 -0700 Subject: [PATCH 15/28] fix: Include current turn context when include_contents='none' The intended behavior for include_contents='none' is to: - Exclude conversation history from previous turns - Still include current turn context (user input, tool calls/responses within current turn) https://google.github.io/adk-docs/agents/llm-agents/#managing-context-include_contents This resolves https://github.com/google/adk-python/issues/1124 PiperOrigin-RevId: 775400036 --- src/google/adk/agents/llm_agent.py | 8 +- src/google/adk/flows/llm_flows/contents.py | 53 +++- .../agents/test_llm_agent_include_contents.py | 242 ++++++++++++++++++ 3 files changed, 296 insertions(+), 7 deletions(-) create mode 100644 tests/unittests/agents/test_llm_agent_include_contents.py diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index fe145a60e..64c3628df 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -161,10 +161,12 @@ class LlmAgent(BaseAgent): # LLM-based agent transfer configs - End include_contents: Literal['default', 'none'] = 'default' - """Whether to include contents in the model request. + """Controls content inclusion in model requests. - When set to 'none', the model request will not include any contents, such as - user messages, tool results, etc. + Options: + default: Model receives relevant conversation history + none: Model receives no prior history, operates solely on current + instruction and input """ # Controlled input/output configurations - Start diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index ea418888f..039eaf8c5 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -43,12 +43,20 @@ async def run_async( if not isinstance(agent, LlmAgent): return - if agent.include_contents != 'none': + if agent.include_contents == 'default': + # Include full conversation history llm_request.contents = _get_contents( invocation_context.branch, invocation_context.session.events, agent.name, ) + else: + # Include current turn context only (no conversation history) + llm_request.contents = _get_current_turn_contents( + invocation_context.branch, + invocation_context.session.events, + agent.name, + ) # Maintain async generator behavior if False: # Ensures it behaves as a generator @@ -190,13 +198,15 @@ def _get_contents( ) -> list[types.Content]: """Get the contents for the LLM request. + Applies filtering, rearrangement, and content processing to events. + Args: current_branch: The current branch of the agent. - events: A list of events. + events: Events to process. agent_name: The name of the agent. Returns: - A list of contents. + A list of processed contents. """ filtered_events = [] # Parse the events, leaving the contents and the function calls and @@ -211,12 +221,13 @@ def _get_contents( # Skip events without content, or generated neither by user nor by model # or has empty text. # E.g. events purely for mutating session states. + continue if not _is_event_belongs_to_branch(current_branch, event): # Skip events not belong to current branch. continue if _is_auth_event(event): - # skip auth event + # Skip auth events. 
continue filtered_events.append( _convert_foreign_event(event) @@ -224,12 +235,15 @@ def _get_contents( else event ) + # Rearrange events for proper function call/response pairing result_events = _rearrange_events_for_latest_function_response( filtered_events ) result_events = _rearrange_events_for_async_function_responses_in_history( result_events ) + + # Convert events to contents contents = [] for event in result_events: content = copy.deepcopy(event.content) @@ -238,6 +252,37 @@ def _get_contents( return contents +def _get_current_turn_contents( + current_branch: Optional[str], events: list[Event], agent_name: str = '' +) -> list[types.Content]: + """Get contents for the current turn only (no conversation history). + + When include_contents='none', we want to include: + - The current user input + - Tool calls and responses from the current turn + But exclude conversation history from previous turns. + + In multi-agent scenarios, the "current turn" for an agent starts from an + actual user or from another agent. + + Args: + current_branch: The current branch of the agent. + events: A list of all session events. + agent_name: The name of the agent. + + Returns: + A list of contents for the current turn only, preserving context needed + for proper tool execution while excluding conversation history. + """ + # Find the latest event that starts the current turn and process from there + for i in range(len(events) - 1, -1, -1): + event = events[i] + if event.author == 'user' or _is_other_agent_reply(agent_name, event): + return _get_contents(current_branch, events[i:], agent_name) + + return [] + + def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool: """Whether the event is a reply from another agent.""" return bool( diff --git a/tests/unittests/agents/test_llm_agent_include_contents.py b/tests/unittests/agents/test_llm_agent_include_contents.py new file mode 100644 index 000000000..d4d76cf4e --- /dev/null +++ b/tests/unittests/agents/test_llm_agent_include_contents.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for LlmAgent include_contents field behavior.""" + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.genai import types +import pytest + +from .. 
import testing_utils + + +@pytest.mark.asyncio +async def test_include_contents_default_behavior(): + """Test that include_contents='default' preserves conversation history including tool interactions.""" + + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + "First response", + types.Part.from_function_call( + name="simple_tool", args={"message": "second"} + ), + "Second response", + ] + ) + + agent = LlmAgent( + name="test_agent", + model=mock_model, + include_contents="default", + instruction="You are a helpful assistant", + tools=[simple_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + runner.run("First message") + runner.run("Second message") + + # First turn requests + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "First message") + ] + + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ] + + # Second turn should include full conversation history + assert testing_utils.simplify_contents(mock_model.requests[2].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ("model", "First response"), + ("user", "Second message"), + ] + + # Second turn with tool should include full history + current tool interaction + assert testing_utils.simplify_contents(mock_model.requests[3].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ("model", "First response"), + ("user", "Second message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "second"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: second"} + ), + ), + ] + + +@pytest.mark.asyncio +async def test_include_contents_none_behavior(): + """Test that include_contents='none' excludes conversation history but includes current input.""" + + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + "First response", + "Second response", + ] + ) + + agent = LlmAgent( + name="test_agent", + model=mock_model, + include_contents="none", + instruction="You are a helpful assistant", + tools=[simple_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + runner.run("First message") + runner.run("Second message") + + # First turn behavior + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "First message") + ] + + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "First message"), + ( + "model", + 
types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ] + + # Second turn should only have current input, no history + assert testing_utils.simplify_contents(mock_model.requests[2].contents) == [ + ("user", "Second message") + ] + + # System instruction and tools should be preserved + assert ( + "You are a helpful assistant" + in mock_model.requests[0].config.system_instruction + ) + assert len(mock_model.requests[0].config.tools) > 0 + + +@pytest.mark.asyncio +async def test_include_contents_none_sequential_agents(): + """Test include_contents='none' with sequential agents.""" + + agent1_model = testing_utils.MockModel.create( + responses=["Agent1 response: XYZ"] + ) + agent1 = LlmAgent( + name="agent1", + model=agent1_model, + instruction="You are Agent1", + ) + + agent2_model = testing_utils.MockModel.create( + responses=["Agent2 final response"] + ) + agent2 = LlmAgent( + name="agent2", + model=agent2_model, + include_contents="none", + instruction="You are Agent2", + ) + + sequential_agent = SequentialAgent( + name="sequential_test_agent", sub_agents=[agent1, agent2] + ) + + runner = testing_utils.InMemoryRunner(sequential_agent) + events = runner.run("Original user request") + + assert len(events) == 2 + assert events[0].author == "agent1" + assert events[1].author == "agent2" + + # Agent1 sees original user request + agent1_contents = testing_utils.simplify_contents( + agent1_model.requests[0].contents + ) + assert ("user", "Original user request") in agent1_contents + + # Agent2 with include_contents='none' should not see original request + agent2_contents = testing_utils.simplify_contents( + agent2_model.requests[0].contents + ) + + assert not any( + "Original user request" in str(content) for _, content in agent2_contents + ) + assert any( + "Agent1 response" in str(content) for _, content in agent2_contents + ) From 09f1269bf7fa46ab4b9324e7f92b4f70ffc923e5 Mon Sep 17 00:00:00 2001 From: Dave Bunten Date: Tue, 24 Jun 2025 15:17:39 -0700 Subject: [PATCH 16/28] ci(tests): leverage official uv action for install Merge https://github.com/google/adk-python/pull/1547 This PR replaces the `curl`-based installation of `uv` to instead use the [official GitHub Action from Astral](https://github.com/astral-sh/setup-uv). 
Closes #1545 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1547 from d33bs:use-uv-action 05ab7a138cbb5babee30ea81e83f26064e041529 PiperOrigin-RevId: 775402484 --- .github/workflows/python-unit-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 565ee1dca..52e61b8a3 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -36,8 +36,8 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v6 - name: Install dependencies run: | From 88a4402d142672171d0a8ceae74671f47fa14289 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 24 Jun 2025 16:14:52 -0700 Subject: [PATCH 17/28] chore: Do not send api request when session does not have events PiperOrigin-RevId: 775423356 --- .../adk/memory/vertex_ai_memory_bank_service.py | 15 +++++++++------ .../memory/test_vertex_ai_memory_bank_service.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index b5b70ab1c..083b48e8d 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -78,12 +78,15 @@ async def add_session_to_memory(self, session: Session): }, } - api_response = await api_client.async_request( - http_method='POST', - path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', - request_dict=request_dict, - ) - logger.info(f'Generate memory response: {api_response}') + if events: + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', + request_dict=request_dict, + ) + logger.info(f'Generate memory response: {api_response}') + else: + logger.info('No events to add to memory.') @override async def search_memory(self, *, app_name: str, user_id: str, query: str): diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 27e2bbdd5..2fbf3291c 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -48,6 +48,13 @@ ], ) +MOCK_SESSION_WITH_EMPTY_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='444', + last_update_time=22333, +) + RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' @@ -136,6 +143,15 @@ async def test_add_session_to_memory(mock_get_api_client): ) +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_empty_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) + + mock_get_api_client.async_request.assert_not_called() + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_search_memory(mock_get_api_client): From ef3c745d655538ebd1ed735671be615f842341a8 Mon Sep 17 00:00:00 2001 From: Aditya Mulik Date: Tue, 24 Jun 2025 16:44:00 -0700 Subject: [PATCH 18/28] fix: typo fix in sample agent instruction Merge 
https://github.com/google/adk-python/pull/1623 fix: minor typo fix in the agent instruction COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1623 from adityamulik:minor_typo_fix 12ea09ae397b5c5e2a9ada48017cd1ca14add72e PiperOrigin-RevId: 775433411 --- contributing/samples/artifact_save_text/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributing/samples/artifact_save_text/agent.py b/contributing/samples/artifact_save_text/agent.py index 53a7f300d..3ce43bcd1 100755 --- a/contributing/samples/artifact_save_text/agent.py +++ b/contributing/samples/artifact_save_text/agent.py @@ -31,7 +31,7 @@ async def log_query(tool_context: ToolContext, query: str): model='gemini-2.0-flash', name='log_agent', description='Log user query.', - instruction="""Always log the user query and reploy "kk, I've logged." + instruction="""Always log the user query and reply "kk, I've logged." """, tools=[log_query], generate_content_config=types.GenerateContentConfig( From 917a8a19f794ba33fef08898937a73f0ceb809a2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 24 Jun 2025 16:45:45 -0700 Subject: [PATCH 19/28] chore: Adapt oauth calendar agent to use authenticated tool PiperOrigin-RevId: 775433950 --- .../samples/oauth_calendar_agent/agent.py | 116 ++++++------------ 1 file changed, 35 insertions(+), 81 deletions(-) diff --git a/contributing/samples/oauth_calendar_agent/agent.py b/contributing/samples/oauth_calendar_agent/agent.py index a1b1dea87..3f966b787 100644 --- a/contributing/samples/oauth_calendar_agent/agent.py +++ b/contributing/samples/oauth_calendar_agent/agent.py @@ -13,7 +13,6 @@ # limitations under the License. from datetime import datetime -import json import os from dotenv import load_dotenv @@ -27,8 +26,8 @@ from google.adk.auth import AuthCredentialTypes from google.adk.auth import OAuth2Auth from google.adk.tools import ToolContext +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool from google.adk.tools.google_api_tool import CalendarToolset -from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build @@ -56,6 +55,7 @@ def list_calendar_events( end_time: str, limit: int, tool_context: ToolContext, + credential: AuthCredential, ) -> list[dict]: """Search for calendar events. @@ -80,84 +80,11 @@ def list_calendar_events( Returns: list[dict]: A list of events that match the search criteria. """ - creds = None - - # Check if the tokes were already in the session state, which means the user - # has already gone through the OAuth flow and successfully authenticated and - # authorized the tool to access their calendar. - if "calendar_tool_tokens" in tool_context.state: - creds = Credentials.from_authorized_user_info( - tool_context.state["calendar_tool_tokens"], SCOPES - ) - if not creds or not creds.valid: - # If the access token is expired, refresh it with the refresh token. 
- if creds and creds.expired and creds.refresh_token: - creds.refresh(Request()) - else: - auth_scheme = OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://accounts.google.com/o/oauth2/auth", - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - "https://www.googleapis.com/auth/calendar": ( - "See, edit, share, and permanently delete all the" - " calendars you can access using Google Calendar" - ) - }, - ) - ) - ) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=oauth_client_id, client_secret=oauth_client_secret - ), - ) - # If the user has not gone through the OAuth flow before, or the refresh - # token also expired, we need to ask users to go through the OAuth flow. - # First we check whether the user has just gone through the OAuth flow and - # Oauth response is just passed back. - auth_response = tool_context.get_auth_response( - AuthConfig( - auth_scheme=auth_scheme, raw_auth_credential=auth_credential - ) - ) - if auth_response: - # ADK exchanged the access token already for us - access_token = auth_response.oauth2.access_token - refresh_token = auth_response.oauth2.refresh_token - - creds = Credentials( - token=access_token, - refresh_token=refresh_token, - token_uri=auth_scheme.flows.authorizationCode.tokenUrl, - client_id=oauth_client_id, - client_secret=oauth_client_secret, - scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()), - ) - else: - # If there are no auth response which means the user has not gone - # through the OAuth flow yet, we need to ask users to go through the - # OAuth flow. - tool_context.request_credential( - AuthConfig( - auth_scheme=auth_scheme, - raw_auth_credential=auth_credential, - ) - ) - # The return value is optional and could be any dict object. It will be - # wrapped in a dict with key as 'result' and value as the return value - # if the object returned is not a dict. This response will be passed - # to LLM to generate a user friendly message. e.g. LLM will tell user: - # "I need your authorization to access your calendar. Please authorize - # me so I can check your meetings for today." - return "Need User Authorization to access their calendar." - # We store the access token and refresh token in the session state for the - # next runs. This is just an example. On production, a tool should store - # those credentials in some secure store or properly encrypt it before store - # it in the session state. 
- tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json()) + + creds = Credentials( + token=credential.oauth2.access_token, + refresh_token=credential.oauth2.refresh_token, + ) service = build("calendar", "v3", credentials=creds) events_result = ( @@ -208,6 +135,33 @@ def update_time(callback_context: CallbackContext): Currnet time: {_time} """, - tools=[list_calendar_events, calendar_toolset], + tools=[ + AuthenticatedFunctionTool( + func=list_calendar_events, + auth_config=AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl=( + "https://accounts.google.com/o/oauth2/auth" + ), + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + "https://www.googleapis.com/auth/calendar": "", + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=oauth_client_id, + client_secret=oauth_client_secret, + ), + ), + ), + ), + calendar_toolset, + ], before_agent_callback=update_time, ) From 6729edd08e427e3be78d8e4665443f5bbabfd635 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Jun 2025 06:04:49 -0700 Subject: [PATCH 20/28] refactor: Rename the Google API based bigquery sample agent This change renames the sample agent based on the Google API based tools to reflect the larger purpose and avoid confusion with the built-in BigQuery tools. In addition, it also renames the root agent in the BigQuery sample agent to "bigquery_agent" PiperOrigin-RevId: 775655226 --- contributing/samples/bigquery/agent.py | 2 +- .../{bigquery_agent => google_api}/README.md | 33 ++++++++----------- .../__init__.py | 0 .../{bigquery_agent => google_api}/agent.py | 2 +- 4 files changed, 16 insertions(+), 21 deletions(-) rename contributing/samples/{bigquery_agent => google_api}/README.md (50%) rename contributing/samples/{bigquery_agent => google_api}/__init__.py (100%) rename contributing/samples/{bigquery_agent => google_api}/agent.py (98%) diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py index 3cd1eb997..c1b265c00 100644 --- a/contributing/samples/bigquery/agent.py +++ b/contributing/samples/bigquery/agent.py @@ -60,7 +60,7 @@ # debug CLI root_agent = llm_agent.Agent( model="gemini-2.0-flash", - name="hello_agent", + name="bigquery_agent", description=( "Agent to answer questions about BigQuery data and models and execute" " SQL queries." diff --git a/contributing/samples/bigquery_agent/README.md b/contributing/samples/google_api/README.md similarity index 50% rename from contributing/samples/bigquery_agent/README.md rename to contributing/samples/google_api/README.md index c7dc7fd8b..c1e6e8d4c 100644 --- a/contributing/samples/bigquery_agent/README.md +++ b/contributing/samples/google_api/README.md @@ -1,45 +1,40 @@ -# BigQuery Sample +# Google API Tools Sample ## Introduction -This sample tests and demos the BigQuery support in ADK via two tools: +This sample tests and demos Google API tools available in the +`google.adk.tools.google_api_tool` module. We pick the following BigQuery API +tools for this sample agent: -* 1. bigquery_datasets_list: +1. `bigquery_datasets_list`: List user's datasets. - List user's datasets. +2. `bigquery_datasets_get`: Get a dataset's details. -* 2. bigquery_datasets_get: - Get a dataset's details. +3. `bigquery_datasets_insert`: Create a new dataset. -* 3. bigquery_datasets_insert: - Create a new dataset. +4. `bigquery_tables_list`: List all tables in a dataset. -* 4. 
bigquery_tables_list: - List all tables in a dataset. +5. `bigquery_tables_get`: Get a table's details. -* 5. bigquery_tables_get: - Get a table's details. - -* 6. bigquery_tables_insert: - Insert a new table into a dataset. +6. `bigquery_tables_insert`: Insert a new table into a dataset. ## How to use -* 1. Follow https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. to get your client id and client secret. +1. Follow https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. to get your client id and client secret. Be sure to choose "web" as your client type. -* 2. Configure your `.env` file to add two variables: +2. Configure your `.env` file to add two variables: * OAUTH_CLIENT_ID={your client id} * OAUTH_CLIENT_SECRET={your client secret} Note: don't create a separate `.env` file , instead put it to the same `.env` file that stores your Vertex AI or Dev ML credentials -* 3. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". +3. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". Note: localhost here is just a hostname that you use to access the dev ui, replace it with the actual hostname you use to access the dev ui. -* 4. For 1st run, allow popup for localhost in Chrome. +4. For 1st run, allow popup for localhost in Chrome. ## Sample prompt diff --git a/contributing/samples/bigquery_agent/__init__.py b/contributing/samples/google_api/__init__.py similarity index 100% rename from contributing/samples/bigquery_agent/__init__.py rename to contributing/samples/google_api/__init__.py diff --git a/contributing/samples/bigquery_agent/agent.py b/contributing/samples/google_api/agent.py similarity index 98% rename from contributing/samples/bigquery_agent/agent.py rename to contributing/samples/google_api/agent.py index 976cea170..1cdbab9c6 100644 --- a/contributing/samples/bigquery_agent/agent.py +++ b/contributing/samples/google_api/agent.py @@ -40,7 +40,7 @@ root_agent = Agent( model="gemini-2.0-flash", - name="bigquery_agent", + name="google_api_bigquery_agent", instruction=""" You are a helpful Google BigQuery agent that help to manage users' data on Google BigQuery. Use the provided tools to conduct various operations on users' data in Google BigQuery. From f54b9b6ad10220ddb2a69e2b951c0bc57a50a8b6 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 09:05:23 -0700 Subject: [PATCH 21/28] chore: Add unit tests for contents.py PiperOrigin-RevId: 775713101 --- .../flows/llm_flows/test_contents.py | 361 ++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 tests/unittests/flows/llm_flows/test_contents.py diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py new file mode 100644 index 000000000..a330852a1 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -0,0 +1,361 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows import contents +from google.adk.flows.llm_flows.contents import _convert_foreign_event +from google.adk.flows.llm_flows.contents import _get_contents +from google.adk.flows.llm_flows.contents import _merge_function_response_events +from google.adk.flows.llm_flows.contents import _rearrange_events_for_async_function_responses_in_history +from google.adk.flows.llm_flows.contents import _rearrange_events_for_latest_function_response +from google.adk.models import LlmRequest +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_content_processor_no_contents(): + """Test ContentLlmRequestProcessor when include_contents is 'none'.""" + agent = Agent(model="gemini-1.5-flash", name="agent", include_contents="none") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events + assert len(events) == 0 + # Contents should not be set when include_contents is 'none' + assert llm_request.contents == [] + + +@pytest.mark.asyncio +async def test_content_processor_with_contents(): + """Test ContentLlmRequestProcessor when include_contents is not 'none'.""" + agent = Agent(model="gemini-1.5-flash", name="agent") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Add some test events to the session + test_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + invocation_context.session.events = [test_event] + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events (processor doesn't emit events, just modifies request) + assert len(events) == 0 + # Contents should be set + assert llm_request.contents is not None + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts[0].text == "Hello" + + +@pytest.mark.asyncio +async def test_content_processor_non_llm_agent(): + """Test ContentLlmRequestProcessor with non-LLM agent.""" + from google.adk.agents.base_agent import BaseAgent + + # Create a base agent (not LLM agent) + agent = BaseAgent(name="base_agent") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events and not modify request + 
assert len(events) == 0 + assert llm_request.contents == [] + + +def test_get_contents_empty_events(): + """Test _get_contents with empty events list.""" + contents_result = _get_contents(None, [], "test_agent") + assert contents_result == [] + + +def test_get_contents_with_events(): + """Test _get_contents with valid events.""" + test_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents(None, [test_event], "test_agent") + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + +def test_get_contents_filters_empty_events(): + """Test _get_contents filters out events with empty content.""" + # Event with empty text + empty_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content(role="user", parts=[types.Part.from_text(text="")]), + ) + + # Event without content + no_content_event = Event( + invocation_id="test_inv", + author="user", + ) + + # Valid event + valid_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents( + None, [empty_event, no_content_event, valid_event], "test_agent" + ) + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + +def test_convert_foreign_event(): + """Test _convert_foreign_event function.""" + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Agent response")] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] said: Agent response" in converted_event.content.parts[1].text + ) + + +def test_convert_event_with_function_call(): + """Test _convert_foreign_event with function call.""" + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] called tool `test_function`" + in converted_event.content.parts[1].text + ) + assert "{'param': 'value'}" in converted_event.content.parts[1].text + + +def test_convert_event_with_function_response(): + """Test _convert_foreign_event with function response.""" + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert 
converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] `test_function` tool returned result:" + in converted_event.content.parts[1].text + ) + assert "{'result': 'success'}" in converted_event.content.parts[1].text + + +def test_merge_function_response_events(): + """Test _merge_function_response_events function.""" + # Create initial function response event + function_response1 = types.FunctionResponse( + id="func_123", name="test_function", response={"status": "pending"} + ) + + initial_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response1)] + ), + ) + + # Create final function response event + function_response2 = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + final_event = Event( + invocation_id="test_inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response2)] + ), + ) + + merged_event = _merge_function_response_events([initial_event, final_event]) + + assert ( + merged_event.invocation_id == "test_inv" + ) # Should keep initial event ID + assert len(merged_event.content.parts) == 1 + # The first part should be replaced with the final response + assert merged_event.content.parts[0].function_response.response == { + "result": "success" + } + + +def test_rearrange_events_for_async_function_responses(): + """Test _rearrange_events_for_async_function_responses_in_history function.""" + # Create function call event + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create function response event + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test rearrangement + events = [call_event, response_event] + rearranged = _rearrange_events_for_async_function_responses_in_history(events) + + # Should have both events in correct order + assert len(rearranged) == 2 + assert rearranged[0] == call_event + assert rearranged[1] == response_event + + +def test_rearrange_events_for_latest_function_response(): + """Test _rearrange_events_for_latest_function_response function.""" + # Create function call event + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create intermediate event + intermediate_event = Event( + invocation_id="test_inv2", + author="agent", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Processing...")] + ), + ) + + # Create function response event + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv3", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test with matching function call and response + events = [call_event, 
intermediate_event, response_event] + rearranged = _rearrange_events_for_latest_function_response(events) + + # Should remove intermediate events and merge responses + assert len(rearranged) == 2 + assert rearranged[0] == call_event From a623467299e768be93f516a9afb533c32172fd74 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 09:18:34 -0700 Subject: [PATCH 22/28] chore: Enhance a2a context id parsing and construction logic PiperOrigin-RevId: 775718282 --- src/google/adk/a2a/converters/utils.py | 26 ++- .../a2a/converters/test_request_converter.py | 8 +- tests/unittests/a2a/converters/test_utils.py | 213 ++++++++++++++++++ 3 files changed, 239 insertions(+), 8 deletions(-) create mode 100644 tests/unittests/a2a/converters/test_utils.py diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py index ecbff1e10..acb2581d4 100644 --- a/src/google/adk/a2a/converters/utils.py +++ b/src/google/adk/a2a/converters/utils.py @@ -16,6 +16,7 @@ ADK_METADATA_KEY_PREFIX = "adk_" ADK_CONTEXT_ID_PREFIX = "ADK" +ADK_CONTEXT_ID_SEPARATOR = "/" def _get_adk_metadata_key(key: str) -> str: @@ -45,8 +46,17 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str: Returns: The A2A context id. + + Raises: + ValueError: If any of the input parameters are empty or None. """ - return [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id].join("$") + if not all([app_name, user_id, session_id]): + raise ValueError( + "All parameters (app_name, user_id, session_id) must be non-empty" + ) + return ADK_CONTEXT_ID_SEPARATOR.join( + [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id] + ) def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: @@ -64,8 +74,16 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: if not context_id: return None, None, None - prefix, app_name, user_id, session_id = context_id.split("$") - if prefix == "ADK" and app_name and user_id and session_id: - return app_name, user_id, session_id + try: + parts = context_id.split(ADK_CONTEXT_ID_SEPARATOR) + if len(parts) != 4: + return None, None, None + + prefix, app_name, user_id, session_id = parts + if prefix == ADK_CONTEXT_ID_PREFIX and app_name and user_id and session_id: + return app_name, user_id, session_id + except ValueError: + # Handle any split errors gracefully + pass return None, None, None diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index 02c6400fc..08266751e 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -244,7 +244,7 @@ def test_convert_a2a_request_basic( request = Mock(spec=RequestContext) request.message = mock_message - request.context_id = "ADK$app$user$session" + request.context_id = "ADK/app/user/session" mock_from_context_id.return_value = ( "app_name", @@ -271,7 +271,7 @@ def test_convert_a2a_request_basic( assert isinstance(result["run_config"], RunConfig) # Verify calls - mock_from_context_id.assert_called_once_with("ADK$app$user$session") + mock_from_context_id.assert_called_once_with("ADK/app/user/session") mock_get_user_id.assert_called_once_with(request, "user_from_context") assert mock_convert_part.call_count == 2 mock_convert_part.assert_any_call(mock_part1) @@ -302,7 +302,7 @@ def test_convert_a2a_request_empty_parts( request = Mock(spec=RequestContext) request.message = mock_message - request.context_id = "ADK$app$user$session" + 
request.context_id = "ADK/app/user/session" mock_from_context_id.return_value = ( "app_name", @@ -431,7 +431,7 @@ def test_end_to_end_conversion_with_auth_user(self, mock_convert_part): request = Mock(spec=RequestContext) request.call_context = mock_call_context request.message = mock_message - request.context_id = "ADK$myapp$context_user$mysession" + request.context_id = "ADK/myapp/context_user/mysession" request.current_task = None request.task_id = "task123" diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py new file mode 100644 index 000000000..f919cbd00 --- /dev/null +++ b/tests/unittests/a2a/converters/test_utils.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +import pytest + + +class TestUtilsFunctions: + """Test suite for utils module functions.""" + + def test_get_adk_metadata_key_success(self): + """Test successful metadata key generation.""" + key = "test_key" + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_get_adk_metadata_key_empty_string(self): + """Test metadata key generation with empty string.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key("") + + def test_get_adk_metadata_key_none(self): + """Test metadata key generation with None.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key(None) + + def test_get_adk_metadata_key_whitespace(self): + """Test metadata key generation with whitespace string.""" + key = " " + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_to_a2a_context_id_success(self): + """Test successful context ID generation.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session" + assert result == expected + + def test_to_a2a_context_id_empty_app_name(self): + """Test context ID generation with empty app name.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("", "user", "session") + + def test_to_a2a_context_id_empty_user_id(self): + """Test context ID generation with empty user ID.""" + with pytest.raises( + ValueError, + match=( + 
"All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "", "session") + + def test_to_a2a_context_id_empty_session_id(self): + """Test context ID generation with empty session ID.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "user", "") + + def test_to_a2a_context_id_none_values(self): + """Test context ID generation with None values.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id(None, "user", "session") + + def test_to_a2a_context_id_special_characters(self): + """Test context ID generation with special characters.""" + app_name = "test-app@2024" + user_id = "user_123" + session_id = "session-456" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}/test-app@2024/user_123/session-456" + assert result == expected + + def test_from_a2a_context_id_success(self): + """Test successful context ID parsing.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app" + assert user_id == "test-user" + assert session_id == "test-session" + + def test_from_a2a_context_id_none_input(self): + """Test context ID parsing with None input.""" + result = _from_a2a_context_id(None) + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_string(self): + """Test context ID parsing with empty string.""" + result = _from_a2a_context_id("") + assert result == (None, None, None) + + def test_from_a2a_context_id_invalid_prefix(self): + """Test context ID parsing with invalid prefix.""" + context_id = "INVALID/test-app/test-user/test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_few_parts(self): + """Test context ID parsing with too few parts.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_many_parts(self): + """Test context ID parsing with too many parts.""" + context_id = ( + f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session/extra" + ) + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_components(self): + """Test context ID parsing with empty components.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}//test-user/test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_no_dollar_separator(self): + """Test context ID parsing without dollar separators.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}-test-app-test-user-test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_roundtrip_context_id(self): + """Test roundtrip conversion: to -> from.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + # Convert to context ID + context_id = _to_a2a_context_id(app_name, user_id, session_id) + + # Convert back + parsed_app, parsed_user, parsed_session = _from_a2a_context_id(context_id) + + assert parsed_app == app_name + assert parsed_user == user_id + assert parsed_session == session_id + + def 
test_from_a2a_context_id_special_characters(self): + """Test context ID parsing with special characters.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app@2024/user_123/session-456" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app@2024" + assert user_id == "user_123" + assert session_id == "session-456" From 5306ddad4dde29748fe9c75e01511fc59e28a8d1 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Wed, 25 Jun 2025 10:18:30 -0700 Subject: [PATCH 23/28] chore: Release 1.5.0 PiperOrigin-RevId: 775742049 --- CHANGELOG.md | 35 +++++++++++++++++++++++++++++++++++ src/google/adk/version.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce36dcdcf..b6bba2692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # Changelog +## [1.5.0](https://github.com/google/adk-python/compare/v1.4.2...v1.5.0) (2025-06-25) + + +### Features + +* Add a new option `eval_storage_uri` in adk web & adk eval to specify GCS bucket to store eval data ([fa025d7](https://github.com/google/adk-python/commit/fa025d755978e1506fa0da1fecc49775bebc1045)) +* Add ADK examples for litellm with add_function_to_prompt ([f33e090](https://github.com/google/adk-python/commit/f33e0903b21b752168db3006dd034d7d43f7e84d)) +* Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint ([abc89d2](https://github.com/google/adk-python/commit/abc89d2c811ba00805f81b27a3a07d56bdf55a0b)) +* Add rouge_score library to ADK eval dependencies, and implement RougeEvaluator that is computes ROUGE-1 for "response_match_score" metric ([9597a44](https://github.com/google/adk-python/commit/9597a446fdec63ad9e4c2692d6966b14f80ff8e2)) +* Add usage span attributes to telemetry ([#356](https://github.com/google/adk-python/issues/356)) ([ea69c90](https://github.com/google/adk-python/commit/ea69c9093a16489afdf72657136c96f61c69cafd)) +* Add Vertex Express mode compatibility for VertexAiSessionService ([00cc8cd](https://github.com/google/adk-python/commit/00cc8cd6433fc45ecfc2dbaa04dbbc1a81213b4d)) + + +### Bug Fixes + +* Include current turn context when include_contents='none' ([9e473e0](https://github.com/google/adk-python/commit/9e473e0abdded24e710fd857782356c15d04b515)) +* Make LiteLLM streaming truly asynchronous ([bd67e84](https://github.com/google/adk-python/commit/bd67e8480f6e8b4b0f8c22b94f15a8cda1336339)) +* Make raw_auth_credential and exchanged_auth_credential optional given their default value is None ([acbdca0](https://github.com/google/adk-python/commit/acbdca0d8400e292ba5525931175e0d6feab15f1)) +* Minor typo fix in the agent instruction ([ef3c745](https://github.com/google/adk-python/commit/ef3c745d655538ebd1ed735671be615f842341a8)) +* Typo fix in sample agent instruction ([ef3c745](https://github.com/google/adk-python/commit/ef3c745d655538ebd1ed735671be615f842341a8)) +* Update contributing links ([a1e1441](https://github.com/google/adk-python/commit/a1e14411159fd9f3e114e15b39b4949d0fd6ecb1)) +* Use starred tuple unpacking on GCS artifact blob names ([3b1d9a8](https://github.com/google/adk-python/commit/3b1d9a8a3e631ca2d86d30f09640497f1728986c)) + + +### Chore + +* Do not send api request when session does not have events ([88a4402](https://github.com/google/adk-python/commit/88a4402d142672171d0a8ceae74671f47fa14289)) +* Leverage official uv action for install([09f1269](https://github.com/google/adk-python/commit/09f1269bf7fa46ab4b9324e7f92b4f70ffc923e5)) +* Update google-genai package and related deps to 
latest([ed7a21e](https://github.com/google/adk-python/commit/ed7a21e1890466fcdf04f7025775305dc71f603d)) +* Add credential service backed by session state([29cd183](https://github.com/google/adk-python/commit/29cd183aa1b47dc4f5d8afe22f410f8546634abc)) +* Clarify the behavior of Event.invocation_id([f033e40](https://github.com/google/adk-python/commit/f033e405c10ff8d86550d1419a9d63c0099182f9)) +* Send user message to the agent that returned a corresponding function call if user message is a function response([7c670f6](https://github.com/google/adk-python/commit/7c670f638bc17374ceb08740bdd057e55c9c2e12)) +* Add request converter to convert a2a request to ADK request([fb13963](https://github.com/google/adk-python/commit/fb13963deda0ff0650ac27771711ea0411474bf5)) +* Support allow_origins in cloud_run deployment ([2fd8feb](https://github.com/google/adk-python/commit/2fd8feb65d6ae59732fb3ec0652d5650f47132cc)) + ## [1.4.2](https://github.com/google/adk-python/compare/v1.4.1...v1.4.2) (2025-06-20) diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 9accc1025..1c061dd03 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.4.2" +__version__ = "1.5.0" From 738d1a8b84f388dfc761f5818da9467aedc160cf Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 25 Jun 2025 10:19:06 -0700 Subject: [PATCH 24/28] chore: create an agent to check issue format and content for bugs and feature requests This agent will pose a comment to ask for more information according to the template if necessary. PiperOrigin-RevId: 775742256 --- .github/ISSUE_TEMPLATE/bug_report.md | 3 + .../adk_issue_formatting_agent/__init__.py | 15 ++ .../adk_issue_formatting_agent/agent.py | 241 ++++++++++++++++++ .../adk_issue_formatting_agent/settings.py | 33 +++ .../adk_issue_formatting_agent/utils.py | 53 ++++ 5 files changed, 345 insertions(+) create mode 100644 contributing/samples/adk_issue_formatting_agent/__init__.py create mode 100644 contributing/samples/adk_issue_formatting_agent/agent.py create mode 100644 contributing/samples/adk_issue_formatting_agent/settings.py create mode 100644 contributing/samples/adk_issue_formatting_agent/utils.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 7c2ffdd95..f04f3f039 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -31,5 +31,8 @@ If applicable, add screenshots to help explain your problem. - Python version(python -V): - ADK version(pip show google-adk): + **Model Information:** + For example, which model is being used. + **Additional context** Add any other context about the problem here. diff --git a/contributing/samples/adk_issue_formatting_agent/__init__.py b/contributing/samples/adk_issue_formatting_agent/__init__.py new file mode 100644 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/adk_issue_formatting_agent/agent.py b/contributing/samples/adk_issue_formatting_agent/agent.py new file mode 100644 index 000000000..78add9b83 --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/agent.py @@ -0,0 +1,241 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any + +from adk_issue_formatting_agent.settings import GITHUB_BASE_URL +from adk_issue_formatting_agent.settings import IS_INTERACTIVE +from adk_issue_formatting_agent.settings import OWNER +from adk_issue_formatting_agent.settings import REPO +from adk_issue_formatting_agent.utils import error_response +from adk_issue_formatting_agent.utils import get_request +from adk_issue_formatting_agent.utils import post_request +from adk_issue_formatting_agent.utils import read_file +from google.adk import Agent +import requests + +BUG_REPORT_TEMPLATE = read_file( + Path(__file__).parent / "../../../../.github/ISSUE_TEMPLATE/bug_report.md" +) +FREATURE_REQUEST_TEMPLATE = read_file( + Path(__file__).parent + / "../../../../.github/ISSUE_TEMPLATE/feature_request.md" +) + +APPROVAL_INSTRUCTION = ( + "**Do not** wait or ask for user approval or confirmation for adding the" + " comment." +) +if IS_INTERACTIVE: + APPROVAL_INSTRUCTION = ( + "Ask for user approval or confirmation for adding the comment." + ) + + +def list_open_issues(issue_count: int) -> dict[str, Any]: + """List most recent `issue_count` numer of open issues in the repo. + + Args: + issue_count: number of issues to return + + Returns: + The status of this request, with a list of issues when successful. + """ + url = f"{GITHUB_BASE_URL}/search/issues" + query = f"repo:{OWNER}/{REPO} is:open is:issue" + params = { + "q": query, + "sort": "created", + "order": "desc", + "per_page": issue_count, + "page": 1, + } + + try: + response = get_request(url, params) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + issues = response.get("items", None) + return {"status": "success", "issues": issues} + + +def get_issue(issue_number: int) -> dict[str, Any]: + """Get the details of the specified issue number. + + Args: + issue_number: issue number of the Github issue. + + Returns: + The status of this request, with the issue details when successful. + """ + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" + try: + response = get_request(url) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return {"status": "success", "issue": response} + + +def add_comment_to_issue(issue_number: int, comment: str) -> dict[str, any]: + """Add the specified comment to the given issue number. 
+ + Args: + issue_number: issue number of the Github issue + comment: comment to add + + Returns: + The the status of this request, with the applied comment when successful. + """ + print(f"Attempting to add comment '{comment}' to issue #{issue_number}") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/comments" + payload = {"body": comment} + + try: + response = post_request(url, payload) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return { + "status": "success", + "added_comment": response, + } + + +def list_comments_on_issue(issue_number: int) -> dict[str, any]: + """List all comments on the given issue number. + + Args: + issue_number: issue number of the Github issue + + Returns: + The the status of this request, with the list of comments when successful. + """ + print(f"Attempting to list comments on issue #{issue_number}") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/comments" + + try: + response = get_request(url) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return {"status": "success", "comments": response} + + +root_agent = Agent( + model="gemini-2.5-pro", + name="adk_issue_formatting_assistant", + description="Check ADK issue format and content.", + instruction=f""" + # 1. IDENTITY + You are an AI assistant designed to help maintain the quality and consistency of issues in our GitHub repository. + Your primary role is to act as a "GitHub Issue Format Validator." You will analyze new and existing **open** issues + to ensure they contain all the necessary information as required by our templates. You are helpful, polite, + and precise in your feedback. + + # 2. CONTEXT & RESOURCES + * **Repository:** You are operating on the GitHub repository `{OWNER}/{REPO}`. + * **Bug Report Template:** (`{BUG_REPORT_TEMPLATE}`) + * **Feature Request Template:** (`{FREATURE_REQUEST_TEMPLATE}`) + + # 3. CORE MISSION + Your goal is to check if a GitHub issue, identified as either a "bug" or a "feature request," + contains all the information required by the corresponding template. If it does not, your job is + to post a single, helpful comment asking the original author to provide the missing information. + {APPROVAL_INSTRUCTION} + + **IMPORTANT NOTE:** + * You add one comment at most each time you are invoked. + * Don't proceed to other issues which are not the target issues. + * Don't take any action on closed issues. + + # 4. BEHAVIORAL RULES & LOGIC + + ## Step 1: Identify Issue Type & Applicability + + Your first task is to determine if the issue is a valid target for validation. + + 1. **Assess Content Intent:** You must perform a quick semantic check of the issue's title, body, and comments. + If you determine the issue's content is fundamentally *not* a bug report or a feature request + (for example, it is a general question, a request for help, or a discussion prompt), then you must ignore it. + 2. **Exit Condition:** If the issue does not clearly fall into the categories of "bug" or "feature request" + based on both its labels and its content, **take no action**. + + ## Step 2: Analyze the Issue Content + + If you have determined the issue is a valid bug or feature request, your analysis depends on whether it has comments. + + **Scenario A: Issue has NO comments** + 1. Read the main body of the issue. + 2. Compare the content of the issue body against the required headings/sections in the relevant template (Bug or Feature). + 3. 
Check for the presence of content under each heading. A heading with no content below it is considered incomplete. + 4. If one or more sections are missing or empty, proceed to Step 3. + 5. If all sections are filled out, your task is complete. Do nothing. + + **Scenario B: Issue HAS one or more comments** + 1. First, analyze the main issue body to see which sections of the template are filled out. + 2. Next, read through **all** the comments in chronological order. + 3. As you read the comments, check if the information provided in them satisfies any of the template sections that were missing from the original issue body. + 4. After analyzing the body and all comments, determine if any required sections from the template *still* remain unaddressed. + 5. If one or more sections are still missing information, proceed to Step 3. + 6. If the issue body and comments *collectively* provide all the required information, your task is complete. Do nothing. + + ## Step 3: Formulate and Post a Comment (If Necessary) + + If you determined in Step 2 that information is missing, you must post a **single comment** on the issue. + + Please include a bolded note in your comment that this comment was added by an ADK agent. + + **Comment Guidelines:** + * **Be Polite and Helpful:** Start with a friendly tone. + * **Be Specific:** Clearly list only the sections from the template that are still missing. Do not list sections that have already been filled out. + * **Address the Author:** Mention the issue author by their username (e.g., `@username`). + * **Provide Context:** Explain *why* the information is needed (e.g., "to help us reproduce the bug" or "to better understand your request"). + * **Do not be repetitive:** If you have already commented on an issue asking for information, do not comment again unless new information has been added and it's still incomplete. + + **Example Comment for a Bug Report:** + > **Response from ADK Agent** + > + > Hello @[issue-author-username], thank you for submitting this issue! + > + > To help us investigate and resolve this bug effectively, could you please provide the missing details for the following sections of our bug report template: + > + > * **To Reproduce:** (Please provide the specific steps required to reproduce the behavior) + > * **Desktop (please complete the following information):** (Please provide OS, Python version, and ADK version) + > + > This information will give us the context we need to move forward. Thanks! + + **Example Comment for a Feature Request:** + > **Response from ADK Agent** + > + > Hi @[issue-author-username], thanks for this great suggestion! + > + > To help our team better understand and evaluate your feature request, could you please provide a bit more information on the following section: + > + > * **Is your feature request related to a problem? Please describe.** + > + > We look forward to hearing more about your idea! + + # 5. FINAL INSTRUCTION + + Execute this process for the given GitHub issue. Your final output should either be **[NO ACTION]** + if the issue is complete or invalid, or **[POST COMMENT]** followed by the exact text of the comment you will post. + + Please include your justification for your decision in your output. 
+ """, + tools={ + list_open_issues, + get_issue, + add_comment_to_issue, + list_comments_on_issue, + }, +) diff --git a/contributing/samples/adk_issue_formatting_agent/settings.py b/contributing/samples/adk_issue_formatting_agent/settings.py new file mode 100644 index 000000000..d29bda9b7 --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/settings.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dotenv import load_dotenv + +load_dotenv(override=True) + +GITHUB_BASE_URL = "https://api.github.com" + +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +if not GITHUB_TOKEN: + raise ValueError("GITHUB_TOKEN environment variable not set") + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +EVENT_NAME = os.getenv("EVENT_NAME") +ISSUE_NUMBER = os.getenv("ISSUE_NUMBER") +ISSUE_COUNT_TO_PROCESS = os.getenv("ISSUE_COUNT_TO_PROCESS") + +IS_INTERACTIVE = os.environ.get("INTERACTIVE", "1").lower() in ["true", "1"] diff --git a/contributing/samples/adk_issue_formatting_agent/utils.py b/contributing/samples/adk_issue_formatting_agent/utils.py new file mode 100644 index 000000000..2ee735d3d --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/utils.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +from adk_issue_formatting_agent.settings import GITHUB_TOKEN +import requests + +headers = { + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + + +def get_request( + url: str, params: dict[str, Any] | None = None +) -> dict[str, Any]: + if params is None: + params = {} + response = requests.get(url, headers=headers, params=params, timeout=60) + response.raise_for_status() + return response.json() + + +def post_request(url: str, payload: Any) -> dict[str, Any]: + response = requests.post(url, headers=headers, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def error_response(error_message: str) -> dict[str, Any]: + return {"status": "error", "message": error_message} + + +def read_file(file_path: str) -> str: + """Read the content of the given file.""" + try: + with open(file_path, "r") as f: + return f.read() + except FileNotFoundError: + print(f"Error: File not found: {file_path}.") + return "" From 832a6333512eb96fccc3d394939fd947a38fe409 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 13:58:14 -0700 Subject: [PATCH 25/28] chore: Enhance a2a part converters a. fix binary data conversion b. support thoughts, code execution result, executable codes conversion PiperOrigin-RevId: 775827259 --- .../adk/a2a/converters/part_converter.py | 106 +++++- .../a2a/converters/test_part_converter.py | 342 ++++++++++++++++-- 2 files changed, 403 insertions(+), 45 deletions(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index c47ac7276..8dab1097d 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -18,6 +18,7 @@ from __future__ import annotations +import base64 import json import logging import sys @@ -43,8 +44,11 @@ logger = logging.getLogger('google_adk.' + __name__) A2A_DATA_PART_METADATA_TYPE_KEY = 'type' +A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = 'is_long_running' A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call' A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' +A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result' +A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code' @working_in_progress @@ -67,7 +71,8 @@ def convert_a2a_part_to_genai_part( elif isinstance(part.file, a2a_types.FileWithBytes): return genai_types.Part( inline_data=genai_types.Blob( - data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType + data=base64.b64decode(part.file.bytes), + mime_type=part.file.mimeType, ) ) else: @@ -84,7 +89,11 @@ def convert_a2a_part_to_genai_part( # response. 
# TODO once A2A defined how to suervice such information, migrate below # logic accordinlgy - if part.metadata and A2A_DATA_PART_METADATA_TYPE_KEY in part.metadata: + if ( + part.metadata + and _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + in part.metadata + ): if ( part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL @@ -103,6 +112,24 @@ def convert_a2a_part_to_genai_part( part.data, by_alias=True ) ) + if ( + part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + ): + return genai_types.Part( + code_execution_result=genai_types.CodeExecutionResult.model_validate( + part.data, by_alias=True + ) + ) + if ( + part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + ): + return genai_types.Part( + executable_code=genai_types.ExecutableCode.model_validate( + part.data, by_alias=True + ) + ) return genai_types.Part(text=json.dumps(part.data)) logger.warning( @@ -118,27 +145,40 @@ def convert_genai_part_to_a2a_part( part: genai_types.Part, ) -> Optional[a2a_types.Part]: """Convert a Google GenAI Part to an A2A Part.""" + if part.text: - return a2a_types.TextPart(text=part.text) + a2a_part = a2a_types.TextPart(text=part.text) + if part.thought is not None: + a2a_part.metadata = {_get_adk_metadata_key('thought'): part.thought} + return a2a_types.Part(root=a2a_part) if part.file_data: - return a2a_types.FilePart( - file=a2a_types.FileWithUri( - uri=part.file_data.file_uri, - mimeType=part.file_data.mime_type, + return a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=part.file_data.file_uri, + mimeType=part.file_data.mime_type, + ) ) ) if part.inline_data: - return a2a_types.Part( - root=a2a_types.FilePart( - file=a2a_types.FileWithBytes( - bytes=part.inline_data.data, - mimeType=part.inline_data.mime_type, - ) + a2a_part = a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), + mimeType=part.inline_data.mime_type, ) ) + if part.video_metadata: + a2a_part.metadata = { + _get_adk_metadata_key( + 'video_metadata' + ): part.video_metadata.model_dump(by_alias=True, exclude_none=True) + } + + return a2a_types.Part(root=a2a_part) + # Conver the funcall and function reponse to A2A DataPart. # This is mainly for converting human in the loop and auth request and # response. 
@@ -151,9 +191,9 @@ def convert_genai_part_to_a2a_part( by_alias=True, exclude_none=True ), metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - ) + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL }, ) ) @@ -165,9 +205,37 @@ def convert_genai_part_to_a2a_part( by_alias=True, exclude_none=True ), metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - ) + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + }, + ) + ) + + if part.code_execution_result: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.code_execution_result.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + }, + ) + ) + + if part.executable_code: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.executable_code.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE }, ) ) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 4b9bd47cf..1e8f0d4a3 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -21,17 +21,20 @@ # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" ) # Import dependencies with version checking try: from a2a import types as a2a_types + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part + from google.adk.a2a.converters.utils import _get_adk_metadata_key from google.genai import types as genai_types except ImportError as e: if sys.version_info < (3, 10): @@ -44,9 +47,12 @@ class DummyTypes: genai_types = DummyTypes() A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = "function_call" A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = "function_response" + A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = "code_execution_result" + A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = "executable_code" A2A_DATA_PART_METADATA_TYPE_KEY = "type" convert_a2a_part_to_genai_part = lambda x: None convert_genai_part_to_a2a_part = lambda x: None + _get_adk_metadata_key = lambda x: f"adk_{x}" else: raise e @@ -92,11 +98,14 @@ def test_convert_file_part_with_bytes(self): """Test conversion of A2A FilePart with bytes to GenAI Part.""" # Arrange test_bytes = b"test file content" - # Note: A2A FileWithBytes converts bytes to string automatically + # A2A FileWithBytes expects base64-encoded string + import base64 + + base64_encoded = 
base64.b64encode(test_bytes).decode("utf-8") a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithBytes( - bytes=test_bytes, mimeType="text/plain" + bytes=base64_encoded, mimeType="text/plain" ) ) ) @@ -108,7 +117,7 @@ def test_convert_file_part_with_bytes(self): assert result is not None assert isinstance(result, genai_types.Part) assert result.inline_data is not None - # Source code now properly converts A2A string back to bytes for GenAI Blob + # The converter decodes base64 back to original bytes assert result.inline_data.data == test_bytes assert result.inline_data.mime_type == "text/plain" @@ -123,9 +132,9 @@ def test_convert_data_part_function_call(self): root=a2a_types.DataPart( data=function_call_data, metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - ), + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, }, ) @@ -152,9 +161,9 @@ def test_convert_data_part_function_response(self): root=a2a_types.DataPart( data=function_response_data, metadata={ - A2A_DATA_PART_METADATA_TYPE_KEY: ( - A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - ), + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, }, ) @@ -260,8 +269,25 @@ def test_convert_text_part(self): # Assert assert result is not None - assert isinstance(result, a2a_types.TextPart) - assert result.text == "Hello, world!" + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.TextPart) + assert result.root.text == "Hello, world!" + + def test_convert_text_part_with_thought(self): + """Test conversion of GenAI text Part with thought to A2A Part.""" + # Arrange - thought is a boolean field in genai_types.Part + genai_part = genai_types.Part(text="Hello, world!", thought=True) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.TextPart) + assert result.root.text == "Hello, world!" 
+ assert result.root.metadata is not None + assert result.root.metadata[_get_adk_metadata_key("thought")] == True def test_convert_file_data_part(self): """Test conversion of GenAI file_data Part to A2A Part.""" @@ -277,10 +303,11 @@ def test_convert_file_data_part(self): # Assert assert result is not None - assert isinstance(result, a2a_types.FilePart) - assert isinstance(result.file, a2a_types.FileWithUri) - assert result.file.uri == "gs://bucket/file.txt" - assert result.file.mimeType == "text/plain" + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithUri) + assert result.root.file.uri == "gs://bucket/file.txt" + assert result.root.file.mimeType == "text/plain" def test_convert_inline_data_part(self): """Test conversion of GenAI inline_data Part to A2A Part.""" @@ -298,10 +325,34 @@ def test_convert_inline_data_part(self): assert isinstance(result, a2a_types.Part) assert isinstance(result.root, a2a_types.FilePart) assert isinstance(result.root.file, a2a_types.FileWithBytes) - # A2A FileWithBytes stores bytes as strings - assert result.root.file.bytes == test_bytes.decode("utf-8") + # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility + import base64 + + expected_base64 = base64.b64encode(test_bytes).decode("utf-8") + assert result.root.file.bytes == expected_base64 assert result.root.file.mimeType == "text/plain" + def test_convert_inline_data_part_with_video_metadata(self): + """Test conversion of GenAI inline_data Part with video metadata to A2A Part.""" + # Arrange + test_bytes = b"test video content" + video_metadata = genai_types.VideoMetadata(fps=30.0) + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="video/mp4"), + video_metadata=video_metadata, + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + assert result.root.metadata is not None + assert _get_adk_metadata_key("video_metadata") in result.root.metadata + def test_convert_function_call_part(self): """Test conversion of GenAI function_call Part to A2A Part.""" # Arrange @@ -320,7 +371,9 @@ def test_convert_function_call_part(self): expected_data = function_call.model_dump(by_alias=True, exclude_none=True) assert result.root.data == expected_data assert ( - result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ) @@ -344,10 +397,62 @@ def test_convert_function_response_part(self): ) assert result.root.data == expected_data assert ( - result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE ) + def test_convert_code_execution_result_part(self): + """Test conversion of GenAI code_execution_result Part to A2A Part.""" + # Arrange + code_execution_result = genai_types.CodeExecutionResult( + outcome=genai_types.Outcome.OUTCOME_OK, output="Hello, World!" 
+ ) + genai_part = genai_types.Part(code_execution_result=code_execution_result) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = code_execution_result.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + ) + + def test_convert_executable_code_part(self): + """Test conversion of GenAI executable_code Part to A2A Part.""" + # Arrange + executable_code = genai_types.ExecutableCode( + language=genai_types.Language.PYTHON, code="print('Hello, World!')" + ) + genai_part = genai_types.Part(executable_code=executable_code) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = executable_code.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + ) + def test_convert_unsupported_part(self): """Test handling of unsupported GenAI Part types.""" # Arrange - Create a GenAI Part with no recognized fields @@ -379,8 +484,9 @@ def test_text_part_round_trip(self): # Assert assert result_a2a_part is not None - assert isinstance(result_a2a_part, a2a_types.TextPart) - assert result_a2a_part.text == original_text + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.TextPart) + assert result_a2a_part.root.text == original_text def test_file_uri_round_trip(self): """Test round-trip conversion for file parts with URI.""" @@ -401,10 +507,122 @@ def test_file_uri_round_trip(self): # Assert assert result_a2a_part is not None - assert isinstance(result_a2a_part, a2a_types.FilePart) - assert isinstance(result_a2a_part.file, a2a_types.FileWithUri) - assert result_a2a_part.file.uri == original_uri - assert result_a2a_part.file.mimeType == original_mime_type + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.FilePart) + assert isinstance(result_a2a_part.root.file, a2a_types.FileWithUri) + assert result_a2a_part.root.file.uri == original_uri + assert result_a2a_part.root.file.mimeType == original_mime_type + + def test_file_bytes_round_trip(self): + """Test round-trip conversion for file parts with bytes.""" + # Arrange + original_bytes = b"test file content for round trip" + original_mime_type = "application/octet-stream" + + # Start with GenAI part (the more common starting point) + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=original_bytes, mime_type=original_mime_type + ) + ) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.inline_data is not None + assert result_genai_part.inline_data.data == original_bytes + assert result_genai_part.inline_data.mime_type == original_mime_type + + def test_function_call_round_trip(self): + """Test round-trip 
conversion for function call parts.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.function_call is not None + assert result_genai_part.function_call.name == function_call.name + assert result_genai_part.function_call.args == function_call.args + + def test_function_response_round_trip(self): + """Test round-trip conversion for function response parts.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.function_response is not None + assert result_genai_part.function_response.name == function_response.name + assert ( + result_genai_part.function_response.response + == function_response.response + ) + + def test_code_execution_result_round_trip(self): + """Test round-trip conversion for code execution result parts.""" + # Arrange + code_execution_result = genai_types.CodeExecutionResult( + outcome=genai_types.Outcome.OUTCOME_OK, output="Hello, World!" + ) + genai_part = genai_types.Part(code_execution_result=code_execution_result) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.code_execution_result is not None + assert ( + result_genai_part.code_execution_result.outcome + == code_execution_result.outcome + ) + assert ( + result_genai_part.code_execution_result.output + == code_execution_result.output + ) + + def test_executable_code_round_trip(self): + """Test round-trip conversion for executable code parts.""" + # Arrange + executable_code = genai_types.ExecutableCode( + language=genai_types.Language.PYTHON, code="print('Hello, World!')" + ) + genai_part = genai_types.Part(executable_code=executable_code) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.executable_code is not None + assert ( + result_genai_part.executable_code.language == executable_code.language + ) + assert result_genai_part.executable_code.code == executable_code.code class TestEdgeCases: @@ -468,3 +686,75 @@ def test_data_part_with_empty_metadata(self): # Assert assert result is not None assert result.text == json.dumps(data) + + +class TestNewConstants: + """Test cases for new constants and functionality.""" + + def test_new_constants_exist(self): + """Test that new constants are defined.""" + assert ( + 
A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + == "code_execution_result" + ) + assert A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE == "executable_code" + + def test_convert_a2a_data_part_with_code_execution_result_metadata(self): + """Test conversion of A2A DataPart with code execution result metadata.""" + # Arrange + code_execution_result_data = { + "outcome": "OUTCOME_OK", + "output": "Hello, World!", + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=code_execution_result_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + # Now it should convert back to a proper CodeExecutionResult + assert result.code_execution_result is not None + assert ( + result.code_execution_result.outcome == genai_types.Outcome.OUTCOME_OK + ) + assert result.code_execution_result.output == "Hello, World!" + + def test_convert_a2a_data_part_with_executable_code_metadata(self): + """Test conversion of A2A DataPart with executable code metadata.""" + # Arrange + executable_code_data = { + "language": "PYTHON", + "code": "print('Hello, World!')", + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=executable_code_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + # Now it should convert back to a proper ExecutableCode + assert result.executable_code is not None + assert result.executable_code.language == genai_types.Language.PYTHON + assert result.executable_code.code == "print('Hello, World!')" From a71dbdf9e24a44f18b23be082ef9cb14d9b03692 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 25 Jun 2025 15:31:42 -0700 Subject: [PATCH 26/28] chore: Enhance a2a event converter a. fix function call long running id matching logic b. fix error code conversion logic c. add input required and auth required status conversion logic d. add a2a Task/Message to ADK event converter f. 
set task id and context id from input argument PiperOrigin-RevId: 775860563 --- .../adk/a2a/converters/event_converter.py | 288 ++++++- .../a2a/converters/test_event_converter.py | 701 +++++++++++++++++- 2 files changed, 918 insertions(+), 71 deletions(-) diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index 5594c0e63..25183f6be 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -14,7 +14,8 @@ from __future__ import annotations -import datetime +from datetime import datetime +from datetime import timezone import logging from typing import Any from typing import Dict @@ -26,18 +27,24 @@ from a2a.types import Artifact from a2a.types import DataPart from a2a.types import Message +from a2a.types import Part as A2APart from a2a.types import Role +from a2a.types import Task from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart +from google.genai import types as genai_types from ...agents.invocation_context import InvocationContext from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from ...utils.feature_decorator import working_in_progress +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_a2a_part_to_genai_part from .part_converter import convert_genai_part_to_a2a_part from .utils import _get_adk_metadata_key @@ -143,6 +150,8 @@ def _convert_artifact_to_a2a_events( invocation_context: InvocationContext, filename: str, version: int, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> TaskArtifactUpdateEvent: """Converts a new artifact version to an A2A TaskArtifactUpdateEvent. @@ -151,6 +160,7 @@ def _convert_artifact_to_a2a_events( invocation_context: The invocation context. filename: The name of the artifact file. version: The version number of the artifact. + task_id: Optional task ID to use for generated events. If not provided, new UUIDs will be generated. Returns: A TaskArtifactUpdateEvent representing the artifact update. @@ -186,9 +196,9 @@ def _convert_artifact_to_a2a_events( ) return TaskArtifactUpdateEvent( - taskId=str(uuid.uuid4()), + taskId=task_id, append=False, - contextId=invocation_context.session.id, + contextId=context_id, lastChunk=True, artifact=Artifact( artifactId=artifact_id, @@ -210,7 +220,7 @@ def _convert_artifact_to_a2a_events( raise RuntimeError(f"Artifact conversion failed: {e}") from e -def _process_long_running_tool(a2a_part, event: Event) -> None: +def _process_long_running_tool(a2a_part: A2APart, event: Event) -> None: """Processes long-running tool metadata for an A2A part. 
Args: @@ -220,18 +230,173 @@ def _process_long_running_tool(a2a_part, event: Event) -> None: if ( isinstance(a2a_part.root, DataPart) and event.long_running_tool_ids + and a2a_part.root.metadata and a2a_part.root.metadata.get( _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and a2a_part.root.metadata.get("id") in event.long_running_tool_ids + and a2a_part.root.data.get("id") in event.long_running_tool_ids ): - a2a_part.root.metadata[_get_adk_metadata_key("is_long_running")] = True + a2a_part.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ] = True @working_in_progress -def convert_event_to_a2a_status_message( - event: Event, invocation_context: InvocationContext +def convert_a2a_task_to_event( + a2a_task: Task, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, +) -> Event: + """Converts an A2A task to an ADK event. + + Args: + a2a_task: The A2A task to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + + Returns: + An ADK Event object representing the converted task. + + Raises: + ValueError: If a2a_task is None. + RuntimeError: If conversion of the underlying message fails. + """ + if a2a_task is None: + raise ValueError("A2A task cannot be None") + + try: + # Extract message from task status or history + message = None + if a2a_task.status and a2a_task.status.message: + message = a2a_task.status.message + elif a2a_task.history: + message = a2a_task.history[-1] + + # Convert message if available + if message: + try: + return convert_a2a_message_to_event(message, author, invocation_context) + except Exception as e: + logger.error("Failed to convert A2A task message to event: %s", e) + raise RuntimeError(f"Failed to convert task message: {e}") from e + + # Create minimal event if no message is available + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + ) + + except Exception as e: + logger.error("Failed to convert A2A task to event: %s", e) + raise + + +@working_in_progress +def convert_a2a_message_to_event( + a2a_message: Message, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, +) -> Event: + """Converts an A2A message to an ADK event. + + Args: + a2a_message: The A2A message to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + + Returns: + An ADK Event object with converted content and long-running tool metadata. + + Raises: + ValueError: If a2a_message is None. + RuntimeError: If conversion of message parts fails. 
+ """ + if a2a_message is None: + raise ValueError("A2A message cannot be None") + + if not a2a_message.parts: + logger.warning( + "A2A message has no parts, creating event with empty content" + ) + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + content=genai_types.Content(role="model", parts=[]), + ) + + try: + parts = [] + long_running_tool_ids = set() + + for a2a_part in a2a_message.parts: + try: + part = convert_a2a_part_to_genai_part(a2a_part) + if part is None: + logger.warning("Failed to convert A2A part, skipping: %s", a2a_part) + continue + + # Check for long-running tools + if ( + a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY + ) + ) + is True + ): + long_running_tool_ids.add(part.function_call.id) + + parts.append(part) + + except Exception as e: + logger.error("Failed to convert A2A part: %s, error: %s", a2a_part, e) + # Continue processing other parts instead of failing completely + continue + + if not parts: + logger.warning( + "No parts could be converted from A2A message %s", a2a_message + ) + + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + long_running_tool_ids=long_running_tool_ids + if long_running_tool_ids + else None, + content=genai_types.Content( + role="model", + parts=parts, + ), + ) + + except Exception as e: + logger.error("Failed to convert A2A message to event: %s", e) + raise RuntimeError(f"Failed to convert message: {e}") from e + + +@working_in_progress +def convert_event_to_a2a_message( + event: Event, invocation_context: InvocationContext, role: Role = Role.agent ) -> Optional[Message]: """Converts an ADK event to an A2A message. @@ -262,9 +427,7 @@ def convert_event_to_a2a_status_message( _process_long_running_tool(a2a_part, event) if a2a_parts: - return Message( - messageId=str(uuid.uuid4()), role=Role.agent, parts=a2a_parts - ) + return Message(messageId=str(uuid.uuid4()), role=role, parts=a2a_parts) except Exception as e: logger.error("Failed to convert event to status message: %s", e) @@ -274,38 +437,57 @@ def convert_event_to_a2a_status_message( def _create_error_status_event( - event: Event, invocation_context: InvocationContext + event: Event, + invocation_context: InvocationContext, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> TaskStatusUpdateEvent: """Creates a TaskStatusUpdateEvent for error scenarios. Args: event: The ADK event containing error information. invocation_context: The invocation context. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. Returns: A TaskStatusUpdateEvent with FAILED state. 
""" error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + # Get context metadata and add error code + event_metadata = _get_context_metadata(event, invocation_context) + if event.error_code: + event_metadata[_get_adk_metadata_key("error_code")] = str(event.error_code) + return TaskStatusUpdateEvent( - taskId=str(uuid.uuid4()), - contextId=invocation_context.session.id, - final=False, - metadata=_get_context_metadata(event, invocation_context), + taskId=task_id, + contextId=context_id, + metadata=event_metadata, status=TaskStatus( state=TaskState.failed, message=Message( messageId=str(uuid.uuid4()), role=Role.agent, parts=[TextPart(text=error_message)], + metadata={ + _get_adk_metadata_key("error_code"): str(event.error_code) + } + if event.error_code + else {}, ), - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.now(timezone.utc).isoformat(), ), + final=False, ) -def _create_running_status_event( - message: Message, invocation_context: InvocationContext, event: Event +def _create_status_update_event( + message: Message, + invocation_context: InvocationContext, + event: Event, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> TaskStatusUpdateEvent: """Creates a TaskStatusUpdateEvent for running scenarios. @@ -313,32 +495,70 @@ def _create_running_status_event( message: The A2A message to include. invocation_context: The invocation context. event: The ADK event. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + Returns: A TaskStatusUpdateEvent with RUNNING state. """ + status = TaskStatus( + state=TaskState.working, + message=message, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + if any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME + for part in message.parts + if part.root.metadata + ): + status.state = TaskState.auth_required + elif any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + for part in message.parts + if part.root.metadata + ): + status.state = TaskState.input_required + return TaskStatusUpdateEvent( - taskId=str(uuid.uuid4()), - contextId=invocation_context.session.id, - final=False, - status=TaskStatus( - state=TaskState.working, - message=message, - timestamp=datetime.datetime.now().isoformat(), - ), + taskId=task_id, + contextId=context_id, + status=status, metadata=_get_context_metadata(event, invocation_context), + final=False, ) @working_in_progress def convert_event_to_a2a_events( - event: Event, invocation_context: InvocationContext + event: Event, + invocation_context: InvocationContext, + task_id: Optional[str] = None, + context_id: Optional[str] = None, ) -> List[A2AEvent]: """Converts a GenAI event to a list of A2A events. Args: event: The ADK event to convert. invocation_context: The invocation context. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. Returns: A list of A2A events representing the converted ADK event. 
@@ -358,20 +578,22 @@ def convert_event_to_a2a_events( if event.actions.artifact_delta: for filename, version in event.actions.artifact_delta.items(): artifact_event = _convert_artifact_to_a2a_events( - event, invocation_context, filename, version + event, invocation_context, filename, version, task_id, context_id ) a2a_events.append(artifact_event) # Handle error scenarios if event.error_code: - error_event = _create_error_status_event(event, invocation_context) + error_event = _create_error_status_event( + event, invocation_context, task_id, context_id + ) a2a_events.append(error_event) # Handle regular message content - message = convert_event_to_a2a_status_message(event, invocation_context) + message = convert_event_to_a2a_message(event, invocation_context) if message: - running_event = _create_running_status_event( - message, invocation_context, event + running_event = _create_status_update_event( + message, invocation_context, event, task_id, context_id ) a2a_events.append(running_event) diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index 311ffc954..2ba8e26b4 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -20,7 +20,7 @@ # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" ) # Import dependencies with version checking @@ -28,20 +28,21 @@ from a2a.types import DataPart from a2a.types import Message from a2a.types import Role + from a2a.types import Task from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent from google.adk.a2a.converters.event_converter import _convert_artifact_to_a2a_events from google.adk.a2a.converters.event_converter import _create_artifact_id from google.adk.a2a.converters.event_converter import _create_error_status_event - from google.adk.a2a.converters.event_converter import _create_running_status_event + from google.adk.a2a.converters.event_converter import _create_status_update_event from google.adk.a2a.converters.event_converter import _get_adk_metadata_key from google.adk.a2a.converters.event_converter import _get_context_metadata from google.adk.a2a.converters.event_converter import _process_long_running_tool from google.adk.a2a.converters.event_converter import _serialize_metadata_value from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_status_message + from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX from google.adk.agents.invocation_context import InvocationContext @@ -57,13 +58,14 @@ class DummyTypes: DataPart = DummyTypes() Message = DummyTypes() Role = DummyTypes() + Task = DummyTypes() TaskArtifactUpdateEvent = DummyTypes() TaskState = DummyTypes() TaskStatusUpdateEvent = DummyTypes() _convert_artifact_to_a2a_events = lambda *args: None _create_artifact_id = lambda *args: None _create_error_status_event = lambda *args: None - _create_running_status_event = lambda *args: None + 
_create_status_update_event = lambda *args: None _get_adk_metadata_key = lambda *args: None _get_context_metadata = lambda *args: None _process_long_running_tool = lambda *args: None @@ -71,7 +73,7 @@ class DummyTypes: ADK_METADATA_KEY_PREFIX = "adk_" ARTIFACT_ID_SEPARATOR = "_" convert_event_to_a2a_events = lambda *args: None - convert_event_to_a2a_status_message = lambda *args: None + convert_event_to_a2a_message = lambda *args: None DEFAULT_ERROR_MESSAGE = "error" InvocationContext = DummyTypes() Event = DummyTypes() @@ -233,6 +235,8 @@ def test_convert_artifact_to_a2a_events_success(self, mock_convert_part): """Test successful artifact delta conversion.""" filename = "test.txt" version = 1 + task_id = "test-task-id" + context_id = "test-context-id" mock_artifact_part = Mock() # Create a proper Part that Pydantic will accept @@ -246,11 +250,17 @@ def test_convert_artifact_to_a2a_events_success(self, mock_convert_part): mock_convert_part.return_value = mock_converted_part result = _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, filename, version + self.mock_event, + self.mock_invocation_context, + filename, + version, + task_id, + context_id, ) assert isinstance(result, TaskArtifactUpdateEvent) - assert result.contextId == self.mock_invocation_context.session.id + assert result.taskId == task_id + assert result.contextId == context_id assert result.append is False assert result.lastChunk is True @@ -265,7 +275,7 @@ def test_convert_artifact_to_a2a_events_empty_filename(self): """Test artifact delta conversion with empty filename.""" with pytest.raises(ValueError) as exc_info: _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, "", 1 + self.mock_event, self.mock_invocation_context, "", 1, "", "" ) assert "Filename cannot be empty" in str(exc_info.value) @@ -273,7 +283,7 @@ def test_convert_artifact_to_a2a_events_negative_version(self): """Test artifact delta conversion with negative version.""" with pytest.raises(ValueError) as exc_info: _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, "test.txt", -1 + self.mock_event, self.mock_invocation_context, "test.txt", -1, "", "" ) assert "Version must be non-negative" in str(exc_info.value) @@ -293,7 +303,12 @@ def test_convert_artifact_to_a2a_events_conversion_failure( with pytest.raises(RuntimeError) as exc_info: _convert_artifact_to_a2a_events( - self.mock_event, self.mock_invocation_context, filename, version + self.mock_event, + self.mock_invocation_context, + filename, + version, + "", + "", ) assert "Failed to convert artifact part" in str(exc_info.value) @@ -302,6 +317,8 @@ def test_process_long_running_tool_marks_tool(self): mock_a2a_part = Mock() mock_data_part = Mock(spec=DataPart) mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-123") mock_a2a_part.root = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} @@ -315,7 +332,11 @@ def test_process_long_running_tool_marks_tool(self): "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", "function_call", ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, ): + mock_get_key.side_effect = lambda key: f"adk_{key}" _process_long_running_tool(mock_a2a_part, self.mock_event) @@ -327,6 +348,8 @@ def test_process_long_running_tool_no_marking(self): mock_a2a_part = Mock() mock_data_part = 
Mock(spec=DataPart) mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-456") mock_a2a_part.root = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID @@ -340,7 +363,11 @@ def test_process_long_running_tool_no_marking(self): "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", "function_call", ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, ): + mock_get_key.side_effect = lambda key: f"adk_{key}" _process_long_running_tool(mock_a2a_part, self.mock_event) @@ -368,7 +395,7 @@ def test_convert_event_to_message_success(self, mock_uuid, mock_convert_part): mock_content.parts = [mock_part] self.mock_event.content = mock_content - result = convert_event_to_a2a_status_message( + result = convert_event_to_a2a_message( self.mock_event, self.mock_invocation_context ) @@ -382,7 +409,7 @@ def test_convert_event_to_message_no_content(self): """Test event to message conversion with no content.""" self.mock_event.content = None - result = convert_event_to_a2a_status_message( + result = convert_event_to_a2a_message( self.mock_event, self.mock_invocation_context ) @@ -394,7 +421,7 @@ def test_convert_event_to_message_empty_parts(self): mock_content.parts = [] self.mock_event.content = mock_content - result = convert_event_to_a2a_status_message( + result = convert_event_to_a2a_message( self.mock_event, self.mock_invocation_context ) @@ -403,17 +430,50 @@ def test_convert_event_to_message_empty_parts(self): def test_convert_event_to_message_none_event(self): """Test event to message conversion with None event.""" with pytest.raises(ValueError) as exc_info: - convert_event_to_a2a_status_message(None, self.mock_invocation_context) + convert_event_to_a2a_message(None, self.mock_invocation_context) assert "Event cannot be None" in str(exc_info.value) def test_convert_event_to_message_none_context(self): """Test event to message conversion with None context.""" with pytest.raises(ValueError) as exc_info: - convert_event_to_a2a_status_message(self.mock_event, None) + convert_event_to_a2a_message(self.mock_event, None) assert "Invocation context cannot be None" in str(exc_info.value) + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + def test_convert_event_to_message_with_custom_role( + self, mock_uuid, mock_convert_part + ): + """Test event to message conversion with custom role.""" + mock_uuid.return_value = "test-uuid" + + mock_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types import TextPart + + text_part = TextPart(text="test message") + mock_a2a_part = Part(root=text_part) + mock_convert_part.return_value = mock_a2a_part + + mock_content = Mock() + mock_content.parts = [mock_part] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_message( + self.mock_event, self.mock_invocation_context, role=Role.user + ) + + assert isinstance(result, Message) + assert result.messageId == "test-uuid" + assert result.role == Role.user + assert len(result.parts) == 1 + assert result.parts[0].root.text == "test message" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + @patch("google.adk.a2a.converters.event_converter.datetime") def 
test_create_error_status_event(self, mock_datetime, mock_uuid): """Test creation of error status event.""" mock_uuid.return_value = "test-uuid" @@ -422,18 +482,21 @@ def test_create_error_status_event(self, mock_datetime, mock_uuid): ) self.mock_event.error_message = "Test error message" + task_id = "test-task-id" + context_id = "test-context-id" result = _create_error_status_event( - self.mock_event, self.mock_invocation_context + self.mock_event, self.mock_invocation_context, task_id, context_id ) assert isinstance(result, TaskStatusUpdateEvent) - assert result.contextId == self.mock_invocation_context.session.id + assert result.taskId == task_id + assert result.contextId == context_id assert result.status.state == TaskState.failed assert result.status.message.parts[0].root.text == "Test error message" @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): """Test creation of error status event without error message.""" mock_uuid.return_value = "test-uuid" @@ -441,13 +504,16 @@ def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): "2023-01-01T00:00:00" ) + task_id = "test-task-id" + context_id = "test-context-id" + result = _create_error_status_event( - self.mock_event, self.mock_invocation_context + self.mock_event, self.mock_invocation_context, task_id, context_id ) assert result.status.message.parts[0].root.text == DEFAULT_ERROR_MESSAGE - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_running_status_event(self, mock_datetime): """Test creation of running status event.""" mock_datetime.now.return_value.isoformat.return_value = ( @@ -455,13 +521,21 @@ def test_create_running_status_event(self, mock_datetime): ) mock_message = Mock(spec=Message) - - result = _create_running_status_event( - mock_message, self.mock_invocation_context, self.mock_event + mock_message.parts = [] + task_id = "test-task-id" + context_id = "test-context-id" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, ) assert isinstance(result, TaskStatusUpdateEvent) - assert result.contextId == self.mock_invocation_context.session.id + assert result.taskId == task_id + assert result.contextId == context_id assert result.status.state == TaskState.working assert result.status.message == mock_message @@ -469,11 +543,11 @@ def test_create_running_status_event(self, mock_datetime): "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" ) @patch( - "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" ) @patch("google.adk.a2a.converters.event_converter._create_error_status_event") @patch( - "google.adk.a2a.converters.event_converter._create_running_status_event" + "google.adk.a2a.converters.event_converter._create_status_update_event" ) def test_convert_event_to_a2a_events_full_scenario( self, @@ -514,14 +588,14 @@ def test_convert_event_to_a2a_events_full_scenario( # Verify artifact delta events assert mock_convert_artifact.call_count == 2 - # Verify error event + # Verify error event - now called with task_id and context_id parameters mock_create_error.assert_called_once_with( - 
self.mock_event, self.mock_invocation_context + self.mock_event, self.mock_invocation_context, None, None ) - # Verify running event + # Verify running event - now called with task_id and context_id parameters mock_create_running.assert_called_once_with( - mock_message, self.mock_invocation_context, self.mock_event + mock_message, self.mock_invocation_context, self.mock_event, None, None ) # Verify result contains all events @@ -552,7 +626,7 @@ def test_convert_event_to_a2a_events_none_context(self): assert "Invocation context cannot be None" in str(exc_info.value) @patch( - "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" ) def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): """Test event to A2A events conversion with message only.""" @@ -560,7 +634,7 @@ def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): mock_convert_message.return_value = mock_message with patch( - "google.adk.a2a.converters.event_converter._create_running_status_event" + "google.adk.a2a.converters.event_converter._create_status_update_event" ) as mock_create_running: mock_running_event = Mock() mock_create_running.return_value = mock_running_event @@ -571,15 +645,23 @@ def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): assert len(result) == 1 assert result[0] == mock_running_event + # Verify the function is called with task_id and context_id parameters + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + None, + None, + ) @patch("google.adk.a2a.converters.event_converter.logger") def test_convert_event_to_a2a_events_exception_handling(self, mock_logger): - """Test exception handling in event to A2A events conversion.""" - # Make convert_event_to_a2a_status_message raise an exception + """Test exception handling in convert_event_to_a2a_events.""" + # Make convert_event_to_a2a_message raise an exception with patch( - "google.adk.a2a.converters.event_converter.convert_event_to_a2a_status_message" - ) as mock_convert: - mock_convert.side_effect = Exception("Conversion failed") + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.side_effect = Exception("Test exception") with pytest.raises(Exception): convert_event_to_a2a_events( @@ -587,3 +669,546 @@ def test_convert_event_to_a2a_events_exception_handling(self, mock_logger): ) mock_logger.error.assert_called_once() + + def test_convert_event_to_a2a_events_with_task_id_and_context_id(self): + """Test event to A2A events conversion with specific task_id and context_id.""" + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 1 + assert result[0] == mock_running_event + + # Verify the function is called with the specific task_id and context_id + 
mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + def test_convert_event_to_a2a_events_with_artifacts_and_custom_ids(self): + """Test event to A2A events conversion with artifacts and custom IDs.""" + # Setup artifact delta + self.mock_event.actions.artifact_delta = {"file1.txt": 1} + + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" + ) as mock_convert_artifact: + mock_artifact_event = Mock() + mock_convert_artifact.return_value = mock_artifact_event + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 2 # 1 artifact + 1 status + assert mock_artifact_event in result + assert mock_running_event in result + + # Verify artifact conversion is called with custom IDs + mock_convert_artifact.assert_called_once_with( + self.mock_event, + self.mock_invocation_context, + "file1.txt", + 1, + task_id, + context_id, + ) + + # Verify status update is called with custom IDs + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + def test_create_status_update_event_with_auth_required_state(self): + """Test creation of status update event with auth_required state.""" + from a2a.types import DataPart + from a2a.types import Part + + # Create a mock message with a part that triggers auth_required state + mock_message = Mock(spec=Message) + mock_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = { + "adk_type": "function_call", + "adk_is_long_running": True, + } + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="request_euc") + mock_part.root = mock_data_part + mock_message.parts = [mock_part] + + task_id = "test-task-id" + context_id = "test-context-id" + + with patch( + "google.adk.a2a.converters.event_converter.datetime" + ) as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", + "is_long_running", + ), + patch( + "google.adk.a2a.converters.event_converter.REQUEST_EUC_FUNCTION_CALL_NAME", + "request_euc", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert 
result.status.state == TaskState.auth_required + + def test_create_status_update_event_with_input_required_state(self): + """Test creation of status update event with input_required state.""" + from a2a.types import DataPart + from a2a.types import Part + + # Create a mock message with a part that triggers input_required state + mock_message = Mock(spec=Message) + mock_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = { + "adk_type": "function_call", + "adk_is_long_running": True, + } + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="some_other_function") + mock_part.root = mock_data_part + mock_message.parts = [mock_part] + + task_id = "test-task-id" + context_id = "test-context-id" + + with patch( + "google.adk.a2a.converters.event_converter.datetime" + ) as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", + "is_long_running", + ), + patch( + "google.adk.a2a.converters.event_converter.REQUEST_EUC_FUNCTION_CALL_NAME", + "request_euc", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.status.state == TaskState.input_required + + +class TestA2AToEventConverters: + """Test suite for A2A to Event conversion functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_invocation_context = Mock(spec=InvocationContext) + self.mock_invocation_context.branch = "test-branch" + self.mock_invocation_context.invocation_id = "test-invocation-id" + + def test_convert_a2a_task_to_event_with_status_message(self): + """Test converting A2A task with status message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_status = Mock() + mock_status.message = mock_message + mock_task = Mock(spec=Task) + mock_task.status = mock_status + mock_task.history = [] + + # Mock the convert_a2a_message_to_event function + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_event = Mock(spec=Event) + mock_event.invocation_id = "test-invocation-id" + mock_convert_message.return_value = mock_event + + result = convert_a2a_task_to_event( + mock_task, "test-author", self.mock_invocation_context + ) + + # Verify the message converter was called with correct parameters + mock_convert_message.assert_called_once_with( + mock_message, "test-author", self.mock_invocation_context + ) + assert result == mock_event + assert result.invocation_id == "test-invocation-id" + + def test_convert_a2a_task_to_event_with_history_message(self): + """Test converting A2A task with history message when no status message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock 
message and task + mock_message = Mock(spec=Message) + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [mock_message] + + # Mock the convert_a2a_message_to_event function + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_event = Mock(spec=Event) + mock_event.invocation_id = "test-invocation-id" + mock_convert_message.return_value = mock_event + + result = convert_a2a_task_to_event(mock_task, "test-author") + + # Verify the message converter was called with correct parameters + mock_convert_message.assert_called_once_with( + mock_message, "test-author", None + ) + assert result == mock_event + + def test_convert_a2a_task_to_event_no_message(self): + """Test converting A2A task with no message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock task with no message + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [] + + result = convert_a2a_task_to_event( + mock_task, "test-author", self.mock_invocation_context + ) + + # Verify minimal event was created with correct invocation_id + assert result.author == "test-author" + assert result.branch == "test-branch" + assert result.invocation_id == "test-invocation-id" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_a2a_task_to_event_default_author(self, mock_uuid): + """Test converting A2A task with default author and no invocation context.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock task with no message + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [] + + # Mock UUID generation + mock_uuid.return_value = "generated-uuid" + + result = convert_a2a_task_to_event(mock_task) + + # Verify default author was used and UUID was generated for invocation_id + assert result.author == "a2a agent" + assert result.branch is None + assert result.invocation_id == "generated-uuid" + + def test_convert_a2a_task_to_event_none_task(self): + """Test converting None task raises ValueError.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + with pytest.raises(ValueError, match="A2A task cannot be None"): + convert_a2a_task_to_event(None) + + def test_convert_a2a_task_to_event_message_conversion_error(self): + """Test error handling when message conversion fails.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_status = Mock() + mock_status.message = mock_message + mock_task = Mock(spec=Task) + mock_task.status = mock_status + mock_task.history = [] + + # Mock the convert_a2a_message_to_event function to raise an exception + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_convert_message.side_effect = Exception("Conversion failed") + + with pytest.raises(RuntimeError, match="Failed to convert task message"): + convert_a2a_task_to_event(mock_task, "test-author") + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_success(self, mock_convert_part): + """Test successful conversion of A2A message to event.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + from google.genai import types as genai_types + + # Create mock parts and 
message with valid genai Part + mock_a2a_part = Mock() + mock_genai_part = genai_types.Part(text="test content") + mock_convert_part.return_value = mock_genai_part + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify conversion was successful + assert result.author == "test-author" + assert result.branch == "test-branch" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 1 + assert result.content.parts[0].text == "test content" + mock_convert_part.assert_called_once_with(mock_a2a_part) + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_with_long_running_tools( + self, mock_convert_part + ): + """Test conversion with long-running tools by mocking the entire flow.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Create mock parts and message + mock_a2a_part = Mock() + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + # Mock the part conversion to return None to simulate long-running tool detection logic + mock_convert_part.return_value = None + + # Patch the long-running tool detection since the main logic is in the actual conversion + with patch( + "google.adk.a2a.converters.event_converter.logger" + ) as mock_logger: + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify basic conversion worked + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + # Parts will be empty since conversion returned None, but that's expected for this test + + def test_convert_a2a_message_to_event_empty_parts(self): + """Test conversion with empty parts list.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + mock_message = Mock(spec=Message) + mock_message.parts = [] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created with empty parts + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 0 + + def test_convert_a2a_message_to_event_none_message(self): + """Test converting None message raises ValueError.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + with pytest.raises(ValueError, match="A2A message cannot be None"): + convert_a2a_message_to_event(None) + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_part_conversion_fails( + self, mock_convert_part + ): + """Test handling when part conversion returns None.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Setup mock to return None (conversion failure) + mock_a2a_part = Mock() + mock_convert_part.return_value = None + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created but with no parts + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + 
assert result.content.role == "model" + assert len(result.content.parts) == 0 + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_part_conversion_exception( + self, mock_convert_part + ): + """Test handling when part conversion raises exception.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + from google.genai import types as genai_types + + # Setup mock to raise exception + mock_a2a_part1 = Mock() + mock_a2a_part2 = Mock() + mock_genai_part = genai_types.Part(text="successful conversion") + + mock_convert_part.side_effect = [ + Exception("Conversion failed"), # First part fails + mock_genai_part, # Second part succeeds + ] + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part1, mock_a2a_part2] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created with only the successfully converted part + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 1 + assert result.content.parts[0].text == "successful conversion" + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_missing_tool_id( + self, mock_convert_part + ): + """Test handling of message conversion when part conversion fails.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Create mock parts and message + mock_a2a_part = Mock() + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + # Mock the part conversion to return None + mock_convert_part.return_value = None + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify basic conversion worked + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + # Parts will be empty since conversion returned None + assert len(result.content.parts) == 0 + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_a2a_message_to_event_default_author(self, mock_uuid): + """Test conversion with default author and no invocation context.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + mock_message = Mock(spec=Message) + mock_message.parts = [] + + # Mock UUID generation + mock_uuid.return_value = "generated-uuid" + + result = convert_a2a_message_to_event(mock_message) + + # Verify default author was used and UUID was generated for invocation_id + assert result.author == "a2a agent" + assert result.branch is None + assert result.invocation_id == "generated-uuid" From 3901fade71486a1e9677fe74a120c3f08efe9d9e Mon Sep 17 00:00:00 2001 From: SimonWei <119845914+simonwei97@users.noreply.github.com> Date: Wed, 25 Jun 2025 17:40:11 -0700 Subject: [PATCH 27/28] fix: converts litellm generate config err Merge https://github.com/google/adk-python/pull/1509 Fixs: #1302 Previous PR: https://github.com/google/adk-python/pull/1450 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1509 from simonwei97:fix/litellm-gen-config-converting-err 3120887f29a21789f1b4a7c54af3ed35eb5055e3 PiperOrigin-RevId: 775903671 --- src/google/adk/models/lite_llm.py | 60 +++++++++++++++++++++----- tests/unittests/models/test_litellm.py | 
32 ++++++++++++++ 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 624b7adfc..39514d6f4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -485,16 +486,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[Dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -511,7 +518,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -522,12 +530,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None - - if llm_request.config.response_schema: + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = None + if llm_request.config and llm_request.config.response_schema: response_format = llm_request.config.response_schema - return messages, tools, response_format + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. + generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] + + if not generation_params: + generation_params = None + + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -664,7 +699,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) if "functions" in self._additional_args: # LiteLLM does not support both tools and functions together. 
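The hunk above gathers generation parameters from `GenerateContentConfig` and renames the two fields whose LiteLLM spelling differs (`max_output_tokens` becomes `max_completion_tokens`, `stop_sequences` becomes `stop`), following https://docs.litellm.ai/docs/completion/input. A minimal standalone sketch of that mapping follows; `extract_generation_params` is an illustrative name that does not exist in the patch, and its input dict merely stands in for `llm_request.config.model_dump(exclude_none=True)`.

from typing import Any, Dict, Optional

# Keys the patch forwards to LiteLLM, and the two renames it applies.
_SUPPORTED_KEYS = (
    "temperature",
    "max_output_tokens",
    "top_p",
    "top_k",
    "stop_sequences",
    "presence_penalty",
    "frequency_penalty",
)
_PARAM_MAPPING = {
    "max_output_tokens": "max_completion_tokens",
    "stop_sequences": "stop",
}


def extract_generation_params(
    config_dict: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
  """Returns LiteLLM-style kwargs, or None when nothing relevant is set."""
  params = {
      _PARAM_MAPPING.get(key, key): config_dict[key]
      for key in _SUPPORTED_KEYS
      if key in config_dict
  }
  return params or None


# Example: {"temperature": 0.33, "max_output_tokens": 123, "stop_sequences": ["foo"]}
# becomes {"temperature": 0.33, "max_completion_tokens": 123, "stop": ["foo"]}.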
@@ -678,6 +715,9 @@ async def generate_content_async(
     }
     completion_args.update(self._additional_args)
 
+    if generation_params:
+      completion_args.update(generation_params)
+
     if stream:
       text = ""
       # Track function calls by index
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index d058aa44d..b9b1fb409 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -1447,3 +1447,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "1"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+@pytest.mark.asyncio
+def test_get_completion_inputs_generation_params():
+  # Test that generation_params are extracted and mapped correctly
+  req = LlmRequest(
+      contents=[
+          types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.33,
+          max_output_tokens=123,
+          top_p=0.88,
+          top_k=7,
+          stop_sequences=["foo", "bar"],
+          presence_penalty=0.1,
+          frequency_penalty=0.2,
+      ),
+  )
+  from google.adk.models.lite_llm import _get_completion_inputs
+
+  _, _, _, generation_params = _get_completion_inputs(req)
+  assert generation_params["temperature"] == 0.33
+  assert generation_params["max_completion_tokens"] == 123
+  assert generation_params["top_p"] == 0.88
+  assert generation_params["top_k"] == 7
+  assert generation_params["stop"] == ["foo", "bar"]
+  assert generation_params["presence_penalty"] == 0.1
+  assert generation_params["frequency_penalty"] == 0.2
+  # Should not include max_output_tokens
+  assert "max_output_tokens" not in generation_params
+  assert "stop_sequences" not in generation_params

From 04de3e197d7a57935488eb7bfa647c7ab62cd9d9 Mon Sep 17 00:00:00 2001
From: Ankur Sharma
Date: Wed, 25 Jun 2025 18:31:25 -0700
Subject: [PATCH 28/28] fix: Adding detailed information on each metric evaluation

Additionally, a few other small changes:

* Updated a test fixture to support the latest eval data schema. Somehow I missed doing that previously.
* Updated the `evaluation_generator.py` to use `run_async` instead of `run`.
* Also, raise an informed error when the dependencies required for eval are not installed.
* Also, changed the behavior of the AgentEvaluator.evaluate method to run all the evals, instead of failing at the first eval metric failure.
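With this change, a failing metric no longer aborts the evaluation: all metrics are run and every failure is reported together at the end, with optional per-invocation detail tables. A sketch of how a test typically drives the updated API, based on the argument names in the `evaluate` docstring in the diff below; the module path and eval file here are placeholders, and `print_detailed_results` is left at its default because this patch only shows that parameter on `evaluate_eval_set`.

import pytest

from google.adk.evaluation.agent_evaluator import AgentEvaluator


@pytest.mark.asyncio
async def test_trip_planner_agent_evals():
  # With this patch, all configured metrics are evaluated and every failure
  # is raised in a single assertion at the end, instead of the run stopping
  # at the first failing metric.
  await AgentEvaluator.evaluate(
      agent_module="tests.integration.fixture.trip_planner_agent",
      eval_dataset_file_path_or_dir=(
          "tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json"
      ),
      num_runs=1,
  )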
PiperOrigin-RevId: 775919127 --- src/google/adk/cli/cli_eval.py | 10 +- src/google/adk/cli/cli_tools_click.py | 2 +- src/google/adk/evaluation/agent_evaluator.py | 109 ++++++++++++-- src/google/adk/evaluation/constants.py | 20 +++ .../adk/evaluation/evaluation_generator.py | 2 +- .../trip_planner_agent/initial.session.json | 13 -- .../trip_planner_agent/trip_inquiry.test.json | 133 +++++++++++++++--- 7 files changed, 240 insertions(+), 49 deletions(-) create mode 100644 src/google/adk/evaluation/constants.py delete mode 100644 tests/integration/fixture/trip_planner_agent/initial.session.json diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index 13e205cb7..01b06135a 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -26,6 +26,7 @@ from ..agents import Agent from ..artifacts.base_artifact_service import BaseArtifactService +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.eval_case import EvalCase from ..evaluation.eval_metrics import EvalMetric from ..evaluation.eval_metrics import EvalMetricResult @@ -38,10 +39,6 @@ logger = logging.getLogger("google_adk." + __name__) -MISSING_EVAL_DEPENDENCIES_MESSAGE = ( - "Eval module is not installed, please install via `pip install" - " google-adk[eval]`." -) TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score" RESPONSE_MATCH_SCORE_KEY = "response_match_score" # This evaluation is not very stable. @@ -150,7 +147,7 @@ async def run_evals( artifact_service: The artifact service to use during inferencing. """ try: - from ..evaluation.agent_evaluator import EvaluationGenerator + from ..evaluation.evaluation_generator import EvaluationGenerator except ModuleNotFoundError as e: raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e @@ -252,7 +249,8 @@ async def run_evals( result = "❌ Failed" print(f"Result: {result}\n") - + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e except Exception: # Catching the general exception, so that we don't block other eval # cases. diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c0935cceb..1bc7d5662 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -31,12 +31,12 @@ from . import cli_create from . import cli_deploy from .. import version +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..sessions.in_memory_session_service import InMemorySessionService from .cli import run_cli -from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE from .fast_api import get_fast_api_app from .utils import envs from .utils import evals diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index 6ee001f9d..486d01cf1 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import json import logging import os @@ -23,16 +25,16 @@ from typing import Union import uuid +from google.genai import types as genai_types from pydantic import ValidationError +from .constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from .eval_case import IntermediateData from .eval_set import EvalSet -from .evaluation_generator import EvaluationGenerator from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator from .local_eval_sets_manager import convert_eval_set_to_pydanctic_schema -from .response_evaluator import ResponseEvaluator -from .trajectory_evaluator import TrajectoryEvaluator logger = logging.getLogger("google_adk." + __name__) @@ -96,6 +98,7 @@ async def evaluate_eval_set( criteria: dict[str, float], num_runs=NUM_RUNS, agent_name=None, + print_detailed_results: bool = True, ): """Evaluates an agent using the given EvalSet. @@ -109,7 +112,13 @@ async def evaluate_eval_set( num_runs: Number of times all entries in the eval dataset should be assessed. agent_name: The name of the agent. + print_detailed_results: Whether to print detailed results for each metric + evaluation. """ + try: + from .evaluation_generator import EvaluationGenerator + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e eval_case_responses_list = await EvaluationGenerator.generate_responses( eval_set=eval_set, agent_module_path=agent_module, @@ -117,6 +126,8 @@ async def evaluate_eval_set( agent_name=agent_name, ) + failures = [] + for eval_case_responses in eval_case_responses_list: actual_invocations = [ invocation @@ -139,10 +150,25 @@ async def evaluate_eval_set( ) ) - assert evaluation_result.overall_eval_status == EvalStatus.PASSED, ( - f"{metric_name} for {agent_module} Failed. Expected {threshold}," - f" but got {evaluation_result.overall_score}." - ) + if print_detailed_results: + AgentEvaluator._print_details( + evaluation_result=evaluation_result, + metric_name=metric_name, + threshold=threshold, + ) + + # Gather all the failures. + if evaluation_result.overall_eval_status != EvalStatus.PASSED: + failures.append( + f"{metric_name} for {agent_module} Failed. Expected {threshold}," + f" but got {evaluation_result.overall_score}." + ) + + assert not failures, ( + "Following are all the test failures. If you looking to get more" + " details on the failures, then please re-run this test with" + " `print_details` set to `True`.\n{}".format("\n".join(failures)) + ) @staticmethod async def evaluate( @@ -158,9 +184,10 @@ async def evaluate( agent_module: The path to python module that contains the definition of the agent. There is convention in place here, where the code is going to look for 'root_agent' in the loaded module. - eval_dataset_file_path_or_dir: The eval data set. This can be either a string representing - full path to the file containing eval dataset, or a directory that is - recursively explored for all files that have a `.test.json` suffix. + eval_dataset_file_path_or_dir: The eval data set. This can be either a + string representing full path to the file containing eval dataset, or a + directory that is recursively explored for all files that have a + `.test.json` suffix. num_runs: Number of times all entries in the eval dataset should be assessed. agent_name: The name of the agent. 
@@ -358,6 +385,11 @@ def _validate_input(eval_dataset, criteria): @staticmethod def _get_metric_evaluator(metric_name: str, threshold: float) -> Evaluator: + try: + from .response_evaluator import ResponseEvaluator + from .trajectory_evaluator import TrajectoryEvaluator + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e if metric_name == TOOL_TRAJECTORY_SCORE_KEY: return TrajectoryEvaluator(threshold=threshold) elif ( @@ -367,3 +399,60 @@ def _get_metric_evaluator(metric_name: str, threshold: float) -> Evaluator: return ResponseEvaluator(threshold=threshold, metric_name=metric_name) raise ValueError(f"Unsupported eval metric: {metric_name}") + + @staticmethod + def _print_details( + evaluation_result: EvaluationResult, metric_name: str, threshold: float + ): + try: + from pandas import pandas as pd + from tabulate import tabulate + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e + print( + f"Summary: `{evaluation_result.overall_eval_status}` for Metric:" + f" `{metric_name}`. Expected threshold: `{threshold}`, actual value:" + f" `{evaluation_result.overall_score}`." + ) + + data = [] + for per_invocation_result in evaluation_result.per_invocation_results: + data.append({ + "eval_status": per_invocation_result.eval_status, + "score": per_invocation_result.score, + "threshold": threshold, + "prompt": AgentEvaluator._convert_content_to_text( + per_invocation_result.expected_invocation.user_content + ), + "expected_response": AgentEvaluator._convert_content_to_text( + per_invocation_result.expected_invocation.final_response + ), + "actual_response": AgentEvaluator._convert_content_to_text( + per_invocation_result.actual_invocation.final_response + ), + "expected_tool_calls": AgentEvaluator._convert_tool_calls_to_text( + per_invocation_result.expected_invocation.intermediate_data + ), + "actual_tool_calls": AgentEvaluator._convert_tool_calls_to_text( + per_invocation_result.actual_invocation.intermediate_data + ), + }) + + print(tabulate(pd.DataFrame(data), headers="keys", tablefmt="grid")) + print("\n\n") # Few empty lines for visual clarity + + @staticmethod + def _convert_content_to_text(content: Optional[genai_types.Content]) -> str: + if content and content.parts: + return "\n".join([p.text for p in content.parts if p.text]) + + return "" + + @staticmethod + def _convert_tool_calls_to_text( + intermediate_data: Optional[IntermediateData], + ) -> str: + if intermediate_data and intermediate_data.tool_uses: + return "\n".join([str(t) for t in intermediate_data.tool_uses]) + + return "" diff --git a/src/google/adk/evaluation/constants.py b/src/google/adk/evaluation/constants.py new file mode 100644 index 000000000..74248ed18 --- /dev/null +++ b/src/google/adk/evaluation/constants.py @@ -0,0 +1,20 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +MISSING_EVAL_DEPENDENCIES_MESSAGE = ( + "Eval module is not installed, please install via `pip install" + " google-adk[eval]`." +) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index fbf6ea8e2..1359967bc 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -182,7 +182,7 @@ async def _generate_inferences_from_root_agent( tool_uses = [] invocation_id = "" - for event in runner.run( + async for event in runner.run_async( user_id=user_id, session_id=session_id, new_message=user_content ): invocation_id = ( diff --git a/tests/integration/fixture/trip_planner_agent/initial.session.json b/tests/integration/fixture/trip_planner_agent/initial.session.json deleted file mode 100644 index b33840cda..000000000 --- a/tests/integration/fixture/trip_planner_agent/initial.session.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "id": "test_id", - "app_name": "trip_planner_agent", - "user_id": "test_user", - "state": { - "origin": "San Francisco", - "interests": "Food, Shopping, Museums", - "range": "1000 miles", - "cities": "" - }, - "events": [], - "last_update_time": 1741218714.258285 -} diff --git a/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json b/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json index c504f68e3..317599c6b 100644 --- a/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json +++ b/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json @@ -1,19 +1,116 @@ -[ - { - "query": "Hi, who are you? What can you do?", - "expected_tool_use": [], - "reference": "I am trip_planner, and my goal is to plan the best trip ever. I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n" - }, - { - "query": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?", - "expected_tool_use": [ - { - "tool_name": "transfer_to_agent", - "tool_input": { - "agent_name": "indentify_agent" +{ + "eval_set_id": "e7996ccc-16bc-46bf-9a24-0a3ecc3dacd7", + "name": "e7996ccc-16bc-46bf-9a24-0a3ecc3dacd7", + "description": null, + "eval_cases": [ + { + "eval_id": "/google/src/cloud/ankusharma/CS-agent_evaluator-2025-06-17_115009/google3/third_party/py/google/adk/open_source_workspace/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json", + "conversation": [ + { + "invocation_id": "d7ff8ec1-290b-48c5-b3aa-05cb8f27b8ae", + "user_content": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "Hi, who are you? What can you do?" + } + ], + "role": "user" + }, + "final_response": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "I am trip_planner, and my goal is to plan the best trip ever. 
I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n" + } + ], + "role": "model" + }, + "intermediate_data": { + "tool_uses": [], + "intermediate_responses": [] + }, + "creation_timestamp": 1750190885.419684 + }, + { + "invocation_id": "f515ff57-ff21-488f-ab92-7d7de5bb76fe", + "user_content": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?" + } + ], + "role": "user" + }, + "final_response": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "inline_data": null, + "file_data": null, + "thought_signature": null, + "code_execution_result": null, + "executable_code": null, + "function_call": null, + "function_response": null, + "text": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). After gathering this information, I'll provide a detailed report on my chosen city.\n" + } + ], + "role": "model" + }, + "intermediate_data": { + "tool_uses": [ + { + "id": null, + "args": { + "agent_name": "indentify_agent" + }, + "name": "transfer_to_agent" + } + ], + "intermediate_responses": [] + }, + "creation_timestamp": 1750190885.4197457 } - } - ], - "reference": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). After gathering this information, I'll provide a detailed report on my chosen city.\n" - } -] + ], + "session_input": { + "app_name": "trip_planner_agent", + "user_id": "test_user", + "state": { + "origin": "San Francisco", + "interests": "Food, Shopping, Museums", + "range": "1000 miles", + "cities": "" + } + }, + "creation_timestamp": 1750190885.4197533 + } + ], + "creation_timestamp": 1750190885.4197605 +} \ No newline at end of file
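The migrated fixture above follows the newer eval set schema: an eval set holds eval cases, each with a conversation of invocations (user_content, final_response, intermediate_data) plus session_input, replacing the old flat list of query/expected_tool_use/reference entries. A minimal sketch of reading such a file back, assuming the `EvalSet` class imported by `agent_evaluator.py` is a Pydantic v2 model exposing `model_validate_json`; its definition is not shown in this patch.

from pathlib import Path

from google.adk.evaluation.eval_set import EvalSet

# Path of the fixture migrated above; adjust to wherever the file lives.
fixture = Path(
    "tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json"
)
eval_set = EvalSet.model_validate_json(fixture.read_text())

for eval_case in eval_set.eval_cases:
  for invocation in eval_case.conversation:
    prompt = " ".join(p.text for p in invocation.user_content.parts if p.text)
    expected = " ".join(
        p.text for p in invocation.final_response.parts if p.text
    )
    print(f"user: {prompt}")
    print(f"expected: {expected}")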