From a7ea374dfbc125fcb6e8eb2fc447c09b93b2b200 Mon Sep 17 00:00:00 2001 From: "Wei Sun (Jack)" Date: Wed, 11 Jun 2025 18:24:10 -0700 Subject: [PATCH 01/61] chore: Update isort config to prevent vscode flickering PiperOrigin-RevId: 770406033 --- pyproject.toml | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 158025c7e..20c0524a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,27 +25,27 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start - "authlib>=1.5.1", # For RestAPI Tool - "click>=8.1.8", # For CLI tools - "fastapi>=0.115.0", # FastAPI framework - "google-api-python-client>=2.157.0", # Google API client discovery - "google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store. - "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool - "google-cloud-speech>=2.30.0", # For Audio Transcription - "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service - "google-genai>=1.17.0", # Google GenAI SDK - "graphviz>=0.20.2", # Graphviz for graph rendering - "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset - "opentelemetry-api>=1.31.0", # OpenTelemetry + "authlib>=1.5.1", # For RestAPI Tool + "click>=8.1.8", # For CLI tools + "fastapi>=0.115.0", # FastAPI framework + "google-api-python-client>=2.157.0", # Google API client discovery + "google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store. + "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool + "google-cloud-speech>=2.30.0", # For Audio Transcription + "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service + "google-genai>=1.17.0", # Google GenAI SDK + "graphviz>=0.20.2", # Graphviz for graph rendering + "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset + "opentelemetry-api>=1.31.0", # OpenTelemetry "opentelemetry-exporter-gcp-trace>=1.9.0", "opentelemetry-sdk>=1.31.0", - "pydantic>=2.0, <3.0.0", # For data validation/models - "python-dotenv>=1.0.0", # To manage environment variables - "PyYAML>=6.0.2", # For APIHubToolset. - "sqlalchemy>=2.0", # SQL database ORM - "tzlocal>=5.3", # Time zone utilities + "pydantic>=2.0, <3.0.0", # For data validation/models + "python-dotenv>=1.0.0", # To manage environment variables + "PyYAML>=6.0.2", # For APIHubToolset. + "sqlalchemy>=2.0", # SQL database ORM + "tzlocal>=5.3", # Time zone utilities "typing-extensions>=4.5, <5", - "uvicorn>=0.34.0", # ASGI server for FastAPI + "uvicorn>=0.34.0", # ASGI server for FastAPI # go/keep-sorted end ] dynamic = ["version"] @@ -84,7 +84,7 @@ test = [ "anthropic>=0.43.0", # For anthropic model tests "langchain-community>=0.3.17", "langgraph>=0.2.60", # For LangGraphAgent - "litellm>=1.71.2", # For LiteLLM tests + "litellm>=1.71.2", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "pytest-asyncio>=0.25.0", @@ -140,24 +140,30 @@ pyink-annotation-pragmas = [ requires = ["flit_core >=3.8,<4"] build-backend = "flit_core.buildapi" + [tool.flit.sdist] include = ['src/**/*', 'README.md', 'pyproject.toml', 'LICENSE'] exclude = ['src/**/*.sh'] + [tool.flit.module] name = "google.adk" include = ["py.typed"] + [tool.isort] profile = "google" single_line_exclusions = [] +line_length = 200 # Prevent line wrap flickering. 
known_third_party = ["google.adk"] + [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" + [tool.mypy] python_version = "3.9" exclude = "tests/" From fc65873d7c31be607f6cd6690f142a031631582a Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 11 Jun 2025 19:45:46 -0700 Subject: [PATCH 02/61] chore: Set agent_engine_id in the service constructor, also use the agent_engine_id field instead of overriding app_name in FastAPI endpoint PiperOrigin-RevId: 770427903 --- src/google/adk/cli/fast_api.py | 35 ++----- .../adk/sessions/vertex_ai_session_service.py | 94 +++++++++++-------- .../test_vertex_ai_session_service.py | 18 +++- 3 files changed, 74 insertions(+), 73 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4ef8ae6c2..875c3cb7b 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -276,7 +276,6 @@ async def internal_lifespan(app: FastAPI): memory_service = InMemoryMemoryService() # Build the Session service - agent_engine_id = "" if session_service_uri: if session_service_uri.startswith("agentengine://"): # Create vertex session service @@ -285,8 +284,9 @@ async def internal_lifespan(app: FastAPI): raise click.ClickException("Agent engine id can not be empty.") envs.load_dotenv_for_agent("", agents_dir) session_service = VertexAiSessionService( - os.environ["GOOGLE_CLOUD_PROJECT"], - os.environ["GOOGLE_CLOUD_LOCATION"], + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, ) else: session_service = DatabaseSessionService(db_url=session_service_uri) @@ -357,8 +357,6 @@ def get_session_trace(session_id: str) -> Any: async def get_session( app_name: str, user_id: str, session_id: str ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -371,8 +369,6 @@ async def get_session( response_model_exclude_none=True, ) async def list_sessions(app_name: str, user_id: str) -> list[Session]: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name list_sessions_response = await session_service.list_sessions( app_name=app_name, user_id=user_id ) @@ -393,8 +389,6 @@ async def create_session_with_id( session_id: str, state: Optional[dict[str, Any]] = None, ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name if ( await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id @@ -419,8 +413,6 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name logger.info("New session created") return await session_service.create_session( app_name=app_name, user_id=user_id, state=state @@ -660,8 +652,6 @@ def list_eval_results(app_name: str) -> list[str]: @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") async def delete_session(app_name: str, user_id: str, session_id: str): - # Connect to managed session if agent_engine_id is set. 
- app_name = agent_engine_id if agent_engine_id else app_name await session_service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -677,7 +667,6 @@ async def load_artifact( artifact_name: str, version: Optional[int] = Query(None), ) -> Optional[types.Part]: - app_name = agent_engine_id if agent_engine_id else app_name artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, @@ -700,7 +689,6 @@ async def load_artifact_version( artifact_name: str, version_id: int, ) -> Optional[types.Part]: - app_name = agent_engine_id if agent_engine_id else app_name artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, @@ -719,7 +707,6 @@ async def load_artifact_version( async def list_artifact_names( app_name: str, user_id: str, session_id: str ) -> list[str]: - app_name = agent_engine_id if agent_engine_id else app_name return await artifact_service.list_artifact_keys( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -731,7 +718,6 @@ async def list_artifact_names( async def list_artifact_versions( app_name: str, user_id: str, session_id: str, artifact_name: str ) -> list[int]: - app_name = agent_engine_id if agent_engine_id else app_name return await artifact_service.list_versions( app_name=app_name, user_id=user_id, @@ -745,7 +731,6 @@ async def list_artifact_versions( async def delete_artifact( app_name: str, user_id: str, session_id: str, artifact_name: str ): - app_name = agent_engine_id if agent_engine_id else app_name await artifact_service.delete_artifact( app_name=app_name, user_id=user_id, @@ -755,10 +740,8 @@ async def delete_artifact( @app.post("/run", response_model_exclude_none=True) async def agent_run(req: AgentRunRequest) -> list[Event]: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else req.app_name session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -776,11 +759,9 @@ async def agent_run(req: AgentRunRequest) -> list[Event]: @app.post("/run_sse") async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else req.app_name # SSE endpoint session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -818,8 +799,6 @@ async def event_generator(): async def get_event_graph( app_name: str, user_id: str, session_id: str, event_id: str ): - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -875,8 +854,6 @@ async def agent_live_run( ) -> None: await websocket.accept() - # Connect to managed session if agent_engine_id is set. 
- app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -940,7 +917,7 @@ async def _get_runner_async(app_name: str) -> Runner: return runner_dict[app_name] root_agent = agent_loader.load_agent(app_name) runner = Runner( - app_name=agent_engine_id if agent_engine_id else app_name, + app_name=app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 5d6bed2e6..258dcd933 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -22,7 +22,6 @@ import urllib.parse from dateutil import parser -from google.genai import types from typing_extensions import override from google import genai @@ -40,15 +39,27 @@ class VertexAiSessionService(BaseSessionService): - """Connects to the managed Vertex AI Session Service.""" + """Connects to the Vertex AI Agent Engine Session Service using GenAI API client. + + https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview + """ def __init__( self, - project: str = None, - location: str = None, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, ): - self.project = project - self.location = location + """Initializes the VertexAiSessionService. + + Args: + project: The project id of the project to use. + location: The location of the project to use. + agent_engine_id: The resource ID of the agent engine to use. + """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id @override async def create_session( @@ -64,14 +75,13 @@ async def create_session( 'User-provided Session id is not supported for' ' VertexAISessionService.' 
) - - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() session_json_dict = {'user_id': user_id} if state: session_json_dict['session_state'] = state - api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions', @@ -130,10 +140,10 @@ async def get_session( session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() # Get session resource - api_client = _get_api_client(self.project, self.location) get_session_api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', @@ -203,14 +213,14 @@ async def get_session( async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() path = f'reasoningEngines/{reasoning_engine_id}/sessions' if user_id: parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='') path = path + f'?filter=user_id={parsed_user_id}' - api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='GET', path=path, @@ -236,8 +246,9 @@ async def list_sessions( async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) - api_client = _get_api_client(self.project, self.location) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() + try: await api_client.async_request( http_method='DELETE', @@ -253,8 +264,8 @@ async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. await super().append_event(session=session, event=event) - reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) - api_client = _get_api_client(self.project, self.location) + reasoning_engine_id = self._get_reasoning_engine_id(session.app_name) + api_client = self._get_api_client() await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', @@ -262,15 +273,34 @@ async def append_event(self, session: Session, event: Event) -> Event: ) return event + def _get_reasoning_engine_id(self, app_name: str): + if self._agent_engine_id: + return self._agent_engine_id -def _get_api_client(project: str, location: str): - """Instantiates an API client for the given project and location. + if app_name.isdigit(): + return app_name - It needs to be instantiated inside each request so that the event loop - management. - """ - client = genai.Client(vertexai=True, project=project, location=location) - return client._api_client + pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' + match = re.fullmatch(pattern, app_name) + + if not bool(match): + raise ValueError( + f'App name {app_name} is not valid. It should either be the full' + ' ReasoningEngine resource name, or the reasoning engine id.' 
+ ) + + return match.groups()[-1] + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client def _convert_event_to_json(event: Event) -> Dict[str, Any]: @@ -366,19 +396,3 @@ def _from_api_event(api_event: Dict[str, Any]) -> Event: ) return event - - -def _parse_reasoning_engine_id(app_name: str): - if app_name.isdigit(): - return app_name - - pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' - match = re.fullmatch(pattern, app_name) - - if not bool(match): - raise ValueError( - f'App name {app_name} is not valid. It should either be the full' - ' ReasoningEngine resource name, or the reasoning engine id.' - ) - - return match.groups()[-1] diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 92f6a29dd..6a9e0b46a 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -245,8 +245,14 @@ async def async_request( raise ValueError(f'Unsupported http method: {http_method}') -def mock_vertex_ai_session_service(): +def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None): """Creates a mock Vertex AI Session service for testing.""" + if agent_engine_id: + return VertexAiSessionService( + project='test-project', + location='test-location', + agent_engine_id=agent_engine_id, + ) return VertexAiSessionService( project='test-project', location='test-location' ) @@ -265,7 +271,7 @@ def mock_get_api_client(): '2': (MOCK_EVENT_JSON_2, 'my_token'), } with mock.patch( - 'google.adk.sessions.vertex_ai_session_service._get_api_client', + 'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client', return_value=api_client, ): yield @@ -273,8 +279,12 @@ def mock_get_api_client(): @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') -async def test_get_empty_session(): - session_service = mock_vertex_ai_session_service() +@pytest.mark.parametrize('agent_engine_id', [None, '123']) +async def test_get_empty_session(agent_engine_id): + if agent_engine_id: + session_service = mock_vertex_ai_session_service(agent_engine_id) + else: + session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: await session_service.get_session( app_name='123', user_id='user', session_id='0' From b08bdbcd7f03fb3e2b4a7ea3456eed7335ed9e0e Mon Sep 17 00:00:00 2001 From: "Wei Sun (Jack)" Date: Wed, 11 Jun 2025 22:26:45 -0700 Subject: [PATCH 03/61] chore: Fixes the sample of example_tool The `Example` objects should be strong typed. 
PiperOrigin-RevId: 770476132 --- contributing/samples/hello_world_ma/agent.py | 79 +++++++++++--------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/contributing/samples/hello_world_ma/agent.py b/contributing/samples/hello_world_ma/agent.py index 98e79a397..a6bf78a9e 100755 --- a/contributing/samples/hello_world_ma/agent.py +++ b/contributing/samples/hello_world_ma/agent.py @@ -15,6 +15,7 @@ import random from google.adk.agents import Agent +from google.adk.examples.example import Example from google.adk.tools.example_tool import ExampleTool from google.genai import types @@ -66,43 +67,47 @@ def check_prime(nums: list[int]) -> str: ) -example_tool = ExampleTool([ - { - "input": { - "role": "user", - "parts": [{"text": "Roll a 6-sided die."}], - }, - "output": [ - {"role": "model", "parts": [{"text": "I rolled a 4 for you."}]} - ], - }, - { - "input": { - "role": "user", - "parts": [{"text": "Is 7 a prime number?"}], - }, - "output": [{ - "role": "model", - "parts": [{"text": "Yes, 7 is a prime number."}], - }], - }, - { - "input": { - "role": "user", - "parts": [{"text": "Roll a 10-sided die and check if it's prime."}], - }, - "output": [ - { - "role": "model", - "parts": [{"text": "I rolled an 8 for you."}], - }, - { - "role": "model", - "parts": [{"text": "8 is not a prime number."}], - }, - ], - }, -]) +example_tool = ExampleTool( + examples=[ + Example( + input=types.UserContent( + parts=[types.Part(text="Roll a 6-sided die.")] + ), + output=[ + types.ModelContent( + parts=[types.Part(text="I rolled a 4 for you.")] + ) + ], + ), + Example( + input=types.UserContent( + parts=[types.Part(text="Is 7 a prime number?")] + ), + output=[ + types.ModelContent( + parts=[types.Part(text="Yes, 7 is a prime number.")] + ) + ], + ), + Example( + input=types.UserContent( + parts=[ + types.Part( + text="Roll a 10-sided die and check if it's prime." + ) + ] + ), + output=[ + types.ModelContent( + parts=[types.Part(text="I rolled an 8 for you.")] + ), + types.ModelContent( + parts=[types.Part(text="8 is not a prime number.")] + ), + ], + ), + ] +) prime_agent = Agent( name="prime_agent", From bbceb4f2e89f720533b99cf356c532024a120dc4 Mon Sep 17 00:00:00 2001 From: Koichi Shiraishi Date: Wed, 11 Jun 2025 22:32:31 -0700 Subject: [PATCH 04/61] fix: remove unnecessary double quote on Claude docstring Merge https://github.com/google/adk-python/pull/1266 Subject says it all. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/1266 from zchee:fix-docstring b5ceacc82b45398e9421c0f9a1a4d6352d12e21a PiperOrigin-RevId: 770478088 --- src/google/adk/models/anthropic_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 96b95ac5a..a3a0e0962 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -201,7 +201,7 @@ def function_declaration_to_tool_param( class Claude(BaseLlm): - """ "Integration with Claude models served from Vertex AI. + """Integration with Claude models served from Vertex AI. Attributes: model: The name of the Claude model. 
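
For reference, a minimal usage sketch of the `Claude` wrapper whose docstring is fixed above. This is an assumption-laden illustration, not part of the patch: it assumes ADK's `LlmAgent` accepts a `BaseLlm` instance for its `model` field, that the `claude-3-5-sonnet-v2@20241022` model is enabled in the Vertex AI project, and that `GOOGLE_CLOUD_PROJECT`/`GOOGLE_CLOUD_LOCATION` are set for the underlying Anthropic Vertex client.

# Hedged sketch: wiring the Claude model wrapper into an ADK agent.
# Assumes Vertex AI credentials/env vars are configured and the named
# Claude model is available in the project; adjust names as needed.
from google.adk.agents import LlmAgent
from google.adk.models.anthropic_llm import Claude

root_agent = LlmAgent(
    name="claude_agent",
    model=Claude(model="claude-3-5-sonnet-v2@20241022"),
    instruction="Answer user questions concisely.",
)
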
From 1551bd4f4d7042fffb497d9308b05f92d45d818f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Jun 2025 23:02:20 -0700 Subject: [PATCH 05/61] feat: Re-factor some eval sets manager logic, and implement GcsEvalSetsManager to handle storage of eval sets on GCS Eval sets will be stored as json files under `gs://{bucket_name}/{app_name}/evals/eval_sets/` PiperOrigin-RevId: 770487129 --- .../evaluation/_eval_sets_manager_utils.py | 108 ++++ .../adk/evaluation/gcs_eval_sets_manager.py | 196 +++++++ .../adk/evaluation/local_eval_sets_manager.py | 105 +--- .../evaluation/test_gcs_eval_sets_manager.py | 518 ++++++++++++++++++ .../test_local_eval_sets_manager.py | 87 ++- 5 files changed, 916 insertions(+), 98 deletions(-) create mode 100644 src/google/adk/evaluation/_eval_sets_manager_utils.py create mode 100644 src/google/adk/evaluation/gcs_eval_sets_manager.py create mode 100644 tests/unittests/evaluation/test_gcs_eval_sets_manager.py diff --git a/src/google/adk/evaluation/_eval_sets_manager_utils.py b/src/google/adk/evaluation/_eval_sets_manager_utils.py new file mode 100644 index 000000000..b7e12dd37 --- /dev/null +++ b/src/google/adk/evaluation/_eval_sets_manager_utils.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional + +from ..errors.not_found_error import NotFoundError +from .eval_case import EvalCase +from .eval_set import EvalSet +from .eval_sets_manager import EvalSetsManager + +logger = logging.getLogger("google_adk." + __name__) + + +def get_eval_set_from_app_and_id( + eval_sets_manager: EvalSetsManager, app_name: str, eval_set_id: str +) -> EvalSet: + """Returns an EvalSet if found, otherwise raises NotFoundError.""" + eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) + if not eval_set: + raise NotFoundError(f"Eval set `{eval_set_id}` not found.") + return eval_set + + +def get_eval_case_from_eval_set( + eval_set: EvalSet, eval_case_id: str +) -> Optional[EvalCase]: + """Returns an EvalCase if found, otherwise None.""" + eval_case_to_find = None + + # Look up the eval case by eval_case_id + for eval_case in eval_set.eval_cases: + if eval_case.eval_id == eval_case_id: + eval_case_to_find = eval_case + break + + return eval_case_to_find + + +def add_eval_case_to_eval_set( + eval_set: EvalSet, eval_case: EvalCase +) -> EvalSet: + """Adds an eval case to an eval set and returns the updated eval set.""" + eval_case_id = eval_case.eval_id + + if [x for x in eval_set.eval_cases if x.eval_id == eval_case_id]: + raise ValueError( + f"Eval id `{eval_case_id}` already exists in `{eval_set.eval_set_id}`" + " eval set.", + ) + + eval_set.eval_cases.append(eval_case) + return eval_set + + +def update_eval_case_in_eval_set( + eval_set: EvalSet, updated_eval_case: EvalCase +) -> EvalSet: + """Updates an eval case in an eval set and returns the updated eval set.""" + # Find the eval case to be updated. 
+ eval_case_id = updated_eval_case.eval_id + eval_case_to_update = get_eval_case_from_eval_set(eval_set, eval_case_id) + + if not eval_case_to_update: + raise NotFoundError( + f"Eval case `{eval_case_id}` not found in eval set" + f" `{eval_set.eval_set_id}`." + ) + + # Remove the existing eval case and add the updated eval case. + eval_set.eval_cases.remove(eval_case_to_update) + eval_set.eval_cases.append(updated_eval_case) + return eval_set + + +def delete_eval_case_from_eval_set( + eval_set: EvalSet, eval_case_id: str +) -> EvalSet: + """Deletes an eval case from an eval set and returns the updated eval set.""" + # Find the eval case to be deleted. + eval_case_to_delete = get_eval_case_from_eval_set(eval_set, eval_case_id) + + if not eval_case_to_delete: + raise NotFoundError( + f"Eval case `{eval_case_id}` not found in eval set" + f" `{eval_set.eval_set_id}`." + ) + + # Remove the existing eval case. + logger.info( + "EvalCase`%s` was found in the eval set. It will be removed permanently.", + eval_case_id, + ) + eval_set.eval_cases.remove(eval_case_to_delete) + return eval_set diff --git a/src/google/adk/evaluation/gcs_eval_sets_manager.py b/src/google/adk/evaluation/gcs_eval_sets_manager.py new file mode 100644 index 000000000..fe5d8c9b5 --- /dev/null +++ b/src/google/adk/evaluation/gcs_eval_sets_manager.py @@ -0,0 +1,196 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import re +import time +from typing import Optional + +from google.cloud import exceptions as cloud_exceptions +from google.cloud import storage +from typing_extensions import override + +from ._eval_sets_manager_utils import add_eval_case_to_eval_set +from ._eval_sets_manager_utils import delete_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_set_from_app_and_id +from ._eval_sets_manager_utils import update_eval_case_in_eval_set +from .eval_case import EvalCase +from .eval_set import EvalSet +from .eval_sets_manager import EvalSetsManager + +logger = logging.getLogger("google_adk." + __name__) + +_EVAL_SETS_DIR = "evals/eval_sets" +_EVAL_SET_FILE_EXTENSION = ".evalset.json" + + +class GcsEvalSetsManager(EvalSetsManager): + """An EvalSetsManager that stores eval sets in a GCS bucket.""" + + def __init__(self, bucket_name: str, **kwargs): + """Initializes the GcsEvalSetsManager. + + Args: + bucket_name: The name of the bucket to use. + **kwargs: Keyword arguments to pass to the Google Cloud Storage client. + """ + self.bucket_name = bucket_name + self.storage_client = storage.Client(**kwargs) + self.bucket = self.storage_client.bucket(self.bucket_name) + # Check if the bucket exists. + if not self.bucket.exists(): + raise ValueError( + f"Bucket `{self.bucket_name}` does not exist. Please create it " + "before using the GcsEvalSetsManager." 
+ ) + + def _get_eval_sets_dir(self, app_name: str) -> str: + return f"{app_name}/{_EVAL_SETS_DIR}" + + def _get_eval_set_blob_name(self, app_name: str, eval_set_id: str) -> str: + eval_sets_dir = self._get_eval_sets_dir(app_name) + return f"{eval_sets_dir}/{eval_set_id}{_EVAL_SET_FILE_EXTENSION}" + + def _validate_id(self, id_name: str, id_value: str): + pattern = r"^[a-zA-Z0-9_]+$" + if not bool(re.fullmatch(pattern, id_value)): + raise ValueError( + f"Invalid {id_name}. {id_name} should have the `{pattern}` format", + ) + + def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet): + """Writes an EvalSet to GCS.""" + blob = self.bucket.blob(blob_name) + blob.upload_from_string( + eval_set.model_dump_json(indent=2), + content_type="application/json", + ) + + def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): + eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) + self._write_eval_set_to_blob(eval_set_blob_name, eval_set) + + @override + def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: + """Returns an EvalSet identified by an app_name and eval_set_id.""" + eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) + blob = self.bucket.blob(eval_set_blob_name) + if not blob.exists(): + return None + eval_set_data = blob.download_as_text() + return EvalSet.model_validate_json(eval_set_data) + + @override + def create_eval_set(self, app_name: str, eval_set_id: str): + """Creates an empty EvalSet and saves it to GCS.""" + self._validate_id(id_name="Eval Set Id", id_value=eval_set_id) + new_eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) + if self.bucket.blob(new_eval_set_blob_name).exists(): + raise ValueError( + f"Eval set `{eval_set_id}` already exists for app `{app_name}`." + ) + logger.info("Creating eval set blob: `%s`", new_eval_set_blob_name) + new_eval_set = EvalSet( + eval_set_id=eval_set_id, + name=eval_set_id, + eval_cases=[], + creation_timestamp=time.time(), + ) + self._write_eval_set_to_blob(new_eval_set_blob_name, new_eval_set) + + @override + def list_eval_sets(self, app_name: str) -> list[str]: + """Returns a list of EvalSet ids that belong to the given app_name.""" + eval_sets_dir = self._get_eval_sets_dir(app_name) + eval_sets = [] + try: + for blob in self.bucket.list_blobs(prefix=eval_sets_dir): + if not blob.name.endswith(_EVAL_SET_FILE_EXTENSION): + continue + eval_set_id = blob.name.split("/")[-1].removesuffix( + _EVAL_SET_FILE_EXTENSION + ) + eval_sets.append(eval_set_id) + return sorted(eval_sets) + except cloud_exceptions.NotFound as e: + raise ValueError( + f"App `{app_name}` not found in GCS bucket `{self.bucket_name}`." + ) from e + + @override + def get_eval_case( + self, app_name: str, eval_set_id: str, eval_case_id: str + ) -> Optional[EvalCase]: + """Returns an EvalCase identified by an app_name, eval_set_id and eval_case_id.""" + eval_set = self.get_eval_set(app_name, eval_set_id) + if not eval_set: + return None + return get_eval_case_from_eval_set(eval_set, eval_case_id) + + @override + def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase): + """Adds the given EvalCase to an existing EvalSet. + + Args: + app_name: The name of the app. + eval_set_id: The id of the eval set containing the eval case to update. + eval_case: The EvalCase to add. + + Raises: + NotFoundError: If the eval set is not found. + ValueError: If the eval case already exists in the eval set. 
+ """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = add_eval_case_to_eval_set(eval_set, eval_case) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) + + @override + def update_eval_case( + self, app_name: str, eval_set_id: str, updated_eval_case: EvalCase + ): + """Updates an existing EvalCase. + + Args: + app_name: The name of the app. + eval_set_id: The id of the eval set containing the eval case to update. + updated_eval_case: The updated EvalCase. Overwrites the existing EvalCase + using the eval_id field. + + Raises: + NotFoundError: If the eval set or the eval case is not found. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = update_eval_case_in_eval_set(eval_set, updated_eval_case) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) + + @override + def delete_eval_case( + self, app_name: str, eval_set_id: str, eval_case_id: str + ): + """Deletes the EvalCase with the given eval_case_id from the given EvalSet. + + Args: + app_name: The name of the app. + eval_set_id: The id of the eval set containing the eval case to delete. + eval_case_id: The id of the eval case to delete. + + Raises: + NotFoundError: If the eval set or the eval case to delete is not found. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = delete_eval_case_from_eval_set(eval_set, eval_case_id) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index e01ecd357..0e93b9201 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -27,7 +27,11 @@ from pydantic import ValidationError from typing_extensions import override -from ..errors.not_found_error import NotFoundError +from ._eval_sets_manager_utils import add_eval_case_to_eval_set +from ._eval_sets_manager_utils import delete_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_set_from_app_and_id +from ._eval_sets_manager_utils import update_eval_case_in_eval_set from .eval_case import EvalCase from .eval_case import IntermediateData from .eval_case import Invocation @@ -218,7 +222,7 @@ def create_eval_set(self, app_name: str, eval_set_id: str): eval_cases=[], creation_timestamp=time.time(), ) - self._write_eval_set(new_eval_set_path, new_eval_set) + self._write_eval_set_to_path(new_eval_set_path, new_eval_set) @override def list_eval_sets(self, app_name: str) -> list[str]: @@ -233,51 +237,27 @@ def list_eval_sets(self, app_name: str) -> list[str]: return sorted(eval_sets) - @override - def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase): - """Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id. - - Raises: - NotFoundError: If the eval set is not found. 
- """ - eval_case_id = eval_case.eval_id - self._validate_id(id_name="Eval Case Id", id_value=eval_case_id) - - eval_set = self.get_eval_set(app_name, eval_set_id) - - if not eval_set: - raise NotFoundError(f"Eval set `{eval_set_id}` not found.") - - if [x for x in eval_set.eval_cases if x.eval_id == eval_case_id]: - raise ValueError( - f"Eval id `{eval_case_id}` already exists in `{eval_set_id}`" - " eval set.", - ) - - eval_set.eval_cases.append(eval_case) - - eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - self._write_eval_set(eval_set_file_path, eval_set) - @override def get_eval_case( self, app_name: str, eval_set_id: str, eval_case_id: str ) -> Optional[EvalCase]: """Returns an EvalCase if found, otherwise None.""" eval_set = self.get_eval_set(app_name, eval_set_id) - if not eval_set: return None + return get_eval_case_from_eval_set(eval_set, eval_case_id) - eval_case_to_find = None + @override + def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase): + """Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id. - # Look up the eval case by eval_case_id - for eval_case in eval_set.eval_cases: - if eval_case.eval_id == eval_case_id: - eval_case_to_find = eval_case - break + Raises: + NotFoundError: If the eval set is not found. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = add_eval_case_to_eval_set(eval_set, eval_case) - return eval_case_to_find + self._save_eval_set(app_name, eval_set_id, updated_eval_set) @override def update_eval_case( @@ -288,28 +268,9 @@ def update_eval_case( Raises: NotFoundError: If the eval set or the eval case is not found. """ - eval_case_id = updated_eval_case.eval_id - - # Find the eval case to be updated. - eval_case_to_update = self.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_update: - # Remove the eval case from the existing eval set. - eval_set = self.get_eval_set(app_name, eval_set_id) - eval_set.eval_cases.remove(eval_case_to_update) - - # Add the updated eval case to the existing eval set. - eval_set.eval_cases.append(updated_eval_case) - - # Persit the eval set. - eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - self._write_eval_set(eval_set_file_path, eval_set) - else: - raise NotFoundError( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found.", - ) + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = update_eval_case_in_eval_set(eval_set, updated_eval_case) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) @override def delete_eval_case( @@ -320,25 +281,9 @@ def delete_eval_case( Raises: NotFoundError: If the eval set or the eval case to delete is not found. """ - # Find the eval case that needs to be deleted. - eval_case_to_remove = self.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_remove: - logger.info( - "EvalCase`%s` was found in the eval set. 
It will be removed" - " permanently.", - eval_case_id, - ) - eval_set = self.get_eval_set(app_name, eval_set_id) - eval_set.eval_cases.remove(eval_case_to_remove) - eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - self._write_eval_set(eval_set_file_path, eval_set) - else: - raise NotFoundError( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found.", - ) + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = delete_eval_case_from_eval_set(eval_set, eval_case_id) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: return os.path.join( @@ -354,6 +299,10 @@ def _validate_id(self, id_name: str, id_value: str): f"Invalid {id_name}. {id_name} should have the `{pattern}` format", ) - def _write_eval_set(self, eval_set_path: str, eval_set: EvalSet): + def _write_eval_set_to_path(self, eval_set_path: str, eval_set: EvalSet): with open(eval_set_path, "w") as f: f.write(eval_set.model_dump_json(indent=2)) + + def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): + eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) + self._write_eval_set_to_path(eval_set_file_path, eval_set) diff --git a/tests/unittests/evaluation/test_gcs_eval_sets_manager.py b/tests/unittests/evaluation/test_gcs_eval_sets_manager.py new file mode 100644 index 000000000..9b670a941 --- /dev/null +++ b/tests/unittests/evaluation/test_gcs_eval_sets_manager.py @@ -0,0 +1,518 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from typing import Union + +from google.adk.errors.not_found_error import NotFoundError +from google.adk.evaluation.eval_case import EvalCase +from google.adk.evaluation.eval_set import EvalSet +from google.adk.evaluation.gcs_eval_sets_manager import _EVAL_SET_FILE_EXTENSION +from google.adk.evaluation.gcs_eval_sets_manager import GcsEvalSetsManager +import pytest + + +class MockBlob: + """Mocks a GCS Blob object. + + This class provides mock implementations for a few common GCS Blob methods, + allowing the user to test code that interacts with GCS without actually + connecting to a real bucket. + """ + + def __init__(self, name: str) -> None: + """Initializes a MockBlob. + + Args: + name: The name of the blob. + """ + self.name = name + self.content: Optional[bytes] = None + self.content_type: Optional[str] = None + self._exists: bool = False + + def upload_from_string( + self, data: Union[str, bytes], content_type: Optional[str] = None + ) -> None: + """Mocks uploading data to the blob (from a string or bytes). + + Args: + data: The data to upload (string or bytes). + content_type: The content type of the data (optional). 
+ """ + if isinstance(data, str): + self.content = data.encode("utf-8") + elif isinstance(data, bytes): + self.content = data + else: + raise TypeError("data must be str or bytes") + + if content_type: + self.content_type = content_type + self._exists = True + + def download_as_text(self) -> str: + """Mocks downloading the blob's content as text. + + Returns: + str: The content of the blob as text. + + Raises: + Exception: If the blob doesn't exist (hasn't been uploaded to). + """ + if self.content is None: + return b"" + return self.content + + def delete(self) -> None: + """Mocks deleting a blob.""" + self.content = None + self.content_type = None + self._exists = False + + def exists(self) -> bool: + """Mocks checking if the blob exists.""" + return self._exists + + +class MockBucket: + """Mocks a GCS Bucket object.""" + + def __init__(self, name: str) -> None: + """Initializes a MockBucket. + + Args: + name: The name of the bucket. + """ + self.name = name + self.blobs: dict[str, MockBlob] = {} + + def blob(self, blob_name: str) -> MockBlob: + """Mocks getting a Blob object (doesn't create it in storage). + + Args: + blob_name: The name of the blob. + + Returns: + A MockBlob instance. + """ + if blob_name not in self.blobs: + self.blobs[blob_name] = MockBlob(blob_name) + return self.blobs[blob_name] + + def list_blobs(self, prefix: Optional[str] = None) -> list[MockBlob]: + """Mocks listing blobs in a bucket, optionally with a prefix.""" + if prefix: + return [ + blob for name, blob in self.blobs.items() if name.startswith(prefix) + ] + return list(self.blobs.values()) + + def exists(self) -> bool: + """Mocks checking if the bucket exists.""" + return True + + +class MockClient: + """Mocks the GCS Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.buckets: dict[str, MockBucket] = {} + + def bucket(self, bucket_name: str) -> MockBucket: + """Mocks getting a Bucket object.""" + if bucket_name not in self.buckets: + self.buckets[bucket_name] = MockBucket(bucket_name) + return self.buckets[bucket_name] + + +class TestGcsEvalSetsManager: + """Tests for GcsEvalSetsManager.""" + + @pytest.fixture + def gcs_eval_sets_manager(self, mocker): + mock_storage_client = MockClient() + bucket_name = "test_bucket" + mock_bucket = MockBucket(bucket_name) + mocker.patch.object(mock_storage_client, "bucket", return_value=mock_bucket) + mocker.patch( + "google.cloud.storage.Client", return_value=mock_storage_client + ) + return GcsEvalSetsManager(bucket_name=bucket_name) + + def test_gcs_eval_sets_manager_get_eval_set_success( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mock_bucket = gcs_eval_sets_manager.bucket + mock_blob = mock_bucket.blob( + f"{app_name}/evals/eval_sets/{eval_set_id}{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob.upload_from_string(mock_eval_set.model_dump_json()) + + eval_set = gcs_eval_sets_manager.get_eval_set(app_name, eval_set_id) + + assert eval_set == mock_eval_set + + def test_gcs_eval_sets_manager_get_eval_set_not_found( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + eval_set_id = "test_eval_set_not_exist" + eval_set = gcs_eval_sets_manager.get_eval_set(app_name, eval_set_id) + + assert eval_set is None + + def test_gcs_eval_sets_manager_create_eval_set_success( + self, gcs_eval_sets_manager, mocker + ): + mocked_time = 12345678 + mocker.patch("time.time", return_value=mocked_time) + app_name = "test_app" + eval_set_id = 
"test_eval_set" + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, + "_write_eval_set_to_blob", + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.create_eval_set(app_name, eval_set_id) + + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, + EvalSet( + eval_set_id=eval_set_id, + name=eval_set_id, + eval_cases=[], + creation_timestamp=mocked_time, + ), + ) + + def test_gcs_eval_sets_manager_create_eval_set_invalid_id( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + eval_set_id = "invalid-id" + + with pytest.raises(ValueError, match="Invalid Eval Set Id"): + gcs_eval_sets_manager.create_eval_set(app_name, eval_set_id) + + def test_gcs_eval_sets_manager_list_eval_sets_success( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + mock_blob_1 = MockBlob( + f"test_app/evals/eval_sets/eval_set_1{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob_2 = MockBlob( + f"test_app/evals/eval_sets/eval_set_2{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob_3 = MockBlob("test_app/evals/eval_sets/not_an_eval_set.txt") + mock_bucket = gcs_eval_sets_manager.bucket + mock_bucket.blobs = { + mock_blob_1.name: mock_blob_1, + mock_blob_2.name: mock_blob_2, + mock_blob_3.name: mock_blob_3, + } + + eval_sets = gcs_eval_sets_manager.list_eval_sets(app_name) + + assert eval_sets == ["eval_set_1", "eval_set_2"] + + def test_gcs_eval_sets_manager_add_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) + + assert len(mock_eval_set.eval_cases) == 1 + assert mock_eval_set.eval_cases[0] == mock_eval_case + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, mock_eval_set + ) + + def test_gcs_eval_sets_manager_add_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=None + ) + + with pytest.raises( + NotFoundError, match="Eval set `test_eval_set` not found." + ): + gcs_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) + + def test_gcs_eval_sets_manager_add_eval_case_eval_case_id_exists( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + + with pytest.raises( + ValueError, + match=( + f"Eval id `{eval_case_id}` already exists in `{eval_set_id}` eval" + " set." 
+ ), + ): + gcs_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) + + def test_gcs_eval_sets_manager_get_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + + eval_case = gcs_eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + assert eval_case == mock_eval_case + + def test_gcs_eval_sets_manager_get_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=None + ) + + eval_case = gcs_eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + assert eval_case is None + + def test_gcs_eval_sets_manager_get_eval_case_eval_case_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + + eval_case = gcs_eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + assert eval_case is None + + def test_gcs_eval_sets_manager_update_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase( + eval_id=eval_case_id, conversation=[], creation_timestamp=456 + ) + updated_eval_case = EvalCase( + eval_id=eval_case_id, conversation=[], creation_timestamp=123 + ) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=mock_eval_case + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + assert len(mock_eval_set.eval_cases) == 1 + assert mock_eval_set.eval_cases[0] == updated_eval_case + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, + EvalSet(eval_set_id=eval_set_id, eval_cases=[updated_eval_case]), + ) + + def test_gcs_eval_sets_manager_update_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + updated_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=None + ) + + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + gcs_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + def test_gcs_eval_sets_manager_update_eval_case_eval_case_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = 
EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + updated_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + + with pytest.raises( + NotFoundError, + match=( + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." + ), + ): + gcs_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + def test_gcs_eval_sets_manager_delete_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mock_bucket = gcs_eval_sets_manager.bucket + mock_blob = mock_bucket.blob( + f"{app_name}/evals/eval_sets/{eval_set_id}{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob.upload_from_string(mock_eval_set.model_dump_json()) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=mock_eval_case + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.delete_eval_case(app_name, eval_set_id, eval_case_id) + + assert len(mock_eval_set.eval_cases) == 0 + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, + EvalSet(eval_set_id=eval_set_id, eval_cases=[]), + ) + + def test_gcs_eval_sets_manager_delete_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + gcs_eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + mock_write_eval_set_to_blob.assert_not_called() + + def test_gcs_eval_sets_manager_delete_eval_case_eval_case_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=None + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + + with pytest.raises( + NotFoundError, + match=( + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." 
+ ), + ): + gcs_eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + mock_write_eval_set_to_blob.assert_not_called() diff --git a/tests/unittests/evaluation/test_local_eval_sets_manager.py b/tests/unittests/evaluation/test_local_eval_sets_manager.py index 2b919fa83..de00659e2 100644 --- a/tests/unittests/evaluation/test_local_eval_sets_manager.py +++ b/tests/unittests/evaluation/test_local_eval_sets_manager.py @@ -361,8 +361,8 @@ def test_local_eval_sets_manager_create_eval_set_success( app_name = "test_app" eval_set_id = "test_eval_set" mocker.patch("os.path.exists", return_value=False) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) eval_set_file_path = os.path.join( local_eval_sets_manager._agents_dir, @@ -371,7 +371,7 @@ def test_local_eval_sets_manager_create_eval_set_success( ) local_eval_sets_manager.create_eval_set(app_name, eval_set_id) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( eval_set_file_path, EvalSet( eval_set_id=eval_set_id, @@ -420,8 +420,8 @@ def test_local_eval_sets_manager_add_eval_case_success( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_set", return_value=mock_eval_set, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) local_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) @@ -434,7 +434,7 @@ def test_local_eval_sets_manager_add_eval_case_success( eval_set_id + _EVAL_SET_FILE_EXTENSION, ) mock_eval_set.eval_cases.append(mock_eval_case) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( expected_eval_set_file_path, mock_eval_set ) @@ -568,8 +568,8 @@ def test_local_eval_sets_manager_update_eval_case_success( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", return_value=mock_eval_case, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) local_eval_sets_manager.update_eval_case( @@ -583,12 +583,12 @@ def test_local_eval_sets_manager_update_eval_case_success( app_name, eval_set_id + _EVAL_SET_FILE_EXTENSION, ) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( expected_eval_set_file_path, EvalSet(eval_set_id=eval_set_id, eval_cases=[updated_eval_case]), ) - def test_local_eval_sets_manager_update_eval_case_eval_case_not_found( + def test_local_eval_sets_manager_update_eval_case_eval_set_not_found( self, local_eval_sets_manager, mocker ): app_name = "test_app" @@ -601,10 +601,34 @@ def test_local_eval_sets_manager_update_eval_case_eval_case_not_found( return_value=None, ) + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + local_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + def test_local_eval_sets_manager_update_eval_case_eval_case_not_found( + self, 
local_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + updated_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_set", + return_value=mock_eval_set, + ) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", + return_value=None, + ) with pytest.raises( NotFoundError, match=( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found." + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." ), ): local_eval_sets_manager.update_eval_case( @@ -630,8 +654,8 @@ def test_local_eval_sets_manager_delete_eval_case_success( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", return_value=mock_eval_case, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) local_eval_sets_manager.delete_eval_case( @@ -644,12 +668,12 @@ def test_local_eval_sets_manager_delete_eval_case_success( app_name, eval_set_id + _EVAL_SET_FILE_EXTENSION, ) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( expected_eval_set_file_path, EvalSet(eval_set_id=eval_set_id, eval_cases=[]), ) - def test_local_eval_sets_manager_delete_eval_case_eval_case_not_found( + def test_local_eval_sets_manager_delete_eval_case_eval_set_not_found( self, local_eval_sets_manager, mocker ): app_name = "test_app" @@ -660,18 +684,41 @@ def test_local_eval_sets_manager_delete_eval_case_eval_case_not_found( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", return_value=None, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + local_eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + + mock_write_eval_set_to_path.assert_not_called() + + def test_local_eval_sets_manager_delete_eval_case_eval_case_not_found( + self, local_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_set", + return_value=mock_eval_set, + ) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", + return_value=None, + ) with pytest.raises( NotFoundError, match=( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found." + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." 
), ): local_eval_sets_manager.delete_eval_case( app_name, eval_set_id, eval_case_id ) - - mock_write_eval_set.assert_not_called() From 0a5cf45a75aca7b0322136b65ca5504a0c3c7362 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Jun 2025 23:41:27 -0700 Subject: [PATCH 06/61] feat: Implement GcsEvalSetResultsManager to handle storage of eval sets on GCS, and refactor eval set results manager Eval results will be stored as json files under `gs://{bucket_name}/{app_name}/evals/eval_history/` PiperOrigin-RevId: 770499242 --- .../_eval_set_results_manager_utils.py | 44 ++++ src/google/adk/evaluation/eval_result.py | 19 +- .../evaluation/eval_set_results_manager.py | 7 +- .../gcs_eval_set_results_manager.py | 121 +++++++++++ .../local_eval_set_results_manager.py | 24 +-- tests/unittests/evaluation/mock_gcs_utils.py | 117 +++++++++++ .../test_gcs_eval_set_results_manager.py | 191 ++++++++++++++++++ .../evaluation/test_gcs_eval_sets_manager.py | 120 +---------- .../test_local_eval_set_results_manager.py | 15 +- 9 files changed, 503 insertions(+), 155 deletions(-) create mode 100644 src/google/adk/evaluation/_eval_set_results_manager_utils.py create mode 100644 src/google/adk/evaluation/gcs_eval_set_results_manager.py create mode 100644 tests/unittests/evaluation/mock_gcs_utils.py create mode 100644 tests/unittests/evaluation/test_gcs_eval_set_results_manager.py diff --git a/src/google/adk/evaluation/_eval_set_results_manager_utils.py b/src/google/adk/evaluation/_eval_set_results_manager_utils.py new file mode 100644 index 000000000..8505e68d1 --- /dev/null +++ b/src/google/adk/evaluation/_eval_set_results_manager_utils.py @@ -0,0 +1,44 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import time + +from .eval_result import EvalCaseResult +from .eval_result import EvalSetResult + + +def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str: + """Sanitizes the eval set result name.""" + return eval_set_result_name.replace("/", "_") + + +def create_eval_set_result( + app_name: str, + eval_set_id: str, + eval_case_results: list[EvalCaseResult], +) -> EvalSetResult: + """Creates a new EvalSetResult given eval_case_results.""" + timestamp = time.time() + eval_set_result_id = f"{app_name}_{eval_set_id}_{timestamp}" + eval_set_result_name = _sanitize_eval_set_result_name(eval_set_result_id) + eval_set_result = EvalSetResult( + eval_set_result_id=eval_set_result_id, + eval_set_result_name=eval_set_result_name, + eval_set_id=eval_set_id, + eval_case_results=eval_case_results, + creation_timestamp=timestamp, + ) + return eval_set_result diff --git a/src/google/adk/evaluation/eval_result.py b/src/google/adk/evaluation/eval_result.py index 8f87a14b4..96e8d3c98 100644 --- a/src/google/adk/evaluation/eval_result.py +++ b/src/google/adk/evaluation/eval_result.py @@ -36,8 +36,9 @@ class EvalCaseResult(BaseModel): populate_by_name=True, ) - eval_set_file: str = Field( + eval_set_file: Optional[str] = Field( deprecated=True, + default=None, description="This field is deprecated, use eval_set_id instead.", ) eval_set_id: str = "" @@ -49,11 +50,15 @@ class EvalCaseResult(BaseModel): final_eval_status: EvalStatus """Final eval status for this eval case.""" - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( - deprecated=True, - description=( - "This field is deprecated, use overall_eval_metric_results instead." - ), + eval_metric_results: Optional[list[tuple[EvalMetric, EvalMetricResult]]] = ( + Field( + deprecated=True, + default=None, + description=( + "This field is deprecated, use overall_eval_metric_results" + " instead." + ), + ) ) overall_eval_metric_results: list[EvalMetricResult] @@ -80,7 +85,7 @@ class EvalSetResult(BaseModel): populate_by_name=True, ) eval_set_result_id: str - eval_set_result_name: str + eval_set_result_name: Optional[str] = None eval_set_id: str eval_case_results: list[EvalCaseResult] = Field(default_factory=list) creation_timestamp: float = 0.0 diff --git a/src/google/adk/evaluation/eval_set_results_manager.py b/src/google/adk/evaluation/eval_set_results_manager.py index 5a300ed14..588e823ba 100644 --- a/src/google/adk/evaluation/eval_set_results_manager.py +++ b/src/google/adk/evaluation/eval_set_results_manager.py @@ -16,6 +16,7 @@ from abc import ABC from abc import abstractmethod +from typing import Optional from .eval_result import EvalCaseResult from .eval_result import EvalSetResult @@ -38,7 +39,11 @@ def save_eval_set_result( def get_eval_set_result( self, app_name: str, eval_set_result_id: str ) -> EvalSetResult: - """Returns an EvalSetResult identified by app_name and eval_set_result_id.""" + """Returns the EvalSetResult from app_name and eval_set_result_id. + + Raises: + NotFoundError: If the EvalSetResult is not found. 
+    """
     raise NotImplementedError()
 
   @abstractmethod
diff --git a/src/google/adk/evaluation/gcs_eval_set_results_manager.py b/src/google/adk/evaluation/gcs_eval_set_results_manager.py
new file mode 100644
index 000000000..860d932ff
--- /dev/null
+++ b/src/google/adk/evaluation/gcs_eval_set_results_manager.py
@@ -0,0 +1,121 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+
+from google.cloud import exceptions as cloud_exceptions
+from google.cloud import storage
+from typing_extensions import override
+
+from ..errors.not_found_error import NotFoundError
+from ._eval_set_results_manager_utils import create_eval_set_result
+from .eval_result import EvalCaseResult
+from .eval_result import EvalSetResult
+from .eval_set_results_manager import EvalSetResultsManager
+
+logger = logging.getLogger("google_adk." + __name__)
+
+_EVAL_HISTORY_DIR = "evals/eval_history"
+_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json"
+
+
+class GcsEvalSetResultsManager(EvalSetResultsManager):
+  """An EvalSetResultsManager that stores eval results in a GCS bucket."""
+
+  def __init__(self, bucket_name: str, **kwargs):
+    """Initializes the GcsEvalSetResultsManager.
+
+    Args:
+      bucket_name: The name of the bucket to use.
+      **kwargs: Keyword arguments to pass to the Google Cloud Storage client.
+    """
+    self.bucket_name = bucket_name
+    self.storage_client = storage.Client(**kwargs)
+    self.bucket = self.storage_client.bucket(self.bucket_name)
+    # Check if the bucket exists.
+    if not self.bucket.exists():
+      raise ValueError(
+          f"Bucket `{self.bucket_name}` does not exist. Please create it before"
+          " using the GcsEvalSetResultsManager."
+ ) + + def _get_eval_history_dir(self, app_name: str) -> str: + return f"{app_name}/{_EVAL_HISTORY_DIR}" + + def _get_eval_set_result_blob_name( + self, app_name: str, eval_set_result_id: str + ) -> str: + eval_history_dir = self._get_eval_history_dir(app_name) + return f"{eval_history_dir}/{eval_set_result_id}{_EVAL_SET_RESULT_FILE_EXTENSION}" + + def _write_eval_set_result( + self, blob_name: str, eval_set_result: EvalSetResult + ): + """Writes an EvalSetResult to GCS.""" + blob = self.bucket.blob(blob_name) + blob.upload_from_string( + eval_set_result.model_dump_json(indent=2), + content_type="application/json", + ) + + @override + def save_eval_set_result( + self, + app_name: str, + eval_set_id: str, + eval_case_results: list[EvalCaseResult], + ) -> None: + """Creates and saves a new EvalSetResult given eval_case_results.""" + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + + eval_set_result_blob_name = self._get_eval_set_result_blob_name( + app_name, eval_set_result.eval_set_result_id + ) + logger.info("Writing eval result to blob: %s", eval_set_result_blob_name) + self._write_eval_set_result(eval_set_result_blob_name, eval_set_result) + + @override + def get_eval_set_result( + self, app_name: str, eval_set_result_id: str + ) -> EvalSetResult: + """Returns an EvalSetResult from app_name and eval_set_result_id.""" + eval_set_result_blob_name = self._get_eval_set_result_blob_name( + app_name, eval_set_result_id + ) + blob = self.bucket.blob(eval_set_result_blob_name) + if not blob.exists(): + raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.") + eval_set_result_data = blob.download_as_text() + return EvalSetResult.model_validate_json(eval_set_result_data) + + @override + def list_eval_set_results(self, app_name: str) -> list[str]: + """Returns the eval result ids that belong to the given app_name.""" + eval_history_dir = self._get_eval_history_dir(app_name) + eval_set_results = [] + try: + for blob in self.bucket.list_blobs(prefix=eval_history_dir): + eval_set_result_id = blob.name.split("/")[-1].removesuffix( + _EVAL_SET_RESULT_FILE_EXTENSION + ) + eval_set_results.append(eval_set_result_id) + return sorted(eval_set_results) + except cloud_exceptions.NotFound as e: + raise ValueError( + f"App `{app_name}` not found in GCS bucket `{self.bucket_name}`." 
+ ) from e diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index 598af7f96..3a66f888c 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -17,10 +17,11 @@ import json import logging import os -import time from typing_extensions import override +from ..errors.not_found_error import NotFoundError +from ._eval_set_results_manager_utils import create_eval_set_result from .eval_result import EvalCaseResult from .eval_result import EvalSetResult from .eval_set_results_manager import EvalSetResultsManager @@ -31,10 +32,6 @@ _EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json" -def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str: - return eval_set_result_name.replace("/", "_") - - class LocalEvalSetResultsManager(EvalSetResultsManager): """An EvalSetResult manager that stores eval set results locally on disk.""" @@ -49,15 +46,8 @@ def save_eval_set_result( eval_case_results: list[EvalCaseResult], ) -> None: """Creates and saves a new EvalSetResult given eval_case_results.""" - timestamp = time.time() - eval_set_result_id = app_name + "_" + eval_set_id + "_" + str(timestamp) - eval_set_result_name = _sanitize_eval_set_result_name(eval_set_result_id) - eval_set_result = EvalSetResult( - eval_set_result_id=eval_set_result_id, - eval_set_result_name=eval_set_result_name, - eval_set_id=eval_set_id, - eval_case_results=eval_case_results, - creation_timestamp=timestamp, + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results ) # Write eval result file, with eval_set_result_name. app_eval_history_dir = self._get_eval_history_dir(app_name) @@ -67,7 +57,7 @@ def save_eval_set_result( eval_set_result_json = eval_set_result.model_dump_json() eval_set_result_file_path = os.path.join( app_eval_history_dir, - eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION, + eval_set_result.eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION, ) logger.info("Writing eval result to file: %s", eval_set_result_file_path) with open(eval_set_result_file_path, "w") as f: @@ -87,9 +77,7 @@ def get_eval_set_result( + _EVAL_SET_RESULT_FILE_EXTENSION ) if not os.path.exists(maybe_eval_result_file_path): - raise ValueError( - f"Eval set result `{eval_set_result_id}` does not exist." - ) + raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.") with open(maybe_eval_result_file_path, "r") as file: eval_result_data = json.load(file) return EvalSetResult.model_validate_json(eval_result_data) diff --git a/tests/unittests/evaluation/mock_gcs_utils.py b/tests/unittests/evaluation/mock_gcs_utils.py new file mode 100644 index 000000000..d9ea008c3 --- /dev/null +++ b/tests/unittests/evaluation/mock_gcs_utils.py @@ -0,0 +1,117 @@ +from typing import Optional +from typing import Union + + +class MockBlob: + """Mocks a GCS Blob object. + + This class provides mock implementations for a few common GCS Blob methods, + allowing the user to test code that interacts with GCS without actually + connecting to a real bucket. + """ + + def __init__(self, name: str) -> None: + """Initializes a MockBlob. + + Args: + name: The name of the blob. 
+ """ + self.name = name + self.content: Optional[bytes] = None + self.content_type: Optional[str] = None + self._exists: bool = False + + def upload_from_string( + self, data: Union[str, bytes], content_type: Optional[str] = None + ) -> None: + """Mocks uploading data to the blob (from a string or bytes). + + Args: + data: The data to upload (string or bytes). + content_type: The content type of the data (optional). + """ + if isinstance(data, str): + self.content = data.encode("utf-8") + elif isinstance(data, bytes): + self.content = data + else: + raise TypeError("data must be str or bytes") + + if content_type: + self.content_type = content_type + self._exists = True + + def download_as_text(self) -> str: + """Mocks downloading the blob's content as text. + + Returns: + str: The content of the blob as text. + + Raises: + Exception: If the blob doesn't exist (hasn't been uploaded to). + """ + if self.content is None: + return b"" + return self.content + + def delete(self) -> None: + """Mocks deleting a blob.""" + self.content = None + self.content_type = None + self._exists = False + + def exists(self) -> bool: + """Mocks checking if the blob exists.""" + return self._exists + + +class MockBucket: + """Mocks a GCS Bucket object.""" + + def __init__(self, name: str) -> None: + """Initializes a MockBucket. + + Args: + name: The name of the bucket. + """ + self.name = name + self.blobs: dict[str, MockBlob] = {} + + def blob(self, blob_name: str) -> MockBlob: + """Mocks getting a Blob object (doesn't create it in storage). + + Args: + blob_name: The name of the blob. + + Returns: + A MockBlob instance. + """ + if blob_name not in self.blobs: + self.blobs[blob_name] = MockBlob(blob_name) + return self.blobs[blob_name] + + def list_blobs(self, prefix: Optional[str] = None) -> list[MockBlob]: + """Mocks listing blobs in a bucket, optionally with a prefix.""" + if prefix: + return [ + blob for name, blob in self.blobs.items() if name.startswith(prefix) + ] + return list(self.blobs.values()) + + def exists(self) -> bool: + """Mocks checking if the bucket exists.""" + return True + + +class MockClient: + """Mocks the GCS Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.buckets: dict[str, MockBucket] = {} + + def bucket(self, bucket_name: str) -> MockBucket: + """Mocks getting a Bucket object.""" + if bucket_name not in self.buckets: + self.buckets[bucket_name] = MockBucket(bucket_name) + return self.buckets[bucket_name] diff --git a/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py b/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py new file mode 100644 index 000000000..7fd0bb97e --- /dev/null +++ b/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py @@ -0,0 +1,191 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from google.adk.errors.not_found_error import NotFoundError +from google.adk.evaluation._eval_set_results_manager_utils import _sanitize_eval_set_result_name +from google.adk.evaluation._eval_set_results_manager_utils import create_eval_set_result +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetricResult +from google.adk.evaluation.eval_metrics import EvalMetricResultPerInvocation +from google.adk.evaluation.eval_result import EvalCaseResult +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from google.genai import types as genai_types +import pytest + +from .mock_gcs_utils import MockBucket +from .mock_gcs_utils import MockClient + + +def _get_test_eval_case_results(): + # Create mock Invocation objects + actual_invocation_1 = Invocation( + invocation_id="actual_1", + user_content=genai_types.Content( + parts=[genai_types.Part(text="input_1")] + ), + ) + expected_invocation_1 = Invocation( + invocation_id="expected_1", + user_content=genai_types.Content( + parts=[genai_types.Part(text="expected_input_1")] + ), + ) + actual_invocation_2 = Invocation( + invocation_id="actual_2", + user_content=genai_types.Content( + parts=[genai_types.Part(text="input_2")] + ), + ) + expected_invocation_2 = Invocation( + invocation_id="expected_2", + user_content=genai_types.Content( + parts=[genai_types.Part(text="expected_input_2")] + ), + ) + + eval_metric_result_1 = EvalMetricResult( + metric_name="metric", + threshold=0.8, + score=1.0, + eval_status=EvalStatus.PASSED, + ) + eval_metric_result_2 = EvalMetricResult( + metric_name="metric", + threshold=0.8, + score=0.5, + eval_status=EvalStatus.FAILED, + ) + eval_metric_result_per_invocation_1 = EvalMetricResultPerInvocation( + actual_invocation=actual_invocation_1, + expected_invocation=expected_invocation_1, + eval_metric_results=[eval_metric_result_1], + ) + eval_metric_result_per_invocation_2 = EvalMetricResultPerInvocation( + actual_invocation=actual_invocation_2, + expected_invocation=expected_invocation_2, + eval_metric_results=[eval_metric_result_2], + ) + return [ + EvalCaseResult( + eval_set_id="eval_set", + eval_id="eval_case_1", + final_eval_status=EvalStatus.PASSED, + overall_eval_metric_results=[eval_metric_result_1], + eval_metric_result_per_invocation=[ + eval_metric_result_per_invocation_1 + ], + session_id="session_1", + ), + EvalCaseResult( + eval_set_id="eval_set", + eval_id="eval_case_2", + final_eval_status=EvalStatus.FAILED, + overall_eval_metric_results=[eval_metric_result_2], + eval_metric_result_per_invocation=[ + eval_metric_result_per_invocation_2 + ], + session_id="session_2", + ), + ] + + +class TestGcsEvalSetResultsManager: + + @pytest.fixture + def gcs_eval_set_results_manager(self, mocker): + mock_storage_client = MockClient() + bucket_name = "test_bucket" + mock_bucket = MockBucket(bucket_name) + mocker.patch.object(mock_storage_client, "bucket", return_value=mock_bucket) + mocker.patch( + "google.cloud.storage.Client", return_value=mock_storage_client + ) + return GcsEvalSetResultsManager(bucket_name=bucket_name) + + def test_save_eval_set_result(self, gcs_eval_set_results_manager, mocker): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_results = _get_test_eval_case_results() + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + blob_name = 
gcs_eval_set_results_manager._get_eval_set_result_blob_name( + app_name, eval_set_result.eval_set_result_id + ) + mock_write_eval_set_result = mocker.patch.object( + gcs_eval_set_results_manager, + "_write_eval_set_result", + ) + gcs_eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + mock_write_eval_set_result.assert_called_once_with( + blob_name, + eval_set_result, + ) + + def test_get_eval_set_result_not_found( + self, gcs_eval_set_results_manager, mocker + ): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + with pytest.raises(NotFoundError) as e: + gcs_eval_set_results_manager.get_eval_set_result( + app_name, "non_existent_id" + ) + + def test_get_eval_set_result(self, gcs_eval_set_results_manager, mocker): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_results = _get_test_eval_case_results() + gcs_eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + retrieved_eval_set_result = ( + gcs_eval_set_results_manager.get_eval_set_result( + app_name, eval_set_result.eval_set_result_id + ) + ) + assert retrieved_eval_set_result == eval_set_result + + def test_list_eval_set_results(self, gcs_eval_set_results_manager, mocker): + mocker.patch("time.time", return_value=123) + app_name = "test_app" + eval_set_ids = ["test_eval_set_1", "test_eval_set_2", "test_eval_set_3"] + for eval_set_id in eval_set_ids: + eval_case_results = _get_test_eval_case_results() + gcs_eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + retrieved_eval_set_result_ids = ( + gcs_eval_set_results_manager.list_eval_set_results(app_name) + ) + assert retrieved_eval_set_result_ids == [ + "test_app_test_eval_set_1_123", + "test_app_test_eval_set_2_123", + "test_app_test_eval_set_3_123", + ] + + def test_list_eval_set_results_empty(self, gcs_eval_set_results_manager): + app_name = "test_app" + retrieved_eval_set_result_ids = ( + gcs_eval_set_results_manager.list_eval_set_results(app_name) + ) + assert retrieved_eval_set_result_ids == [] diff --git a/tests/unittests/evaluation/test_gcs_eval_sets_manager.py b/tests/unittests/evaluation/test_gcs_eval_sets_manager.py index 9b670a941..bb8e3bd3b 100644 --- a/tests/unittests/evaluation/test_gcs_eval_sets_manager.py +++ b/tests/unittests/evaluation/test_gcs_eval_sets_manager.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional -from typing import Union - from google.adk.errors.not_found_error import NotFoundError from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_set import EvalSet @@ -22,120 +19,9 @@ from google.adk.evaluation.gcs_eval_sets_manager import GcsEvalSetsManager import pytest - -class MockBlob: - """Mocks a GCS Blob object. - - This class provides mock implementations for a few common GCS Blob methods, - allowing the user to test code that interacts with GCS without actually - connecting to a real bucket. - """ - - def __init__(self, name: str) -> None: - """Initializes a MockBlob. - - Args: - name: The name of the blob. 
- """ - self.name = name - self.content: Optional[bytes] = None - self.content_type: Optional[str] = None - self._exists: bool = False - - def upload_from_string( - self, data: Union[str, bytes], content_type: Optional[str] = None - ) -> None: - """Mocks uploading data to the blob (from a string or bytes). - - Args: - data: The data to upload (string or bytes). - content_type: The content type of the data (optional). - """ - if isinstance(data, str): - self.content = data.encode("utf-8") - elif isinstance(data, bytes): - self.content = data - else: - raise TypeError("data must be str or bytes") - - if content_type: - self.content_type = content_type - self._exists = True - - def download_as_text(self) -> str: - """Mocks downloading the blob's content as text. - - Returns: - str: The content of the blob as text. - - Raises: - Exception: If the blob doesn't exist (hasn't been uploaded to). - """ - if self.content is None: - return b"" - return self.content - - def delete(self) -> None: - """Mocks deleting a blob.""" - self.content = None - self.content_type = None - self._exists = False - - def exists(self) -> bool: - """Mocks checking if the blob exists.""" - return self._exists - - -class MockBucket: - """Mocks a GCS Bucket object.""" - - def __init__(self, name: str) -> None: - """Initializes a MockBucket. - - Args: - name: The name of the bucket. - """ - self.name = name - self.blobs: dict[str, MockBlob] = {} - - def blob(self, blob_name: str) -> MockBlob: - """Mocks getting a Blob object (doesn't create it in storage). - - Args: - blob_name: The name of the blob. - - Returns: - A MockBlob instance. - """ - if blob_name not in self.blobs: - self.blobs[blob_name] = MockBlob(blob_name) - return self.blobs[blob_name] - - def list_blobs(self, prefix: Optional[str] = None) -> list[MockBlob]: - """Mocks listing blobs in a bucket, optionally with a prefix.""" - if prefix: - return [ - blob for name, blob in self.blobs.items() if name.startswith(prefix) - ] - return list(self.blobs.values()) - - def exists(self) -> bool: - """Mocks checking if the bucket exists.""" - return True - - -class MockClient: - """Mocks the GCS Client.""" - - def __init__(self) -> None: - """Initializes MockClient.""" - self.buckets: dict[str, MockBucket] = {} - - def bucket(self, bucket_name: str) -> MockBucket: - """Mocks getting a Bucket object.""" - if bucket_name not in self.buckets: - self.buckets[bucket_name] = MockBucket(bucket_name) - return self.buckets[bucket_name] +from .mock_gcs_utils import MockBlob +from .mock_gcs_utils import MockBucket +from .mock_gcs_utils import MockClient class TestGcsEvalSetsManager: diff --git a/tests/unittests/evaluation/test_local_eval_set_results_manager.py b/tests/unittests/evaluation/test_local_eval_set_results_manager.py index 038f17abb..3411d9b7a 100644 --- a/tests/unittests/evaluation/test_local_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_local_eval_set_results_manager.py @@ -21,24 +21,17 @@ import time from unittest.mock import patch +from google.adk.errors.not_found_error import NotFoundError +from google.adk.evaluation._eval_set_results_manager_utils import _sanitize_eval_set_result_name from google.adk.evaluation.eval_result import EvalCaseResult from google.adk.evaluation.eval_result import EvalSetResult from google.adk.evaluation.evaluator import EvalStatus from google.adk.evaluation.local_eval_set_results_manager import _ADK_EVAL_HISTORY_DIR from google.adk.evaluation.local_eval_set_results_manager import _EVAL_SET_RESULT_FILE_EXTENSION -from 
google.adk.evaluation.local_eval_set_results_manager import _sanitize_eval_set_result_name from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager import pytest -def test_sanitize_eval_set_result_name(): - assert _sanitize_eval_set_result_name("app/name") == "app_name" - assert _sanitize_eval_set_result_name("app_name") == "app_name" - assert _sanitize_eval_set_result_name("app/name/with/slashes") == ( - "app_name_with_slashes" - ) - - class TestLocalEvalSetResultsManager: @pytest.fixture(autouse=True) @@ -115,11 +108,9 @@ def test_get_eval_set_result(self, mock_time): def test_get_eval_set_result_not_found(self, mock_time): mock_time.return_value = self.timestamp - with pytest.raises(ValueError) as e: + with pytest.raises(NotFoundError) as e: self.manager.get_eval_set_result(self.app_name, "non_existent_id") - assert "does not exist" in str(e.value) - @patch("time.time") def test_list_eval_set_results(self, mock_time): mock_time.return_value = self.timestamp From 2ff9b1f639b623165895adfb471aa1920d491006 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Jun 2025 07:52:08 -0700 Subject: [PATCH 07/61] test: Add unit tests for `execute_sql` tool This change introduces unit tests in which the behavior of the tool is asserted for various query types in various write modes through a mocked BigQuery client. PiperOrigin-RevId: 770653117 --- .../bigquery/test_bigquery_query_tool.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index fc592c6c8..35d44ef81 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -16,12 +16,16 @@ import textwrap from typing import Optional +from unittest import mock from google.adk.tools import BaseTool from google.adk.tools.bigquery import BigQueryCredentialsConfig from google.adk.tools.bigquery import BigQueryToolset from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode +from google.adk.tools.bigquery.query_tool import execute_sql +from google.cloud import bigquery +from google.oauth2.credentials import Credentials import pytest @@ -218,3 +222,123 @@ async def test_execute_sql_declaration_write(tool_config): - Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE". - First run "DROP TABLE", followed by "CREATE TABLE". 
 - To insert data into a table, use "INSERT INTO" statement.""")
+
+
+@pytest.mark.parametrize(
+    ("write_mode",),
+    [
+        pytest.param(
+            WriteMode.BLOCKED,
+            id="blocked",
+        ),
+        pytest.param(
+            WriteMode.ALLOWED,
+            id="allowed",
+        ),
+    ],
+)
+def test_execute_sql_select_stmt(write_mode):
+  """Test execute_sql tool for a SELECT statement in both write modes."""
+  project = "my_project"
+  query = "SELECT 123 AS num"
+  statement_type = "SELECT"
+  query_result = [{"num": 123}]
+  credentials = mock.create_autospec(Credentials, instance=True)
+  tool_config = BigQueryToolConfig(write_mode=write_mode)
+
+  with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
+    # The mock instance
+    bq_client = Client.return_value
+
+    # Simulate the result of query API
+    query_job = mock.create_autospec(bigquery.QueryJob)
+    query_job.statement_type = statement_type
+    bq_client.query.return_value = query_job
+
+    # Simulate the result of query_and_wait API
+    bq_client.query_and_wait.return_value = query_result
+
+    # Test the tool
+    result = execute_sql(project, query, credentials, tool_config)
+    assert result == {"status": "SUCCESS", "rows": query_result}
+
+
+@pytest.mark.parametrize(
+    ("query", "statement_type"),
+    [
+        pytest.param(
+            "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num",
+            "CREATE_AS_SELECT",
+            id="create-as-select",
+        ),
+        pytest.param(
+            "DROP TABLE my_dataset.my_table",
+            "DROP_TABLE",
+            id="drop-table",
+        ),
+    ],
+)
+def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
+  """Test execute_sql tool for non-SELECT statements with writes allowed."""
+  project = "my_project"
+  query_result = []
+  credentials = mock.create_autospec(Credentials, instance=True)
+  tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
+
+  with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
+    # The mock instance
+    bq_client = Client.return_value
+
+    # Simulate the result of query API
+    query_job = mock.create_autospec(bigquery.QueryJob)
+    query_job.statement_type = statement_type
+    bq_client.query.return_value = query_job
+
+    # Simulate the result of query_and_wait API
+    bq_client.query_and_wait.return_value = query_result
+
+    # Test the tool
+    result = execute_sql(project, query, credentials, tool_config)
+    assert result == {"status": "SUCCESS", "rows": query_result}
+
+
+@pytest.mark.parametrize(
+    ("query", "statement_type"),
+    [
+        pytest.param(
+            "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num",
+            "CREATE_AS_SELECT",
+            id="create-as-select",
+        ),
+        pytest.param(
+            "DROP TABLE my_dataset.my_table",
+            "DROP_TABLE",
+            id="drop-table",
+        ),
+    ],
+)
+def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
+  """Test execute_sql tool for non-SELECT statements with writes blocked."""
+  project = "my_project"
+  query_result = []
+  credentials = mock.create_autospec(Credentials, instance=True)
+  tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
+
+  with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
+    # The mock instance
+    bq_client = Client.return_value
+
+    # Simulate the result of query API
+    query_job = mock.create_autospec(bigquery.QueryJob)
+    query_job.statement_type = statement_type
+    bq_client.query.return_value = query_job
+
+    # Simulate the result of query_and_wait API
+    bq_client.query_and_wait.return_value = query_result
+
+    # Test the tool
+    result = execute_sql(project, query, credentials, tool_config)
+    assert result == {
+        "status": "ERROR",
+        "error_details": "Read-only mode only
supports SELECT statements.", + } From d22920bd7f827461afd649601326b0c58aea6716 Mon Sep 17 00:00:00 2001 From: GenkiNoguchi <40311508+ammmr@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:32:29 -0700 Subject: [PATCH 08/61] feat: support realtime input config Merge https://github.com/google/adk-python/pull/981 issue: https://github.com/google/adk-python/issues/982 This pull request introduces a new configuration option, `realtime_input_config`, to the `RunConfig` class. **Reason for this change:** Currently, there is no direct way to configure real-time audio input behaviors, such as Voice Activity Detection (VAD), for live agents through the `RunConfig`. The Gemini API documentation (specifically [Configure automatic VAD](https://ai.google.dev/gemini-api/docs/live#configure-automatic-vad)) outlines parameters for VAD that users may want to customize. This change enables users to pass these real-time input configurations, providing more granular control over the audio input for live agents. **Changes made:** - Added a new optional field `realtime_input_config: Optional[types.RealtimeInputConfig]` to the `RunConfig` class. - The docstring for `realtime_input_config` has been added to explain its purpose. **Example Usage (Conceptual):** While the specific structure of `types.RealtimeInputConfig` would define the exact parameters, a user might configure it like this: ```python # (Assuming types.RealtimeInputConfig and types.VadConfig are defined elsewhere) # import your_project.types as types run_config = RunConfig( # ... other configurations ... realtime_input_config=types.RealtimeInputConfig( automatic_activity_detection =types.AutomaticActivityDetection( # VAD specific parameters like sensitivity, endpoint_duration_millis etc. # based on https://ai.google.dev/gemini-api/docs/live#configure-automatic-vad ) # Potentially other real-time input settings could be added here in the future ) ) COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/981 from ammmr:patch-add-realtime-input-config b2e17fbf5742d264029ad49bf632422b5c5b1e0a PiperOrigin-RevId: 770797640 --- src/google/adk/agents/run_config.py | 5 + src/google/adk/flows/llm_flows/basic.py | 3 + .../adk/models/gemini_llm_connection.py | 2 + .../llm_flows/test_base_llm_flow_realtime.py | 201 ++++++++++++++++++ .../models/test_gemini_llm_connection.py | 111 ++++++++++ tests/unittests/testing_utils.py | 11 +- 6 files changed, 330 insertions(+), 3 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py create mode 100644 tests/unittests/models/test_gemini_llm_connection.py diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 566fe8606..5679f04e9 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum import logging import sys @@ -68,6 +70,9 @@ class RunConfig(BaseModel): input_audio_transcription: Optional[types.AudioTranscriptionConfig] = None """Input transcription for live agents with audio input from user.""" + realtime_input_config: Optional[types.RealtimeInputConfig] = None + """Realtime input config for live agents with audio input from user.""" + max_llm_calls: int = 500 """ A limit on the total number of llm calls for a given run. 
diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index d48c8cd20..7efadd97e 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -65,6 +65,9 @@ async def run_async( llm_request.live_connect_config.input_audio_transcription = ( invocation_context.run_config.input_audio_transcription ) + llm_request.live_connect_config.realtime_input_config = ( + invocation_context.run_config.realtime_input_config + ) # TODO: handle tool append here, instead of in BaseTool.process_llm_request. diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 400aab3d6..36aea3b20 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging from typing import AsyncGenerator diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py new file mode 100644 index 000000000..f3eefb186 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py @@ -0,0 +1,201 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.agents import Agent +from google.adk.agents.live_request_queue import LiveRequest +from google.adk.agents.live_request_queue import LiveRequestQueue +from google.adk.agents.run_config import RunConfig +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.genai import types +import pytest + +from ... 
import testing_utils + + +class TestBaseLlmFlow(BaseLlmFlow): + """Test implementation of BaseLlmFlow for testing purposes.""" + + pass + + +@pytest.fixture +def test_blob(): + """Test blob for audio data.""" + return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm') + + +@pytest.fixture +def mock_llm_connection(): + """Mock LLM connection for testing.""" + connection = mock.AsyncMock() + connection.send_realtime = mock.AsyncMock() + return connection + + +@pytest.mark.asyncio +async def test_send_to_model_with_disabled_vad(test_blob, mock_llm_connection): + """Test _send_to_model with automatic_activity_detection.disabled=True.""" + # Create LlmRequest with disabled VAD + realtime_input_config = types.RealtimeInputConfig( + automatic_activity_detection=types.AutomaticActivityDetection( + disabled=True + ) + ) + + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, + user_content='', + run_config=RunConfig(realtime_input_config=realtime_input_config), + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_with_enabled_vad(test_blob, mock_llm_connection): + """Test _send_to_model with automatic_activity_detection.disabled=False. + + Custom VAD activity signal is not supported so we should still disable it. 
+ """ + # Create LlmRequest with enabled VAD + realtime_input_config = types.RealtimeInputConfig( + automatic_activity_detection=types.AutomaticActivityDetection( + disabled=False + ) + ) + + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_without_realtime_config( + test_blob, mock_llm_connection +): + """Test _send_to_model without realtime_input_config (default behavior).""" + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_with_none_automatic_activity_detection( + test_blob, mock_llm_connection +): + """Test _send_to_model with automatic_activity_detection=None.""" + # Create LlmRequest with None automatic_activity_detection + realtime_input_config = types.RealtimeInputConfig( + automatic_activity_detection=None + ) + + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, + user_content='', + run_config=RunConfig(realtime_input_config=realtime_input_config), + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_with_text_content(mock_llm_connection): + """Test _send_to_model with text content (not blob).""" + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send text content to the queue + content = types.Content( + role='user', parts=[types.Part.from_text(text='Hello')] + ) + live_request = 
LiveRequest(content=content) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + # Verify send_content was called instead of send_realtime + mock_llm_connection.send_content.assert_called_once_with(content) + mock_llm_connection.send_realtime.assert_not_called() diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py new file mode 100644 index 000000000..232711503 --- /dev/null +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.models.gemini_llm_connection import GeminiLlmConnection +from google.genai import types +import pytest + + +@pytest.fixture +def mock_gemini_session(): + """Mock Gemini session for testing.""" + return mock.AsyncMock() + + +@pytest.fixture +def gemini_connection(mock_gemini_session): + """GeminiLlmConnection instance with mocked session.""" + return GeminiLlmConnection(mock_gemini_session) + + +@pytest.fixture +def test_blob(): + """Test blob for audio data.""" + return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm') + + +@pytest.mark.asyncio +async def test_send_realtime_default_behavior( + gemini_connection, mock_gemini_session, test_blob +): + """Test send_realtime with default automatic_activity_detection value (True).""" + await gemini_connection.send_realtime(test_blob) + + # Should call send once + mock_gemini_session.send.assert_called_once_with(input=test_blob.model_dump()) + + +@pytest.mark.asyncio +async def test_send_history(gemini_connection, mock_gemini_session): + """Test send_history method.""" + history = [ + types.Content(role='user', parts=[types.Part.from_text(text='Hello')]), + types.Content( + role='model', parts=[types.Part.from_text(text='Hi there!')] + ), + ] + + await gemini_connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + assert 'input' in call_args + assert call_args['input'].turns == history + assert call_args['input'].turn_complete is False # Last message is from model + + +@pytest.mark.asyncio +async def test_send_content_text(gemini_connection, mock_gemini_session): + """Test send_content with text content.""" + content = types.Content( + role='user', parts=[types.Part.from_text(text='Hello')] + ) + + await gemini_connection.send_content(content) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + assert 'input' in call_args + assert call_args['input'].turns == [content] + assert call_args['input'].turn_complete is True + + +@pytest.mark.asyncio +async def test_send_content_function_response( + gemini_connection, mock_gemini_session +): + """Test send_content with function response.""" + function_response = 
types.FunctionResponse( + name='test_function', response={'result': 'success'} + ) + content = types.Content( + role='user', parts=[types.Part(function_response=function_response)] + ) + + await gemini_connection.send_content(content) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + assert 'input' in call_args + assert call_args['input'].function_responses == [function_response] + + +@pytest.mark.asyncio +async def test_close(gemini_connection, mock_gemini_session): + """Test close method.""" + await gemini_connection.close() + + mock_gemini_session.close.assert_called_once() diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index 1a8ed5233..b1d5ff822 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -56,7 +56,9 @@ def __init__(self, parts: list[types.Part]): super().__init__(role='model', parts=parts) -async def create_invocation_context(agent: Agent, user_content: str = ''): +async def create_invocation_context( + agent: Agent, user_content: str = '', run_config: RunConfig = None +): invocation_id = 'test_id' artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() @@ -73,7 +75,7 @@ async def create_invocation_context(agent: Agent, user_content: str = ''): user_content=types.Content( role='user', parts=[types.Part.from_text(text=user_content)] ), - run_config=RunConfig(), + run_config=run_config or RunConfig(), ) if user_content: append_user_content( @@ -205,13 +207,16 @@ async def run_async(self, new_message: types.ContentUnion) -> list[Event]: events.append(event) return events - def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: + def run_live( + self, live_request_queue: LiveRequestQueue, run_config: RunConfig = None + ) -> list[Event]: collected_responses = [] async def consume_responses(session: Session): run_res = self.runner.run_live( session=session, live_request_queue=live_request_queue, + run_config=run_config or RunConfig(), ) async for response in run_res: From 4ccda99e8ec7aa715399b4b83c3f101c299a95e8 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Jun 2025 16:19:42 -0700 Subject: [PATCH 09/61] fix: Merge custom http options with adk specific http options in model api request PiperOrigin-RevId: 770836112 --- src/google/adk/models/google_llm.py | 47 ++-- tests/unittests/models/test_google_llm.py | 249 ++++++++++++++++++++++ 2 files changed, 282 insertions(+), 14 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index fd1d8e582..bff2b675c 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -95,6 +95,13 @@ async def generate_content_async( ) logger.info(_build_request_log(llm_request)) + # add tracking headers to custom headers given it will override the headers + # set in the api client constructor + if llm_request.config and llm_request.config.http_options: + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + llm_request.config.http_options.headers.update(self._tracking_headers) + if stream: responses = await self.api_client.aio.models.generate_content_stream( model=llm_request.model, @@ -201,24 +208,21 @@ def _tracking_headers(self) -> dict[str, str]: return tracking_headers @cached_property - def _live_api_client(self) -> Client: + def _live_api_version(self) -> str: if self._api_backend == GoogleLLMVariant.VERTEX_AI: # use beta version for vertex 
api - api_version = 'v1beta1' - # use default api version for vertex - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=api_version - ) - ) + return 'v1beta1' else: # use v1alpha for using API KEY from Google AI Studio - api_version = 'v1alpha' - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=api_version - ) - ) + return 'v1alpha' + + @cached_property + def _live_api_client(self) -> Client: + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers, api_version=self._live_api_version + ) + ) @contextlib.asynccontextmanager async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: @@ -230,6 +234,21 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: Yields: BaseLlmConnection, the connection to the Gemini model. """ + # add tracking headers to custom headers and set api_version given + # the customized http options will override the one set in the api client + # constructor + if ( + llm_request.live_connect_config + and llm_request.live_connect_config.http_options + ): + if not llm_request.live_connect_config.http_options.headers: + llm_request.live_connect_config.http_options.headers = {} + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers + ) + llm_request.live_connect_config.http_options.api_version = ( + self._live_api_version + ) llm_request.live_connect_config.system_instruction = types.Content( role='system', diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 9278bee54..fb8540bb2 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -341,6 +341,255 @@ async def __aexit__(self, *args): assert connection is mock_connection +@pytest.mark.asyncio +async def test_generate_content_async_with_custom_headers( + gemini_llm, llm_request, generate_content_response +): + """Test that tracking headers are updated when custom headers are provided.""" + # Add custom headers to the request config + custom_headers = {"custom-header": "custom-value"} + for key in gemini_llm._tracking_headers: + custom_headers[key] = "custom " + gemini_llm._tracking_headers[key] + llm_request.config.http_options = types.HttpOptions(headers=custom_headers) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + # Create a mock coroutine that returns the generate_content_response + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + # Verify that the config passed to generate_content contains merged headers + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args + config_arg = call_args.kwargs["config"] + + for key, value in config_arg.http_options.headers.items(): + if key in gemini_llm._tracking_headers: + assert value == gemini_llm._tracking_headers[key] + else: + assert value == custom_headers[key] + + assert len(responses) == 1 + assert isinstance(responses[0], LlmResponse) + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_custom_headers( + gemini_llm, llm_request +): + """Test that tracking headers are updated when custom headers are provided in streaming mode.""" + # Add custom headers to the request config + 
custom_headers = {"custom-header": "custom-value"} + llm_request.config.http_options = types.HttpOptions(headers=custom_headers) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + # Create mock stream responses + class MockAsyncIterator: + + def __init__(self, seq): + self.iter = iter(seq) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + mock_responses = [ + types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + role="model", parts=[Part.from_text(text="Hello")] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + ] + + async def mock_coro(): + return MockAsyncIterator(mock_responses) + + mock_client.aio.models.generate_content_stream.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=True + ) + ] + + # Verify that the config passed to generate_content_stream contains merged headers + mock_client.aio.models.generate_content_stream.assert_called_once() + call_args = mock_client.aio.models.generate_content_stream.call_args + config_arg = call_args.kwargs["config"] + + expected_headers = custom_headers.copy() + expected_headers.update(gemini_llm._tracking_headers) + assert config_arg.http_options.headers == expected_headers + + assert len(responses) == 2 + + +@pytest.mark.asyncio +async def test_generate_content_async_without_custom_headers( + gemini_llm, llm_request, generate_content_response +): + """Test that tracking headers are not modified when no custom headers exist.""" + # Ensure no http_options exist initially + llm_request.config.http_options = None + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + # Verify that the config passed to generate_content has no http_options + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args + config_arg = call_args.kwargs["config"] + assert config_arg.http_options is None + + assert len(responses) == 1 + + +def test_live_api_version_vertex_ai(gemini_llm): + """Test that _live_api_version returns 'v1beta1' for Vertex AI backend.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + assert gemini_llm._live_api_version == "v1beta1" + + +def test_live_api_version_gemini_api(gemini_llm): + """Test that _live_api_version returns 'v1alpha' for Gemini API backend.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API + ): + assert gemini_llm._live_api_version == "v1alpha" + + +def test_live_api_client_properties(gemini_llm): + """Test that _live_api_client is properly configured with tracking headers and API version.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + client = gemini_llm._live_api_client + + # Verify that the client has the correct headers and API version + http_options = client._api_client._http_options + assert http_options.api_version == "v1beta1" + + # Check that tracking headers are included + tracking_headers = gemini_llm._tracking_headers + for key, value in tracking_headers.items(): + assert key in http_options.headers + assert value in 
http_options.headers[key] + + +@pytest.mark.asyncio +async def test_connect_with_custom_headers(gemini_llm, llm_request): + """Test that connect method updates tracking headers and API version when custom headers are provided.""" + # Setup request with live connect config and custom headers + custom_headers = {"custom-live-header": "live-value"} + llm_request.live_connect_config = types.LiveConnectConfig( + http_options=types.HttpOptions(headers=custom_headers) + ) + + mock_live_session = mock.AsyncMock() + + # Mock the _live_api_client to return a mock client + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + # Create a mock context manager + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify that tracking headers were merged with custom headers + expected_headers = custom_headers.copy() + expected_headers.update(gemini_llm._tracking_headers) + assert config_arg.http_options.headers == expected_headers + + # Verify that API version was set + assert config_arg.http_options.api_version == gemini_llm._live_api_version + + # Verify that system instruction and tools were set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools + + # Verify connection is properly wrapped + assert isinstance(connection, GeminiLlmConnection) + + +@pytest.mark.asyncio +async def test_connect_without_custom_headers(gemini_llm, llm_request): + """Test that connect method works properly when no custom headers are provided.""" + # Setup request with live connect config but no custom headers + llm_request.live_connect_config = types.LiveConnectConfig() + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify that http_options remains None since no custom headers were provided + assert config_arg.http_options is None + + # Verify that system instruction and tools were still set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools + + assert isinstance(connection, GeminiLlmConnection) + + @pytest.mark.parametrize( ( "api_backend, " From 29e4ca9152096e19b66d7d84aab121a8515e7d17 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Jun 2025 17:06:21 -0700 Subject: [PATCH 10/61] chore: Add empty a2a package PiperOrigin-RevId: 770851838 --- src/google/adk/a2a/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 src/google/adk/a2a/__init__.py diff --git a/src/google/adk/a2a/__init__.py b/src/google/adk/a2a/__init__.py new file mode 100644 index 
000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From b2fc7740b363a4e33ec99c7377f396f5cee40b5a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Jun 2025 17:29:25 -0700 Subject: [PATCH 11/61] fix: Support project-based gemini model path to use google_search_tool PiperOrigin-RevId: 770858301 --- src/google/adk/tools/vertex_ai_search_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index c370e2a72..5449f5090 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -76,8 +76,8 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if llm_request.model and llm_request.model.startswith('gemini-'): - if llm_request.model.startswith('gemini-1') and llm_request.config.tools: + if llm_request.model and 'gemini-' in llm_request.model: + if 'gemini-1' in llm_request.model and llm_request.config.tools: raise ValueError( 'Vertex AI search tool can not be used with other tools in Gemini' ' 1.x.' From 313f1b0913eacbb0e4a6dde6b2834f744bcad2ac Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Jun 2025 19:20:28 -0700 Subject: [PATCH 12/61] chore: Add all missing direct deps to pyproject.toml PiperOrigin-RevId: 770887371 --- pyproject.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 20c0524a1..0a973f947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start + "anyio>=4.9.0;python_version>='3.10'", # For MCP Session Manager "authlib>=1.5.1", # For RestAPI Tool "click>=8.1.8", # For CLI tools "fastapi>=0.115.0", # FastAPI framework @@ -40,12 +41,16 @@ dependencies = [ "opentelemetry-exporter-gcp-trace>=1.9.0", "opentelemetry-sdk>=1.31.0", "pydantic>=2.0, <3.0.0", # For data validation/models + "python-dateutil>=2.9.0.post0", # For Vertex AI Session Service "python-dotenv>=1.0.0", # To manage environment variables "PyYAML>=6.0.2", # For APIHubToolset.
+ "requests>=2.32.4", "sqlalchemy>=2.0", # SQL database ORM - "tzlocal>=5.3", # Time zone utilities + "starlette>=0.46.2", # For FastAPI CLI "typing-extensions>=4.5, <5", + "tzlocal>=5.3", # Time zone utilities "uvicorn>=0.34.0", # ASGI server for FastAPI + "websockets>=15.0.1", # For BaseLlmFlow # go/keep-sorted end ] dynamic = ["version"] From c5b063f1ff2c8895f4c9f2c1b1ce34b40c94c7c9 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Jun 2025 19:39:40 -0700 Subject: [PATCH 13/61] feat: Add an Oauth2 credential fetcher to exchange and refresh oauth2 token PiperOrigin-RevId: 770892243 --- src/google/adk/auth/auth_credential.py | 5 + .../adk/auth/oauth2_credential_fetcher.py | 169 +++++++ .../auth/test_oauth2_credential_fetcher.py | 441 ++++++++++++++++++ 3 files changed, 615 insertions(+) create mode 100644 src/google/adk/auth/oauth2_credential_fetcher.py create mode 100644 tests/unittests/auth/test_oauth2_credential_fetcher.py diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index db6fa9767..1009a50dd 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum from typing import Any from typing import Dict @@ -75,6 +77,8 @@ class OAuth2Auth(BaseModelWithConfig): auth_code: Optional[str] = None access_token: Optional[str] = None refresh_token: Optional[str] = None + expires_at: Optional[int] = None + expires_in: Optional[int] = None class ServiceAccountCredential(BaseModelWithConfig): @@ -226,3 +230,4 @@ class AuthCredential(BaseModelWithConfig): http: Optional[HttpAuth] = None service_account: Optional[ServiceAccount] = None oauth2: Optional[OAuth2Auth] = None + google_oauth2_json: Optional[str] = None diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py new file mode 100644 index 000000000..1a8692417 --- /dev/null +++ b/src/google/adk/auth/oauth2_credential_fetcher.py @@ -0,0 +1,169 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Tuple + +from fastapi.openapi.models import OAuth2 + +from .auth_credential import AuthCredential +from .auth_schemes import AuthScheme +from .auth_schemes import OAuthGrantType +from .auth_schemes import OpenIdConnectWithConfig + +try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + + +logger = logging.getLogger("google_adk." 
+ __name__) + + +class OAuth2CredentialFetcher: + """Exchanges and refreshes an OAuth2 access token.""" + + def __init__( + self, + auth_scheme: AuthScheme, + auth_credential: AuthCredential, + ): + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential + + def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]: + auth_scheme = self._auth_scheme + auth_credential = self._auth_credential + + if isinstance(auth_scheme, OpenIdConnectWithConfig): + if not hasattr(auth_scheme, "token_endpoint"): + return None, None + token_endpoint = auth_scheme.token_endpoint + scopes = auth_scheme.scopes + elif isinstance(auth_scheme, OAuth2): + if ( + not auth_scheme.flows.authorizationCode + or not auth_scheme.flows.authorizationCode.tokenUrl + ): + return None, None + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + else: + return None, None + + if ( + not auth_credential + or not auth_credential.oauth2 + or not auth_credential.oauth2.client_id + or not auth_credential.oauth2.client_secret + ): + return None, None + + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) + + def _update_credential(self, tokens: OAuth2Token) -> None: + self._auth_credential.oauth2.access_token = tokens.get("access_token") + self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + self._auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + self._auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) + + def exchange(self) -> AuthCredential: + """Exchange an oauth token from the authorization response. + + Returns: + An AuthCredential object containing the access token. + """ + if not AUTHLIB_AVIALABLE: + return self._auth_credential + + if ( + self._auth_credential.oauth2 + and self._auth_credential.oauth2.access_token + ): + return self._auth_credential + + client, token_endpoint = self._oauth2_session() + if not client: + logger.warning("Could not create OAuth2 session for token exchange") + return self._auth_credential + + try: + tokens = client.fetch_token( + token_endpoint, + authorization_response=self._auth_credential.oauth2.auth_response_uri, + code=self._auth_credential.oauth2.auth_code, + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + ) + self._update_credential(tokens) + logger.info("Successfully exchanged OAuth2 tokens") + except Exception as e: + logger.error("Failed to exchange OAuth2 tokens: %s", e) + # Return original credential on failure + return self._auth_credential + + return self._auth_credential + + def refresh(self) -> AuthCredential: + """Refresh an oauth token. + + Returns: + An AuthCredential object containing the refreshed access token. 
+ """ + if not AUTHLIB_AVIALABLE: + return self._auth_credential + credential = self._auth_credential + if not credential.oauth2: + return credential + + if OAuth2Token({ + "expires_at": credential.oauth2.expires_at, + "expires_in": credential.oauth2.expires_in, + }).is_expired(): + client, token_endpoint = self._oauth2_session() + if not client: + logger.warning("Could not create OAuth2 session for token refresh") + return credential + + try: + tokens = client.refresh_token( + url=token_endpoint, + refresh_token=credential.oauth2.refresh_token, + ) + self._update_credential(tokens) + logger.info("Successfully refreshed OAuth2 tokens") + except Exception as e: + logger.error("Failed to refresh OAuth2 tokens: %s", e) + # Return original credential on failure + return credential + + return self._auth_credential diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py b/tests/unittests/auth/test_oauth2_credential_fetcher.py new file mode 100644 index 000000000..0b9b5a3c1 --- /dev/null +++ b/tests/unittests/auth/test_oauth2_credential_fetcher.py @@ -0,0 +1,441 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.oauth2_credential_fetcher import OAuth2CredentialFetcher + + +class TestOAuth2CredentialFetcher: + """Test suite for OAuth2CredentialFetcher.""" + + def test_init(self): + """Test OAuth2CredentialFetcher initialization.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + assert fetcher._auth_scheme == scheme + assert fetcher._auth_credential == credential + + def test_oauth2_session_openid_connect(self): + """Test _oauth2_session with OpenID Connect scheme.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + 
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + client, token_endpoint = fetcher._oauth2_session() + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + + def test_oauth2_session_oauth2_scheme(self): + """Test _oauth2_session with OAuth2 scheme.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + client, token_endpoint = fetcher._oauth2_session() + + assert client is not None + assert token_endpoint == "https://example.com/token" + + def test_oauth2_session_invalid_scheme(self): + """Test _oauth2_session with invalid scheme.""" + scheme = Mock() # Invalid scheme type + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + client, token_endpoint = fetcher._oauth2_session() + + assert client is None + assert token_endpoint is None + + def test_oauth2_session_missing_credentials(self): + """Test _oauth2_session with missing credentials.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + client, token_endpoint = fetcher._oauth2_session() + + assert client is None + assert token_endpoint is None + + def test_update_credential(self): + """Test _update_credential method.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + + fetcher._update_credential(tokens) + + assert credential.oauth2.access_token == "new_access_token" + assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.expires_at == int(time.time()) + 3600 + assert credential.oauth2.expires_in == 3600 + + def 
test_exchange_with_existing_token(self): + """Test exchange method when access token already exists.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) + result = fetcher.exchange() + + assert result == credential + assert result.oauth2.access_token == "existing_token" + + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") + def test_exchange_success(self, mock_oauth2_session): + """Test successful token exchange.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri=( + "https://example.com/callback?code=auth_code&state=test_state" + ), + ), + ) + + # Mock the OAuth2Session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = { + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + } + mock_client.fetch_token.return_value = mock_tokens + + fetcher = OAuth2CredentialFetcher(scheme, credential) + result = fetcher.exchange() + + assert result.oauth2.access_token == "new_access_token" + assert result.oauth2.refresh_token == "new_refresh_token" + mock_client.fetch_token.assert_called_once() + + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") + def test_exchange_with_auth_code(self, mock_oauth2_session): + """Test token exchange with auth code.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_code="test_auth_code", + ), + ) + + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = { + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + } + mock_client.fetch_token.return_value = mock_tokens + + fetcher = OAuth2CredentialFetcher(scheme, credential) + result = fetcher.exchange() + + assert result.oauth2.access_token == "new_access_token" + mock_client.fetch_token.assert_called_once() + + def test_exchange_no_session(self): + """Test exchange when OAuth2Session cannot be created.""" + scheme = Mock() # Invalid scheme + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + ), + ) + + fetcher = OAuth2CredentialFetcher(scheme, credential) 
+ result = fetcher.exchange() + + assert result == credential + assert result.oauth2.access_token is None + + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") + def test_refresh_token_not_expired( + self, mock_oauth2_session, mock_oauth2_token + ): + """Test refresh when token is not expired.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="current_token", + refresh_token="refresh_token", + expires_at=int(time.time()) + 3600, + expires_in=3600, + ), + ) + + # Mock token not expired + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = False + mock_oauth2_token.return_value = mock_token_instance + + fetcher = OAuth2CredentialFetcher(scheme, credential) + result = fetcher.refresh() + + assert result == credential + assert result.oauth2.access_token == "current_token" + mock_oauth2_session.assert_not_called() + + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") + def test_refresh_token_expired_success( + self, mock_oauth2_session, mock_oauth2_token + ): + """Test successful token refresh when token is expired.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="expired_token", + refresh_token="refresh_token", + expires_at=int(time.time()) - 3600, # Expired + expires_in=3600, + ), + ) + + # Mock token expired + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + # Mock refresh token response + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = { + "access_token": "refreshed_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + } + mock_client.refresh_token.return_value = mock_tokens + + fetcher = OAuth2CredentialFetcher(scheme, credential) + result = fetcher.refresh() + + assert result.oauth2.access_token == "refreshed_access_token" + assert result.oauth2.refresh_token == "new_refresh_token" + mock_client.refresh_token.assert_called_once_with( + url="https://example.com/token", + refresh_token="refresh_token", + ) + + def test_refresh_no_oauth2_credential(self): + """Test refresh when oauth2 credential is missing.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP) # No oauth2 + + fetcher = OAuth2CredentialFetcher(scheme, credential) + 
result = fetcher.refresh() + + assert result == credential + + @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") + def test_refresh_no_session(self, mock_oauth2_token): + """Test refresh when OAuth2Session cannot be created.""" + scheme = Mock() # Invalid scheme + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="expired_token", + refresh_token="refresh_token", + expires_at=int(time.time()) - 3600, + ), + ) + + # Mock token expired + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + fetcher = OAuth2CredentialFetcher(scheme, credential) + result = fetcher.refresh() + + assert result == credential + assert result.oauth2.access_token == "expired_token" # Unchanged From 177980106b2f7be9a8c0a02f395ff0f85faa0c5a Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Jun 2025 20:07:06 -0700 Subject: [PATCH 14/61] feat: Support refresh access token automatically for rest_api_tool 1. let auth_handler.py to utilize the oauth2 credential fetcher to exchange token 2. restructure tool_auth_handler.py to support refresh token PiperOrigin-RevId: 770901469 --- src/google/adk/auth/auth_handler.py | 82 +++--------------- .../openapi_spec_parser/tool_auth_handler.py | 85 ++++++++++++------- tests/unittests/auth/test_auth_handler.py | 7 +- .../test_tool_auth_handler.py | 69 ++++++++++++++- 4 files changed, 138 insertions(+), 105 deletions(-) diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 1607dcadf..5f80ee3f1 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -16,16 +16,13 @@ from typing import TYPE_CHECKING -from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import SecurityBase from .auth_credential import AuthCredential -from .auth_credential import AuthCredentialTypes -from .auth_credential import OAuth2Auth from .auth_schemes import AuthSchemeType -from .auth_schemes import OAuthGrantType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig +from .oauth2_credential_fetcher import OAuth2CredentialFetcher if TYPE_CHECKING: from ..sessions.state import State @@ -33,9 +30,9 @@ try: from authlib.integrations.requests_client import OAuth2Session - SUPPORT_TOKEN_EXCHANGE = True + AUTHLIB_AVIALABLE = True except ImportError: - SUPPORT_TOKEN_EXCHANGE = False + AUTHLIB_AVIALABLE = False class AuthHandler: @@ -46,69 +43,9 @@ def __init__(self, auth_config: AuthConfig): def exchange_auth_token( self, ) -> AuthCredential: - """Generates an auth token from the authorization response. - - Returns: - An AuthCredential object containing the access token. - - Raises: - ValueError: If the token endpoint is not configured in the auth - scheme. - AuthCredentialMissingError: If the access token cannot be retrieved - from the token endpoint. 
- """ - auth_scheme = self.auth_config.auth_scheme - auth_credential = self.auth_config.exchanged_auth_credential - if not SUPPORT_TOKEN_EXCHANGE: - return auth_credential - if isinstance(auth_scheme, OpenIdConnectWithConfig): - if not hasattr(auth_scheme, "token_endpoint"): - return self.auth_config.exchanged_auth_credential - token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes - elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): - return self.auth_config.exchanged_auth_credential - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) - else: - return self.auth_config.exchanged_auth_credential - - if ( - not auth_credential - or not auth_credential.oauth2 - or not auth_credential.oauth2.client_id - or not auth_credential.oauth2.client_secret - or auth_credential.oauth2.access_token - or auth_credential.oauth2.refresh_token - ): - return self.auth_config.exchanged_auth_credential - - client = OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ) - tokens = client.fetch_token( - token_endpoint, - authorization_response=auth_credential.oauth2.auth_response_uri, - code=auth_credential.oauth2.auth_code, - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - ) - - updated_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - access_token=tokens.get("access_token"), - refresh_token=tokens.get("refresh_token"), - ), - ) - return updated_credential + return OAuth2CredentialFetcher( + self.auth_config.auth_scheme, self.auth_config.exchanged_auth_credential + ).exchange() def parse_and_store_auth_response(self, state: State) -> None: @@ -204,6 +141,13 @@ def generate_auth_uri( ValueError: If the authorization endpoint is not configured in the auth scheme. """ + if not AUTHLIB_AVIALABLE: + return ( + self.auth_config.raw_auth_credential.model_copy(deep=True) + if self.auth_config.raw_auth_credential + else None + ) + auth_scheme = self.auth_config.auth_scheme auth_credential = self.auth_config.raw_auth_credential diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 71e760e30..c36793fdc 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import logging from typing import Literal from typing import Optional -from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from ....auth.auth_credential import AuthCredential @@ -25,6 +25,7 @@ from ....auth.auth_schemes import AuthScheme from ....auth.auth_schemes import AuthSchemeType from ....auth.auth_tool import AuthConfig +from ....auth.oauth2_credential_fetcher import OAuth2CredentialFetcher from ...tool_context import ToolContext from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError @@ -95,10 +96,9 @@ def store_credential( auth_credential: Optional[AuthCredential], ): if self.tool_context: - serializable_credential = jsonable_encoder( - auth_credential, exclude_none=True + self.tool_context.state[key] = auth_credential.model_dump( + exclude_none=True ) - self.tool_context.state[key] = serializable_credential def remove_credential(self, key: str): del self.tool_context.state[key] @@ -146,20 +146,20 @@ def from_tool_context( credential_store, ) - def _handle_existing_credential( + def _get_existing_credential( self, - ) -> Optional[AuthPreparationResult]: + ) -> Optional[AuthCredential]: """Checks for and returns an existing, exchanged credential.""" if self.credential_store: existing_credential = self.credential_store.get_credential( self.auth_scheme, self.auth_credential ) if existing_credential: - return AuthPreparationResult( - state="done", - auth_scheme=self.auth_scheme, - auth_credential=existing_credential, - ) + if existing_credential.oauth2: + existing_credential = OAuth2CredentialFetcher( + self.auth_scheme, existing_credential + ).refresh() + return existing_credential return None def _exchange_credential( @@ -223,6 +223,17 @@ def _get_auth_response(self) -> AuthCredential: ) ) + def _external_exchange_required(self, credential) -> bool: + return ( + credential.auth_type + in ( + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + ) + and not credential.oauth2.access_token + and not credential.google_oauth2_json + ) + def prepare_auth_credentials( self, ) -> AuthPreparationResult: @@ -233,31 +244,41 @@ def prepare_auth_credentials( return AuthPreparationResult(state="done") # Check for existing credential. 
- existing_result = self._handle_existing_credential() - if existing_result: - return existing_result + existing_credential = self._get_existing_credential() + credential = existing_credential or self.auth_credential # fetch credential from adk framework # Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require # multi-step exchange: # client_id , client_secret -> auth_uri -> auth_code -> access_token # -> bearer token # adk framework supports exchange access_token already - fetched_credential = self._get_auth_response() or self.auth_credential - - exchanged_credential = self._exchange_credential(fetched_credential) + # for other credential, adk can also get back the credential directly + if not credential or self._external_exchange_required(credential): + credential = self._get_auth_response() + # store fetched credential + if credential: + self._store_credential(credential) + else: + self._request_credential() + return AuthPreparationResult( + state="pending", + auth_scheme=self.auth_scheme, + auth_credential=self.auth_credential, + ) - if exchanged_credential: - self._store_credential(exchanged_credential) - return AuthPreparationResult( - state="done", - auth_scheme=self.auth_scheme, - auth_credential=exchanged_credential, - ) - else: - self._request_credential() - return AuthPreparationResult( - state="pending", - auth_scheme=self.auth_scheme, - auth_credential=self.auth_credential, - ) + # here exchangers are doing two different things: + # for service account the exchanger is actually doing the token exchange + # while for oauth2 it's actually doing the credential conversion + # from OAuth2 credential to HTTP credentials for setting credential in + # http header + # TODO cleanup the logic: + # 1. service account token exchanger should happen before we store them in + # the token store + # 2.
blow line should only do credential conversion + + exchanged_credential = self._exchange_credential(credential) + return AuthPreparationResult( + state="done", + auth_scheme=self.auth_scheme, + auth_credential=exchanged_credential, + ) diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index c717682ac..402e4d0cd 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -449,7 +449,7 @@ def test_token_exchange_not_supported( ): """Test when token exchange is not supported.""" monkeypatch.setattr( - "google.adk.auth.auth_handler.SUPPORT_TOKEN_EXCHANGE", False + "google.adk.auth.oauth2_credential_fetcher.AUTHLIB_AVIALABLE", False ) handler = AuthHandler(auth_config_with_auth_code) @@ -537,7 +537,10 @@ def test_credentials_with_token( assert result == oauth2_credentials_with_token - @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) + @patch( + "google.adk.auth.oauth2_credential_fetcher.OAuth2Session", + MockOAuth2Session, + ) def test_successful_token_exchange(self, auth_config_with_auth_code): """Test a successful token exchange.""" handler = AuthHandler(auth_config_with_auth_code) diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index 0a3b8ccce..8db151fc8 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -14,6 +14,7 @@ from typing import Optional from unittest.mock import MagicMock +from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent @@ -147,10 +148,11 @@ def test_openid_connect_with_auth_response( tool_context = create_mock_tool_context() mock_auth_handler = MagicMock() - mock_auth_handler.get_auth_response.return_value = AuthCredential( + returned_credentail = AuthCredential( auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'), ) + mock_auth_handler.get_auth_response.return_value = returned_credentail mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler' monkeypatch.setattr( mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler @@ -172,7 +174,7 @@ def test_openid_connect_with_auth_response( stored_credential = credential_store.get_credential( openid_connect_scheme, openid_connect_credential ) - assert stored_credential == result.auth_credential + assert stored_credential == returned_credentail mock_auth_handler.get_auth_response.assert_called_once() @@ -199,3 +201,66 @@ def test_openid_connect_existing_token( result = handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential == existing_credential + + +@patch( + 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialFetcher' +) +def test_openid_connect_existing_oauth2_token_refresh( + mock_oauth2_fetcher, openid_connect_scheme, openid_connect_credential +): + """Test that OAuth2 tokens are refreshed when existing credentials are found.""" + # Create existing OAuth2 credential + existing_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='existing_token', + refresh_token='refresh_token', + ), + ) 
+ + # Mock the refreshed credential + refreshed_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='refreshed_token', + refresh_token='new_refresh_token', + ), + ) + + # Setup mock OAuth2CredentialFetcher + mock_fetcher_instance = MagicMock() + mock_fetcher_instance.refresh.return_value = refreshed_credential + mock_oauth2_fetcher.return_value = mock_fetcher_instance + + tool_context = create_mock_tool_context() + credential_store = ToolContextCredentialStore(tool_context=tool_context) + + # Store the existing credential + key = credential_store.get_credential_key( + openid_connect_scheme, openid_connect_credential + ) + credential_store.store_credential(key, existing_credential) + + handler = ToolAuthHandler( + tool_context, + openid_connect_scheme, + openid_connect_credential, + credential_store=credential_store, + ) + + result = handler.prepare_auth_credentials() + + # Verify OAuth2CredentialFetcher was called for refresh + mock_oauth2_fetcher.assert_called_once_with( + openid_connect_scheme, existing_credential + ) + mock_fetcher_instance.refresh.assert_called_once() + + assert result.state == 'done' + # The result should contain the refreshed credential after exchange + assert result.auth_credential is not None From dbdeb49090311757b12ba00ed5071abb53acabc2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Jun 2025 20:29:36 -0700 Subject: [PATCH 15/61] chore: Add a2a-sdk to pyproject.toml PiperOrigin-RevId: 770908052 --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 0a973f947..8ece4db81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,12 @@ dev = [ # go/keep-sorted end ] +a2a = [ + # go/keep-sorted start + "a2a-sdk>=0.2.7;python_version>='3.10'" + # go/keep-sorted end +] + eval = [ # go/keep-sorted start "google-cloud-aiplatform[evaluation]>=1.87.0", From 40b15ad278c709bc3c0dd0748332c0934dc8530f Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 12 Jun 2025 23:10:24 -0700 Subject: [PATCH 16/61] refactor: enhance mcp tool session management 1. remove unnecessary cached session instance in mcp toolset 2. move session reinitialization logic from mcp tool and mcp toolset to mcp session manager 3. 
add lock for the code block of session creation to avoid race conditions PiperOrigin-RevId: 770949529 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 206 ++++++++++-------- src/google/adk/tools/mcp_tool/mcp_tool.py | 8 +- src/google/adk/tools/mcp_tool/mcp_toolset.py | 23 +- 3 files changed, 120 insertions(+), 117 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 3a07a6fe1..5bc06e398 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio from contextlib import AsyncExitStack from datetime import timedelta import functools @@ -34,7 +35,6 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client except ImportError as e: - import sys if sys.version_info < (3, 10): raise ImportError( @@ -105,30 +105,29 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(async_reinit_func_name: str): +def retry_on_closed_resource(session_manager_field_name: str): """Decorator to automatically reinitialize session and retry action. When MCP session was closed, the decorator will automatically recreate the session and retry the action with the same parameters. Note: - 1. async_reinit_func_name is the name of the class member function that - reinitializes the MCP session. - 2. Both the decorated function and the async_reinit_func_name must be async - functions. + 1. session_manager_field_name is the name of the class member field that + contains the MCPSessionManager instance. + 2. The session manager must have a reinitialize_session() async method. Usage: class MCPTool: - ... - async def create_session(self): - self.session = ... + def __init__(self): + self._mcp_session_manager = MCPSessionManager(...) - @retry_on_closed_resource('create_session') + @retry_on_closed_resource('_mcp_session_manager') async def use_session(self): - await self.session.call_tool() + session = await self._mcp_session_manager.create_session() + await session.call_tool() Args: - async_reinit_func_name: The name of the async function to recreate session. + session_manager_field_name: The name of the session manager field. Returns: The decorated function. @@ -141,15 +140,21 @@ async def wrapper(self, *args, **kwargs): return await func(self, *args, **kwargs) except anyio.ClosedResourceError as close_err: try: - if hasattr(self, async_reinit_func_name) and callable( - getattr(self, async_reinit_func_name) - ): - async_init_fn = getattr(self, async_reinit_func_name) - await async_init_fn() + if hasattr(self, session_manager_field_name): + session_manager = getattr(self, session_manager_field_name) + if hasattr(session_manager, 'reinitialize_session') and callable( + getattr(session_manager, 'reinitialize_session') + ): + await session_manager.reinitialize_session() + else: + raise ValueError( + f'Session manager {session_manager_field_name} does not have' + ' reinitialize_session method.' + ) from close_err else: raise ValueError( - f'Function {async_reinit_func_name} does not exist in decorated' - ' class. Please check the function name in' + f'Session manager field {session_manager_field_name} does not' + ' exist in decorated class. Please check the field name in' ' retry_on_closed_resource decorator.' 
) from close_err except Exception as reinit_err: @@ -207,6 +212,8 @@ def __init__( # Each session manager maintains its own exit stack for proper cleanup self._exit_stack: Optional[AsyncExitStack] = None self._session: Optional[ClientSession] = None + # Lock to prevent race conditions in session creation + self._session_lock = asyncio.Lock() async def create_session(self) -> ClientSession: """Creates and initializes an MCP client session. @@ -214,83 +221,102 @@ async def create_session(self) -> ClientSession: Returns: ClientSession: The initialized MCP client session. """ + # Fast path: if session already exists, return it without acquiring lock if self._session is not None: return self._session - # Create a new exit stack for this session - self._exit_stack = AsyncExitStack() - - try: - if isinstance(self._connection_params, StdioConnectionParams): - client = stdio_client( - server=self._connection_params.server_params, - errlog=self._errlog, - ) - elif isinstance(self._connection_params, SseConnectionParams): - client = sse_client( - url=self._connection_params.url, - headers=self._connection_params.headers, - timeout=self._connection_params.timeout, - sse_read_timeout=self._connection_params.sse_read_timeout, - ) - elif isinstance(self._connection_params, StreamableHTTPConnectionParams): - client = streamablehttp_client( - url=self._connection_params.url, - headers=self._connection_params.headers, - timeout=timedelta(seconds=self._connection_params.timeout), - sse_read_timeout=timedelta( - seconds=self._connection_params.sse_read_timeout - ), - terminate_on_close=self._connection_params.terminate_on_close, - ) - else: - raise ValueError( - 'Unable to initialize connection. Connection should be' - ' StdioServerParameters or SseServerParams, but got' - f' {self._connection_params}' - ) - - transports = await self._exit_stack.enter_async_context(client) - # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams - # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. 
- if isinstance(self._connection_params, StdioConnectionParams): - session = await self._exit_stack.enter_async_context( - ClientSession( - *transports[:2], - read_timeout_seconds=timedelta( - seconds=self._connection_params.timeout - ), - ) - ) - else: - session = await self._exit_stack.enter_async_context( - ClientSession(*transports[:2]) - ) - await session.initialize() - - self._session = session - return session - - except Exception: - # If session creation fails, clean up the exit stack - if self._exit_stack: - await self._exit_stack.aclose() - self._exit_stack = None - raise + # Use async lock to prevent race conditions + async with self._session_lock: + # Double-check: session might have been created while waiting for lock + if self._session is not None: + return self._session + + # Create a new exit stack for this session + self._exit_stack = AsyncExitStack() + + try: + if isinstance(self._connection_params, StdioConnectionParams): + client = stdio_client( + server=self._connection_params.server_params, + errlog=self._errlog, + ) + elif isinstance(self._connection_params, SseConnectionParams): + client = sse_client( + url=self._connection_params.url, + headers=self._connection_params.headers, + timeout=self._connection_params.timeout, + sse_read_timeout=self._connection_params.sse_read_timeout, + ) + elif isinstance( + self._connection_params, StreamableHTTPConnectionParams + ): + client = streamablehttp_client( + url=self._connection_params.url, + headers=self._connection_params.headers, + timeout=timedelta(seconds=self._connection_params.timeout), + sse_read_timeout=timedelta( + seconds=self._connection_params.sse_read_timeout + ), + terminate_on_close=self._connection_params.terminate_on_close, + ) + else: + raise ValueError( + 'Unable to initialize connection. Connection should be' + ' StdioServerParameters or SseServerParams, but got' + f' {self._connection_params}' + ) + + transports = await self._exit_stack.enter_async_context(client) + # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams + # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. 
+ if isinstance(self._connection_params, StdioConnectionParams): + session = await self._exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta( + seconds=self._connection_params.timeout + ), + ) + ) + else: + session = await self._exit_stack.enter_async_context( + ClientSession(*transports[:2]) + ) + await session.initialize() + + self._session = session + return session + + except Exception: + # If session creation fails, clean up the exit stack + if self._exit_stack: + await self._exit_stack.aclose() + self._exit_stack = None + raise async def close(self): """Closes the session and cleans up resources.""" - if self._exit_stack: - try: - await self._exit_stack.aclose() - except Exception as e: - # Log the error but don't re-raise to avoid blocking shutdown - print( - f'Warning: Error during MCP session cleanup: {e}', file=self._errlog - ) - finally: - self._exit_stack = None - self._session = None + if not self._exit_stack: + return + async with self._session_lock: + if self._exit_stack: + try: + await self._exit_stack.aclose() + except Exception as e: + # Log the error but don't re-raise to avoid blocking shutdown + print( + f'Warning: Error during MCP session cleanup: {e}', + file=self._errlog, + ) + finally: + self._exit_stack = None + self._session = None + + async def reinitialize_session(self): + """Reinitializes the session when connection is lost.""" + # Close the old session and create a new one + await self.close() + await self.create_session() SseServerParams = SseConnectionParams diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 463202b18..fac710375 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -105,7 +105,7 @@ def _get_declaration(self) -> FunctionDeclaration: ) return function_decl - @retry_on_closed_resource("_reinitialize_session") + @retry_on_closed_resource("_mcp_session_manager") async def run_async(self, *, args, tool_context: ToolContext): """Runs the tool asynchronously. @@ -122,9 +122,3 @@ async def run_async(self, *, args, tool_context: ToolContext): # TODO(cheliu): Support passing tool context to MCP Server. response = await session.call_tool(self.name, arguments=args) return response - - async def _reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and create a new one - await self._mcp_session_manager.close() - await self._mcp_session_manager.create_session() diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 8076752b4..f55693e86 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -28,10 +28,8 @@ from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_closed_resource from .mcp_session_manager import SseConnectionParams -from .mcp_session_manager import SseServerParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams -from .mcp_session_manager import StreamableHTTPServerParams # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. 
@@ -127,9 +125,7 @@ def __init__( errlog=self._errlog, ) - self._session = None - - @retry_on_closed_resource("_reinitialize_session") + @retry_on_closed_resource("_mcp_session_manager") async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, @@ -144,11 +140,10 @@ async def get_tools( List[BaseTool]: A list of tools available under the specified context. """ # Get session from session manager - if not self._session: - self._session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session() # Fetch available tools from the MCP server - tools_response: ListToolsResult = await self._session.list_tools() + tools_response: ListToolsResult = await session.list_tools() # Apply filtering based on context and tool_filter tools = [] @@ -162,14 +157,6 @@ async def get_tools( tools.append(mcp_tool) return tools - async def _reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and clear cache - await self._mcp_session_manager.close() - self._session = await self._mcp_session_manager.create_session() - - # Tools will be reloaded on next get_tools call - async def close(self) -> None: """Performs cleanup and releases resources held by the toolset. @@ -182,7 +169,3 @@ async def close(self) -> None: except Exception as e: # Log the error but don't re-raise to avoid blocking shutdown print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog) - finally: - # Clear cached tools - self._tools_cache = None - self._tools_loaded = False From cb5597096902525d791c7c9e85b3319a93c5456e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 13 Jun 2025 10:09:53 -0700 Subject: [PATCH 17/61] refactor: Simplify agent_tool.py PiperOrigin-RevId: 771135917 --- src/google/adk/tools/agent_tool.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 9c4de3660..d1137b58e 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -97,17 +97,6 @@ async def run_async( if isinstance(self.agent, LlmAgent) and self.agent.input_schema: input_value = self.agent.input_schema.model_validate(args) - else: - input_value = args['request'] - - if isinstance(self.agent, LlmAgent) and self.agent.input_schema: - if isinstance(input_value, dict): - input_value = self.agent.input_schema.model_validate(input_value) - if not isinstance(input_value, self.agent.input_schema): - raise ValueError( - f'Input value {input_value} is not of type' - f' `{self.agent.input_schema}`.' 
-        )
       content = types.Content(
           role='user',
           parts=[
@@ -119,7 +108,7 @@ async def run_async(
     else:
       content = types.Content(
           role='user',
-          parts=[types.Part.from_text(text=input_value)],
+          parts=[types.Part.from_text(text=args['request'])],
       )
     runner = Runner(
         app_name=self.agent.name,
@@ -145,15 +134,11 @@ async def run_async(
     if not last_event or not last_event.content or not last_event.content.parts:
       return ''
+    merged_text = '\n'.join(p.text for p in last_event.content.parts if p.text)
     if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
-      merged_text = '\n'.join(
-          [p.text for p in last_event.content.parts if p.text]
-      )
       tool_result = self.agent.output_schema.model_validate_json(
           merged_text
       ).model_dump(exclude_none=True)
     else:
-      tool_result = '\n'.join(
-          [p.text for p in last_event.content.parts if p.text]
-      )
+      tool_result = merged_text
     return tool_result

From 8e285874da7f5188ea228eb4d7262dbb33b1ae6f Mon Sep 17 00:00:00 2001
From: Google Team Member
Date: Fri, 13 Jun 2025 10:36:34 -0700
Subject: [PATCH 18/61] feat: Add integration tests for litellm with and
 without turning on add_function_to_prompt

Add experiments for https://github.com/google/adk-python/issues/1273

PiperOrigin-RevId: 771145715
---
 .../models/test_litellm_no_function.py        |  65 +++++++++++
 .../models/test_litellm_with_function.py      | 104 ++++++++++++++++++
 2 files changed, 169 insertions(+)
 create mode 100644 tests/integration/models/test_litellm_no_function.py
 create mode 100644 tests/integration/models/test_litellm_with_function.py

diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py
new file mode 100644
index 000000000..e662384ce
--- /dev/null
+++ b/tests/integration/models/test_litellm_no_function.py
@@ -0,0 +1,65 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.models import LlmRequest
+from google.adk.models import LlmResponse
+from google.adk.models.lite_llm import LiteLlm
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import Part
+import pytest
+
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
+
+
+_SYSTEM_PROMPT = """You are a helpful assistant."""
+
+
+@pytest.fixture
+def oss_llm():
+  return LiteLlm(model=_TEST_MODEL_NAME)
+
+
+@pytest.fixture
+def llm_request():
+  return LlmRequest(
+      model=_TEST_MODEL_NAME,
+      contents=[Content(role="user", parts=[Part.from_text(text="hello")])],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction=_SYSTEM_PROMPT,
+      ),
+  )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async(oss_llm, llm_request):
+  async for response in oss_llm.generate_content_async(llm_request):
+    assert isinstance(response, LlmResponse)
+    assert response.content.parts[0].text
+
+
+# Note that this test disables streaming because streaming is not yet
+# supported properly by the current test model.
+@pytest.mark.asyncio +async def test_generate_content_async_stream(oss_llm, llm_request): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request, stream=False + ) + ] + part = responses[0].content.parts[0] + assert len(part.text) > 0 diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py new file mode 100644 index 000000000..a2ceb540a --- /dev/null +++ b/tests/integration/models/test_litellm_with_function.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.models import LlmRequest +from google.adk.models import LlmResponse +from google.adk.models.lite_llm import LiteLlm +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import litellm +import pytest + +litellm.add_function_to_prompt = True + +_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" + + +_SYSTEM_PROMPT = """ +You are a helpful assistant, and call tools optionally. +If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs. +""" + + +_FUNCTIONS = [{ + "name": "get_weather", + "description": "Get the weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city, e.g. San Francisco", + }, + }, + "required": ["city"], + }, +}] + + +def get_weather(city: str) -> str: + """Simulates a web search. Use it get information on weather. + + Args: + city: A string containing the location to get weather information for. + + Returns: + A string with the simulated weather information for the queried city. + """ + if "sf" in city.lower() or "san francisco" in city.lower(): + return "It's 70 degrees and foggy." + return "It's 80 degrees and sunny." + + +@pytest.fixture +def oss_llm_with_function(): + return LiteLlm(model=_TEST_MODEL_NAME, functions=_FUNCTIONS) + + +@pytest.fixture +def llm_request(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="What is the weather in San Francisco?") + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + ), + ) + + +# Note that, this test disabled streaming because streaming is not supported +# properly in the current test model for now. 
+@pytest.mark.asyncio +async def test_generate_content_asyn_with_function( + oss_llm_with_function, llm_request +): + responses = [ + resp + async for resp in oss_llm_with_function.generate_content_async( + llm_request, stream=False + ) + ] + function_call = responses[0].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" From 131957c531899be6d362c1f9a229274437dfe4b0 Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Fri, 13 Jun 2025 10:55:57 -0700 Subject: [PATCH 19/61] chore: Triaging agent improvements & github workflow * modified list issues to only return unlabelled open issues * added github workflow to run on schedule and issue open/reopen * interactive/workflow modes * readme document PiperOrigin-RevId: 771152306 --- .github/workflows/triage.yml | 43 +++++ .../samples/adk_triaging_agent/README.md | 67 +++++++ .../samples/adk_triaging_agent/__init__.py | 15 ++ .../samples/adk_triaging_agent/agent.py | 115 +++++++----- .../samples/adk_triaging_agent/main.py | 164 ++++++++++++++++++ 5 files changed, 360 insertions(+), 44 deletions(-) create mode 100644 .github/workflows/triage.yml create mode 100644 contributing/samples/adk_triaging_agent/README.md create mode 100755 contributing/samples/adk_triaging_agent/__init__.py create mode 100644 contributing/samples/adk_triaging_agent/main.py diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml new file mode 100644 index 000000000..2e258857f --- /dev/null +++ b/.github/workflows/triage.yml @@ -0,0 +1,43 @@ +name: ADK Issue Triaging Agent + +on: + issues: + types: [opened, reopened] + schedule: + - cron: '0 */6 * * *' # every 6h + +jobs: + agent-triage-issues: + runs-on: ubuntu-latest + permissions: + issues: write + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install requests google-adk + + - name: Run Triaging Script + env: + GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} + GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} + GOOGLE_GENAI_USE_VERTEXAI: 0 + OWNER: 'google' + REPO: 'adk-python' + INTERACTIVE: 0 + EVENT_NAME: ${{ github.event_name }} # 'issues', 'schedule', etc. + ISSUE_NUMBER: ${{ github.event.issue.number }} + ISSUE_TITLE: ${{ github.event.issue.title }} + ISSUE_BODY: ${{ github.event.issue.body }} + ISSUE_COUNT_TO_PROCESS: '3' # Process 3 issues at a time on schedule + run: python contributing/samples/adk_triaging_agent/main.py diff --git a/contributing/samples/adk_triaging_agent/README.md b/contributing/samples/adk_triaging_agent/README.md new file mode 100644 index 000000000..be4071b61 --- /dev/null +++ b/contributing/samples/adk_triaging_agent/README.md @@ -0,0 +1,67 @@ +# ADK Issue Triaging Assistant + +The ADK Issue Triaging Assistant is a Python-based agent designed to help manage and triage GitHub issues for the `google/adk-python` repository. It uses a large language model to analyze new and unlabelled issues, recommend appropriate labels based on a predefined set of rules, and apply them. + +This agent can be operated in two distinct modes: an interactive mode for local use or as a fully automated GitHub Actions workflow. + +--- + +## Interactive Mode + +This mode allows you to run the agent locally to review its recommendations in real-time before any changes are made to your repository's issues. 
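+
+The agent reads its configuration from environment variables (see *Setup and
+Configuration* below). As an illustration only (the variable names are the ones
+this sample uses; the values are placeholders), a local `.env` file for
+interactive runs might look like:
+
+```bash
+GITHUB_TOKEN=<your GitHub personal access token>
+GOOGLE_API_KEY=<your Gemini API key>
+OWNER=google
+REPO=adk-python
+INTERACTIVE=1
+```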
+ +### Features +* **Web Interface**: The agent's interactive mode can be rendered in a web browser using the ADK's `adk web` command. +* **User Approval**: In interactive mode, the agent is instructed to ask for your confirmation before applying a label to a GitHub issue. + +### Running in Interactive Mode +To run the agent in interactive mode, first set the required environment variables. Then, execute the following command in your terminal: + +```bash +adk web +``` +This will start a local server and provide a URL to access the agent's web interface in your browser. + +--- + +## GitHub Workflow Mode + +For automated, hands-off issue triaging, the agent can be integrated directly into your repository's CI/CD pipeline using a GitHub Actions workflow. + +### Workflow Triggers +The GitHub workflow is configured to run on specific triggers: + +1. **Issue Events**: The workflow executes automatically whenever a new issue is `opened` or an existing one is `reopened`. + +2. **Scheduled Runs**: The workflow also runs on a recurring schedule (every 6 hours) to process any unlabelled issues that may have been missed. + +### Automated Labeling +When running as part of the GitHub workflow, the agent operates non-interactively. It identifies the best label and applies it directly without requiring user approval. This behavior is configured by setting the `INTERACTIVE` environment variable to `0` in the workflow file. + +### Workflow Configuration +The workflow is defined in a YAML file (`.github/workflows/triage.yml`). This file contains the steps to check out the code, set up the Python environment, install dependencies, and run the triaging script with the necessary environment variables and secrets. + +--- + +## Setup and Configuration + +Whether running in interactive or workflow mode, the agent requires the following setup. + +### Dependencies +The agent requires the following Python libraries. + +```bash +pip install --upgrade pip +pip install google-adk requests +``` + +### Environment Variables +The following environment variables are required for the agent to connect to the necessary services. + +* `GITHUB_TOKEN`: **(Required)** A GitHub Personal Access Token with `issues:write` permissions. Needed for both interactive and workflow modes. +* `GOOGLE_API_KEY`: **(Required)** Your API key for the Gemini API. Needed for both interactive and workflow modes. +* `OWNER`: The GitHub organization or username that owns the repository (e.g., `google`). Needed for both modes. +* `REPO`: The name of the GitHub repository (e.g., `adk-python`). Needed for both modes. +* `INTERACTIVE`: Controls the agent's interaction mode. For the automated workflow, this is set to `0`. For interactive mode, it should be set to `1` or left unset. + +For local execution in interactive mode, you can place these variables in a `.env` file in the project's root directory. For the GitHub workflow, they should be configured as repository secrets. \ No newline at end of file diff --git a/contributing/samples/adk_triaging_agent/__init__.py b/contributing/samples/adk_triaging_agent/__init__.py new file mode 100755 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/adk_triaging_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index 2720e5b46..ecf574572 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ -13,61 +13,73 @@ # limitations under the License. import os -import random -import time from google.adk import Agent -from google.adk.tools.tool_context import ToolContext -from google.genai import types import requests -# Read the PAT from the environment variable -GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") # Ensure you've set this in your shell +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") if not GITHUB_TOKEN: raise ValueError("GITHUB_TOKEN environment variable not set") -# Repository information -OWNER = "google" -REPO = "adk-python" +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +BOT_LABEL = os.getenv("BOT_LABEL", "bot_triaged") -# Base URL for the GitHub API BASE_URL = "https://api.github.com" -# Headers including the Authorization header headers = { "Authorization": f"token {GITHUB_TOKEN}", "Accept": "application/vnd.github.v3+json", } +ALLOWED_LABELS = [ + "documentation", + "services", + "question", + "tools", + "eval", + "live", + "models", + "tracing", + "core", + "web", +] -def list_issues(per_page: int): + +def is_interactive(): + return os.environ.get("INTERACTIVE", "1").lower() in ["true", "1"] + + +def list_issues(issue_count: int): """ Generator to list all issues for the repository by handling pagination. Args: - per_page: number of pages to return per page. + issue_count: number of issues to return """ - state = "open" - # only process the 1st page for testing for now - page = 1 - results = [] - url = ( # :contentReference[oaicite:16]{index=16} - f"{BASE_URL}/repos/{OWNER}/{REPO}/issues" - ) - # Warning: let's only handle max 10 issues at a time to avoid bad results - params = {"state": state, "per_page": per_page, "page": page} - response = requests.get(url, headers=headers, params=params) - response.raise_for_status() # :contentReference[oaicite:17]{index=17} - issues = response.json() + query = f"repo:{OWNER}/{REPO} is:open is:issue no:label" + + unlabelled_issues = [] + url = f"{BASE_URL}/search/issues" + + params = { + "q": query, + "sort": "created", + "order": "desc", + "per_page": issue_count, + "page": 1, + } + response = requests.get(url, headers=headers, params=params, timeout=60) + response.raise_for_status() + json_response = response.json() + issues = json_response.get("items", None) if not issues: return [] for issue in issues: - # Skip pull requests (issues API returns PRs as well) - if "pull_request" in issue: - continue - results.append(issue) - return results + if not issue.get("labels", None) or len(issue["labels"]) == 0: + unlabelled_issues.append(issue) + return unlabelled_issues def add_label_to_issue(issue_number: str, label: str): @@ -78,41 +90,56 @@ def add_label_to_issue(issue_number: str, label: str): issue_number: issue number of the Github issue, in string foramt. 
     label: label to assign
   """
+  print(f"Attempting to add label '{label}' to issue #{issue_number}")
+  if label not in ALLOWED_LABELS:
+    error_message = (
+        f"Error: Label '{label}' is not an allowed label. Will not apply."
+    )
+    print(error_message)
+    return {"status": "error", "message": error_message, "applied_label": None}
+
   url = f"{BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/labels"
-  payload = [label]
-  response = requests.post(url, headers=headers, json=payload)
+  payload = [label, BOT_LABEL]
+  response = requests.post(url, headers=headers, json=payload, timeout=60)
   response.raise_for_status()
   return response.json()
 
 
+approval_instruction = (
+    "Only label them when the user approves the labeling!"
+    if is_interactive()
+    else (
+        "Do not ask for user approval for labeling! If you can't find an"
+        " appropriate label for the issue, do not label it."
+    )
+)
+
 root_agent = Agent(
     model="gemini-2.5-pro-preview-05-06",
     name="adk_triaging_assistant",
     description="Triage ADK issues.",
-    instruction="""
-      You are a Github adk-python repo triaging bot. You will help get issues, and label them.
+    instruction=f"""
+      You are a Github adk-python repo triaging bot. You will help get issues, and recommend a label.
+      IMPORTANT: {approval_instruction}
       Here are the rules for labeling:
        - If the user is asking about documentation-related questions, label it with "documentation".
        - If it's about session, memory services, label it with "services"
-       - If it's about UI/web, label it with "question"
+       - If it's about UI/web, label it with "web"
+       - If the user is asking about a question, label it with "question"
        - If it's related to tools, label it with "tools"
        - If it's about agent evalaution, then label it with "eval".
        - If it's about streaming/live, label it with "live".
        - If it's about model support(non-Gemini, like Litellm, Ollama, OpenAI models), label it with "models".
        - If it's about tracing, label it with "tracing".
        - If it's agent orchestration, agent definition, label it with "core".
-      - If you can't find a appropriate labels for the issue, return the issues to user to decide.
+      - If you can't find an appropriate label for the issue, follow the previous instruction that starts with "IMPORTANT:".
+
+      Present the following in an easy-to-read format, highlighting the issue number and your label.
+      - the issue summary in a few sentences
+      - your label recommendation and justification
     """,
     tools=[
        list_issues,
        add_label_to_issue,
    ],
-    generate_content_config=types.GenerateContentConfig(
-        safety_settings=[
-            types.SafetySetting(  # avoid false alarm about rolling dice.
-                category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                threshold=types.HarmBlockThreshold.OFF,
-            ),
-        ]
-    ),
 )
diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py
new file mode 100644
index 000000000..a749b26fc
--- /dev/null
+++ b/contributing/samples/adk_triaging_agent/main.py
@@ -0,0 +1,164 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
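+
+"""Entry point for the ADK issue triaging sample agent.
+
+When invoked for an `issues` event, it triages the single issue passed in via
+environment variables; otherwise (e.g. on the scheduled run) it asks the agent
+to process a small batch of unlabelled issues.
+"""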
+ +import asyncio +import os +import time + +import agent +from dotenv import load_dotenv +from google.adk.agents.run_config import RunConfig +from google.adk.runners import InMemoryRunner +from google.adk.sessions import Session +from google.genai import types +import requests + +load_dotenv(override=True) + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +BASE_URL = "https://api.github.com" +headers = { + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + +if not GITHUB_TOKEN: + print( + "Warning: GITHUB_TOKEN environment variable not set. API calls might" + " fail." + ) + + +async def fetch_specific_issue_details(issue_number: int): + """Fetches details for a single issue if it's unlabelled.""" + if not GITHUB_TOKEN: + print("Cannot fetch issue details: GITHUB_TOKEN is not set.") + return None + + url = f"{BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" + print(f"Fetching details for specific issue: {url}") + try: + response = requests.get(url, headers=headers, timeout=60) + response.raise_for_status() + issue_data = response.json() + if not issue_data.get("labels") or len(issue_data["labels"]) == 0: + print(f"Issue #{issue_number} is unlabelled. Proceeding.") + return { + "number": issue_data["number"], + "title": issue_data["title"], + "body": issue_data.get("body", ""), + } + else: + print(f"Issue #{issue_number} is already labelled. Skipping.") + return None + except requests.exceptions.RequestException as e: + print(f"Error fetching issue #{issue_number}: {e}") + if hasattr(e, "response") and e.response is not None: + print(f"Response content: {e.response.text}") + return None + + +async def main(): + app_name = "triage_app" + user_id_1 = "triage_user" + runner = InMemoryRunner( + agent=agent.root_agent, + app_name=app_name, + ) + session_11 = await runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_agent_prompt(session: Session, prompt_text: str): + content = types.Content( + role="user", parts=[types.Part.from_text(text=prompt_text)] + ) + print(f"\n>>>> Agent Prompt: {prompt_text}") + final_agent_response_parts = [] + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=False), + ): + if event.content.parts and event.content.parts[0].text: + print(f"** {event.author} (ADK): {event.content.parts[0].text}") + if event.author == agent.root_agent.name: + final_agent_response_parts.append(event.content.parts[0].text) + print(f"<<<< Agent Final Output: {''.join(final_agent_response_parts)}\n") + + event_name = os.getenv("EVENT_NAME") + issue_number_str = os.getenv("ISSUE_NUMBER") + + if event_name == "issues" and issue_number_str: + print(f"EVENT: Processing specific issue due to '{event_name}' event.") + try: + issue_number = int(issue_number_str) + specific_issue = await fetch_specific_issue_details(issue_number) + + if specific_issue: + prompt = ( + f"A new GitHub issue #{specific_issue['number']} has been opened or" + f" reopened. Title: \"{specific_issue['title']}\"\nBody:" + f" \"{specific_issue['body']}\"\n\nBased on the rules, recommend an" + " appropriate label and its justification." + " Then, use the 'add_label_to_issue' tool to apply the label " + "directly to this issue." + f" The issue number is {specific_issue['number']}." 
+ ) + await run_agent_prompt(session_11, prompt) + else: + print( + f"No unlabelled issue details found for #{issue_number} or an error" + " occurred. Skipping agent interaction." + ) + + except ValueError: + print(f"Error: Invalid ISSUE_NUMBER received: {issue_number_str}") + + else: + print(f"EVENT: Processing batch of issues (event: {event_name}).") + issue_count_str = os.getenv("ISSUE_COUNT_TO_PROCESS", "3") + try: + num_issues_to_process = int(issue_count_str) + except ValueError: + print(f"Warning: Invalid ISSUE_COUNT_TO_PROCESS. Defaulting to 3.") + num_issues_to_process = 3 + + prompt = ( + f"List the first {num_issues_to_process} unlabelled open issues from" + f" the {OWNER}/{REPO} repository. For each issue, provide a summary," + " recommend a label with justification, and then use the" + " 'add_label_to_issue' tool to apply the recommended label directly." + ) + await run_agent_prompt(session_11, prompt) + + +if __name__ == "__main__": + start_time = time.time() + print( + "Script start time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(start_time)), + ) + print("------------------------------------") + asyncio.run(main()) + end_time = time.time() + print("------------------------------------") + print( + "Script end time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(end_time)), + ) + print("Total script execution time:", f"{end_time - start_time:.2f} seconds") From d129fd636bde64b0b1aae8049950ffbc2f45790c Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 13 Jun 2025 14:52:45 -0700 Subject: [PATCH 20/61] chore: Update the comments of MCPTool PiperOrigin-RevId: 771234262 --- src/google/adk/tools/mcp_tool/mcp_tool.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index fac710375..6553bb2c0 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -49,7 +49,7 @@ class MCPTool(BaseTool): - """Turns a MCP Tool into a Vertex Agent Framework Tool. + """Turns an MCP Tool into an ADK Tool. Internally, the tool initializes from a MCP Tool, and uses the MCP Session to call the tool. @@ -63,9 +63,9 @@ def __init__( auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, ): - """Initializes a MCPTool. + """Initializes an MCPTool. - This tool wraps a MCP Tool interface and uses a session manager to + This tool wraps an MCP Tool interface and uses a session manager to communicate with the MCP server. Args: @@ -111,7 +111,7 @@ async def run_async(self, *, args, tool_context: ToolContext): Args: args: The arguments as a dict to pass to the tool. - tool_context: The tool context from upper level ADK agent. + tool_context: The tool context of the current invocation. Returns: Any: The response from the tool. @@ -119,6 +119,5 @@ async def run_async(self, *, args, tool_context: ToolContext): # Get the session from the session manager session = await self._mcp_session_manager.create_session() - # TODO(cheliu): Support passing tool context to MCP Server. 
response = await session.call_tool(self.name, arguments=args) return response From 233fd2024346abd7f89a16c444de0cf26da5c1a1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 13 Jun 2025 17:25:40 -0700 Subject: [PATCH 21/61] feat: Add import session API in the fast API PiperOrigin-RevId: 771279971 --- src/google/adk/cli/fast_api.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 875c3cb7b..4512174c5 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -412,12 +412,19 @@ async def create_session( app_name: str, user_id: str, state: Optional[dict[str, Any]] = None, + events: Optional[list[Event]] = None, ) -> Session: logger.info("New session created") - return await session_service.create_session( + session = await session_service.create_session( app_name=app_name, user_id=user_id, state=state ) + if events: + for event in events: + await session_service.append_event(session=session, event=event) + + return session + def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str: return os.path.join( agents_dir, From 89321062464fcd06a57588e019ec296ad769f17b Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 13 Jun 2025 21:16:45 -0700 Subject: [PATCH 22/61] chore: Raise error when using features decorated by working_in_progress decorator set environment variable ADK_ALLOW_WIP_FEATURES=true can bypass it. working_in_progress features are not working. ADK users are not supposed to set this environment variable. PiperOrigin-RevId: 771333335 --- src/google/adk/utils/feature_decorator.py | 55 +++++- .../unittests/utils/test_feature_decorator.py | 186 +++++++++++++++++- 2 files changed, 227 insertions(+), 14 deletions(-) diff --git a/src/google/adk/utils/feature_decorator.py b/src/google/adk/utils/feature_decorator.py index 301637319..e3be0a652 100644 --- a/src/google/adk/utils/feature_decorator.py +++ b/src/google/adk/utils/feature_decorator.py @@ -15,29 +15,52 @@ from __future__ import annotations import functools +import os from typing import Callable from typing import cast +from typing import Optional from typing import TypeVar from typing import Union import warnings +from dotenv import load_dotenv + T = TypeVar("T", bound=Union[Callable, type]) def _make_feature_decorator( - *, label: str, default_message: str + *, + label: str, + default_message: str, + block_usage: bool = False, + bypass_env_var: Optional[str] = None, ) -> Callable[[str], Callable[[T], T]]: def decorator_factory(message: str = default_message) -> Callable[[T], T]: def decorator(obj: T) -> T: obj_name = getattr(obj, "__name__", type(obj).__name__) - warn_msg = f"[{label.upper()}] {obj_name}: {message}" + msg = f"[{label.upper()}] {obj_name}: {message}" if isinstance(obj, type): # decorating a class orig_init = obj.__init__ @functools.wraps(orig_init) def new_init(self, *args, **kwargs): - warnings.warn(warn_msg, category=UserWarning, stacklevel=2) + # Load .env file if dotenv is available + load_dotenv() + + # Check if usage should be bypassed via environment variable at call time + should_bypass = ( + bypass_env_var is not None + and os.environ.get(bypass_env_var, "").lower() == "true" + ) + + if should_bypass: + # Bypass completely - no warning, no error + pass + elif block_usage: + raise RuntimeError(msg) + else: + warnings.warn(msg, category=UserWarning, stacklevel=2) return orig_init(self, *args, **kwargs) obj.__init__ = new_init # type: ignore[attr-defined] @@ 
-47,7 +70,22 @@ def new_init(self, *args, **kwargs): @functools.wraps(obj) def wrapper(*args, **kwargs): - warnings.warn(warn_msg, category=UserWarning, stacklevel=2) + # Load .env file if dotenv is available + load_dotenv() + + # Check if usage should be bypassed via environment variable at call time + should_bypass = ( + bypass_env_var is not None + and os.environ.get(bypass_env_var, "").lower() == "true" + ) + + if should_bypass: + # Bypass completely - no warning, no error + pass + elif block_usage: + raise RuntimeError(msg) + else: + warnings.warn(msg, category=UserWarning, stacklevel=2) return obj(*args, **kwargs) return cast(T, wrapper) @@ -65,11 +103,18 @@ def wrapper(*args, **kwargs): working_in_progress = _make_feature_decorator( label="WIP", default_message=( - "This feature is a work in progress and may be incomplete or unstable." + "This feature is a work in progress and is not working completely. ADK" + " users are not supposed to use it." ), + block_usage=True, + bypass_env_var="ADK_ALLOW_WIP_FEATURES", ) """Mark a class or function as a work in progress. +By default, decorated functions/classes will raise RuntimeError when used. +Set ADK_ALLOW_WIP_FEATURES=true environment variable to bypass this restriction. +ADK users are not supposed to set this environment variable. + Sample usage: ``` diff --git a/tests/unittests/utils/test_feature_decorator.py b/tests/unittests/utils/test_feature_decorator.py index aa03fc746..eb700ea6c 100644 --- a/tests/unittests/utils/test_feature_decorator.py +++ b/tests/unittests/utils/test_feature_decorator.py @@ -1,3 +1,5 @@ +import os +import tempfile import warnings from google.adk.utils.feature_decorator import experimental @@ -11,25 +13,176 @@ def run(self): return "running" +@working_in_progress("function not ready") +def wip_function(): + return "executing" + + @experimental("api may have breaking change in the future.") def experimental_fn(): return "executing" -def test_working_in_progress_class_warns(): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") +@experimental("class may change") +class ExperimentalClass: + + def run(self): + return "running experimental" + +def test_working_in_progress_class_raises_error(): + """Test that WIP class raises RuntimeError by default.""" + # Ensure environment variable is not set + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + try: feature = IncompleteFeature() + assert False, "Expected RuntimeError to be raised" + except RuntimeError as e: + assert "[WIP] IncompleteFeature:" in str(e) + assert "don't use yet" in str(e) - assert feature.run() == "running" - assert len(w) == 1 - assert issubclass(w[0].category, UserWarning) - assert "[WIP] IncompleteFeature:" in str(w[0].message) - assert "don't use yet" in str(w[0].message) +def test_working_in_progress_function_raises_error(): + """Test that WIP function raises RuntimeError by default.""" + # Ensure environment variable is not set + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + try: + result = wip_function() + assert False, "Expected RuntimeError to be raised" + except RuntimeError as e: + assert "[WIP] wip_function:" in str(e) + assert "function not ready" in str(e) + + +def test_working_in_progress_class_bypassed_with_env_var(): + """Test that WIP class works without warnings when env var is set.""" + # Set the bypass environment variable + os.environ["ADK_ALLOW_WIP_FEATURES"] = "true" + + try: + with 
warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + feature = IncompleteFeature() + result = feature.run() + + assert result == "running" + # Should have no warnings when bypassed + assert len(w) == 0 + finally: + # Clean up environment variable + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_function_bypassed_with_env_var(): + """Test that WIP function works without warnings when env var is set.""" + # Set the bypass environment variable + os.environ["ADK_ALLOW_WIP_FEATURES"] = "true" + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = wip_function() + + assert result == "executing" + # Should have no warnings when bypassed + assert len(w) == 0 + finally: + # Clean up environment variable + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_env_var_case_insensitive(): + """Test that WIP bypass works with different case values.""" + test_cases = ["true", "True", "TRUE", "tRuE"] + + for case in test_cases: + os.environ["ADK_ALLOW_WIP_FEATURES"] = case + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") -def test_experimental_method_warns(): + result = wip_function() + + assert result == "executing" + assert len(w) == 0 + finally: + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_env_var_false_values(): + """Test that WIP still raises errors with false-like env var values.""" + false_values = ["false", "False", "FALSE", "0", "", "anything_else"] + + for false_val in false_values: + os.environ["ADK_ALLOW_WIP_FEATURES"] = false_val + + try: + result = wip_function() + assert False, f"Expected RuntimeError with env var '{false_val}'" + except RuntimeError as e: + assert "[WIP] wip_function:" in str(e) + finally: + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_loads_from_dotenv_file(): + """Test that WIP decorator can load environment variables from .env file.""" + # Skip test if dotenv is not available + try: + from dotenv import load_dotenv + except ImportError: + import pytest + + pytest.skip("python-dotenv not available") + + # Ensure environment variable is not set in os.environ + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + # Create a temporary .env file in current directory + dotenv_path = ".env.test" + + try: + # Write the env file + with open(dotenv_path, "w") as f: + f.write("ADK_ALLOW_WIP_FEATURES=true\n") + + # Load the environment variables from the file + load_dotenv(dotenv_path) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # This should work because the .env file contains ADK_ALLOW_WIP_FEATURES=true + result = wip_function() + + assert result == "executing" + # Should have no warnings when bypassed via .env file + assert len(w) == 0 + + finally: + # Clean up + try: + os.unlink(dotenv_path) + except FileNotFoundError: + pass + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_experimental_function_warns(): + """Test that experimental function shows warnings (unchanged behavior).""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -40,3 +193,18 @@ def test_experimental_method_warns(): assert 
issubclass(w[0].category, UserWarning) assert "[EXPERIMENTAL] experimental_fn:" in str(w[0].message) assert "breaking change in the future" in str(w[0].message) + + +def test_experimental_class_warns(): + """Test that experimental class shows warnings (unchanged behavior).""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + exp_class = ExperimentalClass() + result = exp_class.run() + + assert result == "running experimental" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] ExperimentalClass:" in str(w[0].message) + assert "class may change" in str(w[0].message) From b51a1f45fdca0b0a3cd250608afb42a7a30d430c Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 13 Jun 2025 23:17:38 -0700 Subject: [PATCH 23/61] chore: support @working_in_progress and @expremental without calling parameters PiperOrigin-RevId: 771357728 --- src/google/adk/utils/feature_decorator.py | 155 +++++++++++------- .../unittests/utils/test_feature_decorator.py | 91 ++++++++++ 2 files changed, 184 insertions(+), 62 deletions(-) diff --git a/src/google/adk/utils/feature_decorator.py b/src/google/adk/utils/feature_decorator.py index e3be0a652..d597063ae 100644 --- a/src/google/adk/utils/feature_decorator.py +++ b/src/google/adk/utils/feature_decorator.py @@ -34,70 +34,90 @@ def _make_feature_decorator( default_message: str, block_usage: bool = False, bypass_env_var: Optional[str] = None, -) -> Callable[[str], Callable[[T], T]]: - def decorator_factory(message: str = default_message) -> Callable[[T], T]: - def decorator(obj: T) -> T: - obj_name = getattr(obj, "__name__", type(obj).__name__) - msg = f"[{label.upper()}] {obj_name}: {message}" - - if isinstance(obj, type): # decorating a class - orig_init = obj.__init__ - - @functools.wraps(orig_init) - def new_init(self, *args, **kwargs): - # Load .env file if dotenv is available - load_dotenv() - - # Check if usage should be bypassed via environment variable at call time - should_bypass = ( - bypass_env_var is not None - and os.environ.get(bypass_env_var, "").lower() == "true" - ) - - if should_bypass: - # Bypass completely - no warning, no error - pass - elif block_usage: - raise RuntimeError(msg) - else: - warnings.warn(msg, category=UserWarning, stacklevel=2) - return orig_init(self, *args, **kwargs) - - obj.__init__ = new_init # type: ignore[attr-defined] - return cast(T, obj) - - elif callable(obj): # decorating a function or method - - @functools.wraps(obj) - def wrapper(*args, **kwargs): - # Load .env file if dotenv is available - load_dotenv() - - # Check if usage should be bypassed via environment variable at call time - should_bypass = ( - bypass_env_var is not None - and os.environ.get(bypass_env_var, "").lower() == "true" - ) - - if should_bypass: - # Bypass completely - no warning, no error - pass - elif block_usage: - raise RuntimeError(msg) - else: - warnings.warn(msg, category=UserWarning, stacklevel=2) - return obj(*args, **kwargs) - - return cast(T, wrapper) - - else: - raise TypeError( - f"@{label} can only be applied to classes or callable objects" +) -> Callable: + def decorator_factory(message_or_obj=None): + # Case 1: Used as @decorator without parentheses + # message_or_obj is the decorated class/function + if message_or_obj is not None and ( + isinstance(message_or_obj, type) or callable(message_or_obj) + ): + return _create_decorator( + default_message, label, block_usage, bypass_env_var + )(message_or_obj) + + # Case 2: Used as @decorator() with or without 
message + # message_or_obj is either None or a string message + message = ( + message_or_obj if isinstance(message_or_obj, str) else default_message + ) + return _create_decorator(message, label, block_usage, bypass_env_var) + + return decorator_factory + + +def _create_decorator( + message: str, label: str, block_usage: bool, bypass_env_var: Optional[str] +) -> Callable[[T], T]: + def decorator(obj: T) -> T: + obj_name = getattr(obj, "__name__", type(obj).__name__) + msg = f"[{label.upper()}] {obj_name}: {message}" + + if isinstance(obj, type): # decorating a class + orig_init = obj.__init__ + + @functools.wraps(orig_init) + def new_init(self, *args, **kwargs): + # Load .env file if dotenv is available + load_dotenv() + + # Check if usage should be bypassed via environment variable at call time + should_bypass = ( + bypass_env_var is not None + and os.environ.get(bypass_env_var, "").lower() == "true" ) - return decorator + if should_bypass: + # Bypass completely - no warning, no error + pass + elif block_usage: + raise RuntimeError(msg) + else: + warnings.warn(msg, category=UserWarning, stacklevel=2) + return orig_init(self, *args, **kwargs) + + obj.__init__ = new_init # type: ignore[attr-defined] + return cast(T, obj) + + elif callable(obj): # decorating a function or method + + @functools.wraps(obj) + def wrapper(*args, **kwargs): + # Load .env file if dotenv is available + load_dotenv() + + # Check if usage should be bypassed via environment variable at call time + should_bypass = ( + bypass_env_var is not None + and os.environ.get(bypass_env_var, "").lower() == "true" + ) - return decorator_factory + if should_bypass: + # Bypass completely - no warning, no error + pass + elif block_usage: + raise RuntimeError(msg) + else: + warnings.warn(msg, category=UserWarning, stacklevel=2) + return obj(*args, **kwargs) + + return cast(T, wrapper) + + else: + raise TypeError( + f"@{label} can only be applied to classes or callable objects" + ) + + return decorator working_in_progress = _make_feature_decorator( @@ -137,8 +157,19 @@ def my_wip_function(): Sample usage: ``` -@experimental("This API may have breaking change in the future.") +# Use with default message +@experimental class ExperimentalClass: pass + +# Use with custom message +@experimental("This API may have breaking change in the future.") +class CustomExperimentalClass: + pass + +# Use with empty parentheses (same as default message) +@experimental() +def experimental_function(): + pass ``` """ diff --git a/tests/unittests/utils/test_feature_decorator.py b/tests/unittests/utils/test_feature_decorator.py index eb700ea6c..e2f16446a 100644 --- a/tests/unittests/utils/test_feature_decorator.py +++ b/tests/unittests/utils/test_feature_decorator.py @@ -30,6 +30,31 @@ def run(self): return "running experimental" +# Test classes/functions for new usage patterns +@experimental +class ExperimentalClassNoParens: + + def run(self): + return "running experimental without parens" + + +@experimental() +class ExperimentalClassEmptyParens: + + def run(self): + return "running experimental with empty parens" + + +@experimental +def experimental_fn_no_parens(): + return "executing without parens" + + +@experimental() +def experimental_fn_empty_parens(): + return "executing with empty parens" + + def test_working_in_progress_class_raises_error(): """Test that WIP class raises RuntimeError by default.""" # Ensure environment variable is not set @@ -208,3 +233,69 @@ def test_experimental_class_warns(): assert issubclass(w[0].category, UserWarning) assert 
"[EXPERIMENTAL] ExperimentalClass:" in str(w[0].message) assert "class may change" in str(w[0].message) + + +def test_experimental_class_no_parens_warns(): + """Test that experimental class without parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + exp_class = ExperimentalClassNoParens() + result = exp_class.run() + + assert result == "running experimental without parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] ExperimentalClassNoParens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) + + +def test_experimental_class_empty_parens_warns(): + """Test that experimental class with empty parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + exp_class = ExperimentalClassEmptyParens() + result = exp_class.run() + + assert result == "running experimental with empty parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] ExperimentalClassEmptyParens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) + + +def test_experimental_function_no_parens_warns(): + """Test that experimental function without parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = experimental_fn_no_parens() + + assert result == "executing without parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] experimental_fn_no_parens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) + + +def test_experimental_function_empty_parens_warns(): + """Test that experimental function with empty parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = experimental_fn_empty_parens() + + assert result == "executing with empty parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] experimental_fn_empty_parens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) From 8ebf229c4787a5e3f559ea418a11f2ed407637e0 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 13 Jun 2025 23:23:45 -0700 Subject: [PATCH 24/61] chore: Add base credential service interface (WIP) PiperOrigin-RevId: 771358480 --- .../samples/oauth_calendar_agent/agent.py | 120 ++++++------------ src/google/adk/auth/auth_handler.py | 4 +- src/google/adk/auth/auth_tool.py | 24 +++- .../adk/auth/credential_service/__init__.py | 13 ++ .../base_credential_service.py | 75 +++++++++++ tests/unittests/auth/test_auth_config.py | 26 +++- tests/unittests/auth/test_auth_handler.py | 6 +- 7 files changed, 176 insertions(+), 92 deletions(-) create mode 100644 src/google/adk/auth/credential_service/__init__.py create mode 100644 src/google/adk/auth/credential_service/base_credential_service.py diff --git a/contributing/samples/oauth_calendar_agent/agent.py b/contributing/samples/oauth_calendar_agent/agent.py index a1b1dea87..9d56d3ff8 100644 --- a/contributing/samples/oauth_calendar_agent/agent.py +++ b/contributing/samples/oauth_calendar_agent/agent.py @@ -27,6 +27,8 @@ from google.adk.auth import AuthCredentialTypes from 
google.adk.auth import OAuth2Auth from google.adk.tools import ToolContext +from google.adk.tools.authenticated_tool.base_authenticated_tool import AuthenticatedFunctionTool +from google.adk.tools.authenticated_tool.credentials_store import ToolContextCredentialsStore from google.adk.tools.google_api_tool import CalendarToolset from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials @@ -56,6 +58,7 @@ def list_calendar_events( end_time: str, limit: int, tool_context: ToolContext, + credential: AuthCredential, ) -> list[dict]: """Search for calendar events. @@ -80,84 +83,11 @@ def list_calendar_events( Returns: list[dict]: A list of events that match the search criteria. """ - creds = None - - # Check if the tokes were already in the session state, which means the user - # has already gone through the OAuth flow and successfully authenticated and - # authorized the tool to access their calendar. - if "calendar_tool_tokens" in tool_context.state: - creds = Credentials.from_authorized_user_info( - tool_context.state["calendar_tool_tokens"], SCOPES - ) - if not creds or not creds.valid: - # If the access token is expired, refresh it with the refresh token. - if creds and creds.expired and creds.refresh_token: - creds.refresh(Request()) - else: - auth_scheme = OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://accounts.google.com/o/oauth2/auth", - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - "https://www.googleapis.com/auth/calendar": ( - "See, edit, share, and permanently delete all the" - " calendars you can access using Google Calendar" - ) - }, - ) - ) - ) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=oauth_client_id, client_secret=oauth_client_secret - ), - ) - # If the user has not gone through the OAuth flow before, or the refresh - # token also expired, we need to ask users to go through the OAuth flow. - # First we check whether the user has just gone through the OAuth flow and - # Oauth response is just passed back. - auth_response = tool_context.get_auth_response( - AuthConfig( - auth_scheme=auth_scheme, raw_auth_credential=auth_credential - ) - ) - if auth_response: - # ADK exchanged the access token already for us - access_token = auth_response.oauth2.access_token - refresh_token = auth_response.oauth2.refresh_token - - creds = Credentials( - token=access_token, - refresh_token=refresh_token, - token_uri=auth_scheme.flows.authorizationCode.tokenUrl, - client_id=oauth_client_id, - client_secret=oauth_client_secret, - scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()), - ) - else: - # If there are no auth response which means the user has not gone - # through the OAuth flow yet, we need to ask users to go through the - # OAuth flow. - tool_context.request_credential( - AuthConfig( - auth_scheme=auth_scheme, - raw_auth_credential=auth_credential, - ) - ) - # The return value is optional and could be any dict object. It will be - # wrapped in a dict with key as 'result' and value as the return value - # if the object returned is not a dict. This response will be passed - # to LLM to generate a user friendly message. e.g. LLM will tell user: - # "I need your authorization to access your calendar. Please authorize - # me so I can check your meetings for today." - return "Need User Authorization to access their calendar." 
- # We store the access token and refresh token in the session state for the - # next runs. This is just an example. On production, a tool should store - # those credentials in some secure store or properly encrypt it before store - # it in the session state. - tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json()) + + creds = Credentials( + token=credential.oauth2.access_token, + refresh_token=credential.oauth2.refresh_token, + ) service = build("calendar", "v3", credentials=creds) events_result = ( @@ -208,6 +138,38 @@ def update_time(callback_context: CallbackContext): Currnet time: {_time} """, - tools=[list_calendar_events, calendar_toolset], + tools=[ + AuthenticatedFunctionTool( + func=list_calendar_events, + auth_config=AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl=( + "https://accounts.google.com/o/oauth2/auth" + ), + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + "https://www.googleapis.com/auth/calendar": ( + "See, edit, share, and permanently delete" + " all the calendars you can access using" + " Google Calendar" + ) + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=oauth_client_id, + client_secret=oauth_client_secret, + ), + ), + ), + credential_store=ToolContextCredentialsStore(), + ), + calendar_toolset, + ], before_agent_callback=update_time, ) diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 5f80ee3f1..3e13cbac2 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -49,7 +49,7 @@ def exchange_auth_token( def parse_and_store_auth_response(self, state: State) -> None: - credential_key = "temp:" + self.auth_config.get_credential_key() + credential_key = "temp:" + self.auth_config.credential_key state[credential_key] = self.auth_config.exchanged_auth_credential if not isinstance( @@ -67,7 +67,7 @@ def _validate(self) -> None: raise ValueError("auth_scheme is empty.") def get_auth_response(self, state: State) -> AuthCredential: - credential_key = "temp:" + self.auth_config.get_credential_key() + credential_key = "temp:" + self.auth_config.credential_key return state.get(credential_key, None) def generate_auth_request(self) -> AuthConfig: diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index a1a19ab87..53c571d42 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -14,6 +14,10 @@ from __future__ import annotations +from typing import Optional + +from typing_extensions import deprecated + from .auth_credential import AuthCredential from .auth_credential import BaseModelWithConfig from .auth_schemes import AuthScheme @@ -45,11 +49,23 @@ class AuthConfig(BaseModelWithConfig): this field to guide the user through the OAuth2 flow and fill auth response in this field""" + credential_key: Optional[str] = None + """A user specified key used to load and save this credential in a credential + service. + """ + + def __init__(self, **data): + super().__init__(**data) + if self.credential_key: + return + self.credential_key = self.get_credential_key() + + @deprecated("This method is deprecated. Use credential_key instead.") def get_credential_key(self): - """Generates a hash key based on auth_scheme and raw_auth_credential. This - hash key can be used to store / retrieve exchanged_auth_credential in a - credentials store. 
+ """Builds a hash key based on auth_scheme and raw_auth_credential used to + save / load this credential to / from a credentials service. """ + auth_scheme = self.auth_scheme if auth_scheme.model_extra: @@ -62,7 +78,7 @@ def get_credential_key(self): ) auth_credential = self.raw_auth_credential - if auth_credential.model_extra: + if auth_credential and auth_credential.model_extra: auth_credential = auth_credential.model_copy(deep=True) auth_credential.model_extra.clear() credential_name = ( diff --git a/src/google/adk/auth/credential_service/__init__.py b/src/google/adk/auth/credential_service/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/auth/credential_service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/auth/credential_service/base_credential_service.py b/src/google/adk/auth/credential_service/base_credential_service.py new file mode 100644 index 000000000..7416ccc65 --- /dev/null +++ b/src/google/adk/auth/credential_service/base_credential_service.py @@ -0,0 +1,75 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from typing import Optional + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import working_in_progress +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig + + +@working_in_progress("Implementation are in progress. Don't use it for now.") +class BaseCredentialService(ABC): + """Abstract class for Service that loads / saves tool credentials from / to + the backend credential store.""" + + @abstractmethod + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + """ + Loads the credential by auth config and current tool context from the + backend credential store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to load the credential. + + tool_context: The context of the current invocation when the tool is + trying to load the credential. + + Returns: + Optional[AuthCredential]: the credential saved in the store. 
+ + """ + + @abstractmethod + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + """ + Saves the exchanged_auth_credential in auth config to the backend credential + store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to save the credential. + + tool_context: The context of the current invocation when the tool is + trying to save the credential. + + Returns: + None + """ diff --git a/tests/unittests/auth/test_auth_config.py b/tests/unittests/auth/test_auth_config.py index 7a20cfc63..a398ef321 100644 --- a/tests/unittests/auth/test_auth_config.py +++ b/tests/unittests/auth/test_auth_config.py @@ -68,10 +68,28 @@ def auth_config(oauth2_auth_scheme, oauth2_credentials): ) -def test_get_credential_key(auth_config): +@pytest.fixture +def auth_config_with_key(oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + credential_key="test_key", + ) + + +def test_custom_credential_key(auth_config_with_key): + """Test using custom credential key.""" + + key = auth_config_with_key.credential_key + assert key == "test_key" + + +def test_credential_key(auth_config): """Test generating a unique credential key.""" - key = auth_config.get_credential_key() + key = auth_config.credential_key assert key.startswith("adk_oauth2_") assert "_oauth2_" in key @@ -80,8 +98,8 @@ def test_get_credential_key_with_extras(auth_config): """Test generating a key when model_extra exists.""" # Add model_extra to test cleanup - original_key = auth_config.get_credential_key() - key = auth_config.get_credential_key() + original_key = auth_config.credential_key + key = auth_config.credential_key auth_config.auth_scheme.model_extra["extra_field"] = "value" auth_config.raw_auth_credential.model_extra["extra_field"] = "value" diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 402e4d0cd..aaed35a19 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -387,7 +387,7 @@ def test_get_auth_response_exists( state = MockState() # Store a credential in the state - credential_key = auth_config.get_credential_key() + credential_key = auth_config.credential_key state["temp:" + credential_key] = oauth2_credentials_with_auth_uri result = handler.get_auth_response(state) @@ -418,7 +418,7 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): handler.parse_and_store_auth_response(state) - credential_key = auth_config.get_credential_key() + credential_key = auth_config.credential_key assert ( state["temp:" + credential_key] == auth_config.exchanged_auth_credential ) @@ -436,7 +436,7 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): handler.parse_and_store_auth_response(state) - credential_key = auth_config_with_exchanged.get_credential_key() + credential_key = auth_config_with_exchanged.credential_key assert state["temp:" + credential_key] == mock_exchange_token.return_value assert mock_exchange_token.called From a19d617ed8cf1c505fe6a86cca2a33904fefde0a Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 13 Jun 2025 23:31:54 -0700 Subject: [PATCH 25/61] chore: Add experimental decorator to BigQuery tools PiperOrigin-RevId: 771359886 --- src/google/adk/tools/bigquery/bigquery_credentials.py | 2 
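To make the new `BaseCredentialService` contract concrete, here is a hedged sketch of a hypothetical `SessionStateCredentialService` that keeps credentials in session state, keyed by the `credential_key` introduced above. It is illustrative only: the base class is gated as work-in-progress, and a real service would persist to a secure backend rather than session state.

```python
from typing import Optional

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_tool import AuthConfig
from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
from google.adk.tools.tool_context import ToolContext


class SessionStateCredentialService(BaseCredentialService):
  """Illustrative only: loads/saves credentials via the session state."""

  async def load_credential(
      self,
      auth_config: AuthConfig,
      tool_context: ToolContext,
  ) -> Optional[AuthCredential]:
    # The credential_key identifies this credential in the store.
    return tool_context.state.get("credential:" + auth_config.credential_key)

  async def save_credential(
      self,
      auth_config: AuthConfig,
      tool_context: ToolContext,
  ) -> None:
    tool_context.state["credential:" + auth_config.credential_key] = (
        auth_config.exchanged_auth_credential
    )
```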
++ src/google/adk/tools/bigquery/bigquery_tool.py | 3 +++ src/google/adk/tools/bigquery/bigquery_toolset.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index 27c3a786c..0a99136c4 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -31,12 +31,14 @@ from ...auth.auth_credential import AuthCredentialTypes from ...auth.auth_credential import OAuth2Auth from ...auth.auth_tool import AuthConfig +from ...utils.feature_decorator import experimental from ..tool_context import ToolContext BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] +@experimental class BigQueryCredentialsConfig(BaseModel): """Configuration for Google API tools. (Experimental)""" diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py index 553012208..182734188 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/bigquery/bigquery_tool.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import inspect from typing import Any @@ -21,6 +22,7 @@ from google.oauth2.credentials import Credentials from typing_extensions import override +from ...utils.feature_decorator import experimental from ..function_tool import FunctionTool from ..tool_context import ToolContext from .bigquery_credentials import BigQueryCredentialsConfig @@ -28,6 +30,7 @@ from .config import BigQueryToolConfig +@experimental class BigQueryTool(FunctionTool): """GoogleApiTool class for tools that call Google APIs. 
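The `@experimental` marker applied above comes from `google.adk.utils.feature_decorator`; its implementation is not part of this patch. Decorators of this kind usually just warn when the annotated API is constructed, so the following is a self-contained illustration of the general pattern (an assumption for readability, not the ADK implementation).

```python
import functools
import warnings


def experimental(cls):
  """Illustrative stand-in: warn the first time an experimental class is built."""
  original_init = cls.__init__

  @functools.wraps(original_init)
  def wrapped_init(self, *args, **kwargs):
    warnings.warn(
        f"{cls.__name__} is experimental and may change without notice.",
        UserWarning,
        stacklevel=2,
    )
    original_init(self, *args, **kwargs)

  cls.__init__ = wrapped_init
  return cls


@experimental
class MyToolset:
  pass


MyToolset()  # warns: "MyToolset is experimental and may change without notice."
```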
diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 5543d103f..313cf4990 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -26,11 +26,13 @@ from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate +from ...utils.feature_decorator import experimental from .bigquery_credentials import BigQueryCredentialsConfig from .bigquery_tool import BigQueryTool from .config import BigQueryToolConfig +@experimental class BigQueryToolset(BaseToolset): """BigQuery Toolset contains tools for interacting with BigQuery data and metadata.""" From d1bda9d946581461df6065b620929a1588b1f64b Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Sat, 14 Jun 2025 12:55:27 -0700 Subject: [PATCH 26/61] chore: Allow working_in_progress feature for unittests PiperOrigin-RevId: 771500394 --- tests/unittests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index ad204005e..2b93226db 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -23,6 +23,7 @@ 'GOOGLE_API_KEY': 'fake_google_api_key', 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project', 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location', + 'ADK_ALLOW_WIP_FEATURES': 'true', } ENV_SETUPS = { From a4d432a9e62a5e759e94168463795cd0e7d7ab72 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Sat, 14 Jun 2025 13:39:14 -0700 Subject: [PATCH 27/61] chore: Add Service Account Credential Exchanger (Experimental) PiperOrigin-RevId: 771507089 --- .../service_account_credential_exchanger.py | 92 +++++ ...st_service_account_credential_exchanger.py | 341 ++++++++++++++++++ 2 files changed, 433 insertions(+) create mode 100644 src/google/adk/auth/service_account_credential_exchanger.py create mode 100644 tests/unittests/auth/test_service_account_credential_exchanger.py diff --git a/src/google/adk/auth/service_account_credential_exchanger.py b/src/google/adk/auth/service_account_credential_exchanger.py new file mode 100644 index 000000000..644501ee6 --- /dev/null +++ b/src/google/adk/auth/service_account_credential_exchanger.py @@ -0,0 +1,92 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential fetcher for Google Service Account.""" + +from __future__ import annotations + +import google.auth +from google.auth.transport.requests import Request +from google.oauth2 import service_account + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_credential import AuthCredentialTypes +from .auth_credential import HttpAuth +from .auth_credential import HttpCredentials + + +@experimental +class ServiceAccountCredentialExchanger: + """Exchanges Google Service Account credentials for an access token. + + Uses the default service credential if `use_default_credential = True`. 
+ Otherwise, uses the service account credential provided in the auth + credential. + """ + + def __init__(self, credential: AuthCredential): + if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: + raise ValueError("Credential is not a service account credential.") + self._credential = credential + + def exchange(self) -> AuthCredential: + """Exchanges the service account auth credential for an access token. + + If the AuthCredential contains a service account credential, it will be used + to exchange for an access token. Otherwise, if use_default_credential is True, + the default application credential will be used for exchanging an access token. + + Returns: + An AuthCredential in HTTP Bearer format, containing the access token. + + Raises: + ValueError: If service account credentials are missing or invalid. + Exception: If credential exchange or refresh fails. + """ + if ( + self._credential is None + or self._credential.service_account is None + or ( + self._credential.service_account.service_account_credential is None + and not self._credential.service_account.use_default_credential + ) + ): + raise ValueError( + "Service account credentials are missing. Please provide them, or set" + " `use_default_credential = True` to use application default" + " credential in a hosted service like Google Cloud Run." + ) + + try: + if self._credential.service_account.use_default_credential: + credentials, _ = google.auth.default() + else: + config = self._credential.service_account + credentials = service_account.Credentials.from_service_account_info( + config.service_account_credential.model_dump(), scopes=config.scopes + ) + + # Refresh credentials to ensure we have a valid access token + credentials.refresh(Request()) + + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=credentials.token), + ), + ) + except Exception as e: + raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/tests/unittests/auth/test_service_account_credential_exchanger.py b/tests/unittests/auth/test_service_account_credential_exchanger.py new file mode 100644 index 000000000..a5c668436 --- /dev/null +++ b/tests/unittests/auth/test_service_account_credential_exchanger.py @@ -0,0 +1,341 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
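Before the unit tests, a short usage sketch of the exchanger defined above. All class and field names are taken from the diff; the sketch assumes Application Default Credentials are available in the environment (for example on Cloud Run).

```python
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger

# Fall back to Application Default Credentials instead of an embedded key file.
credential = AuthCredential(
    auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
    service_account=ServiceAccount(
        use_default_credential=True,
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    ),
)

exchanged = ServiceAccountCredentialExchanger(credential).exchange()
# The result is an HTTP bearer credential carrying a short-lived access token.
assert exchanged.http.scheme == "bearer"
access_token = exchanged.http.credentials.token
```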
+ +"""Unit tests for the ServiceAccountCredentialExchanger.""" + +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger +import pytest + + +class TestServiceAccountCredentialExchanger: + """Test cases for ServiceAccountCredentialExchanger.""" + + def test_init_valid_credential(self): + """Test successful initialization with valid service account credential.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE" + " KEY-----" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + assert exchanger._credential == credential + + def test_init_invalid_credential_type(self): + """Test initialization with invalid credential type raises ValueError.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key="test-key", + ) + + with pytest.raises( + ValueError, match="Credential is not a service account credential" + ): + ServiceAccountCredentialExchanger(credential) + + @patch( + "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) + @patch("google.adk.auth.service_account_credential_exchanger.Request") + def test_exchange_with_explicit_credentials_success( + self, mock_request_class, mock_from_service_account_info + ): + """Test successful exchange with explicit service account credentials.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "mock_access_token" + mock_from_service_account_info.return_value = mock_credentials + + # Create test credential + service_account_cred = ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=service_account_cred, + 
scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + result = exchanger.exchange() + + # Verify the result + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_access_token" + + # Verify mocks were called correctly + mock_from_service_account_info.assert_called_once_with( + service_account_cred.model_dump(), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + mock_credentials.refresh.assert_called_once_with(mock_request) + + @patch( + "google.adk.auth.service_account_credential_exchanger.google.auth.default" + ) + @patch("google.adk.auth.service_account_credential_exchanger.Request") + def test_exchange_with_default_credentials_success( + self, mock_request_class, mock_google_auth_default + ): + """Test successful exchange with default application credentials.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "default_access_token" + mock_google_auth_default.return_value = (mock_credentials, "test-project") + + # Create test credential with use_default_credential=True + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + result = exchanger.exchange() + + # Verify the result + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "default_access_token" + + # Verify mocks were called correctly + mock_google_auth_default.assert_called_once() + mock_credentials.refresh.assert_called_once_with(mock_request) + + def test_exchange_missing_service_account(self): + """Test exchange fails when service_account is None.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=None, + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Service account credentials are missing" + ): + exchanger.exchange() + + def test_exchange_missing_credentials_and_not_default(self): + """Test exchange fails when credentials are missing and use_default_credential is False.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=None, + use_default_credential=False, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Service account credentials are missing" + ): + exchanger.exchange() + + @patch( + "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) + def test_exchange_credential_creation_failure( + self, mock_from_service_account_info + ): + """Test exchange handles credential creation failure gracefully.""" + # Setup mock to raise exception + mock_from_service_account_info.side_effect = Exception( + "Invalid private key" + ) + + # Create test credential + service_account_cred = ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key="invalid-key", + 
client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=service_account_cred, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Failed to exchange service account token" + ): + exchanger.exchange() + + @patch( + "google.adk.auth.service_account_credential_exchanger.google.auth.default" + ) + def test_exchange_default_credential_failure(self, mock_google_auth_default): + """Test exchange handles default credential failure gracefully.""" + # Setup mock to raise exception + mock_google_auth_default.side_effect = Exception( + "No default credentials found" + ) + + # Create test credential with use_default_credential=True + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Failed to exchange service account token" + ): + exchanger.exchange() + + @patch( + "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) + @patch("google.adk.auth.service_account_credential_exchanger.Request") + def test_exchange_refresh_failure( + self, mock_request_class, mock_from_service_account_info + ): + """Test exchange handles credential refresh failure gracefully.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.refresh.side_effect = Exception( + "Network error during refresh" + ) + mock_from_service_account_info.return_value = mock_credentials + + # Create test credential + service_account_cred = ServiceAccountCredential( + type_="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="12345", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + universe_domain="googleapis.com", + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=service_account_cred, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + + with pytest.raises( + ValueError, match="Failed to exchange service account token" + ): + exchanger.exchange() + + def test_exchange_none_credential_in_constructor(self): + """Test that passing None credential raises appropriate error during 
construction.""" + # This test verifies behavior when _credential is None, though this shouldn't + # happen in normal usage due to constructor validation + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + exchanger = ServiceAccountCredentialExchanger(credential) + # Manually set to None to test the validation logic + exchanger._credential = None + + with pytest.raises( + ValueError, match="Service account credentials are missing" + ): + exchanger.exchange() From 675faefc670b5cd41991939fe0fc604df331111a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Jun 2025 08:34:53 -0700 Subject: [PATCH 28/61] feat: Allow data_store_specs pass into ADK VAIS built-in tool PiperOrigin-RevId: 772039465 --- src/google/adk/tools/vertex_ai_search_tool.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 5449f5090..b00cd0329 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -39,6 +39,9 @@ def __init__( self, *, data_store_id: Optional[str] = None, + data_store_specs: Optional[ + list[types.VertexAISearchDataStoreSpec] + ] = None, search_engine_id: Optional[str] = None, filter: Optional[str] = None, max_results: Optional[int] = None, @@ -49,6 +52,8 @@ def __init__( data_store_id: The Vertex AI search data store resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}". + data_store_specs: Specifications that define the specific DataStores to be + searched. It should only be set if engine is used. search_engine_id: The Vertex AI search engine resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}". @@ -64,7 +69,12 @@ def __init__( raise ValueError( 'Either data_store_id or search_engine_id must be specified.' ) + if data_store_specs is not None and search_engine_id is None: + raise ValueError( + 'search_engine_id must be specified if data_store_specs is specified.' 
+ ) self.data_store_id = data_store_id + self.data_store_specs = data_store_specs self.search_engine_id = search_engine_id self.filter = filter self.max_results = max_results @@ -89,6 +99,7 @@ async def process_llm_request( retrieval=types.Retrieval( vertex_ai_search=types.VertexAISearch( datastore=self.data_store_id, + data_store_specs=self.data_store_specs, engine=self.search_engine_id, filter=self.filter, max_results=self.max_results, From badbcbd7a464e6b323cf3164d2bcd4e27cbc057f Mon Sep 17 00:00:00 2001 From: SimonWei <119845914+simonwei97@users.noreply.github.com> Date: Tue, 17 Jun 2025 00:40:41 +0800 Subject: [PATCH 29/61] fix: agent generate config err (#1305) * fix: agent generate config err * fix: resovle comment --------- Co-authored-by: Hangfei Lin Co-authored-by: genquan9 <49327371+genquan9@users.noreply.github.com> --- src/google/adk/models/lite_llm.py | 65 +++++++++++++++++++++----- tests/unittests/models/test_litellm.py | 33 ++++++++++++- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf..c954711ad 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -481,16 +482,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[Dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format, and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -507,7 +514,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -518,12 +526,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = ( + llm_request.config.response_schema if llm_request.config else None + ) + + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. 
+ generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] - if llm_request.config.response_schema: - response_format = llm_request.config.response_schema + if not generation_params: + generation_params = None - return messages, tools, response_format + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -660,7 +695,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) completion_args = { "model": self.model, @@ -668,7 +705,13 @@ async def generate_content_async( "tools": tools, "response_format": response_format, } - completion_args.update(self._additional_args) + + # Merge additional arguments and generation parameters safely + if hasattr(self, "_additional_args") and self._additional_args: + completion_args.update(self._additional_args) + + if generation_params: + completion_args.update(generation_params) if stream: text = "" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae..e600ee7f0 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,7 +13,6 @@ # limitations under the License. -import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1430,3 +1429,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_completion_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params From 1cfc555e70a45bdd23d0741176b3f17300f4b4ab Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Mon, 16 Jun 2025 10:19:37 -0700 Subject: [PATCH 30/61] ADK changes PiperOrigin-RevId: 772078053 --- src/google/adk/models/lite_llm.py | 65 +++++--------------------- tests/unittests/models/test_litellm.py | 33 +------------ 2 files changed, 12 
insertions(+), 86 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index c954711ad..ed54faecf 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,7 +23,6 @@ from typing import Dict from typing import Generator from typing import Iterable -from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -482,22 +481,16 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> Tuple[ - List[Message], - Optional[List[Dict]], - Optional[types.SchemaUnion], - Optional[Dict], -]: - """Converts an LlmRequest to litellm inputs and extracts generation params. +) -> tuple[Iterable[Message], Iterable[dict]]: + """Converts an LlmRequest to litellm inputs. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary, response format, and generation params). + The litellm inputs (message list, tool dictionary and response format). """ - # 1. Construct messages - messages: List[Message] = [] + messages = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -514,8 +507,7 @@ def _get_completion_inputs( ), ) - # 2. Convert tool declarations - tools: Optional[List[Dict]] = None + tools = None if ( llm_request.config and llm_request.config.tools @@ -526,39 +518,12 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - # 3. Handle response format - response_format: Optional[types.SchemaUnion] = ( - llm_request.config.response_schema if llm_request.config else None - ) - - # 4. Extract generation parameters - generation_params: Optional[Dict] = None - if llm_request.config: - config_dict = llm_request.config.model_dump(exclude_none=True) - # Generate LiteLlm parameters here, - # Following https://docs.litellm.ai/docs/completion/input. 
- generation_params = {} - param_mapping = { - "max_output_tokens": "max_completion_tokens", - "stop_sequences": "stop", - } - for key in ( - "temperature", - "max_output_tokens", - "top_p", - "top_k", - "stop_sequences", - "presence_penalty", - "frequency_penalty", - ): - if key in config_dict: - mapped_key = param_mapping.get(key, key) - generation_params[mapped_key] = config_dict[key] + response_format = None - if not generation_params: - generation_params = None + if llm_request.config.response_schema: + response_format = llm_request.config.response_schema - return messages, tools, response_format, generation_params + return messages, tools, response_format def _build_function_declaration_log( @@ -695,9 +660,7 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format, generation_params = ( - _get_completion_inputs(llm_request) - ) + messages, tools, response_format = _get_completion_inputs(llm_request) completion_args = { "model": self.model, @@ -705,13 +668,7 @@ async def generate_content_async( "tools": tools, "response_format": response_format, } - - # Merge additional arguments and generation parameters safely - if hasattr(self, "_additional_args") and self._additional_args: - completion_args.update(self._additional_args) - - if generation_params: - completion_args.update(generation_params) + completion_args.update(self._additional_args) if stream: text = "" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index e600ee7f0..f316e83ae 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,6 +13,7 @@ # limitations under the License. +import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1429,35 +1430,3 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} - - -@pytest.mark.asyncio -def test_get_completion_inputs_generation_params(): - # Test that generation_params are extracted and mapped correctly - req = LlmRequest( - contents=[ - types.Content(role="user", parts=[types.Part.from_text(text="hi")]), - ], - config=types.GenerateContentConfig( - temperature=0.33, - max_output_tokens=123, - top_p=0.88, - top_k=7, - stop_sequences=["foo", "bar"], - presence_penalty=0.1, - frequency_penalty=0.2, - ), - ) - from google.adk.models.lite_llm import _get_completion_inputs - - _, _, _, generation_params = _get_completion_inputs(req) - assert generation_params["temperature"] == 0.33 - assert generation_params["max_completion_tokens"] == 123 - assert generation_params["top_p"] == 0.88 - assert generation_params["top_k"] == 7 - assert generation_params["stop"] == ["foo", "bar"] - assert generation_params["presence_penalty"] == 0.1 - assert generation_params["frequency_penalty"] == 0.2 - # Should not include max_output_tokens - assert "max_output_tokens" not in generation_params - assert "stop_sequences" not in generation_params From 8201f9aebd62ab4cf1ab36e08e475c9aba3ffb57 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Mon, 16 Jun 2025 12:06:30 -0700 Subject: [PATCH 31/61] chore: Added live-streaming sample agent Also added a readme. 
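For readers tracking the back-and-forth above: the rolled-back change mapped `GenerateContentConfig` fields onto LiteLLM completion arguments. The sketch below mirrors that reverted mapping for reference only (it is not current behaviour, and the helper name is for illustration).

```python
from typing import Any, Dict, Optional

from google.genai import types

_PARAM_MAPPING = {
    "max_output_tokens": "max_completion_tokens",
    "stop_sequences": "stop",
}
_SUPPORTED_KEYS = (
    "temperature",
    "max_output_tokens",
    "top_p",
    "top_k",
    "stop_sequences",
    "presence_penalty",
    "frequency_penalty",
)


def to_litellm_params(
    config: Optional[types.GenerateContentConfig],
) -> Optional[Dict[str, Any]]:
  """Sketch of the (since reverted) GenerateContentConfig -> LiteLLM mapping."""
  if not config:
    return None
  config_dict = config.model_dump(exclude_none=True)
  params = {
      _PARAM_MAPPING.get(key, key): config_dict[key]
      for key in _SUPPORTED_KEYS
      if key in config_dict
  }
  return params or None


print(to_litellm_params(types.GenerateContentConfig(max_output_tokens=123)))
# {'max_completion_tokens': 123}
```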
PiperOrigin-RevId: 772120698 --- .../live_bidi_streaming_agent/__init__.py | 15 +++ .../live_bidi_streaming_agent/agent.py | 104 ++++++++++++++++++ .../live_bidi_streaming_agent/readme.md | 37 +++++++ 3 files changed, 156 insertions(+) create mode 100755 contributing/samples/live_bidi_streaming_agent/__init__.py create mode 100755 contributing/samples/live_bidi_streaming_agent/agent.py create mode 100644 contributing/samples/live_bidi_streaming_agent/readme.md diff --git a/contributing/samples/live_bidi_streaming_agent/__init__.py b/contributing/samples/live_bidi_streaming_agent/__init__.py new file mode 100755 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/live_bidi_streaming_agent/agent.py b/contributing/samples/live_bidi_streaming_agent/agent.py new file mode 100755 index 000000000..2896bd70f --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/agent.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk import Agent +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + result = random.randint(1, sides) + if not 'rolls' in tool_context.state: + tool_context.state['rolls'] = [] + + tool_context.state['rolls'] = tool_context.state['rolls'] + [result] + return result + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + 'No prime numbers found.' + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." 
+ ) + + +root_agent = Agent( + model='gemini-2.0-flash-live-preview-04-09', # for Vertex project + # model='gemini-2.0-flash-live-001', # for AI studio key + name='hello_world_agent', + description=( + 'hello world agent that can roll a dice of 8 sides and check prime' + ' numbers.' + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) diff --git a/contributing/samples/live_bidi_streaming_agent/readme.md b/contributing/samples/live_bidi_streaming_agent/readme.md new file mode 100644 index 000000000..6a9258f3e --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/readme.md @@ -0,0 +1,37 @@ +# Simplistic Live (Bidi-Streaming) Agent +This project provides a basic example of a live, bidirectional streaming agent +designed for testing and experimentation. + +You can see full documentation [here](https://google.github.io/adk-docs/streaming/). + +## Getting Started + +Follow these steps to get the agent up and running: + +1. **Start the ADK Web Server** + Open your terminal, navigate to the root directory that contains the + `live_bidi_streaming_agent` folder, and execute the following command: + ```bash + adk web + ``` + +2. **Access the ADK Web UI** + Once the server is running, open your web browser and navigate to the URL + provided in the terminal (it will typically be `http://localhost:8000`). + +3. **Select the Agent** + In the top-left corner of the ADK Web UI, use the dropdown menu to select + this agent. + +4. **Start Streaming** + Click on either the **Audio** or **Video** icon located near the chat input + box to begin the streaming session. + +5. **Interact with the Agent** + You can now begin talking to the agent, and it will respond in real-time. + +## Usage Notes + +* You only need to click the **Audio** or **Video** button once to initiate the + stream. 
The current version does not support stopping and restarting the stream + by clicking the button again during a session. From fe1d5aa439cc56b89d248a52556c0a9b4cbd15e4 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Mon, 16 Jun 2025 14:35:16 -0700 Subject: [PATCH 32/61] feat: add enable_affective_dialog and proactivity to run_config and llm_request PiperOrigin-RevId: 772175206 --- src/google/adk/agents/run_config.py | 6 ++++++ src/google/adk/flows/llm_flows/basic.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 5679f04e9..c9a50a0ae 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -73,6 +73,12 @@ class RunConfig(BaseModel): realtime_input_config: Optional[types.RealtimeInputConfig] = None """Realtime input config for live agents with audio input from user.""" + enable_affective_dialog: Optional[bool] = None + """If enabled, the model will detect emotions and adapt its responses accordingly.""" + + proactivity: Optional[types.ProactivityConfig] = None + """Configures the proactivity of the model. This allows the model to respond proactively to the input and to ignore irrelevant input.""" + max_llm_calls: int = 500 """ A limit on the total number of llm calls for a given run. diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 7efadd97e..ee5c83da1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -68,6 +68,12 @@ async def run_async( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) + llm_request.live_connect_config.enable_affective_dialog = ( + invocation_context.run_config.enable_affective_dialog + ) + llm_request.live_connect_config.proactivity = ( + invocation_context.run_config.proactivity + ) # TODO: handle tool append here, instead of in BaseTool.process_llm_request. From fef87784297b806914de307f48c51d83f977298f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Jun 2025 16:19:47 -0700 Subject: [PATCH 33/61] fix: liteLLM test failures Fix liteLLM test failures for function call responses. 
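The two `RunConfig` fields added above are forwarded into the live connect config by the basic LLM flow, so they take effect on live (bidi) runs. A minimal opt-in sketch, using the import path from the diff and leaving `ProactivityConfig` at its defaults (its fields are not shown in this patch):

```python
from google.adk.agents.run_config import RunConfig
from google.genai import types

run_config = RunConfig(
    enable_affective_dialog=True,           # adapt responses to detected emotion
    proactivity=types.ProactivityConfig(),  # may respond proactively / ignore irrelevant input
    max_llm_calls=100,
)
```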
PiperOrigin-RevId: 772212629 --- tests/unittests/models/test_litellm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae..8b43cc48b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1194,11 +1194,11 @@ async def test_generate_content_async_stream( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" mock_completion.assert_called_once() _, kwargs = mock_completion.call_args @@ -1257,11 +1257,11 @@ async def test_generate_content_async_stream_with_usage_metadata( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" assert responses[3].usage_metadata.prompt_token_count == 10 assert responses[3].usage_metadata.candidates_token_count == 5 From 31b81a342d3438b1efb7557e362b9288810033d5 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 16:21:32 -0700 Subject: [PATCH 34/61] chore: Update streamable http mcp example agent PiperOrigin-RevId: 772213323 --- contributing/samples/mcp_streamablehttp_agent/README.md | 3 +-- contributing/samples/mcp_streamablehttp_agent/agent.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/contributing/samples/mcp_streamablehttp_agent/README.md b/contributing/samples/mcp_streamablehttp_agent/README.md index 1c211dd71..547a0788d 100644 --- a/contributing/samples/mcp_streamablehttp_agent/README.md +++ b/contributing/samples/mcp_streamablehttp_agent/README.md @@ -1,8 +1,7 @@ -This agent connects to a local MCP server via sse. +This agent connects to a local MCP server via Streamable HTTP. 
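For the connection side of this sample, a hedged sketch of wiring `MCPToolset` to a Streamable HTTP endpoint. The parameter names (`connection_params`, `url`), the URL, and the model string are assumptions based on the imports shown, not copied from the sample.

```python
from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPServerParams
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset

root_agent = LlmAgent(
    model="gemini-2.0-flash",  # placeholder model name
    name="filesystem_agent",
    instruction="Help the user inspect files via the MCP filesystem server.",
    tools=[
        MCPToolset(
            # Assumed to point at the local server started with
            # `uv run filesystem_server.py`.
            connection_params=StreamableHTTPServerParams(
                url="http://localhost:3000/mcp",
            ),
        )
    ],
)
```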
To run this agent, start the local MCP server first by : ```bash uv run filesystem_server.py ``` - diff --git a/contributing/samples/mcp_streamablehttp_agent/agent.py b/contributing/samples/mcp_streamablehttp_agent/agent.py index 61d59e051..f165c4c1b 100644 --- a/contributing/samples/mcp_streamablehttp_agent/agent.py +++ b/contributing/samples/mcp_streamablehttp_agent/agent.py @@ -18,7 +18,6 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPServerParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import SseServerParams _allowed_path = os.path.dirname(os.path.abspath(__file__)) From 4bda24517163a52dc227b525f26d3d83ce36f1ec Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 16:25:56 -0700 Subject: [PATCH 35/61] chore: fix oauth_calendar_agent example PiperOrigin-RevId: 772214855 --- .../samples/oauth_calendar_agent/agent.py | 120 ++++++++++++------ 1 file changed, 79 insertions(+), 41 deletions(-) diff --git a/contributing/samples/oauth_calendar_agent/agent.py b/contributing/samples/oauth_calendar_agent/agent.py index 9d56d3ff8..a1b1dea87 100644 --- a/contributing/samples/oauth_calendar_agent/agent.py +++ b/contributing/samples/oauth_calendar_agent/agent.py @@ -27,8 +27,6 @@ from google.adk.auth import AuthCredentialTypes from google.adk.auth import OAuth2Auth from google.adk.tools import ToolContext -from google.adk.tools.authenticated_tool.base_authenticated_tool import AuthenticatedFunctionTool -from google.adk.tools.authenticated_tool.credentials_store import ToolContextCredentialsStore from google.adk.tools.google_api_tool import CalendarToolset from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials @@ -58,7 +56,6 @@ def list_calendar_events( end_time: str, limit: int, tool_context: ToolContext, - credential: AuthCredential, ) -> list[dict]: """Search for calendar events. @@ -83,11 +80,84 @@ def list_calendar_events( Returns: list[dict]: A list of events that match the search criteria. """ - - creds = Credentials( - token=credential.oauth2.access_token, - refresh_token=credential.oauth2.refresh_token, - ) + creds = None + + # Check if the tokes were already in the session state, which means the user + # has already gone through the OAuth flow and successfully authenticated and + # authorized the tool to access their calendar. + if "calendar_tool_tokens" in tool_context.state: + creds = Credentials.from_authorized_user_info( + tool_context.state["calendar_tool_tokens"], SCOPES + ) + if not creds or not creds.valid: + # If the access token is expired, refresh it with the refresh token. + if creds and creds.expired and creds.refresh_token: + creds.refresh(Request()) + else: + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://accounts.google.com/o/oauth2/auth", + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + "https://www.googleapis.com/auth/calendar": ( + "See, edit, share, and permanently delete all the" + " calendars you can access using Google Calendar" + ) + }, + ) + ) + ) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=oauth_client_id, client_secret=oauth_client_secret + ), + ) + # If the user has not gone through the OAuth flow before, or the refresh + # token also expired, we need to ask users to go through the OAuth flow. 
+ # First we check whether the user has just gone through the OAuth flow and + # Oauth response is just passed back. + auth_response = tool_context.get_auth_response( + AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + ) + if auth_response: + # ADK exchanged the access token already for us + access_token = auth_response.oauth2.access_token + refresh_token = auth_response.oauth2.refresh_token + + creds = Credentials( + token=access_token, + refresh_token=refresh_token, + token_uri=auth_scheme.flows.authorizationCode.tokenUrl, + client_id=oauth_client_id, + client_secret=oauth_client_secret, + scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()), + ) + else: + # If there are no auth response which means the user has not gone + # through the OAuth flow yet, we need to ask users to go through the + # OAuth flow. + tool_context.request_credential( + AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential, + ) + ) + # The return value is optional and could be any dict object. It will be + # wrapped in a dict with key as 'result' and value as the return value + # if the object returned is not a dict. This response will be passed + # to LLM to generate a user friendly message. e.g. LLM will tell user: + # "I need your authorization to access your calendar. Please authorize + # me so I can check your meetings for today." + return "Need User Authorization to access their calendar." + # We store the access token and refresh token in the session state for the + # next runs. This is just an example. On production, a tool should store + # those credentials in some secure store or properly encrypt it before store + # it in the session state. + tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json()) service = build("calendar", "v3", credentials=creds) events_result = ( @@ -138,38 +208,6 @@ def update_time(callback_context: CallbackContext): Currnet time: {_time} """, - tools=[ - AuthenticatedFunctionTool( - func=list_calendar_events, - auth_config=AuthConfig( - auth_scheme=OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl=( - "https://accounts.google.com/o/oauth2/auth" - ), - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - "https://www.googleapis.com/auth/calendar": ( - "See, edit, share, and permanently delete" - " all the calendars you can access using" - " Google Calendar" - ) - }, - ) - ) - ), - raw_auth_credential=AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=oauth_client_id, - client_secret=oauth_client_secret, - ), - ), - ), - credential_store=ToolContextCredentialsStore(), - ), - calendar_toolset, - ], + tools=[list_calendar_events, calendar_toolset], before_agent_callback=update_time, ) From aafa80bd85a49fb1c1a255ac797587cffd3fa567 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Jun 2025 16:36:34 -0700 Subject: [PATCH 36/61] fix: stream in litellm + adk and add corresponding integration tests Fixes https://github.com/google/adk-python/issues/1368 PiperOrigin-RevId: 772218385 --- src/google/adk/models/lite_llm.py | 3 +- .../models/test_litellm_no_function.py | 109 +++++++++++++++++- .../models/test_litellm_with_function.py | 25 ++-- 3 files changed, 125 insertions(+), 12 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf..dce5ed7c4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -739,11 +739,12 @@ async 
def generate_content_async( _message_to_generate_content_response( ChatCompletionAssistantMessage( role="assistant", - content="", + content=text, tool_calls=tool_calls, ) ) ) + text = "" function_calls.clear() elif finish_reason == "stop" and text: aggregated_llm_response = _message_to_generate_content_response( diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py index e662384ce..ff5d3bb82 100644 --- a/tests/integration/models/test_litellm_no_function.py +++ b/tests/integration/models/test_litellm_no_function.py @@ -20,12 +20,26 @@ from google.genai.types import Part import pytest -_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """You are a helpful assistant.""" +def get_weather(city: str) -> str: + """Simulates a web search. Use it get information on weather. + + Args: + city: A string containing the location to get weather information for. + + Returns: + A string with the simulated weather information for the queried city. + """ + if "sf" in city.lower() or "san francisco" in city.lower(): + return "It's 70 degrees and foggy." + return "It's 80 degrees and sunny." + + @pytest.fixture def oss_llm(): return LiteLlm(model=_TEST_MODEL_NAME) @@ -44,6 +58,48 @@ def llm_request(): ) +@pytest.fixture +def llm_request_with_tools(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="What is the weather in San Francisco?") + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="get_weather", + description="Get the weather in a given location", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "city": types.Schema( + type=types.Type.STRING, + description=( + "The city to get the weather for." + ), + ), + }, + required=["city"], + ), + ) + ] + ) + ], + ), + ) + + @pytest.mark.asyncio async def test_generate_content_async(oss_llm, llm_request): async for response in oss_llm.generate_content_async(llm_request): @@ -51,10 +107,8 @@ async def test_generate_content_async(oss_llm, llm_request): assert response.content.parts[0].text -# Note that, this test disabled streaming because streaming is not supported -# properly in the current test model for now. 
@pytest.mark.asyncio -async def test_generate_content_async_stream(oss_llm, llm_request): +async def test_generate_content_async(oss_llm, llm_request): responses = [ resp async for resp in oss_llm.generate_content_async( @@ -63,3 +117,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request): ] part = responses[0].content.parts[0] assert len(part.text) > 0 + + +@pytest.mark.asyncio +async def test_generate_content_async_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=False + ) + ] + function_call = responses[0].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_async_stream(oss_llm, llm_request): + responses = [ + resp + async for resp in oss_llm.generate_content_async(llm_request, stream=True) + ] + text = "" + for i in range(len(responses) - 1): + assert responses[i].partial is True + assert responses[i].content.parts[0].text + text += responses[i].content.parts[0].text + + # Last message should be accumulated text + assert responses[-1].content.parts[0].text == text + assert not responses[-1].partial + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py index a2ceb540a..799c55e5c 100644 --- a/tests/integration/models/test_litellm_with_function.py +++ b/tests/integration/models/test_litellm_with_function.py @@ -13,7 +13,6 @@ # limitations under the License. from google.adk.models import LlmRequest -from google.adk.models import LlmResponse from google.adk.models.lite_llm import LiteLlm from google.genai import types from google.genai.types import Content @@ -23,12 +22,11 @@ litellm.add_function_to_prompt = True -_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" - +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """ You are a helpful assistant, and call tools optionally. -If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs. +If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs. """ @@ -40,7 +38,7 @@ "properties": { "city": { "type": "string", - "description": "The city, e.g. San Francisco", + "description": "The city to get the weather for.", }, }, "required": ["city"], @@ -87,8 +85,6 @@ def llm_request(): ) -# Note that, this test disabled streaming because streaming is not supported -# properly in the current test model for now. 
@pytest.mark.asyncio async def test_generate_content_asyn_with_function( oss_llm_with_function, llm_request @@ -102,3 +98,18 @@ async def test_generate_content_asyn_with_function( function_call = responses[0].content.parts[0].function_call assert function_call.name == "get_weather" assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_asyn_stream_with_function( + oss_llm_with_function, llm_request +): + responses = [ + resp + async for resp in oss_llm_with_function.generate_content_async( + llm_request, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" From e384fa4ad76114fa942a3be8bd51bbeb5225e00e Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Mon, 16 Jun 2025 16:57:05 -0700 Subject: [PATCH 37/61] chore: fix previously skipped isort issue PiperOrigin-RevId: 772224853 --- tests/integration/models/test_litellm_no_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py index ff5d3bb82..05072b899 100644 --- a/tests/integration/models/test_litellm_no_function.py +++ b/tests/integration/models/test_litellm_no_function.py @@ -20,7 +20,6 @@ from google.genai.types import Part import pytest - _TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """You are a helpful assistant.""" From a6b1baa61b5dbf4168a035609077094307171135 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 17:04:27 -0700 Subject: [PATCH 38/61] chore: Add base credential exchanger (Experimental) PiperOrigin-RevId: 772227201 --- src/google/adk/auth/exchanger/__init__.py | 25 ++++++++++ .../exchanger/base_credential_exchanger.py | 49 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 src/google/adk/auth/exchanger/__init__.py create mode 100644 src/google/adk/auth/exchanger/base_credential_exchanger.py diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py new file mode 100644 index 000000000..ce5c464c4 --- /dev/null +++ b/src/google/adk/auth/exchanger/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Credential exchanger module.""" + +from .base_credential_exchanger import BaseCredentialExchanger +from .credential_exchanger_registry import CredentialExchangerRegistry +from .service_account_credential_exchanger import ServiceAccountCredentialExchanger + +__all__ = [ + "BaseCredentialExchanger", + "CredentialExchangerRegistry", + "ServiceAccountCredentialExchanger", +] diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py new file mode 100644 index 000000000..1d7417cd0 --- /dev/null +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base credential exchanger interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_schemes import AuthScheme + + +@experimental +class BaseCredentialExchanger(abc.ABC): + """Base interface for credential exchangers.""" + + @abc.abstractmethod + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange credential if needed. + + Args: + auth_credential: The credential to exchange. + auth_scheme: The authentication scheme (optional, some exchangers don't need it). + + Returns: + The exchanged credential. + + Raises: + ValueError: If credential exchange fails. + """ + pass From 28dfcd25128e4cab34764abd1451f15529c4626d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 17:36:12 -0700 Subject: [PATCH 39/61] chore: Add experimental decorator to Oauth2 credential fethcer PiperOrigin-RevId: 772236406 --- src/google/adk/auth/oauth2_credential_fetcher.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py index 1a8692417..cbed70762 100644 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ b/src/google/adk/auth/oauth2_credential_fetcher.py @@ -20,6 +20,7 @@ from fastapi.openapi.models import OAuth2 +from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_schemes import AuthScheme from .auth_schemes import OAuthGrantType @@ -37,8 +38,9 @@ logger = logging.getLogger("google_adk." + __name__) +@experimental class OAuth2CredentialFetcher: - """Exchanges and refreshes an OAuth2 access token.""" + """Exchanges and refreshes an OAuth2 access token. 
(Experimental)""" def __init__( self, From e2a81365ec18cb4ed1a5422513992ccc21962937 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 18:21:03 -0700 Subject: [PATCH 40/61] chore: Add a base credential refresher interface PiperOrigin-RevId: 772248299 --- .../exchanger/base_credential_exchanger.py | 12 ++- src/google/adk/auth/refresher/__init__.py | 21 ++++++ .../refresher/base_credential_refresher.py | 74 +++++++++++++++++++ 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 src/google/adk/auth/refresher/__init__.py create mode 100644 src/google/adk/auth/refresher/base_credential_refresher.py diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py index 1d7417cd0..b09adb80a 100644 --- a/src/google/adk/auth/exchanger/base_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -24,9 +24,17 @@ from ..auth_schemes import AuthScheme +class CredentialExchangError(Exception): + """Base exception for credential exchange errors.""" + + @experimental class BaseCredentialExchanger(abc.ABC): - """Base interface for credential exchangers.""" + """Base interface for credential exchangers. + + Credential exchangers are responsible for exchanging credentials from + one format or scheme to another. + """ @abc.abstractmethod async def exchange( @@ -44,6 +52,6 @@ async def exchange( The exchanged credential. Raises: - ValueError: If credential exchange fails. + CredentialExchangError: If credential exchange fails. """ pass diff --git a/src/google/adk/auth/refresher/__init__.py b/src/google/adk/auth/refresher/__init__.py new file mode 100644 index 000000000..27d7245dc --- /dev/null +++ b/src/google/adk/auth/refresher/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher module.""" + +from .base_credential_refresher import BaseCredentialRefresher + +__all__ = [ + "BaseCredentialRefresher", +] diff --git a/src/google/adk/auth/refresher/base_credential_refresher.py b/src/google/adk/auth/refresher/base_credential_refresher.py new file mode 100644 index 000000000..230b07d09 --- /dev/null +++ b/src/google/adk/auth/refresher/base_credential_refresher.py @@ -0,0 +1,74 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base credential refresher interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.utils.feature_decorator import experimental + + +class CredentialRefresherError(Exception): + """Base exception for credential refresh errors.""" + + +@experimental +class BaseCredentialRefresher(abc.ABC): + """Base interface for credential refreshers. + + Credential refreshers are responsible for checking if a credential is expired + or needs to be refreshed, and for refreshing it if necessary. + """ + + @abc.abstractmethod + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Checks if a credential needs to be refreshed. + + Args: + auth_credential: The credential to check. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + pass + + @abc.abstractmethod + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refreshes a credential if needed. + + Args: + auth_credential: The credential to refresh. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + The refreshed credential. + + Raises: + CredentialRefresherError: If credential refresh fails. + """ + pass From 476805d5b9e6d598ca8bb71488a4923c162cfdbc Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 18:33:41 -0700 Subject: [PATCH 41/61] chore: Add a2a extra dependency for github UT workflows PiperOrigin-RevId: 772251530 --- .github/workflows/python-unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index a504fde0d..0d77402f9 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -43,7 +43,7 @@ jobs: run: | uv venv .venv source .venv/bin/activate - uv sync --extra test --extra eval + uv sync --extra test --extra eval --extra a2a - name: Run unit tests with pytest run: | From 94caccc148833c135b9b60af3c4c54986b10406c Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 19:02:31 -0700 Subject: [PATCH 42/61] refactor: Extract util method from OAuth2 credential fetcher for reuse Context: we'd like to separate fetcher into exchanger and refresher later. This cl help to extract the common utility that will be used by both exchanger and refresher. 
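
A minimal sketch of the intended reuse (illustrative only: the two wrapper
functions exchange_tokens and refresh_tokens below are hypothetical and not
part of this change, while create_oauth2_session and
update_credential_with_tokens are the utilities extracted here):

  from google.adk.auth.auth_credential import AuthCredential
  from google.adk.auth.auth_schemes import AuthScheme
  from google.adk.auth.auth_schemes import OAuthGrantType
  from google.adk.auth.oauth2_credential_util import create_oauth2_session
  from google.adk.auth.oauth2_credential_util import update_credential_with_tokens

  def exchange_tokens(
      scheme: AuthScheme, credential: AuthCredential
  ) -> AuthCredential:
    """Exchanges an auth code for tokens (future exchanger path)."""
    client, token_endpoint = create_oauth2_session(scheme, credential)
    if not client:
      # Scheme/credential combination is not usable; return unchanged.
      return credential
    tokens = client.fetch_token(
        token_endpoint,
        authorization_response=credential.oauth2.auth_response_uri,
        code=credential.oauth2.auth_code,
        grant_type=OAuthGrantType.AUTHORIZATION_CODE,
    )
    # Both paths write the result back through the same shared updater.
    update_credential_with_tokens(credential, tokens)
    return credential

  def refresh_tokens(
      scheme: AuthScheme, credential: AuthCredential
  ) -> AuthCredential:
    """Refreshes an expired access token (future refresher path)."""
    client, token_endpoint = create_oauth2_session(scheme, credential)
    if not client:
      return credential
    tokens = client.refresh_token(
        url=token_endpoint,
        refresh_token=credential.oauth2.refresh_token,
    )
    update_credential_with_tokens(credential, tokens)
    return credential
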
PiperOrigin-RevId: 772257995 --- .../adk/auth/oauth2_credential_fetcher.py | 59 +--- src/google/adk/auth/oauth2_credential_util.py | 107 ++++++ tests/unittests/auth/test_auth_handler.py | 2 +- .../auth/test_oauth2_credential_fetcher.py | 332 +----------------- .../auth/test_oauth2_credential_util.py | 147 ++++++++ 5 files changed, 284 insertions(+), 363 deletions(-) create mode 100644 src/google/adk/auth/oauth2_credential_util.py create mode 100644 tests/unittests/auth/test_oauth2_credential_util.py diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py index cbed70762..c9e838b25 100644 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ b/src/google/adk/auth/oauth2_credential_fetcher.py @@ -15,19 +15,15 @@ from __future__ import annotations import logging -from typing import Optional -from typing import Tuple - -from fastapi.openapi.models import OAuth2 from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_schemes import AuthScheme from .auth_schemes import OAuthGrantType -from .auth_schemes import OpenIdConnectWithConfig +from .oauth2_credential_util import create_oauth2_session +from .oauth2_credential_util import update_credential_with_tokens try: - from authlib.integrations.requests_client import OAuth2Session from authlib.oauth2.rfc6749 import OAuth2Token AUTHLIB_AVIALABLE = True @@ -50,45 +46,6 @@ def __init__( self._auth_scheme = auth_scheme self._auth_credential = auth_credential - def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]: - auth_scheme = self._auth_scheme - auth_credential = self._auth_credential - - if isinstance(auth_scheme, OpenIdConnectWithConfig): - if not hasattr(auth_scheme, "token_endpoint"): - return None, None - token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes - elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): - return None, None - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) - else: - return None, None - - if ( - not auth_credential - or not auth_credential.oauth2 - or not auth_credential.oauth2.client_id - or not auth_credential.oauth2.client_secret - ): - return None, None - - return ( - OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ), - token_endpoint, - ) - def _update_credential(self, tokens: OAuth2Token) -> None: self._auth_credential.oauth2.access_token = tokens.get("access_token") self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") @@ -114,7 +71,9 @@ def exchange(self) -> AuthCredential: ): return self._auth_credential - client, token_endpoint = self._oauth2_session() + client, token_endpoint = create_oauth2_session( + self._auth_scheme, self._auth_credential + ) if not client: logger.warning("Could not create OAuth2 session for token exchange") return self._auth_credential @@ -126,7 +85,7 @@ def exchange(self) -> AuthCredential: code=self._auth_credential.oauth2.auth_code, grant_type=OAuthGrantType.AUTHORIZATION_CODE, ) - self._update_credential(tokens) + update_credential_with_tokens(self._auth_credential, tokens) logger.info("Successfully exchanged OAuth2 tokens") except Exception as e: logger.error("Failed to exchange OAuth2 
tokens: %s", e) @@ -151,7 +110,9 @@ def refresh(self) -> AuthCredential: "expires_at": credential.oauth2.expires_at, "expires_in": credential.oauth2.expires_in, }).is_expired(): - client, token_endpoint = self._oauth2_session() + client, token_endpoint = create_oauth2_session( + self._auth_scheme, self._auth_credential + ) if not client: logger.warning("Could not create OAuth2 session for token refresh") return credential @@ -161,7 +122,7 @@ def refresh(self) -> AuthCredential: url=token_endpoint, refresh_token=credential.oauth2.refresh_token, ) - self._update_credential(tokens) + update_credential_with_tokens(self._auth_credential, tokens) logger.info("Successfully refreshed OAuth2 tokens") except Exception as e: logger.error("Failed to refresh OAuth2 tokens: %s", e) diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py new file mode 100644 index 000000000..51ed4d29f --- /dev/null +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Tuple + +from fastapi.openapi.models import OAuth2 + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_schemes import AuthScheme +from .auth_schemes import OpenIdConnectWithConfig + +try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +def create_oauth2_session( + auth_scheme: AuthScheme, + auth_credential: AuthCredential, +) -> Tuple[Optional[OAuth2Session], Optional[str]]: + """Create an OAuth2 session for token operations. + + Args: + auth_scheme: The authentication scheme configuration. + auth_credential: The authentication credential. + + Returns: + Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session. 
+ """ + if isinstance(auth_scheme, OpenIdConnectWithConfig): + if not hasattr(auth_scheme, "token_endpoint"): + return None, None + token_endpoint = auth_scheme.token_endpoint + scopes = auth_scheme.scopes + elif isinstance(auth_scheme, OAuth2): + if ( + not auth_scheme.flows.authorizationCode + or not auth_scheme.flows.authorizationCode.tokenUrl + ): + return None, None + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + else: + return None, None + + if ( + not auth_credential + or not auth_credential.oauth2 + or not auth_credential.oauth2.client_id + or not auth_credential.oauth2.client_secret + ): + return None, None + + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) + + +@experimental +def update_credential_with_tokens( + auth_credential: AuthCredential, tokens: OAuth2Token +) -> None: + """Update the credential with new tokens. + + Args: + auth_credential: The authentication credential to update. + tokens: The OAuth2Token object containing new token information. + """ + auth_credential.oauth2.access_token = tokens.get("access_token") + auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index aaed35a19..2bfc7d4c9 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -538,7 +538,7 @@ def test_credentials_with_token( assert result == oauth2_credentials_with_token @patch( - "google.adk.auth.oauth2_credential_fetcher.OAuth2Session", + "google.adk.auth.oauth2_credential_util.OAuth2Session", MockOAuth2Session, ) def test_successful_token_exchange(self, auth_config_with_auth_code): diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py b/tests/unittests/auth/test_oauth2_credential_fetcher.py index 0b9b5a3c1..aba6a9923 100644 --- a/tests/unittests/auth/test_oauth2_credential_fetcher.py +++ b/tests/unittests/auth/test_oauth2_credential_fetcher.py @@ -14,7 +14,6 @@ import time from unittest.mock import Mock -from unittest.mock import patch from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import OAuth2 @@ -24,38 +23,15 @@ from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_schemes import OpenIdConnectWithConfig -from google.adk.auth.oauth2_credential_fetcher import OAuth2CredentialFetcher +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens -class TestOAuth2CredentialFetcher: - """Test suite for OAuth2CredentialFetcher.""" +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" - def test_init(self): - """Test OAuth2CredentialFetcher initialization.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - 
token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - assert fetcher._auth_scheme == scheme - assert fetcher._auth_credential == credential - - def test_oauth2_session_openid_connect(self): - """Test _oauth2_session with OpenID Connect scheme.""" + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( @@ -75,16 +51,15 @@ def test_oauth2_session_openid_connect(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" assert client.client_id == "test_client_id" assert client.client_secret == "test_client_secret" - def test_oauth2_session_oauth2_scheme(self): - """Test _oauth2_session with OAuth2 scheme.""" + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" flows = OAuthFlows( authorizationCode=OAuthFlowAuthorizationCode( authorizationUrl="https://example.com/auth", @@ -102,14 +77,13 @@ def test_oauth2_session_oauth2_scheme(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" - def test_oauth2_session_invalid_scheme(self): - """Test _oauth2_session with invalid scheme.""" + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" scheme = Mock() # Invalid scheme type credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -119,14 +93,13 @@ def test_oauth2_session_invalid_scheme(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is None assert token_endpoint is None - def test_oauth2_session_missing_credentials(self): - """Test _oauth2_session with missing credentials.""" + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( @@ -144,23 +117,13 @@ def test_oauth2_session_missing_credentials(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is None assert token_endpoint is None - def test_update_credential(self): - """Test _update_credential method.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) + def test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" credential = AuthCredential( 
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, oauth2=OAuth2Auth( @@ -169,7 +132,6 @@ def test_update_credential(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) tokens = OAuth2Token({ "access_token": "new_access_token", "refresh_token": "new_refresh_token", @@ -177,265 +139,9 @@ def test_update_credential(self): "expires_in": 3600, }) - fetcher._update_credential(tokens) + update_credential_with_tokens(credential, tokens) assert credential.oauth2.access_token == "new_access_token" assert credential.oauth2.refresh_token == "new_refresh_token" assert credential.oauth2.expires_at == int(time.time()) + 3600 assert credential.oauth2.expires_in == 3600 - - def test_exchange_with_existing_token(self): - """Test exchange method when access token already exists.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="existing_token", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token == "existing_token" - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_success(self, mock_oauth2_session): - """Test successful token exchange.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri=( - "https://example.com/callback?code=auth_code&state=test_state" - ), - ), - ) - - # Mock the OAuth2Session - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" - mock_client.fetch_token.assert_called_once() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_with_auth_code(self, mock_oauth2_session): - """Test token exchange with auth code.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_code="test_auth_code", - ), - ) - - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - } - 
mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - mock_client.fetch_token.assert_called_once() - - def test_exchange_no_session(self): - """Test exchange when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri="https://example.com/callback?code=auth_code", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token is None - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_not_expired( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test refresh when token is not expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="current_token", - refresh_token="refresh_token", - expires_at=int(time.time()) + 3600, - expires_in=3600, - ), - ) - - # Mock token not expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = False - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "current_token" - mock_oauth2_session.assert_not_called() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_expired_success( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test successful token refresh when token is expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, # Expired - expires_in=3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - # Mock refresh token response - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "refreshed_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.refresh_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result.oauth2.access_token == "refreshed_access_token" - assert 
result.oauth2.refresh_token == "new_refresh_token" - mock_client.refresh_token.assert_called_once_with( - url="https://example.com/token", - refresh_token="refresh_token", - ) - - def test_refresh_no_oauth2_credential(self): - """Test refresh when oauth2 credential is missing.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP) # No oauth2 - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - def test_refresh_no_session(self, mock_oauth2_token): - """Test refresh when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "expired_token" # Unchanged diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py new file mode 100644 index 000000000..aba6a9923 --- /dev/null +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from unittest.mock import Mock + +from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens + + +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" + + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" + scheme = Mock() # Invalid scheme type + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def 
test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + + update_credential_with_tokens(credential, tokens) + + assert credential.oauth2.access_token == "new_access_token" + assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.expires_at == int(time.time()) + 3600 + assert credential.oauth2.expires_in == 3600 From c755cf23c555a2173b0eafd774cc0cc027b5f3da Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 19:53:08 -0700 Subject: [PATCH 43/61] chore: Ignore a2a ut tests for python 3.9 given a2a-sdk only supports 3.10+ PiperOrigin-RevId: 772270172 --- .github/workflows/python-unit-tests.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 0d77402f9..d4af7b13a 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -48,6 +48,13 @@ jobs: - name: Run unit tests with pytest run: | source .venv/bin/activate - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + if [[ "${{ matrix.python-version }}" == "3.9" ]]; then + pytest tests/unittests \ + --ignore=tests/unittests/a2a \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + else + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + fi \ No newline at end of file From e1812797ad499a2503275e41d28b07338ca951f9 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Jun 2025 20:34:37 -0700 Subject: [PATCH 44/61] chore: Add A2A Part converter (WIP) PiperOrigin-RevId: 772282116 --- src/google/adk/a2a/converters/__init__.py | 13 + .../adk/a2a/converters/part_converter.py | 166 +++++++ tests/unittests/a2a/__init__.py | 13 + tests/unittests/a2a/converters/__init__.py | 13 + .../a2a/converters/test_part_converter.py | 443 ++++++++++++++++++ 5 files changed, 648 insertions(+) create mode 100644 src/google/adk/a2a/converters/__init__.py create mode 100644 src/google/adk/a2a/converters/part_converter.py create mode 100644 tests/unittests/a2a/__init__.py create mode 100644 tests/unittests/a2a/converters/__init__.py create mode 100644 tests/unittests/a2a/converters/test_part_converter.py diff --git a/src/google/adk/a2a/converters/__init__.py b/src/google/adk/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py new file mode 100644 index 000000000..1c51fd7c1 --- /dev/null +++ b/src/google/adk/a2a/converters/part_converter.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +module containing utilities for conversion betwen A2A Part and Google GenAI Part +""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from a2a import types as a2a_types +from google.genai import types as genai_types + +from ...utils.feature_decorator import working_in_progress + +logger = logging.getLogger('google_adk.' + __name__) + +A2A_DATA_PART_METADATA_TYPE_KEY = 'type' +A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call' +A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' + + +@working_in_progress +def convert_a2a_part_to_genai_part( + a2a_part: a2a_types.Part, +) -> Optional[genai_types.Part]: + """Convert an A2A Part to a Google GenAI Part.""" + part = a2a_part.root + if isinstance(part, a2a_types.TextPart): + return genai_types.Part(text=part.text) + + if isinstance(part, a2a_types.FilePart): + if isinstance(part.file, a2a_types.FileWithUri): + return genai_types.Part( + file_data=genai_types.FileData( + file_uri=part.file.uri, mime_type=part.file.mimeType + ) + ) + + elif isinstance(part.file, a2a_types.FileWithBytes): + return genai_types.Part( + inline_data=genai_types.Blob( + data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType + ) + ) + else: + logger.warning( + 'Cannot convert unsupported file type: %s for A2A part: %s', + type(part.file), + a2a_part, + ) + return None + + if isinstance(part, a2a_types.DataPart): + # Conver the Data Part to funcall and function reponse. + # This is mainly for converting human in the loop and auth request and + # response. 
+ # TODO once A2A defined how to suervice such information, migrate below + # logic accordinlgy + if part.metadata and A2A_DATA_PART_METADATA_TYPE_KEY in part.metadata: + if ( + part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ): + return genai_types.Part( + function_call=genai_types.FunctionCall.model_validate( + part.data, by_alias=True + ) + ) + if ( + part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ): + return genai_types.Part( + function_response=genai_types.FunctionResponse.model_validate( + part.data, by_alias=True + ) + ) + return genai_types.Part(text=json.dumps(part.data)) + + logger.warning( + 'Cannot convert unsupported part type: %s for A2A part: %s', + type(part), + a2a_part, + ) + return None + + +@working_in_progress +def convert_genai_part_to_a2a_part( + part: genai_types.Part, +) -> Optional[a2a_types.Part]: + """Convert a Google GenAI Part to an A2A Part.""" + if part.text: + return a2a_types.TextPart(text=part.text) + + if part.file_data: + return a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=part.file_data.file_uri, + mimeType=part.file_data.mime_type, + ) + ) + + if part.inline_data: + return a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=part.inline_data.data, + mimeType=part.inline_data.mime_type, + ) + ) + ) + + # Conver the funcall and function reponse to A2A DataPart. + # This is mainly for converting human in the loop and auth request and + # response. + # TODO once A2A defined how to suervice such information, migrate below + # logic accordinlgy + if part.function_call: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.function_call.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + }, + ) + ) + + if part.function_response: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.function_response.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + }, + ) + ) + + logger.warning( + 'Cannot convert unsupported part for Google GenAI part: %s', + part, + ) + return None diff --git a/tests/unittests/a2a/__init__.py b/tests/unittests/a2a/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/__init__.py b/tests/unittests/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py new file mode 100644 index 000000000..5ad6cd62d --- /dev/null +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -0,0 +1,443 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import Mock +from unittest.mock import patch + +from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part +from google.genai import types as genai_types +import pytest + + +class TestConvertA2aPartToGenaiPart: + """Test cases for convert_a2a_part_to_genai_part function.""" + + def test_convert_text_part(self): + """Test conversion of A2A TextPart to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="Hello, world!")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == "Hello, world!" 
+ + def test_convert_file_part_with_uri(self): + """Test conversion of A2A FilePart with URI to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri="gs://bucket/file.txt", mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.file_data is not None + assert result.file_data.file_uri == "gs://bucket/file.txt" + assert result.file_data.mime_type == "text/plain" + + def test_convert_file_part_with_bytes(self): + """Test conversion of A2A FilePart with bytes to GenAI Part.""" + # Arrange + test_bytes = b"test file content" + # Note: A2A FileWithBytes converts bytes to string automatically + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=test_bytes, mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.inline_data is not None + # Source code now properly converts A2A string back to bytes for GenAI Blob + assert result.inline_data.data == test_bytes + assert result.inline_data.mime_type == "text/plain" + + def test_convert_data_part_function_call(self): + """Test conversion of A2A DataPart with function call metadata.""" + # Arrange + function_call_data = { + "name": "test_function", + "args": {"param1": "value1", "param2": 42}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_call_data, + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_call is not None + assert result.function_call.name == "test_function" + assert result.function_call.args == {"param1": "value1", "param2": 42} + + def test_convert_data_part_function_response(self): + """Test conversion of A2A DataPart with function response metadata.""" + # Arrange + function_response_data = { + "name": "test_function", + "response": {"result": "success", "data": [1, 2, 3]}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_response_data, + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_response is not None + assert result.function_response.name == "test_function" + assert result.function_response.response == { + "result": "success", + "data": [1, 2, 3], + } + + def test_convert_data_part_without_special_metadata(self): + """Test conversion of A2A DataPart without special metadata to text.""" + # Arrange + data = {"key": "value", "number": 123} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata={"other": "metadata"}) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_data_part_no_metadata(self): + """Test conversion of A2A DataPart with no metadata to text.""" + # Arrange + data = {"key": "value", "array": [1, 2, 3]} + a2a_part = 
a2a_types.Part(root=a2a_types.DataPart(data=data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_unsupported_file_type(self): + """Test handling of unsupported file types.""" + + # Arrange - Create a mock unsupported file type + class UnsupportedFileType: + pass + + # Create a part manually since FilePart validation might reject it + mock_file_part = Mock() + mock_file_part.file = UnsupportedFileType() + a2a_part = Mock() + a2a_part.root = mock_file_part + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + def test_convert_unsupported_part_type(self): + """Test handling of unsupported part types.""" + + # Arrange - Create a mock unsupported part type + class UnsupportedPartType: + pass + + mock_part = Mock() + mock_part.root = UnsupportedPartType() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(mock_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestConvertGenaiPartToA2aPart: + """Test cases for convert_genai_part_to_a2a_part function.""" + + def test_convert_text_part(self): + """Test conversion of GenAI text Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part(text="Hello, world!") + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.TextPart) + assert result.text == "Hello, world!" 
+ + def test_convert_file_data_part(self): + """Test conversion of GenAI file_data Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part( + file_data=genai_types.FileData( + file_uri="gs://bucket/file.txt", mime_type="text/plain" + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.FilePart) + assert isinstance(result.file, a2a_types.FileWithUri) + assert result.file.uri == "gs://bucket/file.txt" + assert result.file.mimeType == "text/plain" + + def test_convert_inline_data_part(self): + """Test conversion of GenAI inline_data Part to A2A Part.""" + # Arrange + test_bytes = b"test file content" + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="text/plain") + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + # A2A FileWithBytes stores bytes as strings + assert result.root.file.bytes == test_bytes.decode("utf-8") + assert result.root.file.mimeType == "text/plain" + + def test_convert_function_call_part(self): + """Test conversion of GenAI function_call Part to A2A Part.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_call.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + + def test_convert_function_response_part(self): + """Test conversion of GenAI function_response Part to A2A Part.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_response.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + + def test_convert_unsupported_part(self): + """Test handling of unsupported GenAI Part types.""" + # Arrange - Create a GenAI Part with no recognized fields + genai_part = genai_types.Part() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestRoundTripConversions: + """Test cases for round-trip conversions to ensure consistency.""" + + def test_text_part_round_trip(self): + """Test round-trip conversion for text parts.""" + # Arrange + original_text = "Hello, world!" 
+ a2a_part = a2a_types.Part(root=a2a_types.TextPart(text=original_text)) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.TextPart) + assert result_a2a_part.text == original_text + + def test_file_uri_round_trip(self): + """Test round-trip conversion for file parts with URI.""" + # Arrange + original_uri = "gs://bucket/file.txt" + original_mime_type = "text/plain" + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=original_uri, mimeType=original_mime_type + ) + ) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.FilePart) + assert isinstance(result_a2a_part.file, a2a_types.FileWithUri) + assert result_a2a_part.file.uri == original_uri + assert result_a2a_part.file.mimeType == original_mime_type + + +class TestEdgeCases: + """Test cases for edge cases and error conditions.""" + + def test_empty_text_part(self): + """Test conversion of empty text part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == "" + + def test_none_input_a2a_to_genai(self): + """Test handling of None input for A2A to GenAI conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_a2a_part_to_genai_part(None) + + def test_none_input_genai_to_a2a(self): + """Test handling of None input for GenAI to A2A conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_genai_part_to_a2a_part(None) + + def test_data_part_with_complex_data(self): + """Test conversion of DataPart with complex nested data.""" + # Arrange + complex_data = { + "nested": { + "array": [1, 2, {"inner": "value"}], + "boolean": True, + "null_value": None, + }, + "unicode": "Hello 世界 🌍", + } + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=complex_data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(complex_data) + + def test_data_part_with_empty_metadata(self): + """Test conversion of DataPart with empty metadata dict.""" + # Arrange + data = {"key": "value"} + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data, metadata={})) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(data) From 694b71256c631d44bb4c4488279ea91d82f43e26 Mon Sep 17 00:00:00 2001 From: SimonWei <119845914+simonwei97@users.noreply.github.com> Date: Tue, 17 Jun 2025 23:48:00 +0800 Subject: [PATCH 45/61] fix: agent generate config error (#1450) --- src/google/adk/models/lite_llm.py | 60 +++++++++++++++++++++----- tests/unittests/models/test_litellm.py | 33 +++++++++++++- 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index dce5ed7c4..e34299f6f 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ 
-23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -481,16 +482,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -507,7 +514,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -518,12 +526,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None - - if llm_request.config.response_schema: + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = None + if llm_request.config and llm_request.config.response_schema: response_format = llm_request.config.response_schema - return messages, tools, response_format + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. + generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] + + if not generation_params: + generation_params = None + + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -660,7 +695,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) completion_args = { "model": self.model, @@ -670,6 +707,9 @@ async def generate_content_async( } completion_args.update(self._additional_args) + if generation_params: + completion_args.update(generation_params) + if stream: text = "" # Track function calls by index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8b43cc48b..0125872fd 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,7 +13,6 @@ # limitations under the License. 
-import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1430,3 +1429,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_completion_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params From 1ae176ad2fa2b691714ac979aec21f1cf7d35e45 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Jun 2025 10:30:58 -0700 Subject: [PATCH 46/61] fix: update conversion between Celsius and Fahrenheit #non-breaking The correct conversion from 25 degrees Celsius is 77 degrees Fahrenheit. The previous value of 41 was wrong. PiperOrigin-RevId: 772528757 --- contributing/samples/quickstart/agent.py | 2 +- src/google/adk/models/lite_llm.py | 60 ++++-------------------- tests/unittests/models/test_litellm.py | 33 +------------ 3 files changed, 12 insertions(+), 83 deletions(-) diff --git a/contributing/samples/quickstart/agent.py b/contributing/samples/quickstart/agent.py index fdd6b7f9d..b251069ad 100644 --- a/contributing/samples/quickstart/agent.py +++ b/contributing/samples/quickstart/agent.py @@ -29,7 +29,7 @@ def get_weather(city: str) -> dict: "status": "success", "report": ( "The weather in New York is sunny with a temperature of 25 degrees" - " Celsius (41 degrees Fahrenheit)." + " Celsius (77 degrees Fahrenheit)." ), } else: diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index e34299f6f..dce5ed7c4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,7 +23,6 @@ from typing import Dict from typing import Generator from typing import Iterable -from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -482,22 +481,16 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> Tuple[ - List[Message], - Optional[List[dict]], - Optional[types.SchemaUnion], - Optional[Dict], -]: - """Converts an LlmRequest to litellm inputs and extracts generation params. +) -> tuple[Iterable[Message], Iterable[dict]]: + """Converts an LlmRequest to litellm inputs. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary, response format and generation params). + The litellm inputs (message list, tool dictionary and response format). 
""" - # 1. Construct messages - messages: List[Message] = [] + messages = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -514,8 +507,7 @@ def _get_completion_inputs( ), ) - # 2. Convert tool declarations - tools: Optional[List[Dict]] = None + tools = None if ( llm_request.config and llm_request.config.tools @@ -526,39 +518,12 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - # 3. Handle response format - response_format: Optional[types.SchemaUnion] = None - if llm_request.config and llm_request.config.response_schema: - response_format = llm_request.config.response_schema - - # 4. Extract generation parameters - generation_params: Optional[Dict] = None - if llm_request.config: - config_dict = llm_request.config.model_dump(exclude_none=True) - # Generate LiteLlm parameters here, - # Following https://docs.litellm.ai/docs/completion/input. - generation_params = {} - param_mapping = { - "max_output_tokens": "max_completion_tokens", - "stop_sequences": "stop", - } - for key in ( - "temperature", - "max_output_tokens", - "top_p", - "top_k", - "stop_sequences", - "presence_penalty", - "frequency_penalty", - ): - if key in config_dict: - mapped_key = param_mapping.get(key, key) - generation_params[mapped_key] = config_dict[key] + response_format = None - if not generation_params: - generation_params = None + if llm_request.config.response_schema: + response_format = llm_request.config.response_schema - return messages, tools, response_format, generation_params + return messages, tools, response_format def _build_function_declaration_log( @@ -695,9 +660,7 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format, generation_params = ( - _get_completion_inputs(llm_request) - ) + messages, tools, response_format = _get_completion_inputs(llm_request) completion_args = { "model": self.model, @@ -707,9 +670,6 @@ async def generate_content_async( } completion_args.update(self._additional_args) - if generation_params: - completion_args.update(generation_params) - if stream: text = "" # Track function calls by index diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 0125872fd..8b43cc48b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,6 +13,7 @@ # limitations under the License. 
+import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -1429,35 +1430,3 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} - - -@pytest.mark.asyncio -def test_get_completion_inputs_generation_params(): - # Test that generation_params are extracted and mapped correctly - req = LlmRequest( - contents=[ - types.Content(role="user", parts=[types.Part.from_text(text="hi")]), - ], - config=types.GenerateContentConfig( - temperature=0.33, - max_output_tokens=123, - top_p=0.88, - top_k=7, - stop_sequences=["foo", "bar"], - presence_penalty=0.1, - frequency_penalty=0.2, - ), - ) - from google.adk.models.lite_llm import _get_completion_inputs - - _, _, _, generation_params = _get_completion_inputs(req) - assert generation_params["temperature"] == 0.33 - assert generation_params["max_completion_tokens"] == 123 - assert generation_params["top_p"] == 0.88 - assert generation_params["top_k"] == 7 - assert generation_params["stop"] == ["foo", "bar"] - assert generation_params["presence_penalty"] == 0.1 - assert generation_params["frequency_penalty"] == 0.2 - # Should not include max_output_tokens - assert "max_output_tokens" not in generation_params - assert "stop_sequences" not in generation_params From c04adaade118be242fcf110e24e96253ac6550ab Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 10:45:10 -0700 Subject: [PATCH 47/61] chore: Add in memory credential service (Experimental) PiperOrigin-RevId: 772534962 --- .../in_memory_credential_service.py | 64 ++++ .../test_in_memory_credential_service.py | 323 ++++++++++++++++++ 2 files changed, 387 insertions(+) create mode 100644 src/google/adk/auth/credential_service/in_memory_credential_service.py create mode 100644 tests/unittests/auth/credential_service/test_in_memory_credential_service.py diff --git a/src/google/adk/auth/credential_service/in_memory_credential_service.py b/src/google/adk/auth/credential_service/in_memory_credential_service.py new file mode 100644 index 000000000..f6f51b35a --- /dev/null +++ b/src/google/adk/auth/credential_service/in_memory_credential_service.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class InMemoryCredentialService(BaseCredentialService): + """Class for in memory implementation of credential service(Experimental)""" + + def __init__(self): + super().__init__() + self._credentials = {} + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + credential_bucket = self._get_bucket_for_current_context(tool_context) + return credential_bucket.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + credential_bucket = self._get_bucket_for_current_context(tool_context) + credential_bucket[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) + + def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str: + app_name = tool_context._invocation_context.app_name + user_id = tool_context._invocation_context.user_id + + if app_name not in self._credentials: + self._credentials[app_name] = {} + if user_id not in self._credentials[app_name]: + self._credentials[app_name][user_id] = {} + return self._credentials[app_name][user_id] diff --git a/tests/unittests/auth/credential_service/test_in_memory_credential_service.py b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py new file mode 100644 index 000000000..9312f72a3 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py @@ -0,0 +1,323 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestInMemoryCredentialService: + """Tests for the InMemoryCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create an InMemoryCredentialService instance for testing.""" + return InMemoryCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "test_app" + mock_invocation_context.user_id = "test_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different app/user for testing isolation.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "another_app" + mock_invocation_context.user_id = "another_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + def test_init(self, credential_service): + """Test that the service initializes with an empty store.""" + assert isinstance(credential_service._credentials, dict) + assert len(credential_service._credentials) == 0 + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and 
loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different app/user contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + def test_get_bucket_for_current_context_creates_nested_structure( + self, credential_service, 
tool_context + ): + """Test that _get_bucket_for_current_context creates the proper nested structure.""" + storage = credential_service._get_bucket_for_current_context(tool_context) + + # Verify the nested structure was created + assert "test_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert isinstance(storage, dict) + assert storage is credential_service._credentials["test_app"]["test_user"] + + def test_get_bucket_for_current_context_reuses_existing( + self, credential_service, tool_context + ): + """Test that _get_bucket_for_current_context reuses existing structure.""" + # Create initial structure + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage1["test_key"] = "test_value" + + # Get storage again + storage2 = credential_service._get_bucket_for_current_context(tool_context) + + # Verify it's the same storage instance + assert storage1 is storage2 + assert storage2["test_key"] == "test_value" + + def test_get_storage_different_apps( + self, credential_service, tool_context, another_tool_context + ): + """Test that different apps get different storage instances.""" + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage2 = credential_service._get_bucket_for_current_context( + another_tool_context + ) + + # Verify they are different storage instances + assert storage1 is not storage2 + + # Verify the structure + assert "test_app" in credential_service._credentials + assert "another_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert "another_user" in credential_service._credentials["another_app"] + + @pytest.mark.asyncio + async def test_same_user_different_apps( + self, credential_service, auth_config + ): + """Test that the same user in different apps get isolated storage.""" + # Create two contexts with same user but different apps + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "app1" + mock_invocation_context1.user_id = "same_user" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "app2" + mock_invocation_context2.user_id = "same_user" + context2._invocation_context = mock_invocation_context2 + + # Save credential in app1 + await credential_service.save_credential(auth_config, context1) + + # Try to load from app2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify app1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None + + @pytest.mark.asyncio + async def test_same_app_different_users( + self, credential_service, auth_config + ): + """Test that different users in the same app get isolated storage.""" + # Create two contexts with same app but different users + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "same_app" + mock_invocation_context1.user_id = "user1" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "same_app" + mock_invocation_context2.user_id = "user2" + context2._invocation_context = mock_invocation_context2 + + # Save credential for user1 + await 
credential_service.save_credential(auth_config, context1) + + # Try to load for user2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify user1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None From 6d174eba305a51fcf2122c0fd481378752d690ef Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Jun 2025 18:03:45 -0700 Subject: [PATCH 48/61] fix: Set explicit project in the BigQuery client This change sets an explicit project id in the BigQuery client from the conversation context. Without this the client was trying to set a project from the environment's application default credentials and running into issues where application default credentials is not available. PiperOrigin-RevId: 772695883 --- src/google/adk/tools/bigquery/client.py | 6 +- .../adk/tools/bigquery/metadata_tool.py | 16 ++- src/google/adk/tools/bigquery/query_tool.py | 4 +- .../tools/bigquery/test_bigquery_client.py | 125 ++++++++++++++++++ .../bigquery/test_bigquery_metadata_tool.py | 122 +++++++++++++++++ .../bigquery/test_bigquery_query_tool.py | 58 ++++++-- .../tools/bigquery/test_bigquery_toolset.py | 4 +- 7 files changed, 315 insertions(+), 20 deletions(-) create mode 100644 tests/unittests/tools/bigquery/test_bigquery_client.py create mode 100644 tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d72761b2d..ea1bebc7a 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -21,13 +21,15 @@ USER_AGENT = "adk-bigquery-tool" -def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client: +def get_bigquery_client( + *, project: str, credentials: Credentials +) -> bigquery.Client: """Get a BigQuery client.""" client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT) bigquery_client = bigquery.Client( - credentials=credentials, client_info=client_info + project=project, credentials=credentials, client_info=client_info ) return bigquery_client diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py index 6e279d59e..4f5400611 100644 --- a/src/google/adk/tools/bigquery/metadata_tool.py +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -42,7 +42,9 @@ def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]: 'bbc_news'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) datasets = [] for dataset in bq_client.list_datasets(project_id): @@ -106,7 +108,9 @@ def get_dataset_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) dataset = bq_client.get_dataset( bigquery.DatasetReference(project_id, dataset_id) ) @@ -137,7 +141,9 @@ def list_table_ids( 'local_data_for_better_health_county_data'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) tables = [] for table in bq_client.list_tables( @@ -251,7 +257,9 @@ def get_table_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + 
project=project_id, credentials=credentials + ) return bq_client.get_table( bigquery.TableReference( bigquery.DatasetReference(project_id, dataset_id), table_id diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 80b56aad3..d3a94fda7 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -72,7 +72,9 @@ def execute_sql( """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) if not config or config.write_mode == WriteMode.BLOCKED: query_job = bq_client.query( query, diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py new file mode 100644 index 000000000..612dddd6e --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -0,0 +1,125 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.bigquery.client import get_bigquery_client +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2.credentials import Credentials +import pytest + + +def test_bigquery_client_project(): + """Test BigQuery client project.""" + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the client has the desired project set + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_explicit(): + """Test BigQuery client creation does not invoke default auth.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). 
For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_default_auth(): + """Test BigQuery client creation invokes default auth to set the project.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate credentials + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Simulate output of the default auth + mock_default_auth.return_value = (mock_creds, "test-gcp-project") + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock_creds, + ) + + # Verify that default auth was called once to set the client project + mock_default_auth.assert_called_once() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_env(): + """Test BigQuery client creation sets the project from environment variable.""" + # Let's simulate the project set in environment variables + with mock.patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True + ): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_user_agent(): + """Test BigQuery client user agent.""" + with mock.patch( + "google.cloud.bigquery.client.Connection", autospec=True + ) as mock_connection: + # Trigger the BigQuery client creation + get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the tracking user agent was set + client_info_arg = mock_connection.call_args[1].get("client_info") + assert client_info_arg is not None + assert client_info_arg.user_agent == "adk-bigquery-tool" diff --git a/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py new file mode 100644 index 000000000..14ecea558 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.bigquery import metadata_tool +from google.auth.exceptions import DefaultCredentialsError +from google.cloud import bigquery +from google.oauth2.credentials import Credentials +import pytest + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_dataset_ids(mock_default_auth, mock_list_datasets): + """Test list_dataset_ids tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_datasets.return_value = [ + bigquery.DatasetReference(project, "dataset1"), + bigquery.DatasetReference(project, "dataset2"), + ] + result = metadata_tool.list_dataset_ids(project, mock_credentials) + assert result == ["dataset1", "dataset2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_dataset_info(mock_default_auth, mock_get_dataset): + """Test get_dataset_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_dataset.return_value = mock.create_autospec( + Credentials, instance=True + ) + result = metadata_tool.get_dataset_info( + "my_project_id", "my_dataset_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_table_ids(mock_default_auth, mock_list_tables): + """Test list_table_ids tool invocation.""" + project = "my_project_id" + dataset = "my_dataset_id" + dataset_ref = bigquery.DatasetReference(project, dataset) + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_tables.return_value = [ + bigquery.TableReference(dataset_ref, "table1"), + bigquery.TableReference(dataset_ref, "table2"), + ] + result = metadata_tool.list_table_ids(project, dataset, mock_credentials) + assert result == ["table1", "table2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_table_info(mock_default_auth, mock_get_table): + """Test get_table_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate 
the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_table.return_value = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_table_info( + "my_project_id", "my_dataset_id", "my_table_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index 35d44ef81..3cb8c3c4a 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import os import textwrap from typing import Optional from unittest import mock @@ -24,6 +25,7 @@ from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode from google.adk.tools.bigquery.query_tool import execute_sql +from google.auth.exceptions import DefaultCredentialsError from google.cloud import bigquery from google.oauth2.credentials import Credentials import pytest @@ -227,14 +229,8 @@ async def test_execute_sql_declaration_write(tool_config): @pytest.mark.parametrize( ("write_mode",), [ - pytest.param( - WriteMode.BLOCKED, - id="blocked", - ), - pytest.param( - WriteMode.ALLOWED, - id="allowed", - ), + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), ], ) def test_execute_sql_select_stmt(write_mode): @@ -279,7 +275,7 @@ def test_execute_sql_select_stmt(write_mode): ], ) def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): - """Test execute_sql tool for SELECT query when writes are blocked.""" + """Test execute_sql tool for non-SELECT query when writes are blocked.""" project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) @@ -318,7 +314,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): ], ) def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): - """Test execute_sql tool for SELECT query when writes are blocked.""" + """Test execute_sql tool for non-SELECT query when writes are blocked.""" project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) @@ -342,3 +338,45 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", } + + +@pytest.mark.parametrize( + ("write_mode",), + [ + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) +@mock.patch("google.cloud.bigquery.Client.query", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_execute_sql_no_default_auth( + mock_default_auth, mock_query, mock_query_and_wait, write_mode +): + """Test execute_sql tool invocation does not involve calling default auth.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + query_result = [{"num": 123}] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = 
BigQueryToolConfig(write_mode=write_mode) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + mock_query.return_value = query_job + + # Simulate the result of query_and_wait API + mock_query_and_wait.return_value = query_result + + # Test the tool worked without invoking default auth + result = execute_sql(project, query, credentials, tool_config) + assert result == {"status": "SUCCESS", "rows": query_result} + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index ea9990b9f..4129dc512 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -96,9 +96,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools): ], ) @pytest.mark.asyncio -async def test_bigquery_toolset_unknown_tool_raises( - selected_tools, returned_tools -): +async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools): """Test BigQuery toolset with filter. This test verifies the behavior of the BigQuery toolset when filter is From 5f89a469ec6a9bad5ab8625e71d6b4d54046e2cd Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 18:08:05 -0700 Subject: [PATCH 49/61] chore: Add credential service to runner and invocation context PiperOrigin-RevId: 772697298 --- src/google/adk/agents/invocation_context.py | 2 ++ .../credential_service/base_credential_service.py | 4 ++-- src/google/adk/cli/cli.py | 10 ++++++++++ src/google/adk/cli/fast_api.py | 5 +++++ src/google/adk/runners.py | 7 ++++++- tests/unittests/cli/utils/test_cli.py | 11 ++++++++--- 6 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index f70371535..765f22a2c 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -22,6 +22,7 @@ from pydantic import ConfigDict from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session @@ -115,6 +116,7 @@ class InvocationContext(BaseModel): artifact_service: Optional[BaseArtifactService] = None session_service: BaseSessionService memory_service: Optional[BaseMemoryService] = None + credential_service: Optional[BaseCredentialService] = None invocation_id: str """The id of this invocation context. 
Readonly.""" diff --git a/src/google/adk/auth/credential_service/base_credential_service.py b/src/google/adk/auth/credential_service/base_credential_service.py index 7416ccc65..fc6cd500d 100644 --- a/src/google/adk/auth/credential_service/base_credential_service.py +++ b/src/google/adk/auth/credential_service/base_credential_service.py @@ -19,12 +19,12 @@ from typing import Optional from ...tools.tool_context import ToolContext -from ...utils.feature_decorator import working_in_progress +from ...utils.feature_decorator import experimental from ..auth_credential import AuthCredential from ..auth_tool import AuthConfig -@working_in_progress("Implementation are in progress. Don't use it for now.") +@experimental class BaseCredentialService(ABC): """Abstract class for Service that loads / saves tool credentials from / to the backend credential store.""" diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index aceb3fcce..79d0bfe65 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -24,6 +24,8 @@ from ..agents.llm_agent import LlmAgent from ..artifacts import BaseArtifactService from ..artifacts import InMemoryArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService @@ -43,6 +45,7 @@ async def run_input_file( root_agent: LlmAgent, artifact_service: BaseArtifactService, session_service: BaseSessionService, + credential_service: BaseCredentialService, input_path: str, ) -> Session: runner = Runner( @@ -50,6 +53,7 @@ async def run_input_file( agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) @@ -75,12 +79,14 @@ async def run_interactively( artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, + credential_service: BaseCredentialService, ) -> None: runner = Runner( app_name=session.app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) while True: query = input('[user]: ') @@ -125,6 +131,7 @@ async def run_cli( artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() user_id = 'test_user' session = await session_service.create_session( @@ -141,6 +148,7 @@ async def run_cli( root_agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=input_file, ) elif saved_session_file: @@ -163,6 +171,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) else: click.echo(f'Running agent {root_agent.name}, type exit to exit.') @@ -171,6 +180,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) if save_session: diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4512174c5..46e008655 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -57,6 +57,7 @@ from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from 
..artifacts.in_memory_artifact_service import InMemoryArtifactService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..errors.not_found_error import NotFoundError from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput @@ -305,6 +306,9 @@ async def internal_lifespan(app: FastAPI): else: artifact_service = InMemoryArtifactService() + # Build the Credential service + credential_service = InMemoryCredentialService() + # initialize Agent Loader agent_loader = AgentLoader(agents_dir) @@ -929,6 +933,7 @@ async def _get_runner_async(app_name: str) -> Runner: artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, + credential_service=credential_service, ) runner_dict[app_name] = runner return runner diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index c4fcdfb9e..01412a2b3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -17,7 +17,6 @@ import asyncio import logging import queue -import threading from typing import AsyncGenerator from typing import Generator from typing import Optional @@ -34,6 +33,7 @@ from .agents.run_config import RunConfig from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService +from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event from .memory.base_memory_service import BaseMemoryService @@ -73,6 +73,8 @@ class Runner: """The session service for the runner.""" memory_service: Optional[BaseMemoryService] = None """The memory service for the runner.""" + credential_service: Optional[BaseCredentialService] = None + """The credential service for the runner.""" def __init__( self, @@ -82,6 +84,7 @@ def __init__( artifact_service: Optional[BaseArtifactService] = None, session_service: BaseSessionService, memory_service: Optional[BaseMemoryService] = None, + credential_service: Optional[BaseCredentialService] = None, ): """Initializes the Runner. 
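For context on the constructor change above, here is a minimal illustrative sketch (not part of the patch) of wiring a Runner with the new credential_service keyword, mirroring the wiring this change adds to cli.py and fast_api.py; root_agent and the app name are assumed placeholders defined elsewhere:

    from google.adk.artifacts import InMemoryArtifactService
    from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
    from google.adk.runners import Runner
    from google.adk.sessions.in_memory_session_service import InMemorySessionService

    # root_agent: an agent instance defined elsewhere (assumed for this sketch).
    runner = Runner(
        app_name="my_app",
        agent=root_agent,
        artifact_service=InMemoryArtifactService(),
        session_service=InMemorySessionService(),
        credential_service=InMemoryCredentialService(),  # optional; defaults to None
    )
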
@@ -97,6 +100,7 @@ def __init__( self.artifact_service = artifact_service self.session_service = session_service self.memory_service = memory_service + self.credential_service = credential_service def run( self, @@ -418,6 +422,7 @@ def _new_invocation_context( artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, + credential_service=self.credential_service, invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 1721885f3..2139a8c20 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -129,6 +129,7 @@ def _echo(msg: str) -> None: artifact_service = cli.InMemoryArtifactService() session_service = cli.InMemorySessionService() + credential_service = cli.InMemoryCredentialService() dummy_root = types.SimpleNamespace(name="root") session = await cli.run_input_file( @@ -137,6 +138,7 @@ def _echo(msg: str) -> None: root_agent=dummy_root, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=str(input_path), ) @@ -199,9 +201,10 @@ async def test_run_interactively_whitespace_and_exit( ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - svc = cli.InMemorySessionService() - sess = await svc.create_session(app_name="dummy", user_id="u") + session_service = cli.InMemorySessionService() + sess = await session_service.create_session(app_name="dummy", user_id="u") artifact_service = cli.InMemoryArtifactService() + credential_service = cli.InMemoryCredentialService() root_agent = types.SimpleNamespace(name="root") # fake user input: blank -> 'hello' -> 'exit' @@ -212,7 +215,9 @@ async def test_run_interactively_whitespace_and_exit( echoed: list[str] = [] monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) - await cli.run_interactively(root_agent, artifact_service, sess, svc) + await cli.run_interactively( + root_agent, artifact_service, sess, session_service, credential_service + ) # verify: assistant echoed once with 'echo:hello' assert any("echo:hello" in m for m in echoed) From f9fa7841df81bcfc38d11a3d059c3d02f8ec3794 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Jun 2025 18:30:51 -0700 Subject: [PATCH 50/61] chore: add google-adk/{version} to bigquery user agent PiperOrigin-RevId: 772703504 --- src/google/adk/tools/bigquery/client.py | 4 +++- tests/unittests/tools/bigquery/test_bigquery_client.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index ea1bebc7a..23f1befc5 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -18,7 +18,9 @@ from google.cloud import bigquery from google.oauth2.credentials import Credentials -USER_AGENT = "adk-bigquery-tool" +from ... 
import version + +USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" def get_bigquery_client( diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index 612dddd6e..e8b373416 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -14,6 +14,7 @@ from __future__ import annotations import os +import re from unittest import mock from google.adk.tools.bigquery.client import get_bigquery_client @@ -122,4 +123,7 @@ def test_bigquery_client_user_agent(): # Verify that the tracking user agent was set client_info_arg = mock_connection.call_args[1].get("client_info") assert client_info_arg is not None - assert client_info_arg.user_agent == "adk-bigquery-tool" + assert re.search( + r"adk-bigquery-tool google-adk/([0-9A-Za-z._\-+/]+)", + client_info_arg.user_agent, + ) From 0a9625317a7a511cae39fd566625e98dfab24486 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 18:55:11 -0700 Subject: [PATCH 51/61] refactor: Adapt service account credential exchanger to base credential exchanger interface PiperOrigin-RevId: 772710438 --- src/google/adk/auth/exchanger/__init__.py | 2 - .../service_account_credential_exchanger.py | 70 +++--- tests/unittests/auth/exchanger/__init__.py | 15 ++ ...st_service_account_credential_exchanger.py | 202 +++++++++++++----- 4 files changed, 203 insertions(+), 86 deletions(-) rename src/google/adk/auth/{ => exchanger}/service_account_credential_exchanger.py (57%) create mode 100644 tests/unittests/auth/exchanger/__init__.py rename tests/unittests/auth/{ => exchanger}/test_service_account_credential_exchanger.py (61%) diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py index ce5c464c4..4226ae715 100644 --- a/src/google/adk/auth/exchanger/__init__.py +++ b/src/google/adk/auth/exchanger/__init__.py @@ -15,11 +15,9 @@ """Credential exchanger module.""" from .base_credential_exchanger import BaseCredentialExchanger -from .credential_exchanger_registry import CredentialExchangerRegistry from .service_account_credential_exchanger import ServiceAccountCredentialExchanger __all__ = [ "BaseCredentialExchanger", - "CredentialExchangerRegistry", "ServiceAccountCredentialExchanger", ] diff --git a/src/google/adk/auth/service_account_credential_exchanger.py b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py similarity index 57% rename from src/google/adk/auth/service_account_credential_exchanger.py rename to src/google/adk/auth/exchanger/service_account_credential_exchanger.py index 644501ee6..415081ca5 100644 --- a/src/google/adk/auth/service_account_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py @@ -16,19 +16,22 @@ from __future__ import annotations +from typing import Optional + import google.auth from google.auth.transport.requests import Request from google.oauth2 import service_account +from typing_extensions import override -from ..utils.feature_decorator import experimental -from .auth_credential import AuthCredential -from .auth_credential import AuthCredentialTypes -from .auth_credential import HttpAuth -from .auth_credential import HttpCredentials +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_credential import AuthCredentialTypes +from ..auth_schemes import AuthScheme +from .base_credential_exchanger import 
BaseCredentialExchanger @experimental -class ServiceAccountCredentialExchanger: +class ServiceAccountCredentialExchanger(BaseCredentialExchanger): """Exchanges Google Service Account credentials for an access token. Uses the default service credential if `use_default_credential = True`. @@ -36,44 +39,56 @@ class ServiceAccountCredentialExchanger: credential. """ - def __init__(self, credential: AuthCredential): - if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: - raise ValueError("Credential is not a service account credential.") - self._credential = credential - - def exchange(self) -> AuthCredential: + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: """Exchanges the service account auth credential for an access token. If the AuthCredential contains a service account credential, it will be used to exchange for an access token. Otherwise, if use_default_credential is True, the default application credential will be used for exchanging an access token. + Args: + auth_scheme: The authentication scheme. + auth_credential: The credential to exchange. + Returns: - An AuthCredential in HTTP Bearer format, containing the access token. + An AuthCredential in OAUTH2 format, containing the exchanged credential JSON. Raises: ValueError: If service account credentials are missing or invalid. Exception: If credential exchange or refresh fails. """ + if auth_credential is None: + raise ValueError("Credential cannot be None.") + + if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: + raise ValueError("Credential is not a service account credential.") + + if auth_credential.service_account is None: + raise ValueError( + "Service account credentials are missing. Please provide them." + ) + if ( - self._credential is None - or self._credential.service_account is None - or ( - self._credential.service_account.service_account_credential is None - and not self._credential.service_account.use_default_credential - ) + auth_credential.service_account.service_account_credential is None + and not auth_credential.service_account.use_default_credential ): raise ValueError( - "Service account credentials are missing. Please provide them, or set" - " `use_default_credential = True` to use application default" - " credential in a hosted service like Google Cloud Run." + "Service account credentials are invalid. Please set the" + " service_account_credential field or set `use_default_credential =" + " True` to use application default credential in a hosted service" + " like Google Cloud Run." 
) try: - if self._credential.service_account.use_default_credential: + if auth_credential.service_account.use_default_credential: credentials, _ = google.auth.default() else: - config = self._credential.service_account + config = auth_credential.service_account credentials = service_account.Credentials.from_service_account_info( config.service_account_credential.model_dump(), scopes=config.scopes ) @@ -82,11 +97,8 @@ def exchange(self) -> AuthCredential: credentials.refresh(Request()) return AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token=credentials.token), - ), + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=credentials.to_json(), ) except Exception as e: raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/tests/unittests/auth/exchanger/__init__.py b/tests/unittests/auth/exchanger/__init__.py new file mode 100644 index 000000000..5fb8a262b --- /dev/null +++ b/tests/unittests/auth/exchanger/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for credential exchanger.""" diff --git a/tests/unittests/auth/test_service_account_credential_exchanger.py b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py similarity index 61% rename from tests/unittests/auth/test_service_account_credential_exchanger.py rename to tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py index a5c668436..195e143d3 100644 --- a/tests/unittests/auth/test_service_account_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py @@ -17,19 +17,20 @@ from unittest.mock import MagicMock from unittest.mock import patch +from fastapi.openapi.models import HTTPBearer from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import ServiceAccount from google.adk.auth.auth_credential import ServiceAccountCredential -from google.adk.auth.service_account_credential_exchanger import ServiceAccountCredentialExchanger +from google.adk.auth.exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger import pytest class TestServiceAccountCredentialExchanger: """Test cases for ServiceAccountCredentialExchanger.""" - def test_init_valid_credential(self): - """Test successful initialization with valid service account credential.""" + def test_exchange_with_valid_credential(self): + """Test successful exchange with valid service account credential.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( @@ -55,26 +56,36 @@ def test_init_valid_credential(self): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - assert exchanger._credential == credential + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() - def 
test_init_invalid_credential_type(self): - """Test initialization with invalid credential type raises ValueError.""" + # This should not raise an exception + assert exchanger is not None + + @pytest.mark.asyncio + async def test_exchange_invalid_credential_type(self): + """Test exchange with invalid credential type raises ValueError.""" credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="test-key", ) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + with pytest.raises( ValueError, match="Credential is not a service account credential" ): - ServiceAccountCredentialExchanger(credential) + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_with_explicit_credentials_success( + async def test_exchange_with_explicit_credentials_success( self, mock_request_class, mock_from_service_account_info ): """Test successful exchange with explicit service account credentials.""" @@ -84,6 +95,9 @@ def test_exchange_with_explicit_credentials_success( mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "mock_access_token", "type": "authorized_user"}' + ) mock_from_service_account_info.return_value = mock_credentials # Create test credential @@ -113,13 +127,20 @@ def test_exchange_with_explicit_credentials_success( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - result = exchanger.exchange() + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) # Verify the result - assert result.auth_type == AuthCredentialTypes.HTTP - assert result.http.scheme == "bearer" - assert result.http.credentials.token == "mock_access_token" + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "mock_access_token" or "mock_access_token" in str(exchanged_creds) # Verify mocks were called correctly mock_from_service_account_info.assert_called_once_with( @@ -128,11 +149,14 @@ def test_exchange_with_explicit_credentials_success( ) mock_credentials.refresh.assert_called_once_with(mock_request) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.google.auth.default" + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_with_default_credentials_success( + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" + ) + async def test_exchange_with_default_credentials_success( self, mock_request_class, mock_google_auth_default ): """Test successful exchange with default application credentials.""" @@ -142,6 +166,9 @@ def test_exchange_with_default_credentials_success( mock_credentials = MagicMock() mock_credentials.token = "default_access_token" + 
mock_credentials.to_json.return_value = ( + '{"token": "default_access_token", "type": "authorized_user"}' + ) mock_google_auth_default.return_value = (mock_credentials, "test-project") # Create test credential with use_default_credential=True @@ -153,33 +180,45 @@ def test_exchange_with_default_credentials_success( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - result = exchanger.exchange() + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) # Verify the result - assert result.auth_type == AuthCredentialTypes.HTTP - assert result.http.scheme == "bearer" - assert result.http.credentials.token == "default_access_token" + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "default_access_token" or "default_access_token" in str( + exchanged_creds + ) # Verify mocks were called correctly mock_google_auth_default.assert_called_once() mock_credentials.refresh.assert_called_once_with(mock_request) - def test_exchange_missing_service_account(self): + @pytest.mark.asyncio + async def test_exchange_missing_service_account(self): """Test exchange fails when service_account is None.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=None, ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Service account credentials are missing" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) - def test_exchange_missing_credentials_and_not_default(self): + @pytest.mark.asyncio + async def test_exchange_missing_credentials_and_not_default(self): """Test exchange fails when credentials are missing and use_default_credential is False.""" credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, @@ -190,17 +229,19 @@ def test_exchange_missing_credentials_and_not_default(self): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( - ValueError, match="Service account credentials are missing" + ValueError, match="Service account credentials are invalid" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" ) - def test_exchange_credential_creation_failure( + async def test_exchange_credential_creation_failure( self, mock_from_service_account_info ): """Test exchange handles credential creation failure gracefully.""" @@ -234,17 +275,21 @@ def test_exchange_credential_creation_failure( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio @patch( - "google.adk.auth.service_account_credential_exchanger.google.auth.default" + 
"google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" ) - def test_exchange_default_credential_failure(self, mock_google_auth_default): + async def test_exchange_default_credential_failure( + self, mock_google_auth_default + ): """Test exchange handles default credential failure gracefully.""" # Setup mock to raise exception mock_google_auth_default.side_effect = Exception( @@ -260,18 +305,22 @@ def test_exchange_default_credential_failure(self, mock_google_auth_default): ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + ) @patch( - "google.adk.auth.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" ) - @patch("google.adk.auth.service_account_credential_exchanger.Request") - def test_exchange_refresh_failure( + async def test_exchange_refresh_failure( self, mock_request_class, mock_from_service_account_info ): """Test exchange handles credential refresh failure gracefully.""" @@ -312,30 +361,73 @@ def test_exchange_refresh_failure( ), ) - exchanger = ServiceAccountCredentialExchanger(credential) + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() with pytest.raises( ValueError, match="Failed to exchange service account token" ): - exchanger.exchange() + await exchanger.exchange(credential, auth_scheme) + + @pytest.mark.asyncio + async def test_exchange_none_credential_in_constructor(self): + """Test that passing None credential raises appropriate error during exchange.""" + # This test verifies behavior when credential is None + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + + with pytest.raises(ValueError, match="Credential cannot be None"): + await exchanger.exchange(None, auth_scheme) - def test_exchange_none_credential_in_constructor(self): - """Test that passing None credential raises appropriate error during construction.""" - # This test verifies behavior when _credential is None, though this shouldn't - # happen in normal usage due to constructor validation + @pytest.mark.asyncio + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" + ) + @patch( + "google.adk.auth.exchanger.service_account_credential_exchanger.Request" + ) + async def test_exchange_with_service_account_no_explicit_credentials( + self, mock_request_class, mock_google_auth_default + ): + """Test exchange with service account that has no explicit credentials uses default.""" + # Setup mocks + mock_request = MagicMock() + mock_request_class.return_value = mock_request + + mock_credentials = MagicMock() + mock_credentials.token = "default_access_token" + mock_credentials.to_json.return_value = ( + '{"token": "default_access_token", "type": "authorized_user"}' + ) + mock_google_auth_default.return_value = (mock_credentials, "test-project") + + # Create test credential with no explicit credentials but use_default_credential=True credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( + service_account_credential=None, 
use_default_credential=True, scopes=["https://www.googleapis.com/auth/cloud-platform"], ), ) - exchanger = ServiceAccountCredentialExchanger(credential) - # Manually set to None to test the validation logic - exchanger._credential = None + auth_scheme = HTTPBearer() + exchanger = ServiceAccountCredentialExchanger() + result = await exchanger.exchange(credential, auth_scheme) - with pytest.raises( - ValueError, match="Service account credentials are missing" - ): - exchanger.exchange() + # Verify the result + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.google_oauth2_json is not None + # Verify that google_oauth2_json contains the token + import json + + exchanged_creds = json.loads(result.google_oauth2_json) + assert exchanged_creds.get( + "token" + ) == "default_access_token" or "default_access_token" in str( + exchanged_creds + ) + + # Verify mocks were called correctly + mock_google_auth_default.assert_called_once() + mock_credentials.refresh.assert_called_once_with(mock_request) From 55201cb6a1d59674e9aea1d25da37c6edbb7e0c7 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 19:08:45 -0700 Subject: [PATCH 52/61] chore: Add credential exchanger registry (Experimentals) PiperOrigin-RevId: 772713412 --- .../credential_exchanger_registry.py | 58 +++++ .../test_credential_exchanger_registry.py | 242 ++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 src/google/adk/auth/exchanger/credential_exchanger_registry.py create mode 100644 tests/unittests/auth/exchanger/test_credential_exchanger_registry.py diff --git a/src/google/adk/auth/exchanger/credential_exchanger_registry.py b/src/google/adk/auth/exchanger/credential_exchanger_registry.py new file mode 100644 index 000000000..5af7f3c1a --- /dev/null +++ b/src/google/adk/auth/exchanger/credential_exchanger_registry.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential exchanger registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredentialTypes +from .base_credential_exchanger import BaseCredentialExchanger + + +@experimental +class CredentialExchangerRegistry: + """Registry for credential exchanger instances.""" + + def __init__(self): + self._exchangers: Dict[AuthCredentialTypes, BaseCredentialExchanger] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register an exchanger instance for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchangers[credential_type] = exchanger_instance + + def get_exchanger( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialExchanger]: + """Get the exchanger instance for a credential type. 
+ + Args: + credential_type: The credential type to get exchanger for. + + Returns: + The exchanger instance if registered, None otherwise. + """ + return self._exchangers.get(credential_type) diff --git a/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py new file mode 100644 index 000000000..66b858232 --- /dev/null +++ b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the CredentialExchangerRegistry.""" + +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.exchanger.base_credential_exchanger import BaseCredentialExchanger +from google.adk.auth.exchanger.credential_exchanger_registry import CredentialExchangerRegistry +import pytest + + +class MockCredentialExchanger(BaseCredentialExchanger): + """Mock credential exchanger for testing.""" + + def __init__(self, exchange_result: Optional[AuthCredential] = None): + self.exchange_result = exchange_result or AuthCredential( + auth_type=AuthCredentialTypes.HTTP + ) + + def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Mock exchange method.""" + return self.exchange_result + + +class TestCredentialExchangerRegistry: + """Test cases for CredentialExchangerRegistry.""" + + def test_initialization(self): + """Test that the registry initializes with an empty exchangers dictionary.""" + registry = CredentialExchangerRegistry() + + # Access the private attribute for testing + assert hasattr(registry, '_exchangers') + assert isinstance(registry._exchangers, dict) + assert len(registry._exchangers) == 0 + + def test_register_single_exchanger(self): + """Test registering a single exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Verify the exchanger was registered + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_register_multiple_exchangers(self): + """Test registering multiple exchangers for different credential types.""" + registry = CredentialExchangerRegistry() + + api_key_exchanger = MockCredentialExchanger() + oauth2_exchanger = MockCredentialExchanger() + service_account_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, api_key_exchanger) + registry.register(AuthCredentialTypes.OAUTH2, oauth2_exchanger) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, service_account_exchanger + ) + + # Verify all exchangers were registered correctly + assert ( + 
registry.get_exchanger(AuthCredentialTypes.API_KEY) is api_key_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.OAUTH2) is oauth2_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.SERVICE_ACCOUNT) + is service_account_exchanger + ) + + def test_register_overwrites_existing_exchanger(self): + """Test that registering an exchanger for an existing type overwrites the previous one.""" + registry = CredentialExchangerRegistry() + + first_exchanger = MockCredentialExchanger() + second_exchanger = MockCredentialExchanger() + + # Register first exchanger + registry.register(AuthCredentialTypes.API_KEY, first_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is first_exchanger + ) + + # Register second exchanger for the same type + registry.register(AuthCredentialTypes.API_KEY, second_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is second_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) + is not first_exchanger + ) + + def test_get_exchanger_returns_correct_instance(self): + """Test that get_exchanger returns the correct exchanger instance.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.HTTP, mock_exchanger) + + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.HTTP) + assert retrieved_exchanger is mock_exchanger + assert isinstance(retrieved_exchanger, BaseCredentialExchanger) + + def test_get_exchanger_nonexistent_type_returns_none(self): + """Test that get_exchanger returns None for non-existent credential types.""" + registry = CredentialExchangerRegistry() + + # Try to get an exchanger that was never registered + result = registry.get_exchanger(AuthCredentialTypes.OAUTH2) + assert result is None + + def test_get_exchanger_after_registration_and_removal(self): + """Test behavior when an exchanger is registered and then the registry is cleared indirectly.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + # Register exchanger + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is mock_exchanger + + # Clear the internal dictionary (simulating some edge case) + registry._exchangers.clear() + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is None + + def test_register_with_all_credential_types(self): + """Test registering exchangers for all available credential types.""" + registry = CredentialExchangerRegistry() + + exchangers = {} + credential_types = [ + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + AuthCredentialTypes.SERVICE_ACCOUNT, + ] + + # Register an exchanger for each credential type + for cred_type in credential_types: + exchanger = MockCredentialExchanger() + exchangers[cred_type] = exchanger + registry.register(cred_type, exchanger) + + # Verify all exchangers can be retrieved + for cred_type in credential_types: + retrieved_exchanger = registry.get_exchanger(cred_type) + assert retrieved_exchanger is exchangers[cred_type] + + def test_register_with_mock_exchanger_using_magicmock(self): + """Test registering with a MagicMock exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MagicMock(spec=BaseCredentialExchanger) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + retrieved_exchanger = 
registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_registry_isolation(self): + """Test that different registry instances are isolated from each other.""" + registry1 = CredentialExchangerRegistry() + registry2 = CredentialExchangerRegistry() + + exchanger1 = MockCredentialExchanger() + exchanger2 = MockCredentialExchanger() + + # Register different exchangers in different registry instances + registry1.register(AuthCredentialTypes.API_KEY, exchanger1) + registry2.register(AuthCredentialTypes.API_KEY, exchanger2) + + # Verify isolation + assert registry1.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger1 + assert registry2.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger2 + assert ( + registry1.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger2 + ) + assert ( + registry2.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger1 + ) + + def test_exchanger_functionality_through_registry(self): + """Test that exchangers registered in the registry function correctly.""" + registry = CredentialExchangerRegistry() + + # Create a mock exchanger with specific return value + expected_result = AuthCredential(auth_type=AuthCredentialTypes.HTTP) + mock_exchanger = MockCredentialExchanger(exchange_result=expected_result) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Get the exchanger and test its functionality + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + input_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY) + + result = retrieved_exchanger.exchange(input_credential) + assert result is expected_result + + def test_register_none_exchanger(self): + """Test that registering None as an exchanger works (edge case).""" + registry = CredentialExchangerRegistry() + + # This should work but return None when retrieved + registry.register(AuthCredentialTypes.API_KEY, None) + + result = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert result is None + + def test_internal_dictionary_structure(self): + """Test the internal structure of the registry.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.OAUTH2, mock_exchanger) + + # Verify internal dictionary structure + assert AuthCredentialTypes.OAUTH2 in registry._exchangers + assert registry._exchangers[AuthCredentialTypes.OAUTH2] is mock_exchanger + assert len(registry._exchangers) == 1 From a17ebe6ebd7b58fa86a90cceb7650c2b3187933d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 17 Jun 2025 21:11:56 -0700 Subject: [PATCH 53/61] chore: Add a credential refresher registry PiperOrigin-RevId: 772747251 --- .../credential_refresher_registry.py | 59 ++++++ .../test_credential_refresher_registry.py | 174 ++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 src/google/adk/auth/refresher/credential_refresher_registry.py create mode 100644 tests/unittests/auth/refresher/test_credential_refresher_registry.py diff --git a/src/google/adk/auth/refresher/credential_refresher_registry.py b/src/google/adk/auth/refresher/credential_refresher_registry.py new file mode 100644 index 000000000..90975d66d --- /dev/null +++ b/src/google/adk/auth/refresher/credential_refresher_registry.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.utils.feature_decorator import experimental + +from .base_credential_refresher import BaseCredentialRefresher + + +@experimental +class CredentialRefresherRegistry: + """Registry for credential refresher instances.""" + + def __init__(self): + self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + refresher_instance: BaseCredentialRefresher, + ) -> None: + """Register a refresher instance for a credential type. + + Args: + credential_type: The credential type to register for. + refresher_instance: The refresher instance to register. + """ + self._refreshers[credential_type] = refresher_instance + + def get_refresher( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialRefresher]: + """Get the refresher instance for a credential type. + + Args: + credential_type: The credential type to get refresher for. + + Returns: + The refresher instance if registered, None otherwise. + """ + return self._refreshers.get(credential_type) diff --git a/tests/unittests/auth/refresher/test_credential_refresher_registry.py b/tests/unittests/auth/refresher/test_credential_refresher_registry.py new file mode 100644 index 000000000..b00cc4da8 --- /dev/null +++ b/tests/unittests/auth/refresher/test_credential_refresher_registry.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for CredentialRefresherRegistry.""" + +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher +from google.adk.auth.refresher.credential_refresher_registry import CredentialRefresherRegistry + + +class TestCredentialRefresherRegistry: + """Tests for the CredentialRefresherRegistry class.""" + + def test_init(self): + """Test that registry initializes with empty refreshers dictionary.""" + registry = CredentialRefresherRegistry() + assert registry._refreshers == {} + + def test_register_refresher(self): + """Test registering a refresher instance for a credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher + + def test_register_multiple_refreshers(self): + """Test registering multiple refresher instances for different credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_openid_refresher = Mock(spec=BaseCredentialRefresher) + mock_service_account_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, mock_openid_refresher + ) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, mock_service_account_refresher + ) + + assert ( + registry._refreshers[AuthCredentialTypes.OAUTH2] + == mock_oauth2_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.OPEN_ID_CONNECT] + == mock_openid_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.SERVICE_ACCOUNT] + == mock_service_account_refresher + ) + + def test_register_overwrite_existing_refresher(self): + """Test that registering a refresher overwrites an existing one for the same credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher_1 = Mock(spec=BaseCredentialRefresher) + mock_refresher_2 = Mock(spec=BaseCredentialRefresher) + + # Register first refresher + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_1) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_1 + + # Register second refresher for same credential type + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_2) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_2 + + def test_get_refresher_existing(self): + """Test getting a refresher instance for a registered credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result == mock_refresher + + def test_get_refresher_non_existing(self): + """Test getting a refresher instance for a non-registered credential type returns None.""" + registry = CredentialRefresherRegistry() + + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None + + def test_get_refresher_after_registration(self): + """Test getting refresher instances for multiple credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_api_key_refresher = Mock(spec=BaseCredentialRefresher) + + 
registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register(AuthCredentialTypes.API_KEY, mock_api_key_refresher) + + # Get registered refreshers + oauth2_result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + api_key_result = registry.get_refresher(AuthCredentialTypes.API_KEY) + + assert oauth2_result == mock_oauth2_refresher + assert api_key_result == mock_api_key_refresher + + # Get non-registered refresher + http_result = registry.get_refresher(AuthCredentialTypes.HTTP) + assert http_result is None + + def test_register_all_credential_types(self): + """Test registering refreshers for all available credential types.""" + registry = CredentialRefresherRegistry() + + refreshers = {} + for credential_type in AuthCredentialTypes: + mock_refresher = Mock(spec=BaseCredentialRefresher) + refreshers[credential_type] = mock_refresher + registry.register(credential_type, mock_refresher) + + # Verify all refreshers are registered correctly + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result == refreshers[credential_type] + + def test_empty_registry_get_refresher(self): + """Test getting refresher from empty registry returns None for any credential type.""" + registry = CredentialRefresherRegistry() + + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result is None + + def test_registry_independence(self): + """Test that multiple registry instances are independent.""" + registry1 = CredentialRefresherRegistry() + registry2 = CredentialRefresherRegistry() + + mock_refresher1 = Mock(spec=BaseCredentialRefresher) + mock_refresher2 = Mock(spec=BaseCredentialRefresher) + + registry1.register(AuthCredentialTypes.OAUTH2, mock_refresher1) + registry2.register(AuthCredentialTypes.OAUTH2, mock_refresher2) + + # Verify registries are independent + assert ( + registry1.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher1 + ) + assert ( + registry2.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher2 + ) + assert registry1.get_refresher( + AuthCredentialTypes.OAUTH2 + ) != registry2.get_refresher(AuthCredentialTypes.OAUTH2) + + def test_register_with_none_refresher(self): + """Test registering None as a refresher instance.""" + registry = CredentialRefresherRegistry() + + # This should technically work as the registry accepts any value + registry.register(AuthCredentialTypes.OAUTH2, None) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None From 9a207cb832e86dd9fd643220139a0384388cdb6c Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 10:45:54 -0700 Subject: [PATCH 54/61] refactor: Refactor oauth2_credential_exchanger to exchanger and refresher separately PiperOrigin-RevId: 772979993 --- src/google/adk/auth/auth_handler.py | 19 +- src/google/adk/auth/auth_preprocessor.py | 6 +- .../exchanger/oauth2_credential_exchanger.py | 104 ++++++ .../adk/auth/oauth2_credential_fetcher.py | 132 -------- .../refresher/oauth2_credential_refresher.py | 154 +++++++++ .../integration_connector_tool.py | 4 +- .../openapi_spec_parser/rest_api_tool.py | 6 +- .../openapi_spec_parser/tool_auth_handler.py | 16 +- .../test_oauth2_credential_exchanger.py | 220 +++++++++++++ tests/unittests/auth/refresher/__init__.py | 13 + .../test_oauth2_credential_refresher.py | 297 ++++++++++++++++++ tests/unittests/auth/test_auth_handler.py | 72 +++-- .../auth/test_oauth2_credential_fetcher.py | 147 --------- 
.../test_integration_connector_tool.py | 40 +-- .../openapi_spec_parser/test_rest_api_tool.py | 17 +- .../test_tool_auth_handler.py | 47 +-- 16 files changed, 926 insertions(+), 368 deletions(-) create mode 100644 src/google/adk/auth/exchanger/oauth2_credential_exchanger.py delete mode 100644 src/google/adk/auth/oauth2_credential_fetcher.py create mode 100644 src/google/adk/auth/refresher/oauth2_credential_refresher.py create mode 100644 tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py create mode 100644 tests/unittests/auth/refresher/__init__.py create mode 100644 tests/unittests/auth/refresher/test_oauth2_credential_refresher.py delete mode 100644 tests/unittests/auth/test_oauth2_credential_fetcher.py diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 3e13cbac2..473f31413 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -22,7 +22,7 @@ from .auth_schemes import AuthSchemeType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig -from .oauth2_credential_fetcher import OAuth2CredentialFetcher +from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger if TYPE_CHECKING: from ..sessions.state import State @@ -36,18 +36,23 @@ class AuthHandler: + """A handler that handles the auth flow in Agent Development Kit to help + orchestrate the credential request and response flow (e.g. OAuth flow) + This class should only be used by Agent Development Kit. + """ def __init__(self, auth_config: AuthConfig): self.auth_config = auth_config - def exchange_auth_token( + async def exchange_auth_token( self, ) -> AuthCredential: - return OAuth2CredentialFetcher( - self.auth_config.auth_scheme, self.auth_config.exchanged_auth_credential - ).exchange() + exchanger = OAuth2CredentialExchanger() + return await exchanger.exchange( + self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme + ) - def parse_and_store_auth_response(self, state: State) -> None: + async def parse_and_store_auth_response(self, state: State) -> None: credential_key = "temp:" + self.auth_config.credential_key @@ -60,7 +65,7 @@ def parse_and_store_auth_response(self, state: State) -> None: ): return - state[credential_key] = self.exchange_auth_token() + state[credential_key] = await self.exchange_auth_token() def _validate(self) -> None: if not self.auth_scheme: diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 0c964ed96..b06774973 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -67,9 +67,9 @@ async def run_async( # function call request_euc_function_call_ids.add(function_call_response.id) auth_config = AuthConfig.model_validate(function_call_response.response) - AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - state=invocation_context.session.state - ) + await AuthHandler( + auth_config=auth_config + ).parse_and_store_auth_response(state=invocation_context.session.state) break if not request_euc_function_call_ids: diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py new file mode 100644 index 000000000..768457e1a --- /dev/null +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 credential exchanger implementation.""" + +from __future__ import annotations + +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import OAuthGrantType +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from typing_extensions import override + +from .base_credential_exchanger import BaseCredentialExchanger +from .base_credential_exchanger import CredentialExchangError + +try: + from authlib.integrations.requests_client import OAuth2Session + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialExchanger(BaseCredentialExchanger): + """Exchanges OAuth2 credentials from authorization responses.""" + + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange OAuth2 credential from authorization response. + if credential exchange failed, the original credential will be returned. + + Args: + auth_credential: The OAuth2 credential to exchange. + auth_scheme: The OAuth2 authentication scheme. + + Returns: + The exchanged credential with access token. + + Raises: + CredentialExchangError: If auth_scheme is missing. + """ + if not auth_scheme: + raise CredentialExchangError( + "auth_scheme is required for OAuth2 credential exchange" + ) + + if not AUTHLIB_AVIALABLE: + # If authlib is not available, we cannot exchange the credential. + # We return the original credential without exchange. + # The client using this tool can decide to exchange the credential + # themselves using other lib. + logger.warning( + "authlib is not available, skipping OAuth2 credential exchange." 
+ ) + return auth_credential + + if auth_credential.oauth2 and auth_credential.oauth2.access_token: + return auth_credential + + client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) + if not client: + logger.warning("Could not create OAuth2 session for token exchange") + return auth_credential + + try: + tokens = client.fetch_token( + token_endpoint, + authorization_response=auth_credential.oauth2.auth_response_uri, + code=auth_credential.oauth2.auth_code, + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully exchanged OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise errors in this case + logger.error("Failed to exchange OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py deleted file mode 100644 index c9e838b25..000000000 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import logging - -from ..utils.feature_decorator import experimental -from .auth_credential import AuthCredential -from .auth_schemes import AuthScheme -from .auth_schemes import OAuthGrantType -from .oauth2_credential_util import create_oauth2_session -from .oauth2_credential_util import update_credential_with_tokens - -try: - from authlib.oauth2.rfc6749 import OAuth2Token - - AUTHLIB_AVIALABLE = True -except ImportError: - AUTHLIB_AVIALABLE = False - - -logger = logging.getLogger("google_adk." + __name__) - - -@experimental -class OAuth2CredentialFetcher: - """Exchanges and refreshes an OAuth2 access token. (Experimental)""" - - def __init__( - self, - auth_scheme: AuthScheme, - auth_credential: AuthCredential, - ): - self._auth_scheme = auth_scheme - self._auth_credential = auth_credential - - def _update_credential(self, tokens: OAuth2Token) -> None: - self._auth_credential.oauth2.access_token = tokens.get("access_token") - self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") - self._auth_credential.oauth2.expires_at = ( - int(tokens.get("expires_at")) if tokens.get("expires_at") else None - ) - self._auth_credential.oauth2.expires_in = ( - int(tokens.get("expires_in")) if tokens.get("expires_in") else None - ) - - def exchange(self) -> AuthCredential: - """Exchange an oauth token from the authorization response. - - Returns: - An AuthCredential object containing the access token. 
- """ - if not AUTHLIB_AVIALABLE: - return self._auth_credential - - if ( - self._auth_credential.oauth2 - and self._auth_credential.oauth2.access_token - ): - return self._auth_credential - - client, token_endpoint = create_oauth2_session( - self._auth_scheme, self._auth_credential - ) - if not client: - logger.warning("Could not create OAuth2 session for token exchange") - return self._auth_credential - - try: - tokens = client.fetch_token( - token_endpoint, - authorization_response=self._auth_credential.oauth2.auth_response_uri, - code=self._auth_credential.oauth2.auth_code, - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - ) - update_credential_with_tokens(self._auth_credential, tokens) - logger.info("Successfully exchanged OAuth2 tokens") - except Exception as e: - logger.error("Failed to exchange OAuth2 tokens: %s", e) - # Return original credential on failure - return self._auth_credential - - return self._auth_credential - - def refresh(self) -> AuthCredential: - """Refresh an oauth token. - - Returns: - An AuthCredential object containing the refreshed access token. - """ - if not AUTHLIB_AVIALABLE: - return self._auth_credential - credential = self._auth_credential - if not credential.oauth2: - return credential - - if OAuth2Token({ - "expires_at": credential.oauth2.expires_at, - "expires_in": credential.oauth2.expires_in, - }).is_expired(): - client, token_endpoint = create_oauth2_session( - self._auth_scheme, self._auth_credential - ) - if not client: - logger.warning("Could not create OAuth2 session for token refresh") - return credential - - try: - tokens = client.refresh_token( - url=token_endpoint, - refresh_token=credential.oauth2.refresh_token, - ) - update_credential_with_tokens(self._auth_credential, tokens) - logger.info("Successfully refreshed OAuth2 tokens") - except Exception as e: - logger.error("Failed to refresh OAuth2 tokens: %s", e) - # Return original credential on failure - return credential - - return self._auth_credential diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py new file mode 100644 index 000000000..2d0a8b670 --- /dev/null +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -0,0 +1,154 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
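+
+# This refresher supersedes the refresh() path of the OAuth2CredentialFetcher
+# module removed in this change; tool_auth_handler now calls it through the
+# async BaseCredentialRefresher interface.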
+ +"""OAuth2 credential refresher implementation.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from typing_extensions import override + +from .base_credential_refresher import BaseCredentialRefresher + +try: + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialRefresher(BaseCredentialRefresher): + """Refreshes OAuth2 credentials including Google OAuth2 JSON credentials.""" + + @override + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Check if the OAuth2 credential needs to be refreshed. + + Args: + auth_credential: The OAuth2 credential to check. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + # Handle Google OAuth2 credentials (from service account exchange) + if auth_credential.google_oauth2_json: + try: + google_credential = Credentials.from_authorized_user_info( + json.loads(auth_credential.google_oauth2_json) + ) + return google_credential.expired and bool( + google_credential.refresh_token + ) + except Exception as e: + logger.warning("Failed to parse Google OAuth2 JSON credential: %s", e) + return False + + # Handle regular OAuth2 credentials + elif auth_credential.oauth2 and auth_scheme: + if not AUTHLIB_AVIALABLE: + return False + + if not auth_credential.oauth2: + return False + + return OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired() + + return False + + @override + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refresh the OAuth2 credential. + If refresh failed, return the original credential. + + Args: + auth_credential: The OAuth2 credential to refresh. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + The refreshed credential. + + """ + # Handle Google OAuth2 credentials (from service account exchange) + if auth_credential.google_oauth2_json: + try: + google_credential = Credentials.from_authorized_user_info( + json.loads(auth_credential.google_oauth2_json) + ) + if google_credential.expired and google_credential.refresh_token: + google_credential.refresh(Request()) + auth_credential.google_oauth2_json = google_credential.to_json() + logger.info("Successfully refreshed Google OAuth2 JSON credential") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. 
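+        # A failed refresh is logged and swallowed; the original
+        # google_oauth2_json payload is returned unchanged below.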
+ logger.error("Failed to refresh Google OAuth2 JSON credential: %s", e) + + # Handle regular OAuth2 credentials + elif auth_credential.oauth2 and auth_scheme: + if not AUTHLIB_AVIALABLE: + return auth_credential + + if not auth_credential.oauth2: + return auth_credential + + if OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired(): + client, token_endpoint = create_oauth2_session( + auth_scheme, auth_credential + ) + if not client: + logger.warning("Could not create OAuth2 session for token refresh") + return auth_credential + + try: + tokens = client.refresh_token( + url=token_endpoint, + refresh_token=auth_credential.oauth2.refresh_token, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully refreshed OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. + logger.error("Failed to refresh OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 4e5be5959..5a50a7f0c 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -150,7 +150,7 @@ async def run_async( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self._auth_scheme, self._auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() if auth_result.state == 'pending': return { @@ -178,7 +178,7 @@ async def run_async( args['operation'] = self._operation args['action'] = self._action logger.info('Running tool: %s with args: %s', self.name, args) - return self._rest_api_tool.call(args=args, tool_context=tool_context) + return await self._rest_api_tool.call(args=args, tool_context=tool_context) def __str__(self): return ( diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 1e451fe0f..dee103932 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -345,9 +345,9 @@ def _prepare_request_params( async def run_async( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: - return self.call(args=args, tool_context=tool_context) + return await self.call(args=args, tool_context=tool_context) - def call( + async def call( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: """Executes the REST API call. 
@@ -364,7 +364,7 @@ def call( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self.auth_scheme, self.auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() auth_state, auth_scheme, auth_credential = ( auth_result.state, auth_result.auth_scheme, diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index c36793fdc..08e535d28 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -25,7 +25,7 @@ from ....auth.auth_schemes import AuthScheme from ....auth.auth_schemes import AuthSchemeType from ....auth.auth_tool import AuthConfig -from ....auth.oauth2_credential_fetcher import OAuth2CredentialFetcher +from ....auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher from ...tool_context import ToolContext from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError @@ -146,7 +146,7 @@ def from_tool_context( credential_store, ) - def _get_existing_credential( + async def _get_existing_credential( self, ) -> Optional[AuthCredential]: """Checks for and returns an existing, exchanged credential.""" @@ -156,9 +156,11 @@ def _get_existing_credential( ) if existing_credential: if existing_credential.oauth2: - existing_credential = OAuth2CredentialFetcher( - self.auth_scheme, existing_credential - ).refresh() + refresher = OAuth2CredentialRefresher() + if await refresher.is_refresh_needed(existing_credential): + existing_credential = await refresher.refresh( + existing_credential, self.auth_scheme + ) return existing_credential return None @@ -234,7 +236,7 @@ def _external_exchange_required(self, credential) -> bool: and not credential.google_oauth2_json ) - def prepare_auth_credentials( + async def prepare_auth_credentials( self, ) -> AuthPreparationResult: """Prepares authentication credentials, handling exchange and user interaction.""" @@ -244,7 +246,7 @@ def prepare_auth_credentials( return AuthPreparationResult(state="done") # Check for existing credential. - existing_credential = self._get_existing_credential() + existing_credential = await self._get_existing_credential() credential = existing_credential or self.auth_credential # fetch credential from adk framework diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py new file mode 100644 index 000000000..ef1dbbbee --- /dev/null +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -0,0 +1,220 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
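+
+"""Unit tests for OAuth2CredentialExchanger, covering token exchange success,
+failure handling, and the fallback when authlib is unavailable."""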
+ +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.exchanger.base_credential_exchanger import CredentialExchangError +from google.adk.auth.exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger +import pytest + + +class TestOAuth2CredentialExchanger: + """Test suite for OAuth2CredentialExchanger.""" + + @pytest.mark.asyncio + async def test_exchange_with_existing_token(self): + """Test exchange method when access token already exists.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return the same credential since access token already exists + assert result == credential + assert result.oauth2.access_token == "existing_token" + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_success(self, mock_oauth2_session): + """Test successful token exchange.""" + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Verify token exchange was successful + assert result.oauth2.access_token == "new_access_token" + assert result.oauth2.refresh_token == "new_refresh_token" + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_missing_auth_scheme(self): + """Test exchange with missing auth_scheme raises ValueError.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + try: + await exchanger.exchange(credential, None) + assert False, "Should have raised ValueError" + except CredentialExchangError as e: + assert "auth_scheme is required" in str(e) + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async 
def test_exchange_no_session(self, mock_oauth2_session): + """Test exchange when OAuth2Session cannot be created.""" + # Mock to return None for create_oauth2_session + mock_oauth2_session.return_value = None + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret to trigger session creation failure + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when session creation fails + assert result == credential + assert result.oauth2.access_token is None + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_fetch_token_failure(self, mock_oauth2_session): + """Test exchange when fetch_token fails.""" + # Setup mock to raise exception during fetch_token + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.fetch_token.side_effect = Exception("Token fetch failed") + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when fetch_token fails + assert result == credential + assert result.oauth2.access_token is None + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_authlib_not_available(self): + """Test exchange when authlib is not available.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + + # Mock AUTHLIB_AVIALABLE to False + with patch( + "google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVIALABLE", + False, + ): + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when authlib is not available + assert result == credential + assert result.oauth2.access_token is None diff --git a/tests/unittests/auth/refresher/__init__.py b/tests/unittests/auth/refresher/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/refresher/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py new file mode 100644 index 000000000..b22bf2ccd --- /dev/null +++ b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py @@ -0,0 +1,297 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher +import pytest + + +class TestOAuth2CredentialRefresher: + """Test suite for OAuth2CredentialRefresher.""" + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_not_expired(self, mock_oauth2_token): + """Test needs_refresh when token is not expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = False + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) + 3600, + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert not needs_refresh + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_expired(self, mock_oauth2_token): + """Test needs_refresh when token is expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + 
token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert needs_refresh + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @patch("google.adk.auth.oauth2_credential_util.OAuth2Token") + @pytest.mark.asyncio + async def test_refresh_token_expired_success( + self, mock_oauth2_token, mock_oauth2_session + ): + """Test successful token refresh when token is expired.""" + # Setup mock token + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + # Setup mock session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "refreshed_access_token", + "refresh_token": "refreshed_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.refresh_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="old_token", + refresh_token="old_refresh_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + # Verify token refresh was successful + assert result.oauth2.access_token == "refreshed_access_token" + assert result.oauth2.refresh_token == "refreshed_refresh_token" + mock_client.refresh_token.assert_called_once() + + @pytest.mark.asyncio + async def test_refresh_no_oauth2_credential(self): + """Test refresh with no OAuth2 credential returns original.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + assert result == credential + + @pytest.mark.asyncio + async def test_needs_refresh_google_oauth2_json_expired(self): + """Test needs_refresh with Google OAuth2 JSON credential that is expired.""" + import json + from unittest.mock import patch + + # Mock Google OAuth2 JSON credential data + google_oauth2_json = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + "type": "authorized_user", + }) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=google_oauth2_json, + ) + + # Mock the Google Credentials class + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" + ) as mock_credentials: + mock_google_credential 
= Mock() + mock_google_credential.expired = True + mock_google_credential.refresh_token = "test_refresh_token" + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert needs_refresh + + @pytest.mark.asyncio + async def test_needs_refresh_google_oauth2_json_not_expired(self): + """Test needs_refresh with Google OAuth2 JSON credential that is not expired.""" + import json + from unittest.mock import patch + + # Mock Google OAuth2 JSON credential data + google_oauth2_json = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + "type": "authorized_user", + }) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=google_oauth2_json, + ) + + # Mock the Google Credentials class + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" + ) as mock_credentials: + mock_google_credential = Mock() + mock_google_credential.expired = False + mock_google_credential.refresh_token = "test_refresh_token" + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert not needs_refresh + + @pytest.mark.asyncio + async def test_refresh_google_oauth2_json_success(self): + """Test successful refresh of Google OAuth2 JSON credential.""" + import json + from unittest.mock import patch + + # Mock Google OAuth2 JSON credential data + google_oauth2_json = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + "type": "authorized_user", + }) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json=google_oauth2_json, + ) + + # Mock the Google Credentials and Request classes + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" + ) as mock_credentials: + with patch( + "google.adk.auth.refresher.oauth2_credential_refresher.Request" + ) as mock_request: + mock_google_credential = Mock() + mock_google_credential.expired = True + mock_google_credential.refresh_token = "test_refresh_token" + mock_google_credential.to_json.return_value = json.dumps({ + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "new_refresh_token", + "access_token": "new_access_token", + "type": "authorized_user", + }) + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, None) + + mock_google_credential.refresh.assert_called_once() + assert ( + result.google_oauth2_json != google_oauth2_json + ) # Should be updated + + @pytest.mark.asyncio + async def test_needs_refresh_no_oauth2_credential(self): + """Test needs_refresh with no OAuth2 credential returns False.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert not needs_refresh diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 2bfc7d4c9..f0d730d02 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ 
b/tests/unittests/auth/test_auth_handler.py @@ -13,8 +13,11 @@ # limitations under the License. import copy +import time +from unittest.mock import Mock from unittest.mock import patch +from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import APIKey from fastapi.openapi.models import APIKeyIn from fastapi.openapi.models import OAuth2 @@ -405,7 +408,8 @@ def test_get_auth_response_not_exists(self, auth_config): class TestParseAndStoreAuthResponse: """Tests for the parse_and_store_auth_response method.""" - def test_non_oauth_scheme(self, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_exchanged): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_exchanged) @@ -416,7 +420,7 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): handler = AuthHandler(auth_config) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) credential_key = auth_config.credential_key assert ( @@ -424,7 +428,10 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): ) @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token") - def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_oauth_scheme( + self, mock_exchange_token, auth_config_with_exchanged + ): """Test with an OAuth auth scheme.""" mock_exchange_token.return_value = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -434,7 +441,7 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): handler = AuthHandler(auth_config_with_exchanged) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) credential_key = auth_config_with_exchanged.credential_key assert state["temp:" + credential_key] == mock_exchange_token.return_value @@ -444,20 +451,20 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): class TestExchangeAuthToken: """Tests for the exchange_auth_token method.""" - def test_token_exchange_not_supported( + @pytest.mark.asyncio + async def test_token_exchange_not_supported( self, auth_config_with_auth_code, monkeypatch ): """Test when token exchange is not supported.""" - monkeypatch.setattr( - "google.adk.auth.oauth2_credential_fetcher.AUTHLIB_AVIALABLE", False - ) + monkeypatch.setattr("google.adk.auth.auth_handler.AUTHLIB_AVIALABLE", False) handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config_with_auth_code.exchanged_auth_credential - def test_openid_missing_token_endpoint( + @pytest.mark.asyncio + async def test_openid_missing_token_endpoint( self, openid_auth_scheme, oauth2_credentials_with_auth_code ): """Test OpenID Connect without a token endpoint.""" @@ -472,11 +479,12 @@ def test_openid_missing_token_endpoint( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_oauth2_missing_token_url( + @pytest.mark.asyncio + async def test_oauth2_missing_token_url( self, oauth2_auth_scheme, oauth2_credentials_with_auth_code ): """Test OAuth2 without a token URL.""" @@ -491,11 +499,12 @@ def test_oauth2_missing_token_url( ) handler = AuthHandler(config) - result = 
handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_non_oauth_scheme(self, auth_config_with_auth_code): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_auth_code): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_auth_code) @@ -504,11 +513,12 @@ def test_non_oauth_scheme(self, auth_config_with_auth_code): ) handler = AuthHandler(auth_config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config.exchanged_auth_credential - def test_missing_credentials(self, oauth2_auth_scheme): + @pytest.mark.asyncio + async def test_missing_credentials(self, oauth2_auth_scheme): """Test with missing credentials.""" empty_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2) @@ -518,11 +528,12 @@ def test_missing_credentials(self, oauth2_auth_scheme): ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == empty_credential - def test_credentials_with_token( + @pytest.mark.asyncio + async def test_credentials_with_token( self, auth_config, oauth2_credentials_with_token ): """Test when credentials already have a token.""" @@ -533,18 +544,29 @@ def test_credentials_with_token( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_token - @patch( - "google.adk.auth.oauth2_credential_util.OAuth2Session", - MockOAuth2Session, - ) - def test_successful_token_exchange(self, auth_config_with_auth_code): + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_successful_token_exchange( + self, mock_oauth2_session, auth_config_with_auth_code + ): """Test a successful token exchange.""" + # Setup mock OAuth2Session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "mock_access_token", + "refresh_token": "mock_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result.oauth2.access_token == "mock_access_token" assert result.oauth2.refresh_token == "mock_refresh_token" diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py b/tests/unittests/auth/test_oauth2_credential_fetcher.py deleted file mode 100644 index aba6a9923..000000000 --- a/tests/unittests/auth/test_oauth2_credential_fetcher.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -from unittest.mock import Mock - -from authlib.oauth2.rfc6749 import OAuth2Token -from fastapi.openapi.models import OAuth2 -from fastapi.openapi.models import OAuthFlowAuthorizationCode -from fastapi.openapi.models import OAuthFlows -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_schemes import OpenIdConnectWithConfig -from google.adk.auth.oauth2_credential_util import create_oauth2_session -from google.adk.auth.oauth2_credential_util import update_credential_with_tokens - - -class TestOAuth2CredentialUtil: - """Test suite for OAuth2 credential utility functions.""" - - def test_create_oauth2_session_openid_connect(self): - """Test create_oauth2_session with OpenID Connect scheme.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is not None - assert token_endpoint == "https://example.com/token" - assert client.client_id == "test_client_id" - assert client.client_secret == "test_client_secret" - - def test_create_oauth2_session_oauth2_scheme(self): - """Test create_oauth2_session with OAuth2 scheme.""" - flows = OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://example.com/auth", - tokenUrl="https://example.com/token", - scopes={"read": "Read access", "write": "Write access"}, - ) - ) - scheme = OAuth2(type_="oauth2", flows=flows) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is not None - assert token_endpoint == "https://example.com/token" - - def test_create_oauth2_session_invalid_scheme(self): - """Test create_oauth2_session with invalid scheme.""" - scheme = Mock() # Invalid scheme type - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is None - assert token_endpoint is None - - def test_create_oauth2_session_missing_credentials(self): - """Test create_oauth2_session with missing credentials.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - # Missing client_secret - ), - ) - - client, token_endpoint = create_oauth2_session(scheme, credential) - - assert client is None - assert token_endpoint is None - - def 
test_update_credential_with_tokens(self): - """Test update_credential_with_tokens function.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - ), - ) - - tokens = OAuth2Token({ - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - }) - - update_credential_with_tokens(credential, tokens) - - assert credential.oauth2.access_token == "new_access_token" - assert credential.oauth2.refresh_token == "new_refresh_token" - assert credential.oauth2.expires_at == int(time.time()) + 3600 - assert credential.oauth2.expires_in == 3600 diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py index cd37a105e..c9b542e51 100644 --- a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -20,6 +20,7 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult from google.genai.types import FunctionDeclaration from google.genai.types import Schema from google.genai.types import Type @@ -50,7 +51,9 @@ def mock_rest_api_tool(): "required": ["user_id", "page_size", "filter", "connection_name"], } mock_tool._operation_parser = mock_parser - mock_tool.call.return_value = {"status": "success", "data": "mock_data"} + mock_tool.call = mock.AsyncMock( + return_value={"status": "success", "data": "mock_data"} + ) return mock_tool @@ -179,9 +182,6 @@ async def test_run_with_auth_async_none_token( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) # Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None mock_auth_credential_without_token = AuthCredential( auth_type=AuthCredentialTypes.HTTP, @@ -190,8 +190,12 @@ async def test_run_with_auth_async_none_token( credentials=HttpCredentials(token=None), # Token is None ), ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = ( - mock_auth_credential_without_token + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=( + AuthPreparationResult( + state="done", auth_credential=mock_auth_credential_without_token + ) + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance @@ -229,18 +233,18 @@ async def test_run_with_auth_async( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = 
AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token="mocked_token"), - ), + + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=AuthPreparationResult( + state="done", + auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="mocked_token"), + ), + ), + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance result = await integration_tool_with_auth.run_async( diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 303dda69d..c4cbea7b9 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -14,6 +14,7 @@ import json +from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -194,7 +195,8 @@ def test_get_declaration( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_success( + @pytest.mark.asyncio + async def test_call_success( self, mock_request, mock_tool_context, @@ -217,7 +219,7 @@ def test_call_success( ) # Call the method - result = tool.call(args={}, tool_context=mock_tool_context) + result = await tool.call(args={}, tool_context=mock_tool_context) # Check the result assert result == {"result": "success"} @@ -225,7 +227,8 @@ def test_call_success( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_auth_pending( + @pytest.mark.asyncio + async def test_call_auth_pending( self, mock_request, sample_endpoint, @@ -246,12 +249,14 @@ def test_call_auth_pending( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "pending" + mock_prepare_result = MagicMock() + mock_prepare_result.state = "pending" + mock_tool_auth_handler_instance.prepare_auth_credentials = AsyncMock( + return_value=mock_prepare_result ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance - response = tool.call(args={}, tool_context=None) + response = await tool.call(args={}, tool_context=None) assert response == { "pending": True, "message": "Needs your authorization to access your data.", diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index 8db151fc8..e405ce5b8 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -116,7 +116,8 @@ def openid_connect_credential(): return credential -def test_openid_connect_no_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_no_auth_response( openid_connect_scheme, openid_connect_credential ): # Setup Mock exchanger @@ -132,12 +133,13 @@ def test_openid_connect_no_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'pending' 
assert result.auth_credential == openid_connect_credential -def test_openid_connect_with_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_with_auth_response( openid_connect_scheme, openid_connect_credential, monkeypatch ): mock_exchanger = MockOpenIdConnectCredentialExchanger( @@ -166,7 +168,7 @@ def test_openid_connect_with_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP assert 'test_access_token' in result.auth_credential.http.credentials.token @@ -178,7 +180,8 @@ def test_openid_connect_with_auth_response( mock_auth_handler.get_auth_response.assert_called_once() -def test_openid_connect_existing_token( +@pytest.mark.asyncio +async def test_openid_connect_existing_token( openid_connect_scheme, openid_connect_credential ): _, existing_credential = token_to_scheme_credential( @@ -198,16 +201,17 @@ def test_openid_connect_existing_token( openid_connect_credential, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential == existing_credential @patch( - 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialFetcher' + 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher' ) -def test_openid_connect_existing_oauth2_token_refresh( - mock_oauth2_fetcher, openid_connect_scheme, openid_connect_credential +@pytest.mark.asyncio +async def test_openid_connect_existing_oauth2_token_refresh( + mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential ): """Test that OAuth2 tokens are refreshed when existing credentials are found.""" # Create existing OAuth2 credential @@ -232,10 +236,13 @@ def test_openid_connect_existing_oauth2_token_refresh( ), ) - # Setup mock OAuth2CredentialFetcher - mock_fetcher_instance = MagicMock() - mock_fetcher_instance.refresh.return_value = refreshed_credential - mock_oauth2_fetcher.return_value = mock_fetcher_instance + # Setup mock OAuth2CredentialRefresher + from unittest.mock import AsyncMock + + mock_refresher_instance = MagicMock() + mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential) + mock_oauth2_refresher.return_value = mock_refresher_instance tool_context = create_mock_tool_context() credential_store = ToolContextCredentialStore(tool_context=tool_context) @@ -253,13 +260,17 @@ def test_openid_connect_existing_oauth2_token_refresh( credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() + + # Verify OAuth2CredentialRefresher was called for refresh + mock_oauth2_refresher.assert_called_once() - # Verify OAuth2CredentialFetcher was called for refresh - mock_oauth2_fetcher.assert_called_once_with( - openid_connect_scheme, existing_credential + mock_refresher_instance.is_refresh_needed.assert_called_once_with( + existing_credential + ) + mock_refresher_instance.refresh.assert_called_once_with( + existing_credential, openid_connect_scheme ) - mock_fetcher_instance.refresh.assert_called_once() assert result.state == 'done' # The result should contain the refreshed credential after exchange From 
2c739ab5812d24686cee61a0a1ce808b63ceb883 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:00:41 -0700 Subject: [PATCH 55/61] chore: Add Credential Manager for managing tools credential (Experimental) PiperOrigin-RevId: 772986051 --- src/google/adk/auth/credential_manager.py | 265 +++++++++ .../unittests/auth/test_credential_manager.py | 559 ++++++++++++++++++ 2 files changed, 824 insertions(+) create mode 100644 src/google/adk/auth/credential_manager.py create mode 100644 tests/unittests/auth/test_credential_manager.py diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py new file mode 100644 index 000000000..7471bdffa --- /dev/null +++ b/src/google/adk/auth/credential_manager.py @@ -0,0 +1,265 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from ..tools.tool_context import ToolContext +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_credential import AuthCredentialTypes +from .auth_schemes import AuthSchemeType +from .auth_tool import AuthConfig +from .exchanger.base_credential_exchanger import BaseCredentialExchanger +from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry +from .refresher.base_credential_refresher import BaseCredentialRefresher +from .refresher.credential_refresher_registry import CredentialRefresherRegistry + + +@experimental +class CredentialManager: + """Manages authentication credentials through a structured workflow. + + The CredentialManager orchestrates the complete lifecycle of authentication + credentials, from initial loading to final preparation for use. It provides + a centralized interface for handling various credential types and authentication + schemes while maintaining proper credential hygiene (refresh, exchange, caching). + + This class is only for use by Agent Development Kit. 
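+
+    Credentials are loaded from the credential service (or from the auth
+    response carried in the tool context), then exchanged and refreshed as
+    needed before being returned to the caller.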
+ + Args: + auth_config: Configuration containing authentication scheme and credentials + + Example: + ```python + auth_config = AuthConfig( + auth_scheme=oauth2_scheme, + raw_auth_credential=service_account_credential + ) + manager = CredentialManager(auth_config) + + # Register custom exchanger if needed + manager.register_credential_exchanger( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialExchanger() + ) + + # Register custom refresher if needed + manager.register_credential_refresher( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialRefresher() + ) + + # Load and prepare credential + credential = await manager.load_auth_credential(tool_context) + ``` + """ + + def __init__( + self, + auth_config: AuthConfig, + ): + self._auth_config = auth_config + self._exchanger_registry = CredentialExchangerRegistry() + self._refresher_registry = CredentialRefresherRegistry() + + # Register default exchangers and refreshers + from .exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger + + self._exchanger_registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger() + ) + from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher + + oauth2_refresher = OAuth2CredentialRefresher() + self._refresher_registry.register( + AuthCredentialTypes.OAUTH2, oauth2_refresher + ) + self._refresher_registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher + ) + + def register_credential_exchanger( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register a credential exchanger for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchanger_registry.register(credential_type, exchanger_instance) + + async def request_credential(self, tool_context: ToolContext) -> None: + tool_context.request_credential(self._auth_config) + + async def get_auth_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load and prepare authentication credential through a structured workflow.""" + + # Step 1: Validate credential configuration + await self._validate_credential() + + # Step 2: Check if credential is already ready (no processing needed) + if self._is_credential_ready(): + return self._auth_config.raw_auth_credential + + # Step 3: Try to load existing processed credential + credential = await self._load_existing_credential(tool_context) + + # Step 4: If no existing credential, load from auth response + # TODO instead of load from auth response, we can store auth response in + # credential service. 
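+    # Remember whether the credential came from the client's auth response so
+    # Step 8 knows it must be persisted via _save_credential.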
+ was_from_auth_response = False + if not credential: + credential = await self._load_from_auth_response(tool_context) + was_from_auth_response = True + + # Step 5: If still no credential available, return None + if not credential: + return None + + # Step 6: Exchange credential if needed (e.g., service account to access token) + credential, was_exchanged = await self._exchange_credential(credential) + + # Step 7: Refresh credential if expired + if not was_exchanged: + credential, was_refreshed = await self._refresh_credential(credential) + + # Step 8: Save credential if it was modified + if was_from_auth_response or was_exchanged or was_refreshed: + await self._save_credential(tool_context, credential) + + return credential + + async def _load_existing_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load existing credential from credential service or cached exchanged credential.""" + + # Try loading from credential service first + credential = await self._load_from_credential_service(tool_context) + if credential: + return credential + + # Check if we have a cached exchanged credential + if self._auth_config.exchanged_auth_credential: + return self._auth_config.exchanged_auth_credential + + return None + + async def _load_from_credential_service( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Note: This should be made async in a future refactor + # For now, assuming synchronous operation + return await credential_service.load_credential( + self._auth_config, tool_context + ) + return None + + async def _load_from_auth_response( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from auth response in tool context.""" + return tool_context.get_auth_response(self._auth_config) + + async def _exchange_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Exchange credential if needed and return the credential and whether it was exchanged.""" + exchanger = self._exchanger_registry.get_exchanger(credential.auth_type) + if not exchanger: + return credential, False + + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchanged_credential, True + + async def _refresh_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Refresh credential if expired and return the credential and whether it was refreshed.""" + refresher = self._refresher_registry.get_refresher(credential.auth_type) + if not refresher: + return credential, False + + if await refresher.is_refresh_needed( + credential, self._auth_config.auth_scheme + ): + refreshed_credential = await refresher.refresh( + credential, self._auth_config.auth_scheme + ) + return refreshed_credential, True + + return credential, False + + def _is_credential_ready(self) -> bool: + """Check if credential is ready to use without further processing.""" + raw_credential = self._auth_config.raw_auth_credential + if not raw_credential: + return False + + # Simple credentials that don't need exchange or refresh + return raw_credential.auth_type in ( + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + # Add other simple auth types as needed + ) + + async def _validate_credential(self) -> None: + """Validate credential configuration and raise errors if invalid.""" + if not 
self._auth_config.raw_auth_credential: + if self._auth_config.auth_scheme.type_ in ( + AuthSchemeType.oauth2, + AuthSchemeType.openIdConnect, + ): + raise ValueError( + "raw_auth_credential is required for auth_scheme type " + f"{self._auth_config.auth_scheme.type_}" + ) + + raw_credential = self._auth_config.raw_auth_credential + if raw_credential: + if ( + raw_credential.auth_type + in ( + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + ) + and not raw_credential.oauth2 + ): + raise ValueError( + "auth_config.raw_credential.oauth2 required for credential type " + f"{raw_credential.auth_type}" + ) + # Additional validation can be added here + + async def _save_credential( + self, tool_context: ToolContext, credential: AuthCredential + ) -> None: + """Save credential to credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Update the exchanged credential in config + self._auth_config.exchanged_auth_credential = credential + await credential_service.save_credential(self._auth_config, tool_context) diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py new file mode 100644 index 000000000..283e865a7 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager.py @@ -0,0 +1,559 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
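For orientation before the unit tests below, here is a minimal, hypothetical sketch of how the CredentialManager workflow introduced above might be driven from a tool implementation. Only get_auth_credential and request_credential come from this patch; the auth scheme, raw credential, and the surrounding function are illustrative assumptions.

```python
# Hypothetical usage sketch, not part of this patch. Assumes `oauth2_scheme`
# and `raw_oauth2_credential` are valid AuthScheme / AuthCredential objects
# built elsewhere.
from google.adk.auth.auth_tool import AuthConfig
from google.adk.auth.credential_manager import CredentialManager


async def fetch_report(tool_context, oauth2_scheme, raw_oauth2_credential):
  auth_config = AuthConfig(
      auth_scheme=oauth2_scheme,
      raw_auth_credential=raw_oauth2_credential,
  )
  manager = CredentialManager(auth_config)

  # Runs the validate -> load -> exchange -> refresh -> save workflow.
  credential = await manager.get_auth_credential(tool_context)
  if credential is None:
    # No usable credential yet; ask the client to complete the auth flow.
    await manager.request_credential(tool_context)
    return "Pending User Authorization."

  # `credential` is now ready to use (e.g. an exchanged/refreshed OAuth2 token).
  return credential
```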
+ +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from fastapi.openapi.models import HTTPBearer +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +import pytest + + +class TestCredentialManager: + """Test suite for CredentialManager.""" + + def test_init(self): + """Test CredentialManager initialization.""" + auth_config = Mock(spec=AuthConfig) + manager = CredentialManager(auth_config) + assert manager._auth_config == auth_config + + @pytest.mark.asyncio + async def test_request_credential(self): + """Test request_credential method.""" + auth_config = Mock(spec=AuthConfig) + tool_context = Mock() + tool_context.request_credential = Mock() + + manager = CredentialManager(auth_config) + await manager.request_credential(tool_context) + + tool_context.request_credential.assert_called_once_with(auth_config) + + @pytest.mark.asyncio + async def test_load_auth_credentials_success(self): + """Test load_auth_credential with successful flow.""" + # Create mocks + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + # Mock the credential that will be returned + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + tool_context = Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=mock_credential) + manager._exchange_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._refresh_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify all methods were called + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_called_once_with(mock_credential) + manager._refresh_credential.assert_called_once_with(mock_credential) + manager._save_credential.assert_called_once_with( + tool_context, mock_credential + ) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_auth_credentials_no_credential(self): + """Test load_auth_credential when no credential is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + tool_context = 
Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=None) + manager._exchange_credential = AsyncMock() + manager._refresh_credential = AsyncMock() + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify methods were called but no credential returned + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_not_called() + manager._refresh_credential.assert_not_called() + manager._save_credential.assert_not_called() + + assert result is None + + @pytest.mark.asyncio + async def test_load_existing_credential_already_exchanged(self): + """Test _load_existing_credential when credential is already exchanged.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + auth_config.exchanged_auth_credential = mock_credential + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock(return_value=None) + + result = await manager._load_existing_credential(tool_context) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_existing_credential_with_credential_service(self): + """Test _load_existing_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + + mock_credential = Mock(spec=AuthCredential) + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock( + return_value=mock_credential + ) + + result = await manager._load_existing_credential(tool_context) + + manager._load_from_credential_service.assert_called_once_with(tool_context) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_with_service(self): + """Test _load_from_credential_service from tool context when credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = Mock() + credential_service.load_credential = AsyncMock(return_value=mock_credential) + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await manager._load_from_credential_service(tool_context) + + credential_service.load_credential.assert_called_once_with( + auth_config, tool_context + ) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_no_service(self): + """Test _load_from_credential_service when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await 
manager._load_from_credential_service(tool_context) + + assert result is None + + @pytest.mark.asyncio + async def test_save_credential_with_service(self): + """Test _save_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = AsyncMock() + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + credential_service.save_credential.assert_called_once_with( + auth_config, tool_context + ) + assert auth_config.exchanged_auth_credential == mock_credential + + @pytest.mark.asyncio + async def test_save_credential_no_service(self): + """Test _save_credential when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + mock_credential = Mock(spec=AuthCredential) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + # Should not raise an error, and credential should not be set in auth_config + # when there's no credential service (according to implementation) + assert auth_config.exchanged_auth_credential is None + + @pytest.mark.asyncio + async def test_refresh_credential_oauth2(self): + """Test _refresh_credential with OAuth2 credential.""" + mock_oauth2_auth = Mock(spec=OAuth2Auth) + + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + # Mock refresher + mock_refresher = Mock() + mock_refresher.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher.refresh = AsyncMock(return_value=mock_credential) + + auth_config.raw_auth_credential = mock_credential + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return our mock refresher + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=mock_refresher, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + mock_refresher.is_refresh_needed.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + mock_refresher.refresh.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + assert result == mock_credential + assert was_refreshed is True + + @pytest.mark.asyncio + async def test_refresh_credential_no_refresher(self): + """Test _refresh_credential with credential that has no refresher.""" + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return None (no refresher available) + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=None, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + assert result == mock_credential + assert was_refreshed is False + + @pytest.mark.asyncio + async def test_is_credential_ready_api_key(self): + """Test 
_is_credential_ready with API key credential.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is True + + @pytest.mark.asyncio + async def test_is_credential_ready_oauth2(self): + """Test _is_credential_ready with OAuth2 credential (needs processing).""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is False + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_oauth2(self): + """Test _validate_credential with no raw credential for OAuth2.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_openid(self): + """Test _validate_credential with no raw credential for OpenID Connect.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.openIdConnect + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_other_scheme(self): + """Test _validate_credential with no raw credential for other schemes.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.apiKey + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + await manager._validate_credential() + + # Should return without error for non-OAuth2/OpenID schemes + + @pytest.mark.asyncio + async def test_validate_credential_oauth2_missing_oauth2_field(self): + """Test _validate_credential with OAuth2 credential missing oauth2 field.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + mock_raw_credential.oauth2 = None + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises( + ValueError, match="auth_config.raw_credential.oauth2 required" + ): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_exchange_credentials_service_account(self): + """Test _exchange_credential with service account credential.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT + + mock_exchanged_credential = Mock(spec=AuthCredential) + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + manager = CredentialManager(auth_config) + 
+ # Mock the exchanger that gets created during registration + with patch.object( + manager._exchanger_registry, "get_exchanger" + ) as mock_get_exchanger: + mock_exchanger = Mock() + mock_exchanger.exchange = AsyncMock( + return_value=mock_exchanged_credential + ) + mock_get_exchanger.return_value = mock_exchanger + + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_exchanged_credential + assert was_exchanged is True + mock_get_exchanger.assert_called_once_with( + AuthCredentialTypes.SERVICE_ACCOUNT + ) + mock_exchanger.exchange.assert_called_once_with( + mock_raw_credential, auth_config.auth_scheme + ) + + @pytest.mark.asyncio + async def test_exchange_credential_no_exchanger(self): + """Test _exchange_credential with credential that has no exchanger.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the exchanger registry to return None (no exchanger available) + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=None + ): + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_raw_credential + assert was_exchanged is False + + +# Test fixtures +@pytest.fixture +def oauth2_auth_scheme(): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + +@pytest.fixture +def openid_auth_scheme(): + """Create an OpenID Connect auth scheme for testing.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + + +@pytest.fixture +def bearer_auth_scheme(): + """Create a Bearer auth scheme for testing.""" + return HTTPBearer(bearerFormat="JWT") + + +@pytest.fixture +def oauth2_credential(): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + +@pytest.fixture +def service_account_credential(): + """Create service account credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE" + " KEY-----\n" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="123456789", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + ), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + +@pytest.fixture +def api_key_credential(): + """Create API key credentials for testing.""" + return AuthCredential( + 
auth_type=AuthCredentialTypes.API_KEY,
+      api_key="test-api-key",
+  )
+
+
+@pytest.fixture
+def http_bearer_credential():
+  """Create HTTP Bearer credentials for testing."""
+  return AuthCredential(
+      auth_type=AuthCredentialTypes.HTTP,
+      http=HttpAuth(
+          scheme="bearer",
+          credentials=HttpCredentials(token="bearer-token"),
+      ),
+  )

From dcea7767c67c7edfb694304df32dca10b74c9a71 Mon Sep 17 00:00:00 2001
From: "Xiang (Sean) Zhou"
Date: Wed, 18 Jun 2025 11:15:27 -0700
Subject: [PATCH 56/61] feat: Add Authenticated Tool (Experimental)

PiperOrigin-RevId: 772992074
---
 .../adk/tools/authenticated_function_tool.py  | 107 ++++
 .../adk/tools/base_authenticated_tool.py      | 107 ++++
 .../tools/test_authenticated_function_tool.py | 541 ++++++++++++++++++
 .../tools/test_base_authenticated_tool.py     | 343 +++++++++++
 4 files changed, 1098 insertions(+)
 create mode 100644 src/google/adk/tools/authenticated_function_tool.py
 create mode 100644 src/google/adk/tools/base_authenticated_tool.py
 create mode 100644 tests/unittests/tools/test_authenticated_function_tool.py
 create mode 100644 tests/unittests/tools/test_base_authenticated_tool.py

diff --git a/src/google/adk/tools/authenticated_function_tool.py b/src/google/adk/tools/authenticated_function_tool.py
new file mode 100644
index 000000000..67cc5885f
--- /dev/null
+++ b/src/google/adk/tools/authenticated_function_tool.py
@@ -0,0 +1,107 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import inspect
+import logging
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import Union
+
+from typing_extensions import override
+
+from ..auth.auth_credential import AuthCredential
+from ..auth.auth_tool import AuthConfig
+from ..auth.credential_manager import CredentialManager
+from ..utils.feature_decorator import experimental
+from .function_tool import FunctionTool
+from .tool_context import ToolContext
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+@experimental
+class AuthenticatedFunctionTool(FunctionTool):
+  """A FunctionTool that handles authentication before the actual tool logic
+  gets called. Functions can accept a special `credential` argument, which is
+  the credential ready for use. (Experimental)
+  """
+
+  def __init__(
+      self,
+      *,
+      func: Callable[..., Any],
+      auth_config: AuthConfig = None,
+      response_for_auth_required: Optional[Union[dict[str, Any], str]] = None,
+  ):
+    """Initializes the AuthenticatedFunctionTool.
+
+    Args:
+      func: The function to be called.
+      auth_config: The authentication configuration.
+      response_for_auth_required: The response to return when the tool is
+        requesting auth credential from the client. There are two such cases:
+        the tool doesn't configure any credentials
+        (auth_config.raw_auth_credential is missing), or the credentials
+        configured are not sufficient to authenticate the tool (e.g. an OAuth
+        client id and client secret are configured) and client input is needed
+        (e.g. the client needs to involve the end user in an OAuth flow and
+        get back the OAuth response).
+    """
+    super().__init__(func=func)
+    self._ignore_params.append("credential")
+
+    if auth_config and auth_config.auth_scheme:
+      self._credentials_manager = CredentialManager(auth_config=auth_config)
+    else:
+      logger.warning(
+          "auth_config or auth_config.auth_scheme is missing. Will skip"
+          " authentication. Use FunctionTool instead if authentication is not"
+          " required."
+      )
+      self._credentials_manager = None
+    self._response_for_auth_required = response_for_auth_required
+
+  @override
+  async def run_async(
+      self, *, args: dict[str, Any], tool_context: ToolContext
+  ) -> Any:
+    credential = None
+    if self._credentials_manager:
+      credential = await self._credentials_manager.get_auth_credential(
+          tool_context
+      )
+      if not credential:
+        await self._credentials_manager.request_credential(tool_context)
+        return self._response_for_auth_required or "Pending User Authorization."
+
+    return await self._run_async_impl(
+        args=args, tool_context=tool_context, credential=credential
+    )
+
+  async def _run_async_impl(
+      self,
+      *,
+      args: dict[str, Any],
+      tool_context: ToolContext,
+      credential: AuthCredential,
+  ) -> Any:
+    args_to_call = args.copy()
+    signature = inspect.signature(self.func)
+    if "credential" in signature.parameters:
+      args_to_call["credential"] = credential
+    return await super().run_async(args=args_to_call, tool_context=tool_context)
diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py
new file mode 100644
index 000000000..4858e4953
--- /dev/null
+++ b/src/google/adk/tools/base_authenticated_tool.py
@@ -0,0 +1,107 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import abstractmethod
+import logging
+from typing import Any
+from typing import Optional
+from typing import Union
+
+from typing_extensions import override
+
+from ..auth.auth_credential import AuthCredential
+from ..auth.auth_tool import AuthConfig
+from ..auth.credential_manager import CredentialManager
+from ..utils.feature_decorator import experimental
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+@experimental
+class BaseAuthenticatedTool(BaseTool):
+  """A base tool class that handles authentication before the actual tool logic
+  gets called. Subclasses receive the ready-to-use credential through the
+  `credential` argument of `_run_async_impl`. (Experimental)
+  """
+
+  def __init__(
+      self,
+      *,
+      name,
+      description,
+      auth_config: AuthConfig = None,
+      response_for_auth_required: Optional[Union[dict[str, Any], str]] = None,
+  ):
+    """
+    Args:
+      name: The name of the tool.
+      description: The description of the tool.
+      auth_config: The auth configuration of the tool.
+      response_for_auth_required: The response to return when the tool is
+        requesting auth credential from the client.
+        There are two such cases:
+        the tool doesn't configure any credentials
+        (auth_config.raw_auth_credential is missing), or the credentials
+        configured are not sufficient to authenticate the tool (e.g. an OAuth
+        client id and client secret are configured) and client input is needed
+        (e.g. the client needs to involve the end user in an OAuth flow and
+        get back the OAuth response).
+    """
+    super().__init__(
+        name=name,
+        description=description,
+    )
+
+    if auth_config and auth_config.auth_scheme:
+      self._credentials_manager = CredentialManager(auth_config=auth_config)
+    else:
+      logger.warning(
+          "auth_config or auth_config.auth_scheme is missing. Will skip"
+          " authentication. Use FunctionTool instead if authentication is not"
+          " required."
+      )
+      self._credentials_manager = None
+    self._response_for_auth_required = response_for_auth_required
+
+  @override
+  async def run_async(
+      self, *, args: dict[str, Any], tool_context: ToolContext
+  ) -> Any:
+    credential = None
+    if self._credentials_manager:
+      credential = await self._credentials_manager.get_auth_credential(
+          tool_context
+      )
+      if not credential:
+        await self._credentials_manager.request_credential(tool_context)
+        return self._response_for_auth_required or "Pending User Authorization."
+
+    return await self._run_async_impl(
+        args=args,
+        tool_context=tool_context,
+        credential=credential,
+    )
+
+  @abstractmethod
+  async def _run_async_impl(
+      self,
+      *,
+      args: dict[str, Any],
+      tool_context: ToolContext,
+      credential: AuthCredential,
+  ) -> Any:
+    pass
diff --git a/tests/unittests/tools/test_authenticated_function_tool.py b/tests/unittests/tools/test_authenticated_function_tool.py
new file mode 100644
index 000000000..88454032a
--- /dev/null
+++ b/tests/unittests/tools/test_authenticated_function_tool.py
@@ -0,0 +1,541 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
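Before the unit tests for AuthenticatedFunctionTool, a brief hedged sketch of the intended wiring: a plain function that declares a `credential` parameter is wrapped by the tool, which resolves the credential (or asks the client for one) before the function runs. The function body and the auth configuration here are illustrative assumptions, not part of the patch.

```python
# Hypothetical usage sketch, not part of this patch.
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool


def list_calendar_events(date: str, credential: AuthCredential) -> dict:
  # `credential` is injected by AuthenticatedFunctionTool because the
  # parameter name matches; it is already exchanged/refreshed.
  return {"date": date, "auth_type": str(credential.auth_type)}


def build_calendar_tool(oauth2_auth_config):
  # `oauth2_auth_config` is assumed to be an AuthConfig with an OAuth2
  # auth_scheme and raw_auth_credential set up elsewhere.
  return AuthenticatedFunctionTool(
      func=list_calendar_events,
      auth_config=oauth2_auth_config,
      response_for_auth_required="Please complete sign-in and retry.",
  )
```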
+ +import inspect +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool +from google.adk.tools.tool_context import ToolContext +import pytest + +# Test functions for different scenarios + + +def sync_function_no_credential(arg1: str, arg2: int) -> str: + """Test sync function without credential parameter.""" + return f"sync_result_{arg1}_{arg2}" + + +async def async_function_no_credential(arg1: str, arg2: int) -> str: + """Test async function without credential parameter.""" + return f"async_result_{arg1}_{arg2}" + + +def sync_function_with_credential(arg1: str, credential: AuthCredential) -> str: + """Test sync function with credential parameter.""" + return f"sync_cred_result_{arg1}_{credential.auth_type.value}" + + +async def async_function_with_credential( + arg1: str, credential: AuthCredential +) -> str: + """Test async function with credential parameter.""" + return f"async_cred_result_{arg1}_{credential.auth_type.value}" + + +def sync_function_with_tool_context( + arg1: str, tool_context: ToolContext +) -> str: + """Test sync function with tool_context parameter.""" + return f"sync_context_result_{arg1}" + + +async def async_function_with_both( + arg1: str, tool_context: ToolContext, credential: AuthCredential +) -> str: + """Test async function with both tool_context and credential parameters.""" + return f"async_both_result_{arg1}_{credential.auth_type.value}" + + +def function_with_optional_args( + arg1: str, arg2: str = "default", credential: AuthCredential = None +) -> str: + """Test function with optional arguments.""" + cred_type = credential.auth_type.value if credential else "none" + return f"optional_result_{arg1}_{arg2}_{cred_type}" + + +class MockCallable: + """Test callable class for testing.""" + + def __init__(self): + self.__name__ = "MockCallable" + self.__doc__ = "Test callable documentation" + + def __call__(self, arg1: str, credential: AuthCredential) -> str: + return f"callable_result_{arg1}_{credential.auth_type.value}" + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + # Create a mock auth_type that returns the expected value + mock_auth_type = Mock() + mock_auth_type.value = "oauth2" + credential.auth_type = mock_auth_type + return credential + + +class TestAuthenticatedFunctionTool: + """Test suite for AuthenticatedFunctionTool.""" + + def test_init_with_sync_function(self): + """Test initialization with synchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, + auth_config=auth_config, + response_for_auth_required="Please authenticate", + ) + + assert tool.name == "sync_function_no_credential" + assert ( + tool.description == "Test sync function without credential parameter." 
+ ) + assert tool.func == sync_function_no_credential + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == "Please authenticate" + assert "credential" in tool._ignore_params + + def test_init_with_async_function(self): + """Test initialization with asynchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=async_function_no_credential, auth_config=auth_config + ) + + assert tool.name == "async_function_no_credential" + assert ( + tool.description == "Test async function without credential parameter." + ) + assert tool.func == async_function_no_credential + assert tool._response_for_auth_required is None + + def test_init_with_callable(self): + """Test initialization with callable object.""" + auth_config = _create_mock_auth_config() + test_callable = MockCallable() + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + + assert tool.name == "MockCallable" + assert tool.description == "Test callable documentation" + assert tool.func == test_callable + + def test_init_no_auth_config(self): + """Test initialization without auth_config.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + + assert tool._credentials_manager is None + + @pytest.mark.asyncio + async def test_run_async_sync_function_no_credential_manager(self): + """Test run_async with sync function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_async_function_no_credential_manager(self): + """Test run_async with async function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=async_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "async_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"sync_cred_result_test_{credential.auth_type.value}" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_async_function_with_credential(self): + """Test run_async with async function 
that expects credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_cred_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, + auth_config=auth_config, + response_for_auth_required="Custom auth required", + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Custom auth required" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_default_message(self): + """Test run_async when no credential is available with default message.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." 
+ + @pytest.mark.asyncio + async def test_run_async_function_without_credential_param(self): + """Test run_async with function that doesn't have credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Credential should not be passed to function since it doesn't have the parameter + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_function_with_tool_context(self): + """Test run_async with function that has tool_context parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_tool_context, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_context_result_test" + + @pytest.mark.asyncio + async def test_run_async_function_with_both_params(self): + """Test run_async with function that has both tool_context and credential parameters.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_both, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_both_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_function_with_optional_credential(self): + """Test run_async with function that has optional credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=function_with_optional_args, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert ( + result == f"optional_result_test_default_{credential.auth_type.value}" + ) + + @pytest.mark.asyncio + async def test_run_async_callable_object(self): + """Test run_async with callable object.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + test_callable = MockCallable() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + 
mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"callable_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_propagates_function_exception(self): + """Test that run_async propagates exceptions from the wrapped function.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + def failing_function(arg1: str, credential: AuthCredential) -> str: + raise ValueError("Function failed") + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=failing_function, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(ValueError, match="Function failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_missing_required_args(self): + """Test run_async with missing required arguments.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} # Missing arg2 + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should return error dict indicating missing parameters + assert isinstance(result, dict) + assert "error" in result + assert "arg2" in result["error"] + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_credential_in_ignore_params(self): + """Test that 'credential' is added to ignore_params during initialization.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + assert "credential" in tool._ignore_params + + @pytest.mark.asyncio + async def test_run_async_with_none_credential(self): + """Test run_async when credential is None but function expects it.""" + tool = AuthenticatedFunctionTool(func=function_with_optional_args) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "optional_result_test_default_none" + + def test_signature_inspection(self): + """Test that the tool correctly inspects function signatures.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + signature = inspect.signature(tool.func) + assert "credential" in signature.parameters + assert "arg1" 
in signature.parameters + + @pytest.mark.asyncio + async def test_args_to_call_modification(self): + """Test that args_to_call is properly modified with credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + # Create a spy function to check what arguments are passed + original_args = {} + + def spy_function(arg1: str, credential: AuthCredential) -> str: + nonlocal original_args + original_args = {"arg1": arg1, "credential": credential} + return "spy_result" + + tool = AuthenticatedFunctionTool(func=spy_function, auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "spy_result" + assert original_args is not None + assert original_args["arg1"] == "test" + assert original_args["credential"] == credential diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py new file mode 100644 index 000000000..55454224d --- /dev/null +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -0,0 +1,343 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
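Similarly, before the unit tests for BaseAuthenticatedTool, a hedged subclassing sketch that mirrors the pattern used by the test helper below: only `_run_async_impl` has to be implemented, and it runs only after the base class has resolved a credential (or short-circuited with the auth-required response). The tool name, description, and return payload are illustrative assumptions.

```python
# Hypothetical subclassing sketch, not part of this patch.
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool
from google.adk.tools.tool_context import ToolContext


class TicketLookupTool(BaseAuthenticatedTool):
  """Example tool that calls a (fictional) ticket API with a ready credential."""

  def __init__(self, auth_config):
    super().__init__(
        name="ticket_lookup",
        description="Looks up a support ticket by id.",
        auth_config=auth_config,
        response_for_auth_required={"status": "auth_required"},
    )

  async def _run_async_impl(
      self, *, args: dict, tool_context: ToolContext, credential: AuthCredential
  ):
    # `credential` is already resolved by the base class; it is None only when
    # no auth_config/auth_scheme was configured.
    return {
        "ticket_id": args.get("ticket_id"),
        "authenticated": credential is not None,
    }
```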
+ +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestAuthenticatedTool(BaseAuthenticatedTool): + """Test implementation of BaseAuthenticatedTool for testing purposes.""" + + def __init__( + self, + name="test_auth_tool", + description="Test authenticated tool", + auth_config=None, + unauthenticated_response=None, + ): + super().__init__( + name=name, + description=description, + auth_config=auth_config, + response_for_auth_required=unauthenticated_response, + ) + self.run_impl_called = False + self.run_impl_result = "test_result" + + async def _run_async_impl(self, *, args, tool_context, credential): + """Test implementation of the abstract method.""" + self.run_impl_called = True + self.last_args = args + self.last_tool_context = tool_context + self.last_credential = credential + return self.run_impl_result + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + credential.auth_type = AuthCredentialTypes.OAUTH2 + return credential + + +class TestBaseAuthenticatedTool: + """Test suite for BaseAuthenticatedTool.""" + + def test_init_with_auth_config(self): + """Test initialization with auth_config.""" + auth_config = _create_mock_auth_config() + unauthenticated_response = {"error": "Not authenticated"} + + tool = _TestAuthenticatedTool( + name="test_tool", + description="Test description", + auth_config=auth_config, + unauthenticated_response=unauthenticated_response, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test description" + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == unauthenticated_response + + def test_init_with_no_auth_config(self): + """Test initialization without auth_config.""" + tool = _TestAuthenticatedTool() + + assert tool.name == "test_auth_tool" + assert tool.description == "Test authenticated tool" + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._credentials_manager is None + + def test_init_with_default_unauthenticated_response(self): + """Test initialization with default unauthenticated response.""" + auth_config = _create_mock_auth_config() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._response_for_auth_required is None + + @pytest.mark.asyncio + async def test_run_async_no_credentials_manager(self): + """Test run_async when no credentials manager is configured.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + 
result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential is None + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential == credential + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." 
+ assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_custom_response(self): + """Test run_async when no credential is available with custom response.""" + auth_config = _create_mock_auth_config() + custom_response = { + "status": "authentication_required", + "message": "Please login", + } + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_string_response(self): + """Test run_async when no credential is available with string response.""" + auth_config = _create_mock_auth_config() + custom_response = "Custom authentication required message" + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + + @pytest.mark.asyncio + async def test_run_async_propagates_impl_exception(self): + """Test that run_async propagates exceptions from _run_async_impl.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + # Make the implementation raise an exception + async def failing_impl(*, args, tool_context, credential): + raise ValueError("Implementation failed") + + tool._run_async_impl = failing_impl + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(ValueError, match="Implementation failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_with_different_args_types(self): + """Test run_async with different argument types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + + # Test with empty args + result = await tool.run_async(args={}, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == {} + + # Test with complex args + complex_args = { + "string_param": 
"test", + "number_param": 42, + "list_param": [1, 2, 3], + "dict_param": {"nested": "value"}, + } + result = await tool.run_async(args=complex_args, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == complex_args + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_abstract_nature(self): + """Test that BaseAuthenticatedTool cannot be instantiated directly.""" + with pytest.raises(TypeError): + # This should fail because _run_async_impl is abstract + BaseAuthenticatedTool(name="test", description="test") + + @pytest.mark.asyncio + async def test_run_async_return_values(self): + """Test run_async with different return value types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {} + + # Test with None return + tool.run_impl_result = None + result = await tool.run_async(args=args, tool_context=tool_context) + assert result is None + + # Test with dict return + tool.run_impl_result = {"key": "value"} + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == {"key": "value"} + + # Test with list return + tool.run_impl_result = [1, 2, 3] + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == [1, 2, 3] From 18a541c8fa5d9cac2769c1875d5d9dc4f782ca75 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:33:09 -0700 Subject: [PATCH 57/61] chore: Ignore mcp_tool ut tests for python 3.9 given mcp sdk only supports 3.10+ PiperOrigin-RevId: 772999037 --- .github/workflows/python-unit-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index d4af7b13a..565ee1dca 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -51,6 +51,7 @@ jobs: if [[ "${{ matrix.python-version }}" == "3.9" ]]; then pytest tests/unittests \ --ignore=tests/unittests/a2a \ + --ignore=tests/unittests/tools/mcp_tool \ --ignore=tests/unittests/artifacts/test_artifact_service.py \ --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py else From 157d9be88d92f22320604832e5a334a6eb81e4af Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 11:43:32 -0700 Subject: [PATCH 58/61] feat: Enable MCP Tool Auth (Experimental) PiperOrigin-RevId: 773002759 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 237 ++++++----- src/google/adk/tools/mcp_tool/mcp_tool.py | 88 ++++- src/google/adk/tools/mcp_tool/mcp_toolset.py | 12 +- tests/unittests/tools/mcp_tool/__init__.py | 13 + .../mcp_tool/test_mcp_session_manager.py | 342 ++++++++++++++++ .../unittests/tools/mcp_tool/test_mcp_tool.py | 373 ++++++++++++++++++ .../tools/mcp_tool/test_mcp_toolset.py | 269 +++++++++++++ 7 files changed, 1231 insertions(+), 103 deletions(-) 
create mode 100644 tests/unittests/tools/mcp_tool/__init__.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_session_manager.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_tool.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_toolset.py diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 5bc06e398..90b39e6cb 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -18,9 +18,12 @@ from contextlib import AsyncExitStack from datetime import timedelta import functools +import hashlib +import json import logging import sys from typing import Any +from typing import Dict from typing import Optional from typing import TextIO from typing import Union @@ -105,74 +108,39 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(session_manager_field_name: str): - """Decorator to automatically reinitialize session and retry action. +def retry_on_closed_resource(func): + """Decorator to automatically retry action when MCP session is closed. - When MCP session was closed, the decorator will automatically recreate the - session and retry the action with the same parameters. - - Note: - 1. session_manager_field_name is the name of the class member field that - contains the MCPSessionManager instance. - 2. The session manager must have a reinitialize_session() async method. - - Usage: - class MCPTool: - def __init__(self): - self._mcp_session_manager = MCPSessionManager(...) - - @retry_on_closed_resource('_mcp_session_manager') - async def use_session(self): - session = await self._mcp_session_manager.create_session() - await session.call_tool() + When MCP session was closed, the decorator will automatically retry the + action once. The create_session method will handle creating a new session + if the old one was disconnected. Args: - session_manager_field_name: The name of the session manager field. + func: The function to decorate. Returns: The decorated function. """ - def decorator(func): - @functools.wraps(func) # Preserves original function metadata - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except anyio.ClosedResourceError as close_err: - try: - if hasattr(self, session_manager_field_name): - session_manager = getattr(self, session_manager_field_name) - if hasattr(session_manager, 'reinitialize_session') and callable( - getattr(session_manager, 'reinitialize_session') - ): - await session_manager.reinitialize_session() - else: - raise ValueError( - f'Session manager {session_manager_field_name} does not have' - ' reinitialize_session method.' - ) from close_err - else: - raise ValueError( - f'Session manager field {session_manager_field_name} does not' - ' exist in decorated class. Please check the field name in' - ' retry_on_closed_resource decorator.' 
- ) from close_err - except Exception as reinit_err: - raise RuntimeError( - f'Error reinitializing: {reinit_err}' - ) from reinit_err - return await func(self, *args, **kwargs) - - return wrapper - - return decorator + @functools.wraps(func) # Preserves original function metadata + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except anyio.ClosedResourceError: + # Simply retry the function - create_session will handle + # detecting and replacing disconnected sessions + logger.info('Retrying %s due to closed resource', func.__name__) + return await func(self, *args, **kwargs) + + return wrapper class MCPSessionManager: """Manages MCP client sessions. This class provides methods for creating and initializing MCP client sessions, - handling different connection parameters (Stdio and SSE). + handling different connection parameters (Stdio and SSE) and supporting + session pooling based on authentication headers. """ def __init__( @@ -209,30 +177,125 @@ def __init__( else: self._connection_params = connection_params self._errlog = errlog - # Each session manager maintains its own exit stack for proper cleanup - self._exit_stack: Optional[AsyncExitStack] = None - self._session: Optional[ClientSession] = None + + # Session pool: maps session keys to (session, exit_stack) tuples + self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + # Lock to prevent race conditions in session creation self._session_lock = asyncio.Lock() - async def create_session(self) -> ClientSession: + def _generate_session_key( + self, merged_headers: Optional[Dict[str, str]] = None + ) -> str: + """Generates a session key based on connection params and merged headers. + + For StdioConnectionParams, returns a constant key since headers are not + supported. For SSE and StreamableHTTP connections, generates a key based + on the provided merged headers. + + Args: + merged_headers: Already merged headers (base + additional). + + Returns: + A unique session key string. + """ + if isinstance(self._connection_params, StdioConnectionParams): + # For stdio connections, headers are not supported, so use constant key + return 'stdio_session' + + # For SSE and StreamableHTTP connections, use merged headers + if merged_headers: + headers_json = json.dumps(merged_headers, sort_keys=True) + headers_hash = hashlib.md5(headers_json.encode()).hexdigest() + return f'session_{headers_hash}' + else: + return 'session_no_headers' + + def _merge_headers( + self, additional_headers: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """Merges base connection headers with additional headers. + + Args: + additional_headers: Optional headers to merge with connection headers. + + Returns: + Merged headers dictionary, or None if no headers are provided. + """ + if isinstance(self._connection_params, StdioConnectionParams) or isinstance( + self._connection_params, StdioServerParameters + ): + # Stdio connections don't support headers + return None + + base_headers = {} + if ( + hasattr(self._connection_params, 'headers') + and self._connection_params.headers + ): + base_headers = self._connection_params.headers.copy() + + if additional_headers: + base_headers.update(additional_headers) + + return base_headers + + def _is_session_disconnected(self, session: ClientSession) -> bool: + """Checks if a session is disconnected or closed. + + Args: + session: The ClientSession to check. + + Returns: + True if the session is disconnected, False otherwise. 
+ """ + return session._read_stream._closed or session._write_stream._closed + + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: """Creates and initializes an MCP client session. + This method will check if an existing session for the given headers + is still connected. If it's disconnected, it will be cleaned up and + a new session will be created. + + Args: + headers: Optional headers to include in the session. These will be + merged with any existing connection headers. Only applicable + for SSE and StreamableHTTP connections. + Returns: ClientSession: The initialized MCP client session. """ - # Fast path: if session already exists, return it without acquiring lock - if self._session is not None: - return self._session + # Merge headers once at the beginning + merged_headers = self._merge_headers(headers) + + # Generate session key using merged headers + session_key = self._generate_session_key(merged_headers) # Use async lock to prevent race conditions async with self._session_lock: - # Double-check: session might have been created while waiting for lock - if self._session is not None: - return self._session - - # Create a new exit stack for this session - self._exit_stack = AsyncExitStack() + # Check if we have an existing session + if session_key in self._sessions: + session, exit_stack = self._sessions[session_key] + + # Check if the existing session is still connected + if not self._is_session_disconnected(session): + # Session is still good, return it + return session + else: + # Session is disconnected, clean it up + logger.info('Cleaning up disconnected session: %s', session_key) + try: + await exit_stack.aclose() + except Exception as e: + logger.warning('Error during disconnected session cleanup: %s', e) + finally: + del self._sessions[session_key] + + # Create a new session (either first time or replacing disconnected one) + exit_stack = AsyncExitStack() try: if isinstance(self._connection_params, StdioConnectionParams): @@ -243,7 +306,7 @@ async def create_session(self) -> ClientSession: elif isinstance(self._connection_params, SseConnectionParams): client = sse_client( url=self._connection_params.url, - headers=self._connection_params.headers, + headers=merged_headers, timeout=self._connection_params.timeout, sse_read_timeout=self._connection_params.sse_read_timeout, ) @@ -252,7 +315,7 @@ async def create_session(self) -> ClientSession: ): client = streamablehttp_client( url=self._connection_params.url, - headers=self._connection_params.headers, + headers=merged_headers, timeout=timedelta(seconds=self._connection_params.timeout), sse_read_timeout=timedelta( seconds=self._connection_params.sse_read_timeout @@ -266,11 +329,11 @@ async def create_session(self) -> ClientSession: f' {self._connection_params}' ) - transports = await self._exit_stack.enter_async_context(client) + transports = await exit_stack.enter_async_context(client) # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. 
if isinstance(self._connection_params, StdioConnectionParams): - session = await self._exit_stack.enter_async_context( + session = await exit_stack.enter_async_context( ClientSession( *transports[:2], read_timeout_seconds=timedelta( @@ -279,44 +342,38 @@ async def create_session(self) -> ClientSession: ) ) else: - session = await self._exit_stack.enter_async_context( + session = await exit_stack.enter_async_context( ClientSession(*transports[:2]) ) await session.initialize() - self._session = session + # Store session and exit stack in the pool + self._sessions[session_key] = (session, exit_stack) + logger.debug('Created new session: %s', session_key) return session except Exception: # If session creation fails, clean up the exit stack - if self._exit_stack: - await self._exit_stack.aclose() - self._exit_stack = None + if exit_stack: + await exit_stack.aclose() raise async def close(self): - """Closes the session and cleans up resources.""" - if not self._exit_stack: - return + """Closes all sessions and cleans up resources.""" async with self._session_lock: - if self._exit_stack: + for session_key in list(self._sessions.keys()): + _, exit_stack = self._sessions[session_key] try: - await self._exit_stack.aclose() + await exit_stack.aclose() except Exception as e: # Log the error but don't re-raise to avoid blocking shutdown print( - f'Warning: Error during MCP session cleanup: {e}', + 'Warning: Error during MCP session cleanup for' + f' {session_key}: {e}', file=self._errlog, ) finally: - self._exit_stack = None - self._session = None - - async def reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and create a new one - await self.close() - await self.create_session() + del self._sessions[session_key] SseServerParams = SseConnectionParams diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 6553bb2c0..24998c925 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,10 +14,13 @@ from __future__ import annotations +import base64 +import json import logging from typing import Optional from google.genai.types import FunctionDeclaration +from google.oauth2.credentials import Credentials from typing_extensions import override from .._gemini_schema_util import _to_gemini_schema @@ -42,13 +45,15 @@ from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme -from ..base_tool import BaseTool +from ...auth.auth_tool import AuthConfig +from ..base_authenticated_tool import BaseAuthenticatedTool +# import from ..tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) -class MCPTool(BaseTool): +class MCPTool(BaseAuthenticatedTool): """Turns an MCP Tool into an ADK Tool. Internally, the tool initializes from a MCP Tool, and uses the MCP Session to @@ -77,19 +82,17 @@ def __init__( Raises: ValueError: If mcp_tool or mcp_session_manager is None. """ - if mcp_tool is None: - raise ValueError("mcp_tool cannot be None") - if mcp_session_manager is None: - raise ValueError("mcp_session_manager cannot be None") super().__init__( name=mcp_tool.name, description=mcp_tool.description if mcp_tool.description else "", + auth_config=AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + if auth_scheme + else None, ) self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager - # TODO(cheliu): Support passing auth to MCP Server. 
- self._auth_scheme = auth_scheme - self._auth_credential = auth_credential @override def _get_declaration(self) -> FunctionDeclaration: @@ -105,8 +108,11 @@ def _get_declaration(self) -> FunctionDeclaration: ) return function_decl - @retry_on_closed_resource("_mcp_session_manager") - async def run_async(self, *, args, tool_context: ToolContext): + @retry_on_closed_resource + @override + async def _run_async_impl( + self, *, args, tool_context: ToolContext, credential: AuthCredential + ): """Runs the tool asynchronously. Args: @@ -116,8 +122,66 @@ async def run_async(self, *, args, tool_context: ToolContext): Returns: Any: The response from the tool. """ + # Extract headers from credential for session pooling + headers = await self._get_headers(tool_context, credential) + # Get the session from the session manager - session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session(headers=headers) response = await session.call_tool(self.name, arguments=args) return response + + async def _get_headers( + self, tool_context: ToolContext, credential: AuthCredential + ) -> Optional[dict[str, str]]: + headers = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.google_oauth2_json: + google_credential = Credentials.from_authorized_user_info( + json.loads(credential.google_oauth2_json) + ) + headers = {"Authorization": f"Bearer {google_credential.token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + # For API keys, we'll add them as headers since MCP typically uses header-based auth + # The specific header name would depend on the API, using a common default + # TODO Allow user to specify the header name for API keys. 
+ headers = {"X-API-Key": credential.api_key} + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + # If we reach here, we can try to use google_oauth2_json or log a warning + logger.warning( + "Service account credentials should be exchanged for access" + " tokens before MCP session creation" + ) + + return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index f55693e86..c01b0cec2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -22,6 +22,8 @@ from typing import Union from ...agents.readonly_context import ReadonlyContext +from ...auth.auth_credential import AuthCredential +from ...auth.auth_schemes import AuthScheme from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate @@ -94,6 +96,8 @@ def __init__( ], tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, errlog: TextIO = sys.stderr, + auth_scheme: Optional[AuthScheme] = None, + auth_credential: Optional[AuthCredential] = None, ): """Initializes the MCPToolset. @@ -110,6 +114,8 @@ def __init__( list of tool names to include - A ToolPredicate function for custom filtering logic errlog: TextIO stream for error logging. + auth_scheme: The auth scheme of the tool for tool calling + auth_credential: The auth credential of the tool for tool calling """ super().__init__(tool_filter=tool_filter) @@ -124,8 +130,10 @@ def __init__( connection_params=self._connection_params, errlog=self._errlog, ) + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential - @retry_on_closed_resource("_mcp_session_manager") + @retry_on_closed_resource async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, @@ -151,6 +159,8 @@ async def get_tools( mcp_tool = MCPTool( mcp_tool=tool, mcp_session_manager=self._mcp_session_manager, + auth_scheme=self._auth_scheme, + auth_credential=self._auth_credential, ) if self._is_tool_selected(mcp_tool, readonly_context): diff --git a/tests/unittests/tools/mcp_tool/__init__.py b/tests/unittests/tools/mcp_tool/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/tools/mcp_tool/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py new file mode 100644 index 000000000..448d41260 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -0,0 +1,342 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +from io import StringIO +import json +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +import pytest + +# Import real MCP classes +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockClientSession: + """Mock ClientSession for testing.""" + + def __init__(self): + self._read_stream = Mock() + self._write_stream = Mock() + self._read_stream._closed = False + self._write_stream._closed = False + self.initialize = AsyncMock() + + +class MockAsyncExitStack: + """Mock AsyncExitStack for testing.""" + + def __init__(self): + self.aclose = AsyncMock() + self.enter_async_context = AsyncMock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class TestMCPSessionManager: + """Test suite for MCPSessionManager class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_stdio_connection_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=5.0 + ) + + def test_init_with_stdio_server_parameters(self): + """Test initialization with StdioServerParameters (deprecated).""" + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.logger" + ) as mock_logger: + manager = MCPSessionManager(self.mock_stdio_params) + + # Should log deprecation warning + mock_logger.warning.assert_called_once() + assert "StdioServerParameters is not recommended" in str( + mock_logger.warning.call_args + ) + + # Should convert to StdioConnectionParams + assert isinstance(manager._connection_params, StdioConnectionParams) + assert manager._connection_params.server_params == self.mock_stdio_params + assert manager._connection_params.timeout == 5 + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + assert manager._connection_params == self.mock_stdio_connection_params + assert manager._errlog == sys.stderr + assert manager._sessions == {} + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", + headers={"Authorization": "Bearer token"}, + timeout=10.0, + ) + manager = MCPSessionManager(sse_params) + + assert manager._connection_params == sse_params + + def 
test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", timeout=15.0 + ) + manager = MCPSessionManager(http_params) + + assert manager._connection_params == http_params + + def test_generate_session_key_stdio(self): + """Test session key generation for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # For stdio, headers should be ignored and return constant key + key1 = manager._generate_session_key({"Authorization": "Bearer token"}) + key2 = manager._generate_session_key(None) + + assert key1 == "stdio_session" + assert key2 == "stdio_session" + assert key1 == key2 + + def test_generate_session_key_sse(self): + """Test session key generation for SSE connections.""" + sse_params = SseConnectionParams(url="https://example.com/mcp") + manager = MCPSessionManager(sse_params) + + headers1 = {"Authorization": "Bearer token1"} + headers2 = {"Authorization": "Bearer token2"} + + key1 = manager._generate_session_key(headers1) + key2 = manager._generate_session_key(headers2) + key3 = manager._generate_session_key(headers1) + + # Different headers should generate different keys + assert key1 != key2 + # Same headers should generate same key + assert key1 == key3 + + # Should be deterministic hash + headers_json = json.dumps(headers1, sort_keys=True) + expected_hash = hashlib.md5(headers_json.encode()).hexdigest() + assert key1 == f"session_{expected_hash}" + + def test_merge_headers_stdio(self): + """Test header merging for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Stdio connections don't support headers + headers = manager._merge_headers({"Authorization": "Bearer token"}) + assert headers is None + + def test_merge_headers_sse(self): + """Test header merging for SSE connections.""" + base_headers = {"Content-Type": "application/json"} + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers=base_headers + ) + manager = MCPSessionManager(sse_params) + + # With additional headers + additional = {"Authorization": "Bearer token"} + merged = manager._merge_headers(additional) + + expected = { + "Content-Type": "application/json", + "Authorization": "Bearer token", + } + assert merged == expected + + def test_is_session_disconnected(self): + """Test session disconnection detection.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session + session = MockClientSession() + + # Not disconnected + assert not manager._is_session_disconnected(session) + + # Disconnected - read stream closed + session._read_stream._closed = True + assert manager._is_session_disconnected(session) + + @pytest.mark.asyncio + async def test_create_session_stdio_new(self): + """Test creating a new stdio session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_session = MockClientSession() + mock_exit_stack = MockAsyncExitStack() + + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ) as mock_stdio: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_exit_stack_class: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" + ) as mock_session_class: + + # Setup mocks + mock_exit_stack_class.return_value = mock_exit_stack + mock_stdio.return_value = AsyncMock() + mock_exit_stack.enter_async_context.side_effect = [ + ("read", "write"), # 
First call returns transports + mock_session, # Second call returns session + ] + mock_session_class.return_value = mock_session + + # Create session + session = await manager.create_session() + + # Verify session creation + assert session == mock_session + assert len(manager._sessions) == 1 + assert "stdio_session" in manager._sessions + + # Verify session was initialized + mock_session.initialize.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_reuse_existing(self): + """Test reusing an existing connected session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock existing session + existing_session = MockClientSession() + existing_exit_stack = MockAsyncExitStack() + manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) + + # Session is connected + existing_session._read_stream._closed = False + existing_session._write_stream._closed = False + + session = await manager.create_session() + + # Should reuse existing session + assert session == existing_session + assert len(manager._sessions) == 1 + + # Should not create new session + existing_session.initialize.assert_not_called() + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup of all sessions.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + await manager.close() + + # All sessions should be closed + exit_stack1.aclose.assert_called_once() + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + @pytest.mark.asyncio + async def test_close_with_errors(self): + """Test cleanup when some sessions fail to close.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + exit_stack1.aclose.side_effect = Exception("Close error 1") + + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + custom_errlog = StringIO() + manager._errlog = custom_errlog + + # Should not raise exception + await manager.close() + + # Good session should still be closed + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + # Error should be logged + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCP session cleanup" in error_output + assert "Close error 1" in error_output + + +def test_retry_on_closed_resource_decorator(): + """Test the retry_on_closed_resource decorator.""" + + call_count = 0 + + @retry_on_closed_resource + async def mock_function(self): + nonlocal call_count + call_count += 1 + if call_count == 1: + import anyio + + raise anyio.ClosedResourceError("Resource closed") + return "success" + + @pytest.mark.asyncio + async def test_retry(): + nonlocal call_count + call_count = 0 + + mock_self = Mock() + result = await mock_function(mock_self) + + assert result == "success" + assert call_count == 2 # First call fails, second succeeds + + # Run the test + import asyncio + + asyncio.run(test_retry()) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py new file 
mode 100644 index 000000000..4d9cffb4d --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -0,0 +1,373 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +import pytest + + +# Mock MCP Tool from mcp.types +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name="test_tool", description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "First parameter"}, + "param2": {"type": "integer", "description": "Second parameter"}, + }, + "required": ["param1"], + } + + +class TestMCPTool: + """Test suite for MCPTool class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_mcp_tool = MockMCPTool() + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization without auth.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test tool description" + assert tool._mcp_tool == self.mock_mcp_tool + assert tool._mcp_session_manager == self.mock_session_manager + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances instead of mocks + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + # The auth config is stored in the parent class _credentials_manager + assert tool._credentials_manager is not None + assert tool._credentials_manager._auth_config.auth_scheme == auth_scheme + assert ( + tool._credentials_manager._auth_config.raw_auth_credential + == auth_credential + ) + + 
def test_init_with_empty_description(self): + """Test initialization with empty description.""" + mock_tool = MockMCPTool(description=None) + tool = MCPTool( + mcp_tool=mock_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.description == "" + + def test_get_declaration(self): + """Test function declaration generation.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + declaration = tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_tool" + assert declaration.description == "Test tool description" + assert declaration.parameters is not None + + @pytest.mark.asyncio + async def test_run_async_impl_no_auth(self): + """Test running tool without authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=None + ) + + assert result == expected_response + self.mock_session_manager.create_session.assert_called_once_with( + headers=None + ) + # Fix: call_tool uses 'arguments' parameter, not positional args + self.mock_session.call_tool.assert_called_once_with( + "test_tool", arguments=args + ) + + @pytest.mark.asyncio + async def test_run_async_impl_with_oauth2(self): + """Test running tool with OAuth2 authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create OAuth2 credential + oauth2_auth = OAuth2Auth(access_token="test_access_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + assert result == expected_response + # Check that headers were passed correctly + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == {"Authorization": "Bearer test_access_token"} + + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) 
+ + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer bearer_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key(self): + """Test header generation for API Key credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"X-API-Key": "my_api_key"} + + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_tool.json") + @patch("google.adk.tools.mcp_tool.mcp_tool.Credentials") + async def test_get_headers_google_oauth2_json( + self, mock_credentials, mock_json + ): + """Test header generation for Google OAuth2 JSON credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Mock the JSON parsing and Credentials creation + mock_json.loads.return_value = {"token": "google_token"} + mock_google_credential = Mock() + mock_google_credential.token = "google_access_token" + mock_credentials.from_authorized_user_info.return_value = ( + mock_google_credential + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + google_oauth2_json='{"token": "google_token"}', + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer google_access_token"} + mock_json.loads.assert_called_once_with('{"token": "google_token"}') + mock_credentials.from_authorized_user_info.assert_called_once_with( + {"token": "google_token"} + ) + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, None) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account_no_json(self): + """Test header generation for service account credentials without google_oauth2_json.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create service account credential without google_oauth2_json + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should 
return None as no google_oauth2_json is provided + assert headers is None + + @pytest.mark.asyncio + async def test_run_async_impl_retry_decorator(self): + """Test that the retry decorator is applied correctly.""" + # This is more of an integration test to ensure the decorator is present + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Check that the method has the retry decorator + assert hasattr(tool._run_async_impl, "__wrapped__") + + @pytest.mark.asyncio + async def test_get_headers_http_custom_scheme(self): + """Test header generation for custom HTTP scheme.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "custom custom_token"} + + def test_init_validation(self): + """Test that initialization validates required parameters.""" + # This test ensures that the MCPTool properly handles its dependencies + with pytest.raises(TypeError): + MCPTool() # Missing required parameters + + with pytest.raises(TypeError): + MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py new file mode 100644 index 000000000..0ba29b1da --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
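
The toolset tests below build on the header-keyed session pooling added to MCPSessionManager in this patch; a brief sketch of that behavior (an illustration, not part of the diff), assuming manager is an MCPSessionManager created from SseConnectionParams or StreamableHTTPConnectionParams and that the coroutine runs inside an event loop:

    async def demo_session_pooling(manager):
      # Identical merged headers hash to the same session key, so the pooled
      # session is reused; different headers get a session of their own.
      a = await manager.create_session(headers={"Authorization": "Bearer token-a"})
      b = await manager.create_session(headers={"Authorization": "Bearer token-b"})
      a_again = await manager.create_session(headers={"Authorization": "Bearer token-a"})
      assert a is a_again
      assert a is not b
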
+ +from io import StringIO +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +import pytest + +# Import the real MCP classes for proper instantiation +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name, description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": {"param": {"type": "string"}}, + } + + +class MockListToolsResult: + """Mock ListToolsResult for testing.""" + + def __init__(self, tools): + self.tools = tools + + +class TestMCPToolset: + """Test suite for MCPToolset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization with StdioServerParameters.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Note: StdioServerParameters gets converted to StdioConnectionParams internally + assert toolset._errlog == sys.stderr + assert toolset._auth_scheme is None + assert toolset._auth_credential is None + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + stdio_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=10.0 + ) + toolset = MCPToolset(connection_params=stdio_params) + + assert toolset._connection_params == stdio_params + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers={"Authorization": "Bearer token"} + ) + toolset = MCPToolset(connection_params=sse_params) + + assert toolset._connection_params == sse_params + + def test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", + headers={"Content-Type": "application/json"}, + ) + toolset = MCPToolset(connection_params=http_params) + + assert toolset._connection_params == http_params + + def test_init_with_tool_filter_list(self): + """Test initialization with tool filter as list.""" + tool_filter = ["tool1", "tool2"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + + # The tool 
filter is stored in the parent BaseToolset class + # We can verify it by checking the filtering behavior in get_tools + assert toolset._is_tool_selected is not None + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + from google.adk.auth.auth_credential import OAuth2Auth + + auth_credential = AuthCredential( + auth_type="oauth2", + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + assert toolset._auth_scheme == auth_scheme + assert toolset._auth_credential == auth_credential + + def test_init_missing_connection_params(self): + """Test initialization with missing connection params raises error.""" + with pytest.raises(ValueError, match="Missing connection params"): + MCPToolset(connection_params=None) + + @pytest.mark.asyncio + async def test_get_tools_basic(self): + """Test getting tools without filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 3 + for tool in tools: + assert isinstance(tool, MCPTool) + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + assert tools[2].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_list_filter(self): + """Test getting tools with list-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + tool_filter = ["tool1", "tool3"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_function_filter(self): + """Test getting tools with function-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("read_file"), + MockMCPTool("write_file"), + MockMCPTool("list_directory"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + def file_tools_filter(tool, context): + """Filter for file-related tools only.""" + return "file" in tool.name + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=file_tools_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "read_file" + assert tools[1].name == "write_file" + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + await toolset.close() + + self.mock_session_manager.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_exception(self): + """Test 
cleanup when session manager raises exception.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + # Mock close to raise an exception + self.mock_session_manager.close = AsyncMock( + side_effect=Exception("Cleanup error") + ) + + custom_errlog = StringIO() + toolset._errlog = custom_errlog + + # Should not raise exception + await toolset.close() + + # Should log the error + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCPToolset cleanup" in error_output + assert "Cleanup error" in error_output + + @pytest.mark.asyncio + async def test_get_tools_retry_decorator(self): + """Test that get_tools has retry decorator applied.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Check that the method has the retry decorator + assert hasattr(toolset.get_tools, "__wrapped__") From 58e07cae83048d5213d822be5197a96be9ce2950 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Wed, 18 Jun 2025 16:33:56 -0700 Subject: [PATCH 59/61] fix: Fix tracing for live the original code passed in wrong args. now fixed. tested locally. PiperOrigin-RevId: 773108589 --- src/google/adk/flows/llm_flows/functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2541ac664..2772550c2 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -288,8 +288,7 @@ async def handle_function_calls_live( trace_tool_call( tool=tool, args=function_args, - response_event_id=function_response_event.id, - function_response=function_response, + function_response_event=function_response_event, ) function_response_events.append(function_response_event) From 913d771d6dda4f0b4a5f9c82ab914f3495a92092 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 17:40:12 -0700 Subject: [PATCH 60/61] chore: Raise meaningful errors when importing a2a modules for python 3.9 PiperOrigin-RevId: 773128206 --- src/google/adk/a2a/converters/part_converter.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index 1c51fd7c1..2d94abd7c 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -20,9 +20,20 @@ import json import logging +import sys from typing import Optional -from a2a import types as a2a_types +try: + from a2a import types as a2a_types +except ImportError as e: + if sys.version_info < (3, 10): + raise ImportError( + 'A2A Tool requires Python 3.10 or above. Please upgrade your Python' + ' version.' + ) from e + else: + raise e + from google.genai import types as genai_types from ...utils.feature_decorator import working_in_progress From 9a1115c504427ed8285b5c2053946c11c5d7c0a6 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 18 Jun 2025 18:18:17 -0700 Subject: [PATCH 61/61] chore: Remove service account support given it was not correctly supported. 
PiperOrigin-RevId: 773137317 --- src/google/adk/auth/auth_credential.py | 1 - src/google/adk/auth/credential_manager.py | 6 +- src/google/adk/auth/exchanger/__init__.py | 2 - .../service_account_credential_exchanger.py | 104 ----- .../refresher/oauth2_credential_refresher.py | 32 +- src/google/adk/tools/mcp_tool/mcp_tool.py | 10 +- .../openapi_spec_parser/tool_auth_handler.py | 1 - ...st_service_account_credential_exchanger.py | 433 ------------------ .../test_oauth2_credential_refresher.py | 118 ----- .../unittests/auth/test_credential_manager.py | 26 +- .../unittests/tools/mcp_tool/test_mcp_tool.py | 42 +- 11 files changed, 15 insertions(+), 760 deletions(-) delete mode 100644 src/google/adk/auth/exchanger/service_account_credential_exchanger.py delete mode 100644 tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 1009a50dd..34d04dde9 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -230,4 +230,3 @@ class AuthCredential(BaseModelWithConfig): http: Optional[HttpAuth] = None service_account: Optional[ServiceAccount] = None oauth2: Optional[OAuth2Auth] = None - google_oauth2_json: Optional[str] = None diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index 7471bdffa..0dbf006ab 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -76,11 +76,7 @@ def __init__( self._refresher_registry = CredentialRefresherRegistry() # Register default exchangers and refreshers - from .exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger - - self._exchanger_registry.register( - AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger() - ) + # TODO: support service account credential exchanger from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher oauth2_refresher = OAuth2CredentialRefresher() diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py index 4226ae715..3b0fbb246 100644 --- a/src/google/adk/auth/exchanger/__init__.py +++ b/src/google/adk/auth/exchanger/__init__.py @@ -15,9 +15,7 @@ """Credential exchanger module.""" from .base_credential_exchanger import BaseCredentialExchanger -from .service_account_credential_exchanger import ServiceAccountCredentialExchanger __all__ = [ "BaseCredentialExchanger", - "ServiceAccountCredentialExchanger", ] diff --git a/src/google/adk/auth/exchanger/service_account_credential_exchanger.py b/src/google/adk/auth/exchanger/service_account_credential_exchanger.py deleted file mode 100644 index 415081ca5..000000000 --- a/src/google/adk/auth/exchanger/service_account_credential_exchanger.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Credential fetcher for Google Service Account.""" - -from __future__ import annotations - -from typing import Optional - -import google.auth -from google.auth.transport.requests import Request -from google.oauth2 import service_account -from typing_extensions import override - -from ...utils.feature_decorator import experimental -from ..auth_credential import AuthCredential -from ..auth_credential import AuthCredentialTypes -from ..auth_schemes import AuthScheme -from .base_credential_exchanger import BaseCredentialExchanger - - -@experimental -class ServiceAccountCredentialExchanger(BaseCredentialExchanger): - """Exchanges Google Service Account credentials for an access token. - - Uses the default service credential if `use_default_credential = True`. - Otherwise, uses the service account credential provided in the auth - credential. - """ - - @override - async def exchange( - self, - auth_credential: AuthCredential, - auth_scheme: Optional[AuthScheme] = None, - ) -> AuthCredential: - """Exchanges the service account auth credential for an access token. - - If the AuthCredential contains a service account credential, it will be used - to exchange for an access token. Otherwise, if use_default_credential is True, - the default application credential will be used for exchanging an access token. - - Args: - auth_scheme: The authentication scheme. - auth_credential: The credential to exchange. - - Returns: - An AuthCredential in OAUTH2 format, containing the exchanged credential JSON. - - Raises: - ValueError: If service account credentials are missing or invalid. - Exception: If credential exchange or refresh fails. - """ - if auth_credential is None: - raise ValueError("Credential cannot be None.") - - if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT: - raise ValueError("Credential is not a service account credential.") - - if auth_credential.service_account is None: - raise ValueError( - "Service account credentials are missing. Please provide them." - ) - - if ( - auth_credential.service_account.service_account_credential is None - and not auth_credential.service_account.use_default_credential - ): - raise ValueError( - "Service account credentials are invalid. Please set the" - " service_account_credential field or set `use_default_credential =" - " True` to use application default credential in a hosted service" - " like Google Cloud Run." - ) - - try: - if auth_credential.service_account.use_default_credential: - credentials, _ = google.auth.default() - else: - config = auth_credential.service_account - credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes - ) - - # Refresh credentials to ensure we have a valid access token - credentials.refresh(Request()) - - return AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=credentials.to_json(), - ) - except Exception as e: - raise ValueError(f"Failed to exchange service account token: {e}") from e diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py index 2d0a8b670..4c19520ce 100644 --- a/src/google/adk/auth/refresher/oauth2_credential_refresher.py +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -60,27 +60,12 @@ async def is_refresh_needed( Returns: True if the credential needs to be refreshed, False otherwise. 
""" - # Handle Google OAuth2 credentials (from service account exchange) - if auth_credential.google_oauth2_json: - try: - google_credential = Credentials.from_authorized_user_info( - json.loads(auth_credential.google_oauth2_json) - ) - return google_credential.expired and bool( - google_credential.refresh_token - ) - except Exception as e: - logger.warning("Failed to parse Google OAuth2 JSON credential: %s", e) - return False # Handle regular OAuth2 credentials - elif auth_credential.oauth2 and auth_scheme: + if auth_credential.oauth2: if not AUTHLIB_AVIALABLE: return False - if not auth_credential.oauth2: - return False - return OAuth2Token({ "expires_at": auth_credential.oauth2.expires_at, "expires_in": auth_credential.oauth2.expires_in, @@ -105,22 +90,9 @@ async def refresh( The refreshed credential. """ - # Handle Google OAuth2 credentials (from service account exchange) - if auth_credential.google_oauth2_json: - try: - google_credential = Credentials.from_authorized_user_info( - json.loads(auth_credential.google_oauth2_json) - ) - if google_credential.expired and google_credential.refresh_token: - google_credential.refresh(Request()) - auth_credential.google_oauth2_json = google_credential.to_json() - logger.info("Successfully refreshed Google OAuth2 JSON credential") - except Exception as e: - # TODO reconsider whether we should raise error when refresh failed. - logger.error("Failed to refresh Google OAuth2 JSON credential: %s", e) # Handle regular OAuth2 credentials - elif auth_credential.oauth2 and auth_scheme: + if auth_credential.oauth2 and auth_scheme: if not AUTHLIB_AVIALABLE: return auth_credential diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 24998c925..310fc48f1 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -138,11 +138,6 @@ async def _get_headers( if credential: if credential.oauth2: headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} - elif credential.google_oauth2_json: - google_credential = Credentials.from_authorized_user_info( - json.loads(credential.google_oauth2_json) - ) - headers = {"Authorization": f"Bearer {google_credential.token}"} elif credential.http: # Handle HTTP authentication schemes if ( @@ -178,10 +173,9 @@ async def _get_headers( headers = {"X-API-Key": credential.api_key} elif credential.service_account: # Service accounts should be exchanged for access tokens before reaching this point - # If we reach here, we can try to use google_oauth2_json or log a warning logger.warning( - "Service account credentials should be exchanged for access" - " tokens before MCP session creation" + "Service account credentials should be exchanged before MCP" + " session creation" ) return headers diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 08e535d28..74166b00e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -233,7 +233,6 @@ def _external_exchange_required(self, credential) -> bool: AuthCredentialTypes.OPEN_ID_CONNECT, ) and not credential.oauth2.access_token - and not credential.google_oauth2_json ) async def prepare_auth_credentials( diff --git a/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py 
b/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py deleted file mode 100644 index 195e143d3..000000000 --- a/tests/unittests/auth/exchanger/test_service_account_credential_exchanger.py +++ /dev/null @@ -1,433 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for the ServiceAccountCredentialExchanger.""" - -from unittest.mock import MagicMock -from unittest.mock import patch - -from fastapi.openapi.models import HTTPBearer -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import ServiceAccount -from google.adk.auth.auth_credential import ServiceAccountCredential -from google.adk.auth.exchanger.service_account_credential_exchanger import ServiceAccountCredentialExchanger -import pytest - - -class TestServiceAccountCredentialExchanger: - """Test cases for ServiceAccountCredentialExchanger.""" - - def test_exchange_with_valid_credential(self): - """Test successful exchange with valid service account credential.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key=( - "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE" - " KEY-----" - ), - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - # This should not raise an exception - assert exchanger is not None - - @pytest.mark.asyncio - async def test_exchange_invalid_credential_type(self): - """Test exchange with invalid credential type raises ValueError.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, - api_key="test-key", - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Credential is not a service account credential" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_with_explicit_credentials_success( - self, mock_request_class, mock_from_service_account_info - ): - """Test successful exchange with explicit service account 
credentials.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.token = "mock_access_token" - mock_credentials.to_json.return_value = ( - '{"token": "mock_access_token", "type": "authorized_user"}' - ) - mock_from_service_account_info.return_value = mock_credentials - - # Create test credential - service_account_cred = ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key=( - "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" - ), - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=service_account_cred, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - result = await exchanger.exchange(credential, auth_scheme) - - # Verify the result - assert result.auth_type == AuthCredentialTypes.OAUTH2 - assert result.google_oauth2_json is not None - # Verify that google_oauth2_json contains the token - import json - - exchanged_creds = json.loads(result.google_oauth2_json) - assert exchanged_creds.get( - "token" - ) == "mock_access_token" or "mock_access_token" in str(exchanged_creds) - - # Verify mocks were called correctly - mock_from_service_account_info.assert_called_once_with( - service_account_cred.model_dump(), - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - mock_credentials.refresh.assert_called_once_with(mock_request) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_with_default_credentials_success( - self, mock_request_class, mock_google_auth_default - ): - """Test successful exchange with default application credentials.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.token = "default_access_token" - mock_credentials.to_json.return_value = ( - '{"token": "default_access_token", "type": "authorized_user"}' - ) - mock_google_auth_default.return_value = (mock_credentials, "test-project") - - # Create test credential with use_default_credential=True - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - result = await exchanger.exchange(credential, auth_scheme) - - # Verify the result - assert result.auth_type == AuthCredentialTypes.OAUTH2 - assert result.google_oauth2_json is not None - # Verify that google_oauth2_json contains the token - import json - - exchanged_creds = json.loads(result.google_oauth2_json) - assert exchanged_creds.get( - "token" - 
) == "default_access_token" or "default_access_token" in str( - exchanged_creds - ) - - # Verify mocks were called correctly - mock_google_auth_default.assert_called_once() - mock_credentials.refresh.assert_called_once_with(mock_request) - - @pytest.mark.asyncio - async def test_exchange_missing_service_account(self): - """Test exchange fails when service_account is None.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=None, - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Service account credentials are missing" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - async def test_exchange_missing_credentials_and_not_default(self): - """Test exchange fails when credentials are missing and use_default_credential is False.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=None, - use_default_credential=False, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Service account credentials are invalid" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" - ) - async def test_exchange_credential_creation_failure( - self, mock_from_service_account_info - ): - """Test exchange handles credential creation failure gracefully.""" - # Setup mock to raise exception - mock_from_service_account_info.side_effect = Exception( - "Invalid private key" - ) - - # Create test credential - service_account_cred = ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key="invalid-key", - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=service_account_cred, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Failed to exchange service account token" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" - ) - async def test_exchange_default_credential_failure( - self, mock_google_auth_default - ): - """Test exchange handles default credential failure gracefully.""" - # Setup mock to raise exception - mock_google_auth_default.side_effect = Exception( - "No default credentials found" - ) - - # Create test credential with use_default_credential=True - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - use_default_credential=True, - 
scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Failed to exchange service account token" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.service_account.Credentials.from_service_account_info" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_refresh_failure( - self, mock_request_class, mock_from_service_account_info - ): - """Test exchange handles credential refresh failure gracefully.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.refresh.side_effect = Exception( - "Network error during refresh" - ) - mock_from_service_account_info.return_value = mock_credentials - - # Create test credential - service_account_cred = ServiceAccountCredential( - type_="service_account", - project_id="test-project", - private_key_id="key-id", - private_key=( - "-----BEGIN PRIVATE KEY-----\nMOCK_KEY\n-----END PRIVATE KEY-----" - ), - client_email="test@test-project.iam.gserviceaccount.com", - client_id="12345", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", - universe_domain="googleapis.com", - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount( - service_account_credential=service_account_cred, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises( - ValueError, match="Failed to exchange service account token" - ): - await exchanger.exchange(credential, auth_scheme) - - @pytest.mark.asyncio - async def test_exchange_none_credential_in_constructor(self): - """Test that passing None credential raises appropriate error during exchange.""" - # This test verifies behavior when credential is None - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - - with pytest.raises(ValueError, match="Credential cannot be None"): - await exchanger.exchange(None, auth_scheme) - - @pytest.mark.asyncio - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.google.auth.default" - ) - @patch( - "google.adk.auth.exchanger.service_account_credential_exchanger.Request" - ) - async def test_exchange_with_service_account_no_explicit_credentials( - self, mock_request_class, mock_google_auth_default - ): - """Test exchange with service account that has no explicit credentials uses default.""" - # Setup mocks - mock_request = MagicMock() - mock_request_class.return_value = mock_request - - mock_credentials = MagicMock() - mock_credentials.token = "default_access_token" - mock_credentials.to_json.return_value = ( - '{"token": "default_access_token", "type": "authorized_user"}' - ) - mock_google_auth_default.return_value = (mock_credentials, "test-project") - - # Create test credential with no explicit credentials but use_default_credential=True - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - 
service_account=ServiceAccount( - service_account_credential=None, - use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ), - ) - - auth_scheme = HTTPBearer() - exchanger = ServiceAccountCredentialExchanger() - result = await exchanger.exchange(credential, auth_scheme) - - # Verify the result - assert result.auth_type == AuthCredentialTypes.OAUTH2 - assert result.google_oauth2_json is not None - # Verify that google_oauth2_json contains the token - import json - - exchanged_creds = json.loads(result.google_oauth2_json) - assert exchanged_creds.get( - "token" - ) == "default_access_token" or "default_access_token" in str( - exchanged_creds - ) - - # Verify mocks were called correctly - mock_google_auth_default.assert_called_once() - mock_credentials.refresh.assert_called_once_with(mock_request) diff --git a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py index b22bf2ccd..3342fcb05 100644 --- a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py +++ b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py @@ -165,124 +165,6 @@ async def test_refresh_no_oauth2_credential(self): assert result == credential - @pytest.mark.asyncio - async def test_needs_refresh_google_oauth2_json_expired(self): - """Test needs_refresh with Google OAuth2 JSON credential that is expired.""" - import json - from unittest.mock import patch - - # Mock Google OAuth2 JSON credential data - google_oauth2_json = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "test_refresh_token", - "type": "authorized_user", - }) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=google_oauth2_json, - ) - - # Mock the Google Credentials class - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" - ) as mock_credentials: - mock_google_credential = Mock() - mock_google_credential.expired = True - mock_google_credential.refresh_token = "test_refresh_token" - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - refresher = OAuth2CredentialRefresher() - needs_refresh = await refresher.is_refresh_needed(credential, None) - - assert needs_refresh - - @pytest.mark.asyncio - async def test_needs_refresh_google_oauth2_json_not_expired(self): - """Test needs_refresh with Google OAuth2 JSON credential that is not expired.""" - import json - from unittest.mock import patch - - # Mock Google OAuth2 JSON credential data - google_oauth2_json = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "test_refresh_token", - "type": "authorized_user", - }) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=google_oauth2_json, - ) - - # Mock the Google Credentials class - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" - ) as mock_credentials: - mock_google_credential = Mock() - mock_google_credential.expired = False - mock_google_credential.refresh_token = "test_refresh_token" - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - refresher = OAuth2CredentialRefresher() - needs_refresh = await refresher.is_refresh_needed(credential, None) - - assert not needs_refresh - - @pytest.mark.asyncio - async def test_refresh_google_oauth2_json_success(self): - """Test successful 
refresh of Google OAuth2 JSON credential.""" - import json - from unittest.mock import patch - - # Mock Google OAuth2 JSON credential data - google_oauth2_json = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "test_refresh_token", - "type": "authorized_user", - }) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json=google_oauth2_json, - ) - - # Mock the Google Credentials and Request classes - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Credentials" - ) as mock_credentials: - with patch( - "google.adk.auth.refresher.oauth2_credential_refresher.Request" - ) as mock_request: - mock_google_credential = Mock() - mock_google_credential.expired = True - mock_google_credential.refresh_token = "test_refresh_token" - mock_google_credential.to_json.return_value = json.dumps({ - "client_id": "test_client_id", - "client_secret": "test_client_secret", - "refresh_token": "new_refresh_token", - "access_token": "new_access_token", - "type": "authorized_user", - }) - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - refresher = OAuth2CredentialRefresher() - result = await refresher.refresh(credential, None) - - mock_google_credential.refresh.assert_called_once() - assert ( - result.google_oauth2_json != google_oauth2_json - ) # Should be updated - @pytest.mark.asyncio async def test_needs_refresh_no_oauth2_credential(self): """Test needs_refresh with no OAuth2 credential returns False.""" diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index 283e865a7..8e3638dd6 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -410,39 +410,25 @@ async def test_validate_credential_oauth2_missing_oauth2_field(self): @pytest.mark.asyncio async def test_exchange_credentials_service_account(self): - """Test _exchange_credential with service account credential.""" + """Test _exchange_credential with service account credential (no exchanger available).""" mock_raw_credential = Mock(spec=AuthCredential) mock_raw_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT - mock_exchanged_credential = Mock(spec=AuthCredential) - auth_config = Mock(spec=AuthConfig) auth_config.auth_scheme = Mock() manager = CredentialManager(auth_config) - # Mock the exchanger that gets created during registration + # Mock the exchanger registry to return None (no exchanger available) with patch.object( - manager._exchanger_registry, "get_exchanger" - ) as mock_get_exchanger: - mock_exchanger = Mock() - mock_exchanger.exchange = AsyncMock( - return_value=mock_exchanged_credential - ) - mock_get_exchanger.return_value = mock_exchanger - + manager._exchanger_registry, "get_exchanger", return_value=None + ): result, was_exchanged = await manager._exchange_credential( mock_raw_credential ) - assert result == mock_exchanged_credential - assert was_exchanged is True - mock_get_exchanger.assert_called_once_with( - AuthCredentialTypes.SERVICE_ACCOUNT - ) - mock_exchanger.exchange.assert_called_once_with( - mock_raw_credential, auth_config.auth_scheme - ) + assert result == mock_raw_credential + assert was_exchanged is False @pytest.mark.asyncio async def test_exchange_credential_no_exchanger(self): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 4d9cffb4d..d25a84eac 100644 --- 
a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -263,40 +263,6 @@ async def test_get_headers_api_key(self): assert headers == {"X-API-Key": "my_api_key"} - @pytest.mark.asyncio - @patch("google.adk.tools.mcp_tool.mcp_tool.json") - @patch("google.adk.tools.mcp_tool.mcp_tool.Credentials") - async def test_get_headers_google_oauth2_json( - self, mock_credentials, mock_json - ): - """Test header generation for Google OAuth2 JSON credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - # Mock the JSON parsing and Credentials creation - mock_json.loads.return_value = {"token": "google_token"} - mock_google_credential = Mock() - mock_google_credential.token = "google_access_token" - mock_credentials.from_authorized_user_info.return_value = ( - mock_google_credential - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - google_oauth2_json='{"token": "google_token"}', - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - assert headers == {"Authorization": "Bearer google_access_token"} - mock_json.loads.assert_called_once_with('{"token": "google_token"}') - mock_credentials.from_authorized_user_info.assert_called_once_with( - {"token": "google_token"} - ) - @pytest.mark.asyncio async def test_get_headers_no_credential(self): """Test header generation with no credentials.""" @@ -311,14 +277,14 @@ async def test_get_headers_no_credential(self): assert headers is None @pytest.mark.asyncio - async def test_get_headers_service_account_no_json(self): - """Test header generation for service account credentials without google_oauth2_json.""" + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" tool = MCPTool( mcp_tool=self.mock_mcp_tool, mcp_session_manager=self.mock_session_manager, ) - # Create service account credential without google_oauth2_json + # Create service account credential service_account = ServiceAccount(scopes=["test"]) credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, @@ -328,7 +294,7 @@ async def test_get_headers_service_account_no_json(self): tool_context = Mock(spec=ToolContext) headers = await tool._get_headers(tool_context, credential) - # Should return None as no google_oauth2_json is provided + # Should return None as service account credentials are not supported for direct header generation assert headers is None @pytest.mark.asyncio