From 18dd87a648c240c8c1e6750bd07614ee9676cfb5 Mon Sep 17 00:00:00 2001 From: avidela Date: Mon, 16 Jun 2025 20:33:11 +1200 Subject: [PATCH 1/2] feat: add optional per-agent Vertex AI project and location configuration --- src/google/adk/models/anthropic_llm.py | 20 ++-- src/google/adk/models/google_llm.py | 25 +++-- .../models/test_vertex_per_agent_config.py | 97 +++++++++++++++++++ 3 files changed, 128 insertions(+), 14 deletions(-) create mode 100644 tests/unittests/models/test_vertex_per_agent_config.py diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index a3a0e0962..15126c684 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -205,10 +205,14 @@ class Claude(BaseLlm): Attributes: model: The name of the Claude model. + project_id: Optional Google Cloud project ID. If not provided, uses GOOGLE_CLOUD_PROJECT environment variable. + location: Optional Google Cloud location. If not provided, uses GOOGLE_CLOUD_LOCATION environment variable. """ model: str = "claude-3-5-sonnet-v2@20241022" - + project_id: Optional[str] = None + location: Optional[str] = None + @staticmethod @override def supported_models() -> list[str]: @@ -250,16 +254,16 @@ async def generate_content_async( @cached_property def _anthropic_client(self) -> AnthropicVertex: - if ( - "GOOGLE_CLOUD_PROJECT" not in os.environ - or "GOOGLE_CLOUD_LOCATION" not in os.environ - ): + project = self.project_id or os.environ.get("GOOGLE_CLOUD_PROJECT") + location = self.location or os.environ.get("GOOGLE_CLOUD_LOCATION") + + if not project or not location: raise ValueError( "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using" " Anthropic on Vertex." ) - + return AnthropicVertex( - project_id=os.environ["GOOGLE_CLOUD_PROJECT"], - region=os.environ["GOOGLE_CLOUD_LOCATION"], + project_id=project, + region=location, ) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index bff2b675c..b68dff53f 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -24,7 +24,7 @@ from typing import cast from typing import TYPE_CHECKING from typing import Union - +from typing import Optional from google.genai import Client from google.genai import types from typing_extensions import override @@ -52,10 +52,15 @@ class Gemini(BaseLlm): Attributes: model: The name of the Gemini model. + project_id: Optional Google Cloud project ID. If not provided, uses GOOGLE_CLOUD_PROJECT environment variable. + location: Optional Google Cloud location. If not provided, uses GOOGLE_CLOUD_LOCATION environment variable. + """ model: str = 'gemini-1.5-flash' - + project_id: Optional[str] = None + location: Optional[str] = None + @staticmethod @override def supported_models() -> list[str]: @@ -177,14 +182,22 @@ async def generate_content_async( @cached_property def api_client(self) -> Client: - """Provides the api client. + """Provides the api client with per-instance configuration support. Returns: The api client. """ - return Client( - http_options=types.HttpOptions(headers=self._tracking_headers) - ) + if self.project_id or self.location: + return Client( + vertexai=True, + project=self.project_id, + location=self.location, + http_options=types.HttpOptions(headers=self._tracking_headers) + ) + else: + return Client( + http_options=types.HttpOptions(headers=self._tracking_headers) + ) @cached_property def _api_backend(self) -> GoogleLLMVariant: diff --git a/tests/unittests/models/test_vertex_per_agent_config.py b/tests/unittests/models/test_vertex_per_agent_config.py new file mode 100644 index 000000000..8a108e9d4 --- /dev/null +++ b/tests/unittests/models/test_vertex_per_agent_config.py @@ -0,0 +1,97 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch +from src.google.adk.models.anthropic_llm import Claude +from src.google.adk.models.google_llm import Gemini + + +def test_claude_custom_config(): + claude = Claude( + project_id="test-project-claude", + location="us-central1" + ) + + assert claude.project_id == "test-project-claude" + assert claude.location == "us-central1" + + +def test_gemini_custom_config(): + gemini = Gemini( + project_id="test-project-gemini", + location="europe-west1" + ) + + assert gemini.project_id == "test-project-gemini" + assert gemini.location == "europe-west1" + + +def test_claude_per_instance_configuration(): + claude1 = Claude(project_id="project-1", location="us-central1") + claude2 = Claude(project_id="project-2", location="europe-west1") + claude3 = Claude() + + assert claude1.project_id == "project-1" + assert claude1.location == "us-central1" + + assert claude2.project_id == "project-2" + assert claude2.location == "europe-west1" + + assert claude3.project_id is None + assert claude3.location is None + + +def test_gemini_per_instance_configuration(): + gemini1 = Gemini(project_id="project-1", location="us-central1") + gemini2 = Gemini(project_id="project-2", location="europe-west1") + gemini3 = Gemini() + + assert gemini1.project_id == "project-1" + assert gemini1.location == "us-central1" + + assert gemini2.project_id == "project-2" + assert gemini2.location == "europe-west1" + + assert gemini3.project_id is None + assert gemini3.location is None + + +def test_backward_compatibility(): + claude = Claude() + gemini = Gemini() + + assert claude.project_id is None + assert claude.location is None + assert gemini.project_id is None + assert gemini.location is None + + +@patch.dict('os.environ', {'GOOGLE_CLOUD_PROJECT': 'env-project', 'GOOGLE_CLOUD_LOCATION': 'env-location'}) +def test_claude_fallback_to_env_vars(): + claude = Claude() + + cache_key = f"{claude.project_id or 'default'}:{claude.location or 'default'}" + assert cache_key == "default:default" + + +def test_mixed_configuration(): + claude_custom = Claude(project_id="custom-project", location="us-west1") + claude_default = Claude() + + key_custom = f"{claude_custom.project_id or 'default'}:{claude_custom.location or 'default'}" + key_default = f"{claude_default.project_id or 'default'}:{claude_default.location or 'default'}" + + assert key_custom != key_default + assert key_custom == "custom-project:us-west1" + assert key_default == "default:default" From 3620fd6f07abe4c2786e232af7985598d3c69979 Mon Sep 17 00:00:00 2001 From: avidela Date: Mon, 16 Jun 2025 20:36:06 +1200 Subject: [PATCH 2/2] style: run autoformat --- src/google/adk/models/anthropic_llm.py | 6 +- src/google/adk/models/google_llm.py | 7 +- .../models/test_vertex_per_agent_config.py | 123 +++++++++--------- 3 files changed, 71 insertions(+), 65 deletions(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 15126c684..08211683b 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -212,7 +212,7 @@ class Claude(BaseLlm): model: str = "claude-3-5-sonnet-v2@20241022" project_id: Optional[str] = None location: Optional[str] = None - + @staticmethod @override def supported_models() -> list[str]: @@ -256,13 +256,13 @@ async def generate_content_async( def _anthropic_client(self) -> AnthropicVertex: project = self.project_id or os.environ.get("GOOGLE_CLOUD_PROJECT") location = self.location or os.environ.get("GOOGLE_CLOUD_LOCATION") - + if not project or not location: raise ValueError( "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using" " Anthropic on Vertex." ) - + return AnthropicVertex( project_id=project, region=location, diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index b68dff53f..3c2109286 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -22,9 +22,10 @@ import sys from typing import AsyncGenerator from typing import cast +from typing import Optional from typing import TYPE_CHECKING from typing import Union -from typing import Optional + from google.genai import Client from google.genai import types from typing_extensions import override @@ -60,7 +61,7 @@ class Gemini(BaseLlm): model: str = 'gemini-1.5-flash' project_id: Optional[str] = None location: Optional[str] = None - + @staticmethod @override def supported_models() -> list[str]: @@ -192,7 +193,7 @@ def api_client(self) -> Client: vertexai=True, project=self.project_id, location=self.location, - http_options=types.HttpOptions(headers=self._tracking_headers) + http_options=types.HttpOptions(headers=self._tracking_headers), ) else: return Client( diff --git a/tests/unittests/models/test_vertex_per_agent_config.py b/tests/unittests/models/test_vertex_per_agent_config.py index 8a108e9d4..1b79a3a00 100644 --- a/tests/unittests/models/test_vertex_per_agent_config.py +++ b/tests/unittests/models/test_vertex_per_agent_config.py @@ -13,85 +13,90 @@ # limitations under the License. from unittest.mock import patch + from src.google.adk.models.anthropic_llm import Claude from src.google.adk.models.google_llm import Gemini def test_claude_custom_config(): - claude = Claude( - project_id="test-project-claude", - location="us-central1" - ) - - assert claude.project_id == "test-project-claude" - assert claude.location == "us-central1" + claude = Claude(project_id="test-project-claude", location="us-central1") + + assert claude.project_id == "test-project-claude" + assert claude.location == "us-central1" def test_gemini_custom_config(): - gemini = Gemini( - project_id="test-project-gemini", - location="europe-west1" - ) - - assert gemini.project_id == "test-project-gemini" - assert gemini.location == "europe-west1" + gemini = Gemini(project_id="test-project-gemini", location="europe-west1") + + assert gemini.project_id == "test-project-gemini" + assert gemini.location == "europe-west1" def test_claude_per_instance_configuration(): - claude1 = Claude(project_id="project-1", location="us-central1") - claude2 = Claude(project_id="project-2", location="europe-west1") - claude3 = Claude() - - assert claude1.project_id == "project-1" - assert claude1.location == "us-central1" - - assert claude2.project_id == "project-2" - assert claude2.location == "europe-west1" - - assert claude3.project_id is None - assert claude3.location is None + claude1 = Claude(project_id="project-1", location="us-central1") + claude2 = Claude(project_id="project-2", location="europe-west1") + claude3 = Claude() + + assert claude1.project_id == "project-1" + assert claude1.location == "us-central1" + + assert claude2.project_id == "project-2" + assert claude2.location == "europe-west1" + + assert claude3.project_id is None + assert claude3.location is None def test_gemini_per_instance_configuration(): - gemini1 = Gemini(project_id="project-1", location="us-central1") - gemini2 = Gemini(project_id="project-2", location="europe-west1") - gemini3 = Gemini() - - assert gemini1.project_id == "project-1" - assert gemini1.location == "us-central1" - - assert gemini2.project_id == "project-2" - assert gemini2.location == "europe-west1" - - assert gemini3.project_id is None - assert gemini3.location is None + gemini1 = Gemini(project_id="project-1", location="us-central1") + gemini2 = Gemini(project_id="project-2", location="europe-west1") + gemini3 = Gemini() + + assert gemini1.project_id == "project-1" + assert gemini1.location == "us-central1" + + assert gemini2.project_id == "project-2" + assert gemini2.location == "europe-west1" + + assert gemini3.project_id is None + assert gemini3.location is None def test_backward_compatibility(): - claude = Claude() - gemini = Gemini() - - assert claude.project_id is None - assert claude.location is None - assert gemini.project_id is None - assert gemini.location is None + claude = Claude() + gemini = Gemini() + + assert claude.project_id is None + assert claude.location is None + assert gemini.project_id is None + assert gemini.location is None -@patch.dict('os.environ', {'GOOGLE_CLOUD_PROJECT': 'env-project', 'GOOGLE_CLOUD_LOCATION': 'env-location'}) +@patch.dict( + "os.environ", + { + "GOOGLE_CLOUD_PROJECT": "env-project", + "GOOGLE_CLOUD_LOCATION": "env-location", + }, +) def test_claude_fallback_to_env_vars(): - claude = Claude() - - cache_key = f"{claude.project_id or 'default'}:{claude.location or 'default'}" - assert cache_key == "default:default" + claude = Claude() + + cache_key = f"{claude.project_id or 'default'}:{claude.location or 'default'}" + assert cache_key == "default:default" def test_mixed_configuration(): - claude_custom = Claude(project_id="custom-project", location="us-west1") - claude_default = Claude() - - key_custom = f"{claude_custom.project_id or 'default'}:{claude_custom.location or 'default'}" - key_default = f"{claude_default.project_id or 'default'}:{claude_default.location or 'default'}" - - assert key_custom != key_default - assert key_custom == "custom-project:us-west1" - assert key_default == "default:default" + claude_custom = Claude(project_id="custom-project", location="us-west1") + claude_default = Claude() + + key_custom = ( + f"{claude_custom.project_id or 'default'}:{claude_custom.location or 'default'}" + ) + key_default = ( + f"{claude_default.project_id or 'default'}:{claude_default.location or 'default'}" + ) + + assert key_custom != key_default + assert key_custom == "custom-project:us-west1" + assert key_default == "default:default"