diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 7c2ffdd95..f04f3f039 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -31,5 +31,8 @@ If applicable, add screenshots to help explain your problem. - Python version(python -V): - ADK version(pip show google-adk): + **Model Information:** + For example, which model is being used. + **Additional context** Add any other context about the problem here. diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index a504fde0d..52e61b8a3 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -36,18 +36,26 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v6 - name: Install dependencies run: | uv venv .venv source .venv/bin/activate - uv sync --extra test --extra eval + uv sync --extra test --extra eval --extra a2a - name: Run unit tests with pytest run: | source .venv/bin/activate - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + if [[ "${{ matrix.python-version }}" == "3.9" ]]; then + pytest tests/unittests \ + --ignore=tests/unittests/a2a \ + --ignore=tests/unittests/tools/mcp_tool \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + else + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + fi \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6fb068d48..6f398cbf9 100644 --- a/.gitignore +++ b/.gitignore @@ -82,6 +82,7 @@ log/ .env.development.local .env.test.local .env.production.local +uv.lock # Google Cloud specific .gcloudignore diff --git a/CHANGELOG.md b/CHANGELOG.md index 04740bb7a..b6bba2692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,84 @@ # Changelog +## [1.5.0](https://github.com/google/adk-python/compare/v1.4.2...v1.5.0) (2025-06-25) + + +### Features + +* Add a new option `eval_storage_uri` in adk web & adk eval to specify GCS bucket to store eval data ([fa025d7](https://github.com/google/adk-python/commit/fa025d755978e1506fa0da1fecc49775bebc1045)) +* Add ADK examples for litellm with add_function_to_prompt ([f33e090](https://github.com/google/adk-python/commit/f33e0903b21b752168db3006dd034d7d43f7e84d)) +* Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint ([abc89d2](https://github.com/google/adk-python/commit/abc89d2c811ba00805f81b27a3a07d56bdf55a0b)) +* Add rouge_score library to ADK eval dependencies, and implement RougeEvaluator that is computes ROUGE-1 for "response_match_score" metric ([9597a44](https://github.com/google/adk-python/commit/9597a446fdec63ad9e4c2692d6966b14f80ff8e2)) +* Add usage span attributes to telemetry ([#356](https://github.com/google/adk-python/issues/356)) ([ea69c90](https://github.com/google/adk-python/commit/ea69c9093a16489afdf72657136c96f61c69cafd)) +* Add Vertex Express mode compatibility for VertexAiSessionService ([00cc8cd](https://github.com/google/adk-python/commit/00cc8cd6433fc45ecfc2dbaa04dbbc1a81213b4d)) + + +### Bug Fixes + +* Include current turn 
context when include_contents='none' ([9e473e0](https://github.com/google/adk-python/commit/9e473e0abdded24e710fd857782356c15d04b515)) +* Make LiteLLM streaming truly asynchronous ([bd67e84](https://github.com/google/adk-python/commit/bd67e8480f6e8b4b0f8c22b94f15a8cda1336339)) +* Make raw_auth_credential and exchanged_auth_credential optional given their default value is None ([acbdca0](https://github.com/google/adk-python/commit/acbdca0d8400e292ba5525931175e0d6feab15f1)) +* Minor typo fix in the agent instruction ([ef3c745](https://github.com/google/adk-python/commit/ef3c745d655538ebd1ed735671be615f842341a8)) +* Typo fix in sample agent instruction ([ef3c745](https://github.com/google/adk-python/commit/ef3c745d655538ebd1ed735671be615f842341a8)) +* Update contributing links ([a1e1441](https://github.com/google/adk-python/commit/a1e14411159fd9f3e114e15b39b4949d0fd6ecb1)) +* Use starred tuple unpacking on GCS artifact blob names ([3b1d9a8](https://github.com/google/adk-python/commit/3b1d9a8a3e631ca2d86d30f09640497f1728986c)) + + +### Chore + +* Do not send api request when session does not have events ([88a4402](https://github.com/google/adk-python/commit/88a4402d142672171d0a8ceae74671f47fa14289)) +* Leverage official uv action for install([09f1269](https://github.com/google/adk-python/commit/09f1269bf7fa46ab4b9324e7f92b4f70ffc923e5)) +* Update google-genai package and related deps to latest([ed7a21e](https://github.com/google/adk-python/commit/ed7a21e1890466fcdf04f7025775305dc71f603d)) +* Add credential service backed by session state([29cd183](https://github.com/google/adk-python/commit/29cd183aa1b47dc4f5d8afe22f410f8546634abc)) +* Clarify the behavior of Event.invocation_id([f033e40](https://github.com/google/adk-python/commit/f033e405c10ff8d86550d1419a9d63c0099182f9)) +* Send user message to the agent that returned a corresponding function call if user message is a function response([7c670f6](https://github.com/google/adk-python/commit/7c670f638bc17374ceb08740bdd057e55c9c2e12)) +* Add request converter to convert a2a request to ADK request([fb13963](https://github.com/google/adk-python/commit/fb13963deda0ff0650ac27771711ea0411474bf5)) +* Support allow_origins in cloud_run deployment ([2fd8feb](https://github.com/google/adk-python/commit/2fd8feb65d6ae59732fb3ec0652d5650f47132cc)) + +## [1.4.2](https://github.com/google/adk-python/compare/v1.4.1...v1.4.2) (2025-06-20) + + +### Bug Fixes + +* Add type checking to handle different response type of genai API client ([4d72d31](https://github.com/google/adk-python/commit/4d72d31b13f352245baa72b78502206dcbe25406)) + * This fixes the broken VertexAiSessionService +* Allow more credentials types for BigQuery tools ([2f716ad](https://github.com/google/adk-python/commit/2f716ada7fbcf8e03ff5ae16ce26a80ca6fd7bf6)) + +## [1.4.1](https://github.com/google/adk-python/compare/v1.3.0...v1.4.1) (2025-06-18) + + +### Features + +* Add Authenticated Tool (Experimental) ([dcea776](https://github.com/google/adk-python/commit/dcea7767c67c7edfb694304df32dca10b74c9a71)) +* Add enable_affective_dialog and proactivity to run_config and llm_request ([fe1d5aa](https://github.com/google/adk-python/commit/fe1d5aa439cc56b89d248a52556c0a9b4cbd15e4)) +* Add import session API in the fast API ([233fd20](https://github.com/google/adk-python/commit/233fd2024346abd7f89a16c444de0cf26da5c1a1)) +* Add integration tests for litellm with and without turning on add_function_to_prompt ([8e28587](https://github.com/google/adk-python/commit/8e285874da7f5188ea228eb4d7262dbb33b1ae6f)) +* 
Allow data_store_specs pass into ADK VAIS built-in tool ([675faef](https://github.com/google/adk-python/commit/675faefc670b5cd41991939fe0fc604df331111a)) +* Enable MCP Tool Auth (Experimental) ([157d9be](https://github.com/google/adk-python/commit/157d9be88d92f22320604832e5a334a6eb81e4af)) +* Implement GcsEvalSetResultsManager to handle storage of eval sets on GCS, and refactor eval set results manager ([0a5cf45](https://github.com/google/adk-python/commit/0a5cf45a75aca7b0322136b65ca5504a0c3c7362)) +* Re-factor some eval sets manager logic, and implement GcsEvalSetsManager to handle storage of eval sets on GCS ([1551bd4](https://github.com/google/adk-python/commit/1551bd4f4d7042fffb497d9308b05f92d45d818f)) +* Support real time input config ([d22920b](https://github.com/google/adk-python/commit/d22920bd7f827461afd649601326b0c58aea6716)) +* Support refresh access token automatically for rest_api_tool ([1779801](https://github.com/google/adk-python/commit/177980106b2f7be9a8c0a02f395ff0f85faa0c5a)) + +### Bug Fixes + +* Fix Agent generate config err ([#1305](https://github.com/google/adk-python/issues/1305)) ([badbcbd](https://github.com/google/adk-python/commit/badbcbd7a464e6b323cf3164d2bcd4e27cbc057f)) +* Fix Agent generate config error ([#1450](https://github.com/google/adk-python/issues/1450)) ([694b712](https://github.com/google/adk-python/commit/694b71256c631d44bb4c4488279ea91d82f43e26)) +* Fix liteLLM test failures ([fef8778](https://github.com/google/adk-python/commit/fef87784297b806914de307f48c51d83f977298f)) +* Fix tracing for live ([58e07ca](https://github.com/google/adk-python/commit/58e07cae83048d5213d822be5197a96be9ce2950)) +* Merge custom http options with adk specific http options in model api request ([4ccda99](https://github.com/google/adk-python/commit/4ccda99e8ec7aa715399b4b83c3f101c299a95e8)) +* Remove unnecessary double quote on Claude docstring ([bbceb4f](https://github.com/google/adk-python/commit/bbceb4f2e89f720533b99cf356c532024a120dc4)) +* Set explicit project in the BigQuery client ([6d174eb](https://github.com/google/adk-python/commit/6d174eba305a51fcf2122c0fd481378752d690ef)) +* Support streaming in litellm + adk and add corresponding integration tests ([aafa80b](https://github.com/google/adk-python/commit/aafa80bd85a49fb1c1a255ac797587cffd3fa567)) +* Support project-based gemini model path to use google_search_tool ([b2fc774](https://github.com/google/adk-python/commit/b2fc7740b363a4e33ec99c7377f396f5cee40b5a)) +* Update conversion between Celsius and Fahrenheit ([1ae176a](https://github.com/google/adk-python/commit/1ae176ad2fa2b691714ac979aec21f1cf7d35e45)) + +### Chores + +* Set `agent_engine_id` in the VertexAiSessionService constructor, also use the `agent_engine_id` field instead of overriding `app_name` in FastAPI endpoint ([fc65873](https://github.com/google/adk-python/commit/fc65873d7c31be607f6cd6690f142a031631582a)) + + + ## [1.3.0](https://github.com/google/adk-python/compare/v1.2.1...v1.3.0) (2025-06-11) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c0f3d0069..0d7b2d67d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -200,7 +200,7 @@ For any changes that impact user-facing documentation (guides, API reference, tu ## Contributing Resources -[Contributing folder](https://github.com/google/adk-python/tree/main/contributing/samples) has resources that is helpful for contributors. +[Contributing folder](https://github.com/google/adk-python/tree/main/contributing) has resources that is helpful for contributors. 
## Code reviews diff --git a/README.md b/README.md index 7bd5e7401..874658d07 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ adk eval \ ## 🤝 Contributing We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our -- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/#questions). +- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/). - Then if you want to contribute code, please read [Code Contributing Guidelines](./CONTRIBUTING.md) to get started. ## 📄 License diff --git a/contributing/samples/bigquery_agent/__init__.py b/contributing/samples/adk_issue_formatting_agent/__init__.py similarity index 100% rename from contributing/samples/bigquery_agent/__init__.py rename to contributing/samples/adk_issue_formatting_agent/__init__.py diff --git a/contributing/samples/adk_issue_formatting_agent/agent.py b/contributing/samples/adk_issue_formatting_agent/agent.py new file mode 100644 index 000000000..78add9b83 --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/agent.py @@ -0,0 +1,241 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any + +from adk_issue_formatting_agent.settings import GITHUB_BASE_URL +from adk_issue_formatting_agent.settings import IS_INTERACTIVE +from adk_issue_formatting_agent.settings import OWNER +from adk_issue_formatting_agent.settings import REPO +from adk_issue_formatting_agent.utils import error_response +from adk_issue_formatting_agent.utils import get_request +from adk_issue_formatting_agent.utils import post_request +from adk_issue_formatting_agent.utils import read_file +from google.adk import Agent +import requests + +BUG_REPORT_TEMPLATE = read_file( + Path(__file__).parent / "../../../../.github/ISSUE_TEMPLATE/bug_report.md" +) +FREATURE_REQUEST_TEMPLATE = read_file( + Path(__file__).parent + / "../../../../.github/ISSUE_TEMPLATE/feature_request.md" +) + +APPROVAL_INSTRUCTION = ( + "**Do not** wait or ask for user approval or confirmation for adding the" + " comment." +) +if IS_INTERACTIVE: + APPROVAL_INSTRUCTION = ( + "Ask for user approval or confirmation for adding the comment." + ) + + +def list_open_issues(issue_count: int) -> dict[str, Any]: + """List most recent `issue_count` numer of open issues in the repo. + + Args: + issue_count: number of issues to return + + Returns: + The status of this request, with a list of issues when successful. 
+ """ + url = f"{GITHUB_BASE_URL}/search/issues" + query = f"repo:{OWNER}/{REPO} is:open is:issue" + params = { + "q": query, + "sort": "created", + "order": "desc", + "per_page": issue_count, + "page": 1, + } + + try: + response = get_request(url, params) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + issues = response.get("items", None) + return {"status": "success", "issues": issues} + + +def get_issue(issue_number: int) -> dict[str, Any]: + """Get the details of the specified issue number. + + Args: + issue_number: issue number of the Github issue. + + Returns: + The status of this request, with the issue details when successful. + """ + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" + try: + response = get_request(url) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return {"status": "success", "issue": response} + + +def add_comment_to_issue(issue_number: int, comment: str) -> dict[str, any]: + """Add the specified comment to the given issue number. + + Args: + issue_number: issue number of the Github issue + comment: comment to add + + Returns: + The the status of this request, with the applied comment when successful. + """ + print(f"Attempting to add comment '{comment}' to issue #{issue_number}") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/comments" + payload = {"body": comment} + + try: + response = post_request(url, payload) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return { + "status": "success", + "added_comment": response, + } + + +def list_comments_on_issue(issue_number: int) -> dict[str, any]: + """List all comments on the given issue number. + + Args: + issue_number: issue number of the Github issue + + Returns: + The the status of this request, with the list of comments when successful. + """ + print(f"Attempting to list comments on issue #{issue_number}") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/comments" + + try: + response = get_request(url) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + return {"status": "success", "comments": response} + + +root_agent = Agent( + model="gemini-2.5-pro", + name="adk_issue_formatting_assistant", + description="Check ADK issue format and content.", + instruction=f""" + # 1. IDENTITY + You are an AI assistant designed to help maintain the quality and consistency of issues in our GitHub repository. + Your primary role is to act as a "GitHub Issue Format Validator." You will analyze new and existing **open** issues + to ensure they contain all the necessary information as required by our templates. You are helpful, polite, + and precise in your feedback. + + # 2. CONTEXT & RESOURCES + * **Repository:** You are operating on the GitHub repository `{OWNER}/{REPO}`. + * **Bug Report Template:** (`{BUG_REPORT_TEMPLATE}`) + * **Feature Request Template:** (`{FREATURE_REQUEST_TEMPLATE}`) + + # 3. CORE MISSION + Your goal is to check if a GitHub issue, identified as either a "bug" or a "feature request," + contains all the information required by the corresponding template. If it does not, your job is + to post a single, helpful comment asking the original author to provide the missing information. + {APPROVAL_INSTRUCTION} + + **IMPORTANT NOTE:** + * You add one comment at most each time you are invoked. + * Don't proceed to other issues which are not the target issues. 
+ * Don't take any action on closed issues. + + # 4. BEHAVIORAL RULES & LOGIC + + ## Step 1: Identify Issue Type & Applicability + + Your first task is to determine if the issue is a valid target for validation. + + 1. **Assess Content Intent:** You must perform a quick semantic check of the issue's title, body, and comments. + If you determine the issue's content is fundamentally *not* a bug report or a feature request + (for example, it is a general question, a request for help, or a discussion prompt), then you must ignore it. + 2. **Exit Condition:** If the issue does not clearly fall into the categories of "bug" or "feature request" + based on both its labels and its content, **take no action**. + + ## Step 2: Analyze the Issue Content + + If you have determined the issue is a valid bug or feature request, your analysis depends on whether it has comments. + + **Scenario A: Issue has NO comments** + 1. Read the main body of the issue. + 2. Compare the content of the issue body against the required headings/sections in the relevant template (Bug or Feature). + 3. Check for the presence of content under each heading. A heading with no content below it is considered incomplete. + 4. If one or more sections are missing or empty, proceed to Step 3. + 5. If all sections are filled out, your task is complete. Do nothing. + + **Scenario B: Issue HAS one or more comments** + 1. First, analyze the main issue body to see which sections of the template are filled out. + 2. Next, read through **all** the comments in chronological order. + 3. As you read the comments, check if the information provided in them satisfies any of the template sections that were missing from the original issue body. + 4. After analyzing the body and all comments, determine if any required sections from the template *still* remain unaddressed. + 5. If one or more sections are still missing information, proceed to Step 3. + 6. If the issue body and comments *collectively* provide all the required information, your task is complete. Do nothing. + + ## Step 3: Formulate and Post a Comment (If Necessary) + + If you determined in Step 2 that information is missing, you must post a **single comment** on the issue. + + Please include a bolded note in your comment that this comment was added by an ADK agent. + + **Comment Guidelines:** + * **Be Polite and Helpful:** Start with a friendly tone. + * **Be Specific:** Clearly list only the sections from the template that are still missing. Do not list sections that have already been filled out. + * **Address the Author:** Mention the issue author by their username (e.g., `@username`). + * **Provide Context:** Explain *why* the information is needed (e.g., "to help us reproduce the bug" or "to better understand your request"). + * **Do not be repetitive:** If you have already commented on an issue asking for information, do not comment again unless new information has been added and it's still incomplete. + + **Example Comment for a Bug Report:** + > **Response from ADK Agent** + > + > Hello @[issue-author-username], thank you for submitting this issue! 
+ > + > To help us investigate and resolve this bug effectively, could you please provide the missing details for the following sections of our bug report template: + > + > * **To Reproduce:** (Please provide the specific steps required to reproduce the behavior) + > * **Desktop (please complete the following information):** (Please provide OS, Python version, and ADK version) + > + > This information will give us the context we need to move forward. Thanks! + + **Example Comment for a Feature Request:** + > **Response from ADK Agent** + > + > Hi @[issue-author-username], thanks for this great suggestion! + > + > To help our team better understand and evaluate your feature request, could you please provide a bit more information on the following section: + > + > * **Is your feature request related to a problem? Please describe.** + > + > We look forward to hearing more about your idea! + + # 5. FINAL INSTRUCTION + + Execute this process for the given GitHub issue. Your final output should either be **[NO ACTION]** + if the issue is complete or invalid, or **[POST COMMENT]** followed by the exact text of the comment you will post. + + Please include your justification for your decision in your output. + """, + tools={ + list_open_issues, + get_issue, + add_comment_to_issue, + list_comments_on_issue, + }, +) diff --git a/contributing/samples/adk_issue_formatting_agent/settings.py b/contributing/samples/adk_issue_formatting_agent/settings.py new file mode 100644 index 000000000..d29bda9b7 --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/settings.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dotenv import load_dotenv + +load_dotenv(override=True) + +GITHUB_BASE_URL = "https://api.github.com" + +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +if not GITHUB_TOKEN: + raise ValueError("GITHUB_TOKEN environment variable not set") + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +EVENT_NAME = os.getenv("EVENT_NAME") +ISSUE_NUMBER = os.getenv("ISSUE_NUMBER") +ISSUE_COUNT_TO_PROCESS = os.getenv("ISSUE_COUNT_TO_PROCESS") + +IS_INTERACTIVE = os.environ.get("INTERACTIVE", "1").lower() in ["true", "1"] diff --git a/contributing/samples/adk_issue_formatting_agent/utils.py b/contributing/samples/adk_issue_formatting_agent/utils.py new file mode 100644 index 000000000..2ee735d3d --- /dev/null +++ b/contributing/samples/adk_issue_formatting_agent/utils.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from adk_issue_formatting_agent.settings import GITHUB_TOKEN +import requests + +headers = { + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + + +def get_request( + url: str, params: dict[str, Any] | None = None +) -> dict[str, Any]: + if params is None: + params = {} + response = requests.get(url, headers=headers, params=params, timeout=60) + response.raise_for_status() + return response.json() + + +def post_request(url: str, payload: Any) -> dict[str, Any]: + response = requests.post(url, headers=headers, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def error_response(error_message: str) -> dict[str, Any]: + return {"status": "error", "message": error_message} + + +def read_file(file_path: str) -> str: + """Read the content of the given file.""" + try: + with open(file_path, "r") as f: + return f.read() + except FileNotFoundError: + print(f"Error: File not found: {file_path}.") + return "" diff --git a/contributing/samples/artifact_save_text/agent.py b/contributing/samples/artifact_save_text/agent.py index 53a7f300d..3ce43bcd1 100755 --- a/contributing/samples/artifact_save_text/agent.py +++ b/contributing/samples/artifact_save_text/agent.py @@ -31,7 +31,7 @@ async def log_query(tool_context: ToolContext, query: str): model='gemini-2.0-flash', name='log_agent', description='Log user query.', - instruction="""Always log the user query and reploy "kk, I've logged." + instruction="""Always log the user query and reply "kk, I've logged." """, tools=[log_query], generate_content_config=types.GenerateContentConfig( diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index cd4583c72..050ce1332 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -40,13 +40,28 @@ would set: ### With Application Default Credentials This mode is useful for quick development when the agent builder is the only -user interacting with the agent. The tools are initialized with the default -credentials present on the machine running the agent. +user interacting with the agent. The tools are run with these credentials. 1. Create application default credentials on the machine where the agent would be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc. -1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent +1. Set `CREDENTIALS_TYPE=None` in `agent.py` + +1. Run the agent + +### With Service Account Keys + +This mode is useful for quick development when the agent builder wants to run +the agent with service account credentials. The tools are run with these +credentials. + +1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py` + +1. Download the key file and replace `"service_account_key.json"` with the path + +1. Run the agent ### With Interactive OAuth @@ -72,7 +87,7 @@ type. Note: don't create a separate .env, instead put it to the same .env file that stores your Vertex AI or Dev ML credentials -1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent +1. 
Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent ## Sample prompts diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py index 0999ca12a..c1b265c00 100644 --- a/contributing/samples/bigquery/agent.py +++ b/contributing/samples/bigquery/agent.py @@ -15,24 +15,21 @@ import os from google.adk.agents import llm_agent +from google.adk.auth import AuthCredentialTypes from google.adk.tools.bigquery import BigQueryCredentialsConfig from google.adk.tools.bigquery import BigQueryToolset from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode import google.auth -RUN_WITH_ADC = False +# Define an appropriate credential type +CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2 +# Define BigQuery tool config tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) -if RUN_WITH_ADC: - # Initialize the tools to use the application default credentials. - application_default_credentials, _ = google.auth.default() - credentials_config = BigQueryCredentialsConfig( - credentials=application_default_credentials - ) -else: +if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: # Initiaze the tools to do interactive OAuth # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET # must be set @@ -40,6 +37,20 @@ client_id=os.getenv("OAUTH_CLIENT_ID"), client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) +elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT: + # Initialize the tools to use the credentials in the service account key. + # If this flow is enabled, make sure to replace the file path with your own + # service account key file + # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys + creds, _ = google.auth.load_credentials_from_file("service_account_key.json") + credentials_config = BigQueryCredentialsConfig(credentials=creds) +else: + # Initialize the tools to use the application default credentials. + # https://cloud.google.com/docs/authentication/provide-credentials-adc + application_default_credentials, _ = google.auth.default() + credentials_config = BigQueryCredentialsConfig( + credentials=application_default_credentials + ) bigquery_toolset = BigQueryToolset( credentials_config=credentials_config, bigquery_tool_config=tool_config @@ -49,7 +60,7 @@ # debug CLI root_agent = llm_agent.Agent( model="gemini-2.0-flash", - name="hello_agent", + name="bigquery_agent", description=( "Agent to answer questions about BigQuery data and models and execute" " SQL queries." diff --git a/contributing/samples/bigquery_agent/README.md b/contributing/samples/google_api/README.md similarity index 50% rename from contributing/samples/bigquery_agent/README.md rename to contributing/samples/google_api/README.md index c7dc7fd8b..c1e6e8d4c 100644 --- a/contributing/samples/bigquery_agent/README.md +++ b/contributing/samples/google_api/README.md @@ -1,45 +1,40 @@ -# BigQuery Sample +# Google API Tools Sample ## Introduction -This sample tests and demos the BigQuery support in ADK via two tools: +This sample tests and demos Google API tools available in the +`google.adk.tools.google_api_tool` module. We pick the following BigQuery API +tools for this sample agent: -* 1. bigquery_datasets_list: +1. `bigquery_datasets_list`: List user's datasets. - List user's datasets. +2. `bigquery_datasets_get`: Get a dataset's details. -* 2. bigquery_datasets_get: - Get a dataset's details. +3. `bigquery_datasets_insert`: Create a new dataset. -* 3. 
bigquery_datasets_insert: - Create a new dataset. +4. `bigquery_tables_list`: List all tables in a dataset. -* 4. bigquery_tables_list: - List all tables in a dataset. +5. `bigquery_tables_get`: Get a table's details. -* 5. bigquery_tables_get: - Get a table's details. - -* 6. bigquery_tables_insert: - Insert a new table into a dataset. +6. `bigquery_tables_insert`: Insert a new table into a dataset. ## How to use -* 1. Follow https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. to get your client id and client secret. +1. Follow https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. to get your client id and client secret. Be sure to choose "web" as your client type. -* 2. Configure your `.env` file to add two variables: +2. Configure your `.env` file to add two variables: * OAUTH_CLIENT_ID={your client id} * OAUTH_CLIENT_SECRET={your client secret} Note: don't create a separate `.env` file , instead put it to the same `.env` file that stores your Vertex AI or Dev ML credentials -* 3. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". +3. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". Note: localhost here is just a hostname that you use to access the dev ui, replace it with the actual hostname you use to access the dev ui. -* 4. For 1st run, allow popup for localhost in Chrome. +4. For 1st run, allow popup for localhost in Chrome. ## Sample prompt diff --git a/contributing/samples/google_api/__init__.py b/contributing/samples/google_api/__init__.py new file mode 100644 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/google_api/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/bigquery_agent/agent.py b/contributing/samples/google_api/agent.py similarity index 98% rename from contributing/samples/bigquery_agent/agent.py rename to contributing/samples/google_api/agent.py index 976cea170..1cdbab9c6 100644 --- a/contributing/samples/bigquery_agent/agent.py +++ b/contributing/samples/google_api/agent.py @@ -40,7 +40,7 @@ root_agent = Agent( model="gemini-2.0-flash", - name="bigquery_agent", + name="google_api_bigquery_agent", instruction=""" You are a helpful Google BigQuery agent that help to manage users' data on Google BigQuery. Use the provided tools to conduct various operations on users' data in Google BigQuery. 
diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py new file mode 100644 index 000000000..7d5bb0b1c --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from . import agent diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py new file mode 100644 index 000000000..0f10621ae --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/agent.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random + +from google.adk import Agent +from google.adk.models.lite_llm import LiteLlm +from langchain_core.utils.function_calling import convert_to_openai_function + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +def check_prime(number: int) -> str: + """Check if a given number is prime. + + Args: + number: The input number to check. + + Returns: + A str indicating whether the number is prime. + """ + if number <= 1: + return f"{number} is not prime." + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + return f"{number} is prime." + else: + return f"{number} is not prime." + + +root_agent = Agent( + model=LiteLlm( + model="vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas", + # If the model is not trained with functions and you would like to + # enable function calling, you can add functions to the model, and the + # functions will be added to the prompts during inference. + functions=[ + convert_to_openai_function(roll_die), + convert_to_openai_function(check_prime), + ], + ), + name="data_processing_agent", + description="""You are a helpful assistant.""", + instruction=""" + You are a helpful assistant, and you may optionally call tools. + If you call tools, the tool call should be formatted in JSON, and the tool arguments should be parsed from the user's inputs.
+ """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py b/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py new file mode 100644 index 000000000..123ba1368 --- /dev/null +++ b/contributing/samples/hello_world_litellm_add_function_to_prompt/main.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import time + +import agent +from dotenv import load_dotenv +from google.adk import Runner +from google.adk.artifacts import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.sessions import InMemorySessionService +from google.adk.sessions import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts: + part = event.content.parts[0] + if part.text: + print(f'** {event.author}: {part.text}') + if part.function_call: + print(f'** {event.author} calls tool: {part.function_call}') + if part.function_response: + print( + f'** {event.author} gets tool response: {part.function_response}' + ) + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi, introduce yourself.') + await run_prompt(session_11, 'Roll a die with 100 sides.') + await run_prompt(session_11, 'Check if it is prime.') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/contributing/samples/live_bidi_streaming_agent/__init__.py b/contributing/samples/live_bidi_streaming_agent/__init__.py new file mode 100755 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/live_bidi_streaming_agent/agent.py b/contributing/samples/live_bidi_streaming_agent/agent.py new file mode 100755 index 000000000..2896bd70f --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/agent.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk import Agent +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + result = random.randint(1, sides) + if 'rolls' not in tool_context.state: + tool_context.state['rolls'] = [] + + tool_context.state['rolls'] = tool_context.state['rolls'] + [result] + return result + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which numbers are prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + 'No prime numbers found.' + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model='gemini-2.0-flash-live-preview-04-09', # for Vertex project + # model='gemini-2.0-flash-live-001', # for AI studio key + name='hello_world_agent', + description=( + 'hello world agent that can roll a dice of 8 sides and check prime' + ' numbers.' + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel (in one request and in one round). + It is ok to discuss previous dice rolls, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool.
+ When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) diff --git a/contributing/samples/live_bidi_streaming_agent/readme.md b/contributing/samples/live_bidi_streaming_agent/readme.md new file mode 100644 index 000000000..6a9258f3e --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/readme.md @@ -0,0 +1,37 @@ +# Simplistic Live (Bidi-Streaming) Agent +This project provides a basic example of a live, bidirectional streaming agent +designed for testing and experimentation. + +You can see full documentation [here](https://google.github.io/adk-docs/streaming/). + +## Getting Started + +Follow these steps to get the agent up and running: + +1. **Start the ADK Web Server** + Open your terminal, navigate to the root directory that contains the + `live_bidi_streaming_agent` folder, and execute the following command: + ```bash + adk web + ``` + +2. **Access the ADK Web UI** + Once the server is running, open your web browser and navigate to the URL + provided in the terminal (it will typically be `http://localhost:8000`). + +3. **Select the Agent** + In the top-left corner of the ADK Web UI, use the dropdown menu to select + this agent. + +4. **Start Streaming** + Click on either the **Audio** or **Video** icon located near the chat input + box to begin the streaming session. + +5. **Interact with the Agent** + You can now begin talking to the agent, and it will respond in real-time. + +## Usage Notes + +* You only need to click the **Audio** or **Video** button once to initiate the + stream. The current version does not support stopping and restarting the stream + by clicking the button again during a session. 
diff --git a/contributing/samples/mcp_sse_agent/agent.py b/contributing/samples/mcp_sse_agent/agent.py index 888a88b24..5423bfc6b 100755 --- a/contributing/samples/mcp_sse_agent/agent.py +++ b/contributing/samples/mcp_sse_agent/agent.py @@ -16,8 +16,8 @@ import os from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import SseServerParams _allowed_path = os.path.dirname(os.path.abspath(__file__)) @@ -31,7 +31,7 @@ """, tools=[ MCPToolset( - connection_params=SseServerParams( + connection_params=SseConnectionParams( url='http://localhost:3000/sse', headers={'Accept': 'text/event-stream'}, ), diff --git a/contributing/samples/mcp_streamablehttp_agent/README.md b/contributing/samples/mcp_streamablehttp_agent/README.md index 1c211dd71..547a0788d 100644 --- a/contributing/samples/mcp_streamablehttp_agent/README.md +++ b/contributing/samples/mcp_streamablehttp_agent/README.md @@ -1,8 +1,7 @@ -This agent connects to a local MCP server via sse. +This agent connects to a local MCP server via Streamable HTTP. To run this agent, start the local MCP server first by : ```bash uv run filesystem_server.py ``` - diff --git a/contributing/samples/mcp_streamablehttp_agent/agent.py b/contributing/samples/mcp_streamablehttp_agent/agent.py index 61d59e051..f165c4c1b 100644 --- a/contributing/samples/mcp_streamablehttp_agent/agent.py +++ b/contributing/samples/mcp_streamablehttp_agent/agent.py @@ -18,7 +18,6 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPServerParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import SseServerParams _allowed_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/contributing/samples/oauth_calendar_agent/agent.py b/contributing/samples/oauth_calendar_agent/agent.py index 9d56d3ff8..3f966b787 100644 --- a/contributing/samples/oauth_calendar_agent/agent.py +++ b/contributing/samples/oauth_calendar_agent/agent.py @@ -13,7 +13,6 @@ # limitations under the License. 
from datetime import datetime -import json import os from dotenv import load_dotenv @@ -27,10 +26,8 @@ from google.adk.auth import AuthCredentialTypes from google.adk.auth import OAuth2Auth from google.adk.tools import ToolContext -from google.adk.tools.authenticated_tool.base_authenticated_tool import AuthenticatedFunctionTool -from google.adk.tools.authenticated_tool.credentials_store import ToolContextCredentialsStore +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool from google.adk.tools.google_api_tool import CalendarToolset -from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from googleapiclient.discovery import build @@ -150,11 +147,7 @@ def update_time(callback_context: CallbackContext): ), tokenUrl="https://oauth2.googleapis.com/token", scopes={ - "https://www.googleapis.com/auth/calendar": ( - "See, edit, share, and permanently delete" - " all the calendars you can access using" - " Google Calendar" - ) + "https://www.googleapis.com/auth/calendar": "", }, ) ) @@ -167,7 +160,6 @@ def update_time(callback_context: CallbackContext): ), ), ), - credential_store=ToolContextCredentialsStore(), ), calendar_toolset, ], diff --git a/contributing/samples/quickstart/agent.py b/contributing/samples/quickstart/agent.py index fdd6b7f9d..b251069ad 100644 --- a/contributing/samples/quickstart/agent.py +++ b/contributing/samples/quickstart/agent.py @@ -29,7 +29,7 @@ def get_weather(city: str) -> dict: "status": "success", "report": ( "The weather in New York is sunny with a temperature of 25 degrees" - " Celsius (41 degrees Fahrenheit)." + " Celsius (77 degrees Fahrenheit)." ), } else: diff --git a/pyproject.toml b/pyproject.toml index 8ece4db81..6cf78ab40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start + "PyYAML>=6.0.2", # For APIHubToolset. "anyio>=4.9.0;python_version>='3.10'", # For MCP Session Manager "authlib>=1.5.1", # For RestAPI Tool "click>=8.1.8", # For CLI tools @@ -34,7 +35,7 @@ dependencies = [ "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool "google-cloud-speech>=2.30.0", # For Audio Transcription "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service - "google-genai>=1.17.0", # Google GenAI SDK + "google-genai>=1.21.1", # Google GenAI SDK "graphviz>=0.20.2", # Graphviz for graph rendering "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset "opentelemetry-api>=1.31.0", # OpenTelemetry @@ -43,7 +44,6 @@ dependencies = [ "pydantic>=2.0, <3.0.0", # For data validation/models "python-dateutil>=2.9.0.post0", # For Vertext AI Session Service "python-dotenv>=1.0.0", # To manage environment variables - "PyYAML>=6.0.2", # For APIHubToolset. 
"requests>=2.32.4", "sqlalchemy>=2.0", # SQL database ORM "starlette>=0.46.2", # For FastAPI CLI @@ -70,9 +70,9 @@ dev = [ # go/keep-sorted start "flit>=3.10.0", "isort>=6.0.0", + "mypy>=1.15.0", "pyink>=24.10.0", "pylint>=2.6.0", - "mypy>=1.15.0", # go/keep-sorted end ] @@ -87,6 +87,7 @@ eval = [ "google-cloud-aiplatform[evaluation]>=1.87.0", "pandas>=2.2.3", "tabulate>=0.9.0", + "rouge-score>=0.1.2", # go/keep-sorted end ] @@ -97,7 +98,6 @@ test = [ "langgraph>=0.2.60", # For LangGraphAgent "litellm>=1.71.2", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests - "pytest-asyncio>=0.25.0", "pytest-mock>=3.14.0", "pytest-xdist>=3.6.1", diff --git a/src/google/adk/a2a/converters/__init__.py b/src/google/adk/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py new file mode 100644 index 000000000..25183f6be --- /dev/null +++ b/src/google/adk/a2a/converters/event_converter.py @@ -0,0 +1,604 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +import uuid + +from a2a.server.events import Event as A2AEvent +from a2a.types import Artifact +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.genai import types as genai_types + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from ...utils.feature_decorator import working_in_progress +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_a2a_part_to_genai_part +from .part_converter import convert_genai_part_to_a2a_part +from .utils import _get_adk_metadata_key + +# Constants + +ARTIFACT_ID_SEPARATOR = "-" +DEFAULT_ERROR_MESSAGE = "An error occurred during processing" + +# Logger +logger = logging.getLogger("google_adk." + __name__) + + +def _serialize_metadata_value(value: Any) -> str: + """Safely serializes metadata values to string format. + + Args: + value: The value to serialize. + + Returns: + String representation of the value. + """ + if hasattr(value, "model_dump"): + try: + return value.model_dump(exclude_none=True, by_alias=True) + except Exception as e: + logger.warning("Failed to serialize metadata value: %s", e) + return str(value) + return str(value) + + +def _get_context_metadata( + event: Event, invocation_context: InvocationContext +) -> Dict[str, str]: + """Gets the context metadata for the event. + + Args: + event: The ADK event to extract metadata from. + invocation_context: The invocation context containing session information. + + Returns: + A dictionary containing the context metadata. + + Raises: + ValueError: If required fields are missing from event or context. 
+ """ + if not event: + raise ValueError("Event cannot be None") + if not invocation_context: + raise ValueError("Invocation context cannot be None") + + try: + metadata = { + _get_adk_metadata_key("app_name"): invocation_context.app_name, + _get_adk_metadata_key("user_id"): invocation_context.user_id, + _get_adk_metadata_key("session_id"): invocation_context.session.id, + _get_adk_metadata_key("invocation_id"): event.invocation_id, + _get_adk_metadata_key("author"): event.author, + } + + # Add optional metadata fields if present + optional_fields = [ + ("branch", event.branch), + ("grounding_metadata", event.grounding_metadata), + ("custom_metadata", event.custom_metadata), + ("usage_metadata", event.usage_metadata), + ("error_code", event.error_code), + ] + + for field_name, field_value in optional_fields: + if field_value is not None: + metadata[_get_adk_metadata_key(field_name)] = _serialize_metadata_value( + field_value + ) + + return metadata + + except Exception as e: + logger.error("Failed to create context metadata: %s", e) + raise + + +def _create_artifact_id( + app_name: str, user_id: str, session_id: str, filename: str, version: int +) -> str: + """Creates a unique artifact ID. + + Args: + app_name: The application name. + user_id: The user ID. + session_id: The session ID. + filename: The artifact filename. + version: The artifact version. + + Returns: + A unique artifact ID string. + """ + components = [app_name, user_id, session_id, filename, str(version)] + return ARTIFACT_ID_SEPARATOR.join(components) + + +def _convert_artifact_to_a2a_events( + event: Event, + invocation_context: InvocationContext, + filename: str, + version: int, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> TaskArtifactUpdateEvent: + """Converts a new artifact version to an A2A TaskArtifactUpdateEvent. + + Args: + event: The ADK event containing the artifact information. + invocation_context: The invocation context. + filename: The name of the artifact file. + version: The version number of the artifact. + task_id: Optional task ID to use for generated events. If not provided, new UUIDs will be generated. + + Returns: + A TaskArtifactUpdateEvent representing the artifact update. + + Raises: + ValueError: If required parameters are invalid. + RuntimeError: If artifact loading fails. 
+ """ + if not filename: + raise ValueError("Filename cannot be empty") + if version < 0: + raise ValueError("Version must be non-negative") + + try: + artifact_part = invocation_context.artifact_service.load_artifact( + app_name=invocation_context.app_name, + user_id=invocation_context.user_id, + session_id=invocation_context.session.id, + filename=filename, + version=version, + ) + + converted_part = convert_genai_part_to_a2a_part(part=artifact_part) + if not converted_part: + raise RuntimeError(f"Failed to convert artifact part for {filename}") + + artifact_id = _create_artifact_id( + invocation_context.app_name, + invocation_context.user_id, + invocation_context.session.id, + filename, + version, + ) + + return TaskArtifactUpdateEvent( + taskId=task_id, + append=False, + contextId=context_id, + lastChunk=True, + artifact=Artifact( + artifactId=artifact_id, + name=filename, + metadata={ + "filename": filename, + "version": version, + }, + parts=[converted_part], + ), + ) + except Exception as e: + logger.error( + "Failed to convert artifact for %s, version %s: %s", + filename, + version, + e, + ) + raise RuntimeError(f"Artifact conversion failed: {e}") from e + + +def _process_long_running_tool(a2a_part: A2APart, event: Event) -> None: + """Processes long-running tool metadata for an A2A part. + + Args: + a2a_part: The A2A part to potentially mark as long-running. + event: The ADK event containing long-running tool information. + """ + if ( + isinstance(a2a_part.root, DataPart) + and event.long_running_tool_ids + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and a2a_part.root.data.get("id") in event.long_running_tool_ids + ): + a2a_part.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ] = True + + +@working_in_progress +def convert_a2a_task_to_event( + a2a_task: Task, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, +) -> Event: + """Converts an A2A task to an ADK event. + + Args: + a2a_task: The A2A task to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + + Returns: + An ADK Event object representing the converted task. + + Raises: + ValueError: If a2a_task is None. + RuntimeError: If conversion of the underlying message fails. 
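+
+  Example (illustrative usage):
+
+    event = convert_a2a_task_to_event(a2a_task, author="my_remote_agent")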
+ """ + if a2a_task is None: + raise ValueError("A2A task cannot be None") + + try: + # Extract message from task status or history + message = None + if a2a_task.status and a2a_task.status.message: + message = a2a_task.status.message + elif a2a_task.history: + message = a2a_task.history[-1] + + # Convert message if available + if message: + try: + return convert_a2a_message_to_event(message, author, invocation_context) + except Exception as e: + logger.error("Failed to convert A2A task message to event: %s", e) + raise RuntimeError(f"Failed to convert task message: {e}") from e + + # Create minimal event if no message is available + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + ) + + except Exception as e: + logger.error("Failed to convert A2A task to event: %s", e) + raise + + +@working_in_progress +def convert_a2a_message_to_event( + a2a_message: Message, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, +) -> Event: + """Converts an A2A message to an ADK event. + + Args: + a2a_message: The A2A message to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + + Returns: + An ADK Event object with converted content and long-running tool metadata. + + Raises: + ValueError: If a2a_message is None. + RuntimeError: If conversion of message parts fails. + """ + if a2a_message is None: + raise ValueError("A2A message cannot be None") + + if not a2a_message.parts: + logger.warning( + "A2A message has no parts, creating event with empty content" + ) + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + content=genai_types.Content(role="model", parts=[]), + ) + + try: + parts = [] + long_running_tool_ids = set() + + for a2a_part in a2a_message.parts: + try: + part = convert_a2a_part_to_genai_part(a2a_part) + if part is None: + logger.warning("Failed to convert A2A part, skipping: %s", a2a_part) + continue + + # Check for long-running tools + if ( + a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY + ) + ) + is True + ): + long_running_tool_ids.add(part.function_call.id) + + parts.append(part) + + except Exception as e: + logger.error("Failed to convert A2A part: %s, error: %s", a2a_part, e) + # Continue processing other parts instead of failing completely + continue + + if not parts: + logger.warning( + "No parts could be converted from A2A message %s", a2a_message + ) + + return Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + long_running_tool_ids=long_running_tool_ids + if long_running_tool_ids + else None, + content=genai_types.Content( + role="model", + parts=parts, + ), + ) + + except Exception as e: + logger.error("Failed to convert A2A message to event: %s", e) + raise RuntimeError(f"Failed to convert message: {e}") from e + + +@working_in_progress +def convert_event_to_a2a_message( + 
event: Event, invocation_context: InvocationContext, role: Role = Role.agent +) -> Optional[Message]: + """Converts an ADK event to an A2A message. + + Args: + event: The ADK event to convert. + invocation_context: The invocation context. + + Returns: + An A2A Message if the event has content, None otherwise. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + if not invocation_context: + raise ValueError("Invocation context cannot be None") + + if not event.content or not event.content.parts: + return None + + try: + a2a_parts = [] + for part in event.content.parts: + a2a_part = convert_genai_part_to_a2a_part(part) + if a2a_part: + a2a_parts.append(a2a_part) + _process_long_running_tool(a2a_part, event) + + if a2a_parts: + return Message(messageId=str(uuid.uuid4()), role=role, parts=a2a_parts) + + except Exception as e: + logger.error("Failed to convert event to status message: %s", e) + raise + + return None + + +def _create_error_status_event( + event: Event, + invocation_context: InvocationContext, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for error scenarios. + + Args: + event: The ADK event containing error information. + invocation_context: The invocation context. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + + Returns: + A TaskStatusUpdateEvent with FAILED state. + """ + error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + + # Get context metadata and add error code + event_metadata = _get_context_metadata(event, invocation_context) + if event.error_code: + event_metadata[_get_adk_metadata_key("error_code")] = str(event.error_code) + + return TaskStatusUpdateEvent( + taskId=task_id, + contextId=context_id, + metadata=event_metadata, + status=TaskStatus( + state=TaskState.failed, + message=Message( + messageId=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=error_message)], + metadata={ + _get_adk_metadata_key("error_code"): str(event.error_code) + } + if event.error_code + else {}, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=False, + ) + + +def _create_status_update_event( + message: Message, + invocation_context: InvocationContext, + event: Event, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for running scenarios. + + Args: + message: The A2A message to include. + invocation_context: The invocation context. + event: The ADK event. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + + + Returns: + A TaskStatusUpdateEvent with RUNNING state. 
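+
+  Note:
+    The state is set to auth_required when the message carries a long-running
+    REQUEST_EUC_FUNCTION_CALL_NAME function call, to input_required for any
+    other long-running function call, and to working otherwise.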
+ """ + status = TaskStatus( + state=TaskState.working, + message=message, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + if any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME + for part in message.parts + if part.root.metadata + ): + status.state = TaskState.auth_required + elif any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + for part in message.parts + if part.root.metadata + ): + status.state = TaskState.input_required + + return TaskStatusUpdateEvent( + taskId=task_id, + contextId=context_id, + status=status, + metadata=_get_context_metadata(event, invocation_context), + final=False, + ) + + +@working_in_progress +def convert_event_to_a2a_events( + event: Event, + invocation_context: InvocationContext, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> List[A2AEvent]: + """Converts a GenAI event to a list of A2A events. + + Args: + event: The ADK event to convert. + invocation_context: The invocation context. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + + Returns: + A list of A2A events representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + if not invocation_context: + raise ValueError("Invocation context cannot be None") + + a2a_events = [] + + try: + # Handle artifact deltas + if event.actions.artifact_delta: + for filename, version in event.actions.artifact_delta.items(): + artifact_event = _convert_artifact_to_a2a_events( + event, invocation_context, filename, version, task_id, context_id + ) + a2a_events.append(artifact_event) + + # Handle error scenarios + if event.error_code: + error_event = _create_error_status_event( + event, invocation_context, task_id, context_id + ) + a2a_events.append(error_event) + + # Handle regular message content + message = convert_event_to_a2a_message(event, invocation_context) + if message: + running_event = _create_status_update_event( + message, invocation_context, event, task_id, context_id + ) + a2a_events.append(running_event) + + except Exception as e: + logger.error("Failed to convert event to A2A events: %s", e) + raise + + return a2a_events diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py new file mode 100644 index 000000000..8dab1097d --- /dev/null +++ b/src/google/adk/a2a/converters/part_converter.py @@ -0,0 +1,247 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""
+Module containing utilities for conversion between A2A Part and Google GenAI Part.
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+import logging
+import sys
+from typing import Optional
+
+from .utils import _get_adk_metadata_key
+
+try:
+  from a2a import types as a2a_types
+except ImportError as e:
+  if sys.version_info < (3, 10):
+    raise ImportError(
+        'A2A Tool requires Python 3.10 or above. Please upgrade your Python'
+        ' version.'
+    ) from e
+  else:
+    raise e
+
+from google.genai import types as genai_types
+
+from ...utils.feature_decorator import working_in_progress
+
+logger = logging.getLogger('google_adk.' + __name__)
+
+A2A_DATA_PART_METADATA_TYPE_KEY = 'type'
+A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = 'is_long_running'
+A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call'
+A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response'
+A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result'
+A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code'
+
+
+@working_in_progress
+def convert_a2a_part_to_genai_part(
+    a2a_part: a2a_types.Part,
+) -> Optional[genai_types.Part]:
+  """Convert an A2A Part to a Google GenAI Part."""
+  part = a2a_part.root
+  if isinstance(part, a2a_types.TextPart):
+    return genai_types.Part(text=part.text)
+
+  if isinstance(part, a2a_types.FilePart):
+    if isinstance(part.file, a2a_types.FileWithUri):
+      return genai_types.Part(
+          file_data=genai_types.FileData(
+              file_uri=part.file.uri, mime_type=part.file.mimeType
+          )
+      )
+
+    elif isinstance(part.file, a2a_types.FileWithBytes):
+      return genai_types.Part(
+          inline_data=genai_types.Blob(
+              data=base64.b64decode(part.file.bytes),
+              mime_type=part.file.mimeType,
+          )
+      )
+    else:
+      logger.warning(
+          'Cannot convert unsupported file type: %s for A2A part: %s',
+          type(part.file),
+          a2a_part,
+      )
+      return None
+
+  if isinstance(part, a2a_types.DataPart):
+    # Convert the DataPart to a function call or function response.
+    # This is mainly for converting human in the loop and auth request and
+    # response.
+    # TODO once A2A defines how to service such information, migrate the
+    # logic below accordingly.
+    if (
+        part.metadata
+        and _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
+        in part.metadata
+    ):
+      if (
+          part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)]
+          == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
+      ):
+        return genai_types.Part(
+            function_call=genai_types.FunctionCall.model_validate(
+                part.data, by_alias=True
+            )
+        )
+      if (
+          part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)]
+          == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE
+      ):
+        return genai_types.Part(
+            function_response=genai_types.FunctionResponse.model_validate(
+                part.data, by_alias=True
+            )
+        )
+      if (
+          part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)]
+          == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT
+      ):
+        return genai_types.Part(
+            code_execution_result=genai_types.CodeExecutionResult.model_validate(
+                part.data, by_alias=True
+            )
+        )
+      if (
+          part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)]
+          == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE
+      ):
+        return genai_types.Part(
+            executable_code=genai_types.ExecutableCode.model_validate(
+                part.data, by_alias=True
+            )
+        )
+    return genai_types.Part(text=json.dumps(part.data))
+
+  logger.warning(
+      'Cannot convert unsupported part type: %s for A2A part: %s',
+      type(part),
+      a2a_part,
+  )
+  return None
+
+
+@working_in_progress
+def convert_genai_part_to_a2a_part(
+    part: genai_types.Part,
+) -> Optional[a2a_types.Part]:
+  """Convert a Google GenAI Part to an A2A Part."""
+
+  if part.text:
+    a2a_part = a2a_types.TextPart(text=part.text)
+    if part.thought is not None:
+      a2a_part.metadata = {_get_adk_metadata_key('thought'): part.thought}
+    return a2a_types.Part(root=a2a_part)
+
+  if part.file_data:
+    return a2a_types.Part(
+        root=a2a_types.FilePart(
+            file=a2a_types.FileWithUri(
+                uri=part.file_data.file_uri,
+                mimeType=part.file_data.mime_type,
+            )
+        )
+    )
+
+  if part.inline_data:
+    a2a_part = a2a_types.FilePart(
+        file=a2a_types.FileWithBytes(
+            bytes=base64.b64encode(part.inline_data.data).decode('utf-8'),
+            mimeType=part.inline_data.mime_type,
+        )
+    )
+
+    if part.video_metadata:
+      a2a_part.metadata = {
+          _get_adk_metadata_key(
+              'video_metadata'
+          ): part.video_metadata.model_dump(by_alias=True, exclude_none=True)
+      }
+
+    return a2a_types.Part(root=a2a_part)
+
+  # Convert function calls and function responses to A2A DataPart.
+  # This is mainly for converting human in the loop and auth request and
+  # response.
+  # TODO once A2A defines how to service such information, migrate the
+  # logic below accordingly.
+  if part.function_call:
+    return a2a_types.Part(
+        root=a2a_types.DataPart(
+            data=part.function_call.model_dump(
+                by_alias=True, exclude_none=True
+            ),
+            metadata={
+                _get_adk_metadata_key(
+                    A2A_DATA_PART_METADATA_TYPE_KEY
+                ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
+            },
+        )
+    )
+
+  if part.function_response:
+    return a2a_types.Part(
+        root=a2a_types.DataPart(
+            data=part.function_response.model_dump(
+                by_alias=True, exclude_none=True
+            ),
+            metadata={
+                _get_adk_metadata_key(
+                    A2A_DATA_PART_METADATA_TYPE_KEY
+                ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE
+            },
+        )
+    )
+
+  if part.code_execution_result:
+    return a2a_types.Part(
+        root=a2a_types.DataPart(
+            data=part.code_execution_result.model_dump(
+                by_alias=True, exclude_none=True
+            ),
+            metadata={
+                _get_adk_metadata_key(
+                    A2A_DATA_PART_METADATA_TYPE_KEY
+                ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT
+            },
+        )
+    )
+
+  if part.executable_code:
+    return a2a_types.Part(
+        root=a2a_types.DataPart(
+            data=part.executable_code.model_dump(
+                by_alias=True, exclude_none=True
+            ),
+            metadata={
+                _get_adk_metadata_key(
+                    A2A_DATA_PART_METADATA_TYPE_KEY
+                ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE
+            },
+        )
+    )
+
+  logger.warning(
+      'Cannot convert unsupported part for Google GenAI part: %s',
+      part,
+  )
+  return None
diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py
new file mode 100644
index 000000000..293df46e6
--- /dev/null
+++ b/src/google/adk/a2a/converters/request_converter.py
@@ -0,0 +1,90 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import sys
+from typing import Any
+
+try:
+  from a2a.server.agent_execution import RequestContext
+except ImportError as e:
+  if sys.version_info < (3, 10):
+    raise ImportError(
+        'A2A Tool requires Python 3.10 or above. Please upgrade your Python'
+        ' version.'
+ ) from e + else: + raise e + +from google.genai import types as genai_types + +from ...runners import RunConfig +from ...utils.feature_decorator import working_in_progress +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _from_a2a_context_id +from .utils import _get_adk_metadata_key + + +def _get_user_id(request: RequestContext, user_id_from_context: str) -> str: + # Get user from call context if available (auth is enabled on a2a server) + if request.call_context and request.call_context.user: + return request.call_context.user.user_name + + # Get user from context id if available + if user_id_from_context: + return user_id_from_context + + # Get user from message metadata if available (client is an ADK agent) + if request.message.metadata: + user_id = request.message.metadata.get(_get_adk_metadata_key('user_id')) + if user_id: + return f'ADK_USER_{user_id}' + + # Get user from task if available (client is a an ADK agent) + if request.current_task: + user_id = request.current_task.metadata.get( + _get_adk_metadata_key('user_id') + ) + if user_id: + return f'ADK_USER_{user_id}' + return ( + f'temp_user_{request.task_id}' + if request.task_id + else f'TEMP_USER_{request.message.messageId}' + ) + + +@working_in_progress +def convert_a2a_request_to_adk_run_args( + request: RequestContext, +) -> dict[str, Any]: + + if not request.message: + raise ValueError('Request message cannot be None') + + _, user_id, session_id = _from_a2a_context_id(request.context_id) + + return { + 'user_id': _get_user_id(request, user_id), + 'session_id': session_id, + 'new_message': genai_types.Content( + role='user', + parts=[ + convert_a2a_part_to_genai_part(part) + for part in request.message.parts + ], + ), + 'run_config': RunConfig(), + } diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py new file mode 100644 index 000000000..acb2581d4 --- /dev/null +++ b/src/google/adk/a2a/converters/utils.py @@ -0,0 +1,89 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +ADK_METADATA_KEY_PREFIX = "adk_" +ADK_CONTEXT_ID_PREFIX = "ADK" +ADK_CONTEXT_ID_SEPARATOR = "/" + + +def _get_adk_metadata_key(key: str) -> str: + """Gets the A2A event metadata key for the given key. + + Args: + key: The metadata key to prefix. + + Returns: + The prefixed metadata key. + + Raises: + ValueError: If key is empty or None. + """ + if not key: + raise ValueError("Metadata key cannot be empty or None") + return f"{ADK_METADATA_KEY_PREFIX}{key}" + + +def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str: + """Converts app name, user id and session id to an A2A context id. + + Args: + app_name: The app name. + user_id: The user id. + session_id: The session id. + + Returns: + The A2A context id. + + Raises: + ValueError: If any of the input parameters are empty or None. 
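+
+  Example (illustrative values):
+
+    _to_a2a_context_id("my_app", "user_1", "session_1")
+    # -> "ADK/my_app/user_1/session_1"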
+ """ + if not all([app_name, user_id, session_id]): + raise ValueError( + "All parameters (app_name, user_id, session_id) must be non-empty" + ) + return ADK_CONTEXT_ID_SEPARATOR.join( + [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id] + ) + + +def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: + """Converts an A2A context id to app name, user id and session id. + if context_id is None, return None, None, None + if context_id is not None, but not in the format of + ADK$app_name$user_id$session_id, return None, None, None + + Args: + context_id: The A2A context id. + + Returns: + The app name, user id and session id. + """ + if not context_id: + return None, None, None + + try: + parts = context_id.split(ADK_CONTEXT_ID_SEPARATOR) + if len(parts) != 4: + return None, None, None + + prefix, app_name, user_id, session_id = parts + if prefix == ADK_CONTEXT_ID_PREFIX and app_name and user_id and session_id: + return app_name, user_id, session_id + except ValueError: + # Handle any split errors gracefully + pass + + return None, None, None diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index f70371535..765f22a2c 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -22,6 +22,7 @@ from pydantic import ConfigDict from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session @@ -115,6 +116,7 @@ class InvocationContext(BaseModel): artifact_service: Optional[BaseArtifactService] = None session_service: BaseSessionService memory_service: Optional[BaseMemoryService] = None + credential_service: Optional[BaseCredentialService] = None invocation_id: str """The id of this invocation context. Readonly.""" diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index fe145a60e..64c3628df 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -161,10 +161,12 @@ class LlmAgent(BaseAgent): # LLM-based agent transfer configs - End include_contents: Literal['default', 'none'] = 'default' - """Whether to include contents in the model request. + """Controls content inclusion in model requests. - When set to 'none', the model request will not include any contents, such as - user messages, tool results, etc. + Options: + default: Model receives relevant conversation history + none: Model receives no prior history, operates solely on current + instruction and input """ # Controlled input/output configurations - Start diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 5679f04e9..c9a50a0ae 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -73,6 +73,12 @@ class RunConfig(BaseModel): realtime_input_config: Optional[types.RealtimeInputConfig] = None """Realtime input config for live agents with audio input from user.""" + enable_affective_dialog: Optional[bool] = None + """If enabled, the model will detect emotions and adapt its responses accordingly.""" + + proactivity: Optional[types.ProactivityConfig] = None + """Configures the proactivity of the model. 
This allows the model to respond proactively to the input and to ignore irrelevant input.""" + max_llm_calls: int = 500 """ A limit on the total number of llm calls for a given run. diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index e4af21e15..35aa88622 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -13,6 +13,7 @@ # limitations under the License. """An artifact service implementation using Google Cloud Storage (GCS).""" +from __future__ import annotations import logging from typing import Optional @@ -151,7 +152,7 @@ async def list_artifact_keys( self.bucket, prefix=session_prefix ) for blob in session_blobs: - _, _, _, filename, _ = blob.name.split("/") + *_, filename, _ = blob.name.split("/") filenames.add(filename) user_namespace_prefix = f"{app_name}/{user_id}/user/" @@ -159,7 +160,7 @@ async def list_artifact_keys( self.bucket, prefix=user_namespace_prefix ) for blob in user_namespace_blobs: - _, _, _, filename, _ = blob.name.split("/") + *_, filename, _ = blob.name.split("/") filenames.add(filename) return sorted(list(filenames)) diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 1009a50dd..34d04dde9 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -230,4 +230,3 @@ class AuthCredential(BaseModelWithConfig): http: Optional[HttpAuth] = None service_account: Optional[ServiceAccount] = None oauth2: Optional[OAuth2Auth] = None - google_oauth2_json: Optional[str] = None diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 3e13cbac2..473f31413 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -22,7 +22,7 @@ from .auth_schemes import AuthSchemeType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig -from .oauth2_credential_fetcher import OAuth2CredentialFetcher +from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger if TYPE_CHECKING: from ..sessions.state import State @@ -36,18 +36,23 @@ class AuthHandler: + """A handler that handles the auth flow in Agent Development Kit to help + orchestrate the credential request and response flow (e.g. OAuth flow) + This class should only be used by Agent Development Kit. 
+ """ def __init__(self, auth_config: AuthConfig): self.auth_config = auth_config - def exchange_auth_token( + async def exchange_auth_token( self, ) -> AuthCredential: - return OAuth2CredentialFetcher( - self.auth_config.auth_scheme, self.auth_config.exchanged_auth_credential - ).exchange() + exchanger = OAuth2CredentialExchanger() + return await exchanger.exchange( + self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme + ) - def parse_and_store_auth_response(self, state: State) -> None: + async def parse_and_store_auth_response(self, state: State) -> None: credential_key = "temp:" + self.auth_config.credential_key @@ -60,7 +65,7 @@ def parse_and_store_auth_response(self, state: State) -> None: ): return - state[credential_key] = self.exchange_auth_token() + state[credential_key] = await self.exchange_auth_token() def _validate(self) -> None: if not self.auth_scheme: diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 0c964ed96..b06774973 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -67,9 +67,9 @@ async def run_async( # function call request_euc_function_call_ids.add(function_call_response.id) auth_config = AuthConfig.model_validate(function_call_response.response) - AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - state=invocation_context.session.state - ) + await AuthHandler( + auth_config=auth_config + ).parse_and_store_auth_response(state=invocation_context.session.state) break if not request_euc_function_call_ids: diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index 53c571d42..0316e5258 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -31,12 +31,12 @@ class AuthConfig(BaseModelWithConfig): auth_scheme: AuthScheme """The auth scheme used to collect credentials""" - raw_auth_credential: AuthCredential = None + raw_auth_credential: Optional[AuthCredential] = None """The raw auth credential used to collect credentials. The raw auth credentials are used in some auth scheme that needs to exchange auth credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None. """ - exchanged_auth_credential: AuthCredential = None + exchanged_auth_credential: Optional[AuthCredential] = None """The exchanged auth credential used to collect credentials. adk and client will work together to fill it. For those auth scheme that doesn't need to exchange auth credentials, e.g. API key, service account etc. It's filled by diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py new file mode 100644 index 000000000..0dbf006ab --- /dev/null +++ b/src/google/adk/auth/credential_manager.py @@ -0,0 +1,261 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Optional + +from ..tools.tool_context import ToolContext +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_credential import AuthCredentialTypes +from .auth_schemes import AuthSchemeType +from .auth_tool import AuthConfig +from .exchanger.base_credential_exchanger import BaseCredentialExchanger +from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry +from .refresher.base_credential_refresher import BaseCredentialRefresher +from .refresher.credential_refresher_registry import CredentialRefresherRegistry + + +@experimental +class CredentialManager: + """Manages authentication credentials through a structured workflow. + + The CredentialManager orchestrates the complete lifecycle of authentication + credentials, from initial loading to final preparation for use. It provides + a centralized interface for handling various credential types and authentication + schemes while maintaining proper credential hygiene (refresh, exchange, caching). + + This class is only for use by Agent Development Kit. + + Args: + auth_config: Configuration containing authentication scheme and credentials + + Example: + ```python + auth_config = AuthConfig( + auth_scheme=oauth2_scheme, + raw_auth_credential=service_account_credential + ) + manager = CredentialManager(auth_config) + + # Register custom exchanger if needed + manager.register_credential_exchanger( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialExchanger() + ) + + # Register custom refresher if needed + manager.register_credential_refresher( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialRefresher() + ) + + # Load and prepare credential + credential = await manager.load_auth_credential(tool_context) + ``` + """ + + def __init__( + self, + auth_config: AuthConfig, + ): + self._auth_config = auth_config + self._exchanger_registry = CredentialExchangerRegistry() + self._refresher_registry = CredentialRefresherRegistry() + + # Register default exchangers and refreshers + # TODO: support service account credential exchanger + from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher + + oauth2_refresher = OAuth2CredentialRefresher() + self._refresher_registry.register( + AuthCredentialTypes.OAUTH2, oauth2_refresher + ) + self._refresher_registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher + ) + + def register_credential_exchanger( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register a credential exchanger for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. 
+ """ + self._exchanger_registry.register(credential_type, exchanger_instance) + + async def request_credential(self, tool_context: ToolContext) -> None: + tool_context.request_credential(self._auth_config) + + async def get_auth_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load and prepare authentication credential through a structured workflow.""" + + # Step 1: Validate credential configuration + await self._validate_credential() + + # Step 2: Check if credential is already ready (no processing needed) + if self._is_credential_ready(): + return self._auth_config.raw_auth_credential + + # Step 3: Try to load existing processed credential + credential = await self._load_existing_credential(tool_context) + + # Step 4: If no existing credential, load from auth response + # TODO instead of load from auth response, we can store auth response in + # credential service. + was_from_auth_response = False + if not credential: + credential = await self._load_from_auth_response(tool_context) + was_from_auth_response = True + + # Step 5: If still no credential available, return None + if not credential: + return None + + # Step 6: Exchange credential if needed (e.g., service account to access token) + credential, was_exchanged = await self._exchange_credential(credential) + + # Step 7: Refresh credential if expired + if not was_exchanged: + credential, was_refreshed = await self._refresh_credential(credential) + + # Step 8: Save credential if it was modified + if was_from_auth_response or was_exchanged or was_refreshed: + await self._save_credential(tool_context, credential) + + return credential + + async def _load_existing_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load existing credential from credential service or cached exchanged credential.""" + + # Try loading from credential service first + credential = await self._load_from_credential_service(tool_context) + if credential: + return credential + + # Check if we have a cached exchanged credential + if self._auth_config.exchanged_auth_credential: + return self._auth_config.exchanged_auth_credential + + return None + + async def _load_from_credential_service( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Note: This should be made async in a future refactor + # For now, assuming synchronous operation + return await credential_service.load_credential( + self._auth_config, tool_context + ) + return None + + async def _load_from_auth_response( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from auth response in tool context.""" + return tool_context.get_auth_response(self._auth_config) + + async def _exchange_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Exchange credential if needed and return the credential and whether it was exchanged.""" + exchanger = self._exchanger_registry.get_exchanger(credential.auth_type) + if not exchanger: + return credential, False + + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchanged_credential, True + + async def _refresh_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Refresh credential if expired and return the credential and whether it was refreshed.""" + refresher = 
self._refresher_registry.get_refresher(credential.auth_type) + if not refresher: + return credential, False + + if await refresher.is_refresh_needed( + credential, self._auth_config.auth_scheme + ): + refreshed_credential = await refresher.refresh( + credential, self._auth_config.auth_scheme + ) + return refreshed_credential, True + + return credential, False + + def _is_credential_ready(self) -> bool: + """Check if credential is ready to use without further processing.""" + raw_credential = self._auth_config.raw_auth_credential + if not raw_credential: + return False + + # Simple credentials that don't need exchange or refresh + return raw_credential.auth_type in ( + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + # Add other simple auth types as needed + ) + + async def _validate_credential(self) -> None: + """Validate credential configuration and raise errors if invalid.""" + if not self._auth_config.raw_auth_credential: + if self._auth_config.auth_scheme.type_ in ( + AuthSchemeType.oauth2, + AuthSchemeType.openIdConnect, + ): + raise ValueError( + "raw_auth_credential is required for auth_scheme type " + f"{self._auth_config.auth_scheme.type_}" + ) + + raw_credential = self._auth_config.raw_auth_credential + if raw_credential: + if ( + raw_credential.auth_type + in ( + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + ) + and not raw_credential.oauth2 + ): + raise ValueError( + "auth_config.raw_credential.oauth2 required for credential type " + f"{raw_credential.auth_type}" + ) + # Additional validation can be added here + + async def _save_credential( + self, tool_context: ToolContext, credential: AuthCredential + ) -> None: + """Save credential to credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Update the exchanged credential in config + self._auth_config.exchanged_auth_credential = credential + await credential_service.save_credential(self._auth_config, tool_context) diff --git a/src/google/adk/auth/credential_service/base_credential_service.py b/src/google/adk/auth/credential_service/base_credential_service.py index 7416ccc65..fc6cd500d 100644 --- a/src/google/adk/auth/credential_service/base_credential_service.py +++ b/src/google/adk/auth/credential_service/base_credential_service.py @@ -19,12 +19,12 @@ from typing import Optional from ...tools.tool_context import ToolContext -from ...utils.feature_decorator import working_in_progress +from ...utils.feature_decorator import experimental from ..auth_credential import AuthCredential from ..auth_tool import AuthConfig -@working_in_progress("Implementation are in progress. Don't use it for now.") +@experimental class BaseCredentialService(ABC): """Abstract class for Service that loads / saves tool credentials from / to the backend credential store.""" diff --git a/src/google/adk/auth/credential_service/in_memory_credential_service.py b/src/google/adk/auth/credential_service/in_memory_credential_service.py new file mode 100644 index 000000000..f6f51b35a --- /dev/null +++ b/src/google/adk/auth/credential_service/in_memory_credential_service.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class InMemoryCredentialService(BaseCredentialService): + """Class for in memory implementation of credential service(Experimental)""" + + def __init__(self): + super().__init__() + self._credentials = {} + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + credential_bucket = self._get_bucket_for_current_context(tool_context) + return credential_bucket.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + credential_bucket = self._get_bucket_for_current_context(tool_context) + credential_bucket[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) + + def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str: + app_name = tool_context._invocation_context.app_name + user_id = tool_context._invocation_context.user_id + + if app_name not in self._credentials: + self._credentials[app_name] = {} + if user_id not in self._credentials[app_name]: + self._credentials[app_name][user_id] = {} + return self._credentials[app_name][user_id] diff --git a/src/google/adk/auth/credential_service/session_state_credential_service.py b/src/google/adk/auth/credential_service/session_state_credential_service.py new file mode 100644 index 000000000..e2ff7e07d --- /dev/null +++ b/src/google/adk/auth/credential_service/session_state_credential_service.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class SessionStateCredentialService(BaseCredentialService): + """Class for implementation of credential service using session state as the + store. + Note: store credential in session may not be secure, use at your own risk. 
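+
+  Credentials are saved to and loaded from the session state under
+  auth_config.credential_key, so they share the lifetime of the session.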
+ """ + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + """ + Loads the credential by auth config and current tool context from the + backend credential store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to load the credential. + + tool_context: The context of the current invocation when the tool is + trying to load the credential. + + Returns: + Optional[AuthCredential]: the credential saved in the store. + + """ + return tool_context.state.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + """ + Saves the exchanged_auth_credential in auth config to the backend credential + store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to save the credential. + + tool_context: The context of the current invocation when the tool is + trying to save the credential. + + Returns: + None + """ + + tool_context.state[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py new file mode 100644 index 000000000..3b0fbb246 --- /dev/null +++ b/src/google/adk/auth/exchanger/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential exchanger module.""" + +from .base_credential_exchanger import BaseCredentialExchanger + +__all__ = [ + "BaseCredentialExchanger", +] diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py new file mode 100644 index 000000000..b09adb80a --- /dev/null +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base credential exchanger interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_schemes import AuthScheme + + +class CredentialExchangError(Exception): + """Base exception for credential exchange errors.""" + + +@experimental +class BaseCredentialExchanger(abc.ABC): + """Base interface for credential exchangers. + + Credential exchangers are responsible for exchanging credentials from + one format or scheme to another. + """ + + @abc.abstractmethod + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange credential if needed. + + Args: + auth_credential: The credential to exchange. + auth_scheme: The authentication scheme (optional, some exchangers don't need it). + + Returns: + The exchanged credential. + + Raises: + CredentialExchangError: If credential exchange fails. + """ + pass diff --git a/src/google/adk/auth/exchanger/credential_exchanger_registry.py b/src/google/adk/auth/exchanger/credential_exchanger_registry.py new file mode 100644 index 000000000..5af7f3c1a --- /dev/null +++ b/src/google/adk/auth/exchanger/credential_exchanger_registry.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential exchanger registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredentialTypes +from .base_credential_exchanger import BaseCredentialExchanger + + +@experimental +class CredentialExchangerRegistry: + """Registry for credential exchanger instances.""" + + def __init__(self): + self._exchangers: Dict[AuthCredentialTypes, BaseCredentialExchanger] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register an exchanger instance for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchangers[credential_type] = exchanger_instance + + def get_exchanger( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialExchanger]: + """Get the exchanger instance for a credential type. + + Args: + credential_type: The credential type to get exchanger for. + + Returns: + The exchanger instance if registered, None otherwise. 
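+
+    Example (illustrative usage):
+
+      registry = CredentialExchangerRegistry()
+      registry.register(
+          AuthCredentialTypes.OAUTH2, OAuth2CredentialExchanger()
+      )
+      exchanger = registry.get_exchanger(AuthCredentialTypes.OAUTH2)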
+ """ + return self._exchangers.get(credential_type) diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py new file mode 100644 index 000000000..768457e1a --- /dev/null +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 credential exchanger implementation.""" + +from __future__ import annotations + +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import OAuthGrantType +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from typing_extensions import override + +from .base_credential_exchanger import BaseCredentialExchanger +from .base_credential_exchanger import CredentialExchangError + +try: + from authlib.integrations.requests_client import OAuth2Session + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialExchanger(BaseCredentialExchanger): + """Exchanges OAuth2 credentials from authorization responses.""" + + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange OAuth2 credential from authorization response. + if credential exchange failed, the original credential will be returned. + + Args: + auth_credential: The OAuth2 credential to exchange. + auth_scheme: The OAuth2 authentication scheme. + + Returns: + The exchanged credential with access token. + + Raises: + CredentialExchangError: If auth_scheme is missing. + """ + if not auth_scheme: + raise CredentialExchangError( + "auth_scheme is required for OAuth2 credential exchange" + ) + + if not AUTHLIB_AVIALABLE: + # If authlib is not available, we cannot exchange the credential. + # We return the original credential without exchange. + # The client using this tool can decide to exchange the credential + # themselves using other lib. + logger.warning( + "authlib is not available, skipping OAuth2 credential exchange." 
+ ) + return auth_credential + + if auth_credential.oauth2 and auth_credential.oauth2.access_token: + return auth_credential + + client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) + if not client: + logger.warning("Could not create OAuth2 session for token exchange") + return auth_credential + + try: + tokens = client.fetch_token( + token_endpoint, + authorization_response=auth_credential.oauth2.auth_response_uri, + code=auth_credential.oauth2.auth_code, + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully exchanged OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise errors in this case + logger.error("Failed to exchange OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py deleted file mode 100644 index 1a8692417..000000000 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import logging -from typing import Optional -from typing import Tuple - -from fastapi.openapi.models import OAuth2 - -from .auth_credential import AuthCredential -from .auth_schemes import AuthScheme -from .auth_schemes import OAuthGrantType -from .auth_schemes import OpenIdConnectWithConfig - -try: - from authlib.integrations.requests_client import OAuth2Session - from authlib.oauth2.rfc6749 import OAuth2Token - - AUTHLIB_AVIALABLE = True -except ImportError: - AUTHLIB_AVIALABLE = False - - -logger = logging.getLogger("google_adk." 
+ __name__) - - -class OAuth2CredentialFetcher: - """Exchanges and refreshes an OAuth2 access token.""" - - def __init__( - self, - auth_scheme: AuthScheme, - auth_credential: AuthCredential, - ): - self._auth_scheme = auth_scheme - self._auth_credential = auth_credential - - def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]: - auth_scheme = self._auth_scheme - auth_credential = self._auth_credential - - if isinstance(auth_scheme, OpenIdConnectWithConfig): - if not hasattr(auth_scheme, "token_endpoint"): - return None, None - token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes - elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): - return None, None - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) - else: - return None, None - - if ( - not auth_credential - or not auth_credential.oauth2 - or not auth_credential.oauth2.client_id - or not auth_credential.oauth2.client_secret - ): - return None, None - - return ( - OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ), - token_endpoint, - ) - - def _update_credential(self, tokens: OAuth2Token) -> None: - self._auth_credential.oauth2.access_token = tokens.get("access_token") - self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") - self._auth_credential.oauth2.expires_at = ( - int(tokens.get("expires_at")) if tokens.get("expires_at") else None - ) - self._auth_credential.oauth2.expires_in = ( - int(tokens.get("expires_in")) if tokens.get("expires_in") else None - ) - - def exchange(self) -> AuthCredential: - """Exchange an oauth token from the authorization response. - - Returns: - An AuthCredential object containing the access token. - """ - if not AUTHLIB_AVIALABLE: - return self._auth_credential - - if ( - self._auth_credential.oauth2 - and self._auth_credential.oauth2.access_token - ): - return self._auth_credential - - client, token_endpoint = self._oauth2_session() - if not client: - logger.warning("Could not create OAuth2 session for token exchange") - return self._auth_credential - - try: - tokens = client.fetch_token( - token_endpoint, - authorization_response=self._auth_credential.oauth2.auth_response_uri, - code=self._auth_credential.oauth2.auth_code, - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - ) - self._update_credential(tokens) - logger.info("Successfully exchanged OAuth2 tokens") - except Exception as e: - logger.error("Failed to exchange OAuth2 tokens: %s", e) - # Return original credential on failure - return self._auth_credential - - return self._auth_credential - - def refresh(self) -> AuthCredential: - """Refresh an oauth token. - - Returns: - An AuthCredential object containing the refreshed access token. 
- """ - if not AUTHLIB_AVIALABLE: - return self._auth_credential - credential = self._auth_credential - if not credential.oauth2: - return credential - - if OAuth2Token({ - "expires_at": credential.oauth2.expires_at, - "expires_in": credential.oauth2.expires_in, - }).is_expired(): - client, token_endpoint = self._oauth2_session() - if not client: - logger.warning("Could not create OAuth2 session for token refresh") - return credential - - try: - tokens = client.refresh_token( - url=token_endpoint, - refresh_token=credential.oauth2.refresh_token, - ) - self._update_credential(tokens) - logger.info("Successfully refreshed OAuth2 tokens") - except Exception as e: - logger.error("Failed to refresh OAuth2 tokens: %s", e) - # Return original credential on failure - return credential - - return self._auth_credential diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py new file mode 100644 index 000000000..51ed4d29f --- /dev/null +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Tuple + +from fastapi.openapi.models import OAuth2 + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_schemes import AuthScheme +from .auth_schemes import OpenIdConnectWithConfig + +try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +def create_oauth2_session( + auth_scheme: AuthScheme, + auth_credential: AuthCredential, +) -> Tuple[Optional[OAuth2Session], Optional[str]]: + """Create an OAuth2 session for token operations. + + Args: + auth_scheme: The authentication scheme configuration. + auth_credential: The authentication credential. + + Returns: + Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session. 
+ """ + if isinstance(auth_scheme, OpenIdConnectWithConfig): + if not hasattr(auth_scheme, "token_endpoint"): + return None, None + token_endpoint = auth_scheme.token_endpoint + scopes = auth_scheme.scopes + elif isinstance(auth_scheme, OAuth2): + if ( + not auth_scheme.flows.authorizationCode + or not auth_scheme.flows.authorizationCode.tokenUrl + ): + return None, None + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + else: + return None, None + + if ( + not auth_credential + or not auth_credential.oauth2 + or not auth_credential.oauth2.client_id + or not auth_credential.oauth2.client_secret + ): + return None, None + + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) + + +@experimental +def update_credential_with_tokens( + auth_credential: AuthCredential, tokens: OAuth2Token +) -> None: + """Update the credential with new tokens. + + Args: + auth_credential: The authentication credential to update. + tokens: The OAuth2Token object containing new token information. + """ + auth_credential.oauth2.access_token = tokens.get("access_token") + auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) diff --git a/src/google/adk/auth/refresher/__init__.py b/src/google/adk/auth/refresher/__init__.py new file mode 100644 index 000000000..27d7245dc --- /dev/null +++ b/src/google/adk/auth/refresher/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher module.""" + +from .base_credential_refresher import BaseCredentialRefresher + +__all__ = [ + "BaseCredentialRefresher", +] diff --git a/src/google/adk/auth/refresher/base_credential_refresher.py b/src/google/adk/auth/refresher/base_credential_refresher.py new file mode 100644 index 000000000..230b07d09 --- /dev/null +++ b/src/google/adk/auth/refresher/base_credential_refresher.py @@ -0,0 +1,74 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base credential refresher interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.utils.feature_decorator import experimental + + +class CredentialRefresherError(Exception): + """Base exception for credential refresh errors.""" + + +@experimental +class BaseCredentialRefresher(abc.ABC): + """Base interface for credential refreshers. + + Credential refreshers are responsible for checking if a credential is expired + or needs to be refreshed, and for refreshing it if necessary. + """ + + @abc.abstractmethod + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Checks if a credential needs to be refreshed. + + Args: + auth_credential: The credential to check. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + pass + + @abc.abstractmethod + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refreshes a credential if needed. + + Args: + auth_credential: The credential to refresh. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + The refreshed credential. + + Raises: + CredentialRefresherError: If credential refresh fails. + """ + pass diff --git a/src/google/adk/auth/refresher/credential_refresher_registry.py b/src/google/adk/auth/refresher/credential_refresher_registry.py new file mode 100644 index 000000000..90975d66d --- /dev/null +++ b/src/google/adk/auth/refresher/credential_refresher_registry.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.utils.feature_decorator import experimental + +from .base_credential_refresher import BaseCredentialRefresher + + +@experimental +class CredentialRefresherRegistry: + """Registry for credential refresher instances.""" + + def __init__(self): + self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + refresher_instance: BaseCredentialRefresher, + ) -> None: + """Register a refresher instance for a credential type. + + Args: + credential_type: The credential type to register for. + refresher_instance: The refresher instance to register. + """ + self._refreshers[credential_type] = refresher_instance + + def get_refresher( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialRefresher]: + """Get the refresher instance for a credential type. 
+ + Args: + credential_type: The credential type to get refresher for. + + Returns: + The refresher instance if registered, None otherwise. + """ + return self._refreshers.get(credential_type) diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py new file mode 100644 index 000000000..4c19520ce --- /dev/null +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -0,0 +1,126 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 credential refresher implementation.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from typing_extensions import override + +from .base_credential_refresher import BaseCredentialRefresher + +try: + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialRefresher(BaseCredentialRefresher): + """Refreshes OAuth2 credentials including Google OAuth2 JSON credentials.""" + + @override + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Check if the OAuth2 credential needs to be refreshed. + + Args: + auth_credential: The OAuth2 credential to check. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + + # Handle regular OAuth2 credentials + if auth_credential.oauth2: + if not AUTHLIB_AVIALABLE: + return False + + return OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired() + + return False + + @override + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refresh the OAuth2 credential. + If refresh failed, return the original credential. + + Args: + auth_credential: The OAuth2 credential to refresh. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + The refreshed credential. 
+ + """ + + # Handle regular OAuth2 credentials + if auth_credential.oauth2 and auth_scheme: + if not AUTHLIB_AVIALABLE: + return auth_credential + + if not auth_credential.oauth2: + return auth_credential + + if OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired(): + client, token_endpoint = create_oauth2_session( + auth_scheme, auth_credential + ) + if not client: + logger.warning("Could not create OAuth2 session for token refresh") + return auth_credential + + try: + tokens = client.refresh_token( + url=token_endpoint, + refresh_token=auth_credential.oauth2.refresh_token, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully refreshed OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. + logger.error("Failed to refresh OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index aceb3fcce..79d0bfe65 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -24,6 +24,8 @@ from ..agents.llm_agent import LlmAgent from ..artifacts import BaseArtifactService from ..artifacts import InMemoryArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService @@ -43,6 +45,7 @@ async def run_input_file( root_agent: LlmAgent, artifact_service: BaseArtifactService, session_service: BaseSessionService, + credential_service: BaseCredentialService, input_path: str, ) -> Session: runner = Runner( @@ -50,6 +53,7 @@ async def run_input_file( agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) @@ -75,12 +79,14 @@ async def run_interactively( artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, + credential_service: BaseCredentialService, ) -> None: runner = Runner( app_name=session.app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) while True: query = input('[user]: ') @@ -125,6 +131,7 @@ async def run_cli( artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() user_id = 'test_user' session = await session_service.create_session( @@ -141,6 +148,7 @@ async def run_cli( root_agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=input_file, ) elif saved_session_file: @@ -163,6 +171,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) else: click.echo(f'Running agent {root_agent.name}, type exit to exit.') @@ -171,6 +180,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) if save_session: diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 99c7e9bb1..44d4a900d 100644 --- a/src/google/adk/cli/cli_deploy.py +++ 
b/src/google/adk/cli/cli_deploy.py @@ -55,7 +55,7 @@ EXPOSE {port} -CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} "/app/agents" +CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {allow_origins_option} "/app/agents" """ _AGENT_ENGINE_APP_TEMPLATE = """ @@ -121,8 +121,10 @@ def to_cloud_run( port: int, trace_to_cloud: bool, with_ui: bool, + log_level: str, verbosity: str, adk_version: str, + allow_origins: Optional[list[str]] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, @@ -150,6 +152,7 @@ def to_cloud_run( app_name: The name of the app, by default, it's basename of `agent_folder`. temp_folder: The temp folder for the generated Cloud Run source files. port: The port of the ADK api server. + allow_origins: The list of allowed origins for the ADK api server. trace_to_cloud: Whether to enable Cloud Trace. with_ui: Whether to deploy with UI. verbosity: The verbosity level of the CLI. @@ -183,6 +186,9 @@ def to_cloud_run( # create Dockerfile click.echo('Creating Dockerfile...') host_option = '--host=0.0.0.0' if adk_version > '0.5.0' else '' + allow_origins_option = ( + f'--allow_origins={",".join(allow_origins)}' if allow_origins else '' + ) dockerfile_content = _DOCKERFILE_TEMPLATE.format( gcp_project_id=project, gcp_region=region, @@ -197,6 +203,7 @@ def to_cloud_run( memory_service_uri, ), trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '', + allow_origins_option=allow_origins_option, adk_version=adk_version, host_option=host_option, ) @@ -226,7 +233,7 @@ def to_cloud_run( '--port', str(port), '--verbosity', - verbosity, + log_level.lower() if log_level else verbosity, '--labels', 'created-by=adk', ], diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 8f45db96d..c0935cceb 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -31,14 +31,22 @@ from . import cli_create from . import cli_deploy from .. import version +from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..sessions.in_memory_session_service import InMemorySessionService from .cli import run_cli from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE from .fast_api import get_fast_api_app from .utils import envs +from .utils import evals from .utils import logs +LOG_LEVELS = click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + case_sensitive=False, +) + class HelpfulCommand(click.Command): """Command that shows full help on error instead of just the error message. @@ -277,11 +285,21 @@ def cli_run( default=False, help="Optional. Whether to print detailed results on console or not.", ) +@click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." + ), + default=None, +) def cli_eval( agent_module_file_path: str, - eval_set_file_path: tuple[str], + eval_set_file_path: list[str], config_file_path: str, print_detailed_results: bool, + eval_storage_uri: Optional[str] = None, ): """Evaluates an agent given the eval sets. 
@@ -333,12 +351,33 @@ def cli_eval( root_agent = get_root_agent(agent_module_file_path) reset_func = try_get_reset_func(agent_module_file_path) + gcs_eval_sets_manager = None + eval_set_results_manager = None + if eval_storage_uri: + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( + eval_storage_uri + ) + gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_set_results_manager = gcs_eval_managers.eval_set_results_manager + else: + eval_set_results_manager = LocalEvalSetResultsManager( + agents_dir=os.path.dirname(agent_module_file_path) + ) eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path) eval_set_id_to_eval_cases = {} # Read the eval_set files and get the cases. for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items(): - eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) + if gcs_eval_sets_manager: + eval_set = gcs_eval_sets_manager._load_eval_set_from_blob( + eval_set_file_path + ) + if not eval_set: + raise click.ClickException( + f"Eval set {eval_set_file_path} not found in GCS." + ) + else: + eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) eval_cases = eval_set.eval_cases if eval_case_ids: @@ -373,16 +412,13 @@ async def _collect_eval_results() -> list[EvalCaseResult]: raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) # Write eval set results. - local_eval_set_results_manager = LocalEvalSetResultsManager( - agents_dir=os.path.dirname(agent_module_file_path) - ) eval_set_id_to_eval_results = collections.defaultdict(list) for eval_case_result in eval_results: eval_set_id = eval_case_result.eval_set_id eval_set_id_to_eval_results[eval_set_id].append(eval_case_result) for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items(): - local_eval_set_results_manager.save_eval_set_result( + eval_set_results_manager.save_eval_set_result( app_name=os.path.basename(agent_module_file_path), eval_set_id=eval_set_id, eval_case_results=eval_case_results, @@ -439,12 +475,22 @@ def decorator(func): ), default=None, ) + @click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." + ), + default=None, + ) @click.option( "--memory_service_uri", type=str, help=( """Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service.""" + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Vertex AI Memory Bank Service. e.g. agentengine://12345""" ), default=None, ) @@ -498,13 +544,6 @@ def fast_api_common_options(): """Decorator to add common fast api options to click commands.""" def decorator(func): - @click.option( - "--host", - type=str, - help="Optional. The binding host of the server", - default="127.0.0.1", - show_default=True, - ) @click.option( "--port", type=int, @@ -518,10 +557,7 @@ def decorator(func): ) @click.option( "--log_level", - type=click.Choice( - ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - case_sensitive=False, - ), + type=LOG_LEVELS, default="INFO", help="Optional. Set the logging level", ) @@ -535,7 +571,10 @@ def decorator(func): @click.option( "--reload/--no-reload", default=True, - help="Optional. Whether to enable auto reload for server.", + help=( + "Optional. Whether to enable auto reload for server. Not supported" + " for Cloud Run." 
+ ), ) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -547,6 +586,13 @@ def wrapper(*args, **kwargs): @main.command("web") +@click.option( + "--host", + type=str, + help="Optional. The binding host of the server", + default="127.0.0.1", + show_default=True, +) @fast_api_common_options() @adk_services_options() @deprecated_adk_services_options() @@ -559,6 +605,7 @@ def wrapper(*args, **kwargs): ) def cli_web( agents_dir: str, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -578,7 +625,7 @@ def cli_web( Example: - adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir + adk web --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) @@ -611,6 +658,7 @@ async def _lifespan(app: FastAPI): session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, web=True, trace_to_cloud=trace_to_cloud, @@ -628,6 +676,16 @@ async def _lifespan(app: FastAPI): @main.command("api_server") +@click.option( + "--host", + type=str, + help="Optional. The binding host of the server", + default="127.0.0.1", + show_default=True, +) +@fast_api_common_options() +@adk_services_options() +@deprecated_adk_services_options() # The directory of agents, where each sub-directory is a single agent. # By default, it is the current working directory @click.argument( @@ -637,11 +695,9 @@ async def _lifespan(app: FastAPI): ), default=os.getcwd(), ) -@fast_api_common_options() -@adk_services_options() -@deprecated_adk_services_options() def cli_api_server( agents_dir: str, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -661,7 +717,7 @@ def cli_api_server( Example: - adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir + adk api_server --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) @@ -673,6 +729,7 @@ def cli_api_server( session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, + eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, web=False, trace_to_cloud=trace_to_cloud, @@ -720,19 +777,7 @@ def cli_api_server( " of the AGENT source code)." ), ) -@click.option( - "--port", - type=int, - default=8000, - help="Optional. The port of the ADK API server (default: 8000).", -) -@click.option( - "--trace_to_cloud", - is_flag=True, - show_default=True, - default=False, - help="Optional. Whether to enable Cloud Trace for cloud run.", -) +@fast_api_common_options() @click.option( "--with_ui", is_flag=True, @@ -743,6 +788,11 @@ def cli_api_server( " only)" ), ) +@click.option( + "--verbosity", + type=LOG_LEVELS, + help="Deprecated. Use --log_level instead.", +) @click.option( "--temp_folder", type=str, @@ -756,20 +806,6 @@ def cli_api_server( " (default: a timestamped folder in the system temp directory)." ), ) -@click.option( - "--verbosity", - type=click.Choice( - ["debug", "info", "warning", "error", "critical"], case_sensitive=False - ), - default="WARNING", - help="Optional. 
Override the default verbosity level.", -) -@click.argument( - "agent", - type=click.Path( - exists=True, dir_okay=True, file_okay=False, resolve_path=True - ), -) @click.option( "--adk_version", type=str, @@ -780,8 +816,23 @@ def cli_api_server( " version in the dev environment)" ), ) +@click.option( + "--eval_storage_uri", + type=str, + help=( + "Optional. The evals storage URI to store agent evals," + " supported URIs: gs://." + ), + default=None, +) @adk_services_options() @deprecated_adk_services_options() +@click.argument( + "agent", + type=click.Path( + exists=True, dir_okay=True, file_okay=False, resolve_path=True + ), +) def cli_deploy_cloud_run( agent: str, project: Optional[str], @@ -792,11 +843,15 @@ def cli_deploy_cloud_run( port: int, trace_to_cloud: bool, with_ui: bool, - verbosity: str, adk_version: str, + log_level: Optional[str] = None, + verbosity: str = "WARNING", + reload: bool = True, + allow_origins: Optional[list[str]] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + eval_storage_uri: Optional[str] = None, session_db_url: Optional[str] = None, # Deprecated artifact_storage_uri: Optional[str] = None, # Deprecated ): @@ -808,6 +863,7 @@ def cli_deploy_cloud_run( adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent """ + log_level = log_level or verbosity session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri try: @@ -820,7 +876,9 @@ def cli_deploy_cloud_run( temp_folder=temp_folder, port=port, trace_to_cloud=trace_to_cloud, + allow_origins=allow_origins, with_ui=with_ui, + log_level=log_level, verbosity=verbosity, adk_version=adk_version, session_service_uri=session_service_uri, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4512174c5..abe1961e7 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -57,6 +57,7 @@ from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..errors.not_found_error import NotFoundError from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput @@ -64,10 +65,13 @@ from ..evaluation.eval_metrics import EvalMetricResult from ..evaluation.eval_metrics import EvalMetricResultPerInvocation from ..evaluation.eval_result import EvalSetResult +from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService from ..runners import Runner from ..sessions.database_session_service import DatabaseSessionService @@ -197,6 +201,7 @@ def get_fast_api_app( session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + eval_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, 
trace_to_cloud: bool = False, @@ -255,8 +260,18 @@ async def internal_lifespan(app: FastAPI): runner_dict = {} - eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) - eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) + # Set up eval managers. + eval_sets_manager = None + eval_set_results_manager = None + if eval_storage_uri: + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( + eval_storage_uri + ) + eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_set_results_manager = gcs_eval_managers.eval_set_results_manager + else: + eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) + eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) # Build the Memory service if memory_service_uri: @@ -268,6 +283,16 @@ async def internal_lifespan(app: FastAPI): memory_service = VertexAiRagMemoryService( rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}' ) + elif memory_service_uri.startswith("agentengine://"): + agent_engine_id = memory_service_uri.split("://")[1] + if not agent_engine_id: + raise click.ClickException("Agent engine id can not be empty.") + envs.load_dotenv_for_agent("", agents_dir) + memory_service = VertexAiMemoryBankService( + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, + ) else: raise click.ClickException( "Unsupported memory service URI: %s" % memory_service_uri @@ -305,6 +330,9 @@ async def internal_lifespan(app: FastAPI): else: artifact_service = InMemoryArtifactService() + # Build the Credential service + credential_service = InMemoryCredentialService() + # initialize Agent Loader agent_loader = AgentLoader(agents_dir) @@ -929,6 +957,7 @@ async def _get_runner_async(app_name: str) -> Runner: artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, + credential_service=credential_service, ) runner_dict[app_name] = runner return runner diff --git a/src/google/adk/cli/utils/evals.py b/src/google/adk/cli/utils/evals.py index c8d1a3296..305d47544 100644 --- a/src/google/adk/cli/utils/evals.py +++ b/src/google/adk/cli/utils/evals.py @@ -14,17 +14,36 @@ from __future__ import annotations +import dataclasses +import os from typing import Any from typing import Tuple from google.genai import types as genai_types +from pydantic import alias_generators +from pydantic import BaseModel +from pydantic import ConfigDict from typing_extensions import deprecated from ...evaluation.eval_case import IntermediateData from ...evaluation.eval_case import Invocation +from ...evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from ...evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ...sessions.session import Session +class GcsEvalManagers(BaseModel): + model_config = ConfigDict( + alias_generator=alias_generators.to_camel, + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + eval_sets_manager: GcsEvalSetsManager + + eval_set_results_manager: GcsEvalSetResultsManager + + @deprecated('Use convert_session_to_eval_invocations instead.') def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]: """Converts a session data into eval format. 
@@ -176,3 +195,37 @@ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]: ) return invocations + + +def create_gcs_eval_managers_from_uri( + eval_storage_uri: str, +) -> GcsEvalManagers: + """Creates GcsEvalManagers from eval_storage_uri. + + Args: + eval_storage_uri: The evals storage URI to use. Supported URIs: + gs://. If a path is provided, the bucket will be extracted. + + Returns: + GcsEvalManagers: The GcsEvalManagers object. + + Raises: + ValueError: If the eval_storage_uri is not supported. + """ + if eval_storage_uri.startswith('gs://'): + gcs_bucket = eval_storage_uri.split('://')[1] + eval_sets_manager = GcsEvalSetsManager( + bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT'] + ) + eval_set_results_manager = GcsEvalSetResultsManager( + bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT'] + ) + return GcsEvalManagers( + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, + ) + else: + raise ValueError( + f'Unsupported evals storage URI: {eval_storage_uri}. Supported URIs:' + ' gs://' + ) diff --git a/src/google/adk/evaluation/final_response_match_v1.py b/src/google/adk/evaluation/final_response_match_v1.py new file mode 100644 index 000000000..a034b470f --- /dev/null +++ b/src/google/adk/evaluation/final_response_match_v1.py @@ -0,0 +1,110 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
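+# Illustrative sketch only of how the RougeEvaluator defined below is expected
+# to be used; the invocation lists are placeholders:
+#
+#   metric = EvalMetric(metric_name="response_match_score", threshold=0.8)
+#   evaluator = RougeEvaluator(metric)
+#   result = evaluator.evaluate_invocations(
+#       actual_invocations, expected_invocations
+#   )
+#   print(result.overall_score, result.overall_eval_status)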
+ +from __future__ import annotations + +from typing import Optional + +from google.genai import types as genai_types +from rouge_score import rouge_scorer +from typing_extensions import override + +from .eval_case import Invocation +from .eval_metrics import EvalMetric +from .evaluator import EvalStatus +from .evaluator import EvaluationResult +from .evaluator import Evaluator +from .evaluator import PerInvocationResult + + +class RougeEvaluator(Evaluator): + """Calculates the ROUGE-1 metric to compare responses.""" + + def __init__(self, eval_metric: EvalMetric): + self._eval_metric = eval_metric + + @override + def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + total_score = 0.0 + num_invocations = 0 + per_invocation_results = [] + for actual, expected in zip(actual_invocations, expected_invocations): + reference = _get_text_from_content(expected.final_response) + response = _get_text_from_content(actual.final_response) + rouge_1_scores = _calculate_rouge_1_scores(response, reference) + score = rouge_1_scores.fmeasure + per_invocation_results.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=score, + eval_status=_get_eval_status(score, self._eval_metric.threshold), + ) + ) + total_score += score + num_invocations += 1 + + if per_invocation_results: + overall_score = total_score / num_invocations + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=_get_eval_status( + overall_score, self._eval_metric.threshold + ), + per_invocation_results=per_invocation_results, + ) + + return EvaluationResult() + + +def _get_text_from_content(content: Optional[genai_types.Content]) -> str: + if content and content.parts: + return "\n".join([part.text for part in content.parts if part.text]) + + return "" + + +def _get_eval_status(score: float, threshold: float): + return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED + + +def _calculate_rouge_1_scores(candidate: str, reference: str): + """Calculates the ROUGE-1 score between a candidate and reference text. + + ROUGE-1 measures the overlap of unigrams (single words) between the + candidate and reference texts. The score is broken down into: + - Precision: The proportion of unigrams in the candidate that are also in the + reference. + - Recall: The proportion of unigrams in the reference that are also in the + candidate. + - F-measure: The harmonic mean of precision and recall. + + Args: + candidate: The generated text to be evaluated. + reference: The ground-truth text to compare against. + + Returns: + A dictionary containing the ROUGE-1 precision, recall, and f-measure. + """ + scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) + + # The score method returns a dictionary where keys are the ROUGE types + # and values are Score objects (tuples) with precision, recall, and fmeasure. + scores = scorer.score(reference, candidate) + + return scores["rouge1"] diff --git a/src/google/adk/evaluation/gcs_eval_sets_manager.py b/src/google/adk/evaluation/gcs_eval_sets_manager.py index fe5d8c9b5..c253e4cd5 100644 --- a/src/google/adk/evaluation/gcs_eval_sets_manager.py +++ b/src/google/adk/evaluation/gcs_eval_sets_manager.py @@ -72,6 +72,13 @@ def _validate_id(self, id_name: str, id_value: str): f"Invalid {id_name}. 
{id_name} should have the `{pattern}` format", ) + def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]: + blob = self.bucket.blob(blob_name) + if not blob.exists(): + return None + eval_set_data = blob.download_as_text() + return EvalSet.model_validate_json(eval_set_data) + def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet): """Writes an EvalSet to GCS.""" blob = self.bucket.blob(blob_name) @@ -88,11 +95,7 @@ def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: """Returns an EvalSet identified by an app_name and eval_set_id.""" eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) - blob = self.bucket.blob(eval_set_blob_name) - if not blob.exists(): - return None - eval_set_data = blob.download_as_text() - return EvalSet.model_validate_json(eval_set_data) + return self._load_eval_set_from_blob(eval_set_blob_name) @override def create_eval_set(self, app_name: str, eval_set_id: str): diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py index 52ab50c74..0826f8796 100644 --- a/src/google/adk/evaluation/response_evaluator.py +++ b/src/google/adk/evaluation/response_evaluator.py @@ -27,10 +27,12 @@ from .eval_case import IntermediateData from .eval_case import Invocation +from .eval_metrics import EvalMetric from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator from .evaluator import PerInvocationResult +from .final_response_match_v1 import RougeEvaluator class ResponseEvaluator(Evaluator): @@ -40,7 +42,7 @@ def __init__(self, threshold: float, metric_name: str): if "response_evaluation_score" == metric_name: self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE elif "response_match_score" == metric_name: - self._metric_name = "rouge_1" + self._metric_name = "response_match_score" else: raise ValueError(f"`{metric_name}` is not supported.") @@ -52,6 +54,15 @@ def evaluate_invocations( actual_invocations: list[Invocation], expected_invocations: list[Invocation], ) -> EvaluationResult: + # If the metric is response_match_score, just use the RougeEvaluator. + if self._metric_name == "response_match_score": + rouge_evaluator = RougeEvaluator( + EvalMetric(metric_name=self._metric_name, threshold=self._threshold) + ) + return rouge_evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + total_score = 0.0 num_invocations = 0 per_invocation_results = [] diff --git a/src/google/adk/events/event.py b/src/google/adk/events/event.py index c3b8b8699..6dd617fff 100644 --- a/src/google/adk/events/event.py +++ b/src/google/adk/events/event.py @@ -34,9 +34,10 @@ class Event(LlmResponse): taken by the agents like function calls, etc. Attributes: - invocation_id: The invocation ID of the event. - author: "user" or the name of the agent, indicating who appended the event - to the session. + invocation_id: Required. The invocation ID of the event. Should be non-empty + before appending to a session. + author: Required. "user" or the name of the agent, indicating who appended + the event to the session. actions: The actions taken by the agent. long_running_tool_ids: The ids of the long running function calls. branch: The branch of the event. 
@@ -55,9 +56,8 @@ class Event(LlmResponse): ) """The pydantic model config.""" - # TODO: revert to be required after spark migration invocation_id: str = '' - """The invocation ID of the event.""" + """The invocation ID of the event. Should be non-empty before appending to a session.""" author: str """'user' or the name of the agent, indicating who appended the event to the session.""" diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 7efadd97e..ee5c83da1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -68,6 +68,12 @@ async def run_async( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) + llm_request.live_connect_config.enable_affective_dialog = ( + invocation_context.run_config.enable_affective_dialog + ) + llm_request.live_connect_config.proactivity = ( + invocation_context.run_config.proactivity + ) # TODO: handle tool append here, instead of in BaseTool.process_llm_request. diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index ea418888f..039eaf8c5 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -43,12 +43,20 @@ async def run_async( if not isinstance(agent, LlmAgent): return - if agent.include_contents != 'none': + if agent.include_contents == 'default': + # Include full conversation history llm_request.contents = _get_contents( invocation_context.branch, invocation_context.session.events, agent.name, ) + else: + # Include current turn context only (no conversation history) + llm_request.contents = _get_current_turn_contents( + invocation_context.branch, + invocation_context.session.events, + agent.name, + ) # Maintain async generator behavior if False: # Ensures it behaves as a generator @@ -190,13 +198,15 @@ def _get_contents( ) -> list[types.Content]: """Get the contents for the LLM request. + Applies filtering, rearrangement, and content processing to events. + Args: current_branch: The current branch of the agent. - events: A list of events. + events: Events to process. agent_name: The name of the agent. Returns: - A list of contents. + A list of processed contents. """ filtered_events = [] # Parse the events, leaving the contents and the function calls and @@ -211,12 +221,13 @@ def _get_contents( # Skip events without content, or generated neither by user nor by model # or has empty text. # E.g. events purely for mutating session states. + continue if not _is_event_belongs_to_branch(current_branch, event): # Skip events not belong to current branch. continue if _is_auth_event(event): - # skip auth event + # Skip auth events. continue filtered_events.append( _convert_foreign_event(event) @@ -224,12 +235,15 @@ def _get_contents( else event ) + # Rearrange events for proper function call/response pairing result_events = _rearrange_events_for_latest_function_response( filtered_events ) result_events = _rearrange_events_for_async_function_responses_in_history( result_events ) + + # Convert events to contents contents = [] for event in result_events: content = copy.deepcopy(event.content) @@ -238,6 +252,37 @@ def _get_contents( return contents +def _get_current_turn_contents( + current_branch: Optional[str], events: list[Event], agent_name: str = '' +) -> list[types.Content]: + """Get contents for the current turn only (no conversation history). 
+ + When include_contents='none', we want to include: + - The current user input + - Tool calls and responses from the current turn + But exclude conversation history from previous turns. + + In multi-agent scenarios, the "current turn" for an agent starts from an + actual user or from another agent. + + Args: + current_branch: The current branch of the agent. + events: A list of all session events. + agent_name: The name of the agent. + + Returns: + A list of contents for the current turn only, preserving context needed + for proper tool execution while excluding conversation history. + """ + # Find the latest event that starts the current turn and process from there + for i in range(len(events) - 1, -1, -1): + event = events[i] + if event.author == 'user' or _is_other_agent_reply(agent_name, event): + return _get_contents(current_branch, events[i:], agent_name) + + return [] + + def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool: """Whether the event is a reply from another agent.""" return bool( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2541ac664..5c690f1fd 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -288,8 +288,7 @@ async def handle_function_calls_live( trace_tool_call( tool=tool, args=function_args, - response_event_id=function_response_event.id, - function_response=function_response, + function_response_event=function_response_event, ) function_response_events.append(function_response_event) @@ -520,3 +519,35 @@ def merge_parallel_function_response_events( # Use the base_event as the timestamp merged_event.timestamp = base_event.timestamp return merged_event + + +def find_matching_function_call( + events: list[Event], +) -> Optional[Event]: + """Finds the function call event that matches the function response id of the last event.""" + if not events: + return None + + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index f2ac4f9b5..915d7e517 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -15,12 +15,14 @@ from .base_memory_service import BaseMemoryService from .in_memory_memory_service import InMemoryMemoryService +from .vertex_ai_memory_bank_service import VertexAiMemoryBankService logger = logging.getLogger('google_adk.' + __name__) __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', + 'VertexAiMemoryBankService', ] try: @@ -29,7 +31,7 @@ __all__.append('VertexAiRagMemoryService') except ImportError: logger.debug( - 'The Vertex sdk is not installed. If you want to use the' + 'The Vertex SDK is not installed. If you want to use the' ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' 
) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py new file mode 100644 index 000000000..083b48e8d --- /dev/null +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -0,0 +1,150 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import logging +from typing import Optional +from typing import TYPE_CHECKING + +from typing_extensions import override + +from google import genai + +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..sessions.session import Session + +logger = logging.getLogger('google_adk.' + __name__) + + +class VertexAiMemoryBankService(BaseMemoryService): + """Implementation of the BaseMemoryService using Vertex AI Memory Bank.""" + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, + ): + """Initializes a VertexAiMemoryBankService. + + Args: + project: The project ID of the Memory Bank to use. + location: The location of the Memory Bank to use. + agent_engine_id: The ID of the agent engine to use for the Memory Bank. + e.g. '456' in + 'projects/my-project/locations/us-central1/reasoningEngines/456'. 
+ """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id + + @override + async def add_session_to_memory(self, session: Session): + api_client = self._get_api_client() + + if not self._agent_engine_id: + raise ValueError('Agent Engine ID is required for Memory Bank.') + + events = [] + for event in session.events: + if event.content and event.content.parts: + events.append({ + 'content': event.content.model_dump(exclude_none=True, mode='json') + }) + request_dict = { + 'direct_contents_source': { + 'events': events, + }, + 'scope': { + 'app_name': session.app_name, + 'user_id': session.user_id, + }, + } + + if events: + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:generate', + request_dict=request_dict, + ) + logger.info(f'Generate memory response: {api_response}') + else: + logger.info('No events to add to memory.') + + @override + async def search_memory(self, *, app_name: str, user_id: str, query: str): + api_client = self._get_api_client() + + api_response = await api_client.async_request( + http_method='POST', + path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve', + request_dict={ + 'scope': { + 'app_name': app_name, + 'user_id': user_id, + }, + 'similarity_search_params': { + 'search_query': query, + }, + }, + ) + api_response = _convert_api_response(api_response) + logger.info(f'Search memory response: {api_response}') + + if not api_response or not api_response.get('retrievedMemories', None): + return SearchMemoryResponse() + + memory_events = [] + for memory in api_response.get('retrievedMemories', []): + # TODO: add more complex error handling + memory_events.append( + MemoryEntry( + author='user', + content=genai.types.Content( + parts=[ + genai.types.Part(text=memory.get('memory').get('fact')) + ], + role='user', + ), + timestamp=memory.get('updateTime'), + ) + ) + return SearchMemoryResponse(memories=memory_events) + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + + Returns: + An API client for the given project and location. + """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client + + +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf..39514d6f4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,12 +23,14 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple from typing import Union from google.genai import types +import litellm from litellm import acompletion from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall @@ -53,6 +55,9 @@ from .llm_request import LlmRequest from .llm_response import LlmResponse +# This will add functions to prompts if functions are provided. +litellm.add_function_to_prompt = True + logger = logging.getLogger("google_adk." 
+ __name__) _NEW_LINE = "\n" @@ -481,16 +486,22 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. +) -> Tuple[ + List[Message], + Optional[List[Dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -507,7 +518,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -518,12 +530,39 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None - - if llm_request.config.response_schema: + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = None + if llm_request.config and llm_request.config.response_schema: response_format = llm_request.config.response_schema - return messages, tools, response_format + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: + config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. + generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] + + if not generation_params: + generation_params = None + + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -660,7 +699,13 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) + + if "functions" in self._additional_args: + # LiteLLM does not support both tools and functions together. 
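# A minimal sketch of the parameter mapping above, with assumed sample values:
# selected GenerateContentConfig fields are renamed to their LiteLLM
# equivalents (max_output_tokens -> max_completion_tokens,
# stop_sequences -> stop) before being passed through to acompletion().
from google.genai import types

sample_config = types.GenerateContentConfig(
    temperature=0.1, max_output_tokens=256, top_p=0.9, stop_sequences=["END"]
)
config_dict = sample_config.model_dump(exclude_none=True)
param_mapping = {
    "max_output_tokens": "max_completion_tokens",
    "stop_sequences": "stop",
}
generation_params = {
    param_mapping.get(key, key): config_dict[key]
    for key in (
        "temperature", "max_output_tokens", "top_p", "top_k",
        "stop_sequences", "presence_penalty", "frequency_penalty",
    )
    if key in config_dict
}
# -> {"temperature": 0.1, "max_completion_tokens": 256, "top_p": 0.9, "stop": ["END"]}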
+ tools = None completion_args = { "model": self.model, @@ -670,6 +715,9 @@ async def generate_content_async( } completion_args.update(self._additional_args) + if generation_params: + completion_args.update(generation_params) + if stream: text = "" # Track function calls by index @@ -679,7 +727,7 @@ async def generate_content_async( aggregated_llm_response_with_tool_call = None usage_metadata = None fallback_index = 0 - for part in self.llm_client.completion(**completion_args): + async for part in await self.llm_client.acompletion(**completion_args): for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): index = chunk.index or fallback_index @@ -739,11 +787,12 @@ async def generate_content_async( _message_to_generate_content_response( ChatCompletionAssistantMessage( role="assistant", - content="", + content=text, tool_calls=tool_calls, ) ) ) + text = "" function_calls.clear() elif finish_reason == "stop" and text: aggregated_llm_response = _message_to_generate_content_response( diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index c4fcdfb9e..017997bb3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -17,7 +17,6 @@ import asyncio import logging import queue -import threading from typing import AsyncGenerator from typing import Generator from typing import Optional @@ -34,8 +33,10 @@ from .agents.run_config import RunConfig from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService +from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event +from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread @@ -73,6 +74,8 @@ class Runner: """The session service for the runner.""" memory_service: Optional[BaseMemoryService] = None """The memory service for the runner.""" + credential_service: Optional[BaseCredentialService] = None + """The credential service for the runner.""" def __init__( self, @@ -82,6 +85,7 @@ def __init__( artifact_service: Optional[BaseArtifactService] = None, session_service: BaseSessionService, memory_service: Optional[BaseMemoryService] = None, + credential_service: Optional[BaseCredentialService] = None, ): """Initializes the Runner. @@ -97,6 +101,7 @@ def __init__( self.artifact_service = artifact_service self.session_service = session_service self.memory_service = memory_service + self.credential_service = credential_service def run( self, @@ -333,6 +338,8 @@ def _find_agent_to_run( """Finds the agent to run to continue the session. A qualified agent must be either of: + - The agent that returned a function call and the last user message is a + function response to this function call. - The root agent; - An LlmAgent who replied last and is capable to transfer to any other agent in the agent hierarchy. @@ -344,6 +351,13 @@ def _find_agent_to_run( Returns: The agent of the last message in the session or the root agent. """ + # If the last event is a function response, should send this response to + # the agent that returned the corressponding function call regardless the + # type of the agent. e.g. a remote a2a agent may surface a credential + # request as a special long running function tool call. 
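# Rough illustration of the routing rule described above, using plain dicts
# rather than the real Event/Agent types (the agent names and the id-based
# matching shown here are assumptions for the sketch): the user's function
# response is routed back to the agent that issued the matching function call.
session_events = [
    {'author': 'root_agent', 'function_call_id': None},
    {'author': 'payments_agent', 'function_call_id': 'call-1'},  # e.g. a long-running tool call
    {'author': 'user', 'function_response_id': 'call-1'},        # user answers that call
]
last_event = session_events[-1]
target_author = next(
    (
        e['author']
        for e in reversed(session_events)
        if e.get('function_call_id') == last_event.get('function_response_id')
    ),
    None,
)
assert target_author == 'payments_agent'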
+ event = find_matching_function_call(session.events) + if event and event.author: + return root_agent.find_agent(event.author) for event in filter(lambda e: e.author != 'user', reversed(session.events)): if event.author == root_agent.name: # Found root agent. @@ -418,6 +432,7 @@ def _new_invocation_context( artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, + credential_service=self.credential_service, invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 258dcd933..06a904c89 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -14,7 +14,9 @@ from __future__ import annotations import asyncio +import json import logging +import os import re from typing import Any from typing import Dict @@ -22,6 +24,7 @@ import urllib.parse from dateutil import parser +from google.genai.errors import ClientError from typing_extensions import override from google import genai @@ -87,30 +90,53 @@ async def create_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions', request_dict=session_json_dict, ) + api_response = _convert_api_response(api_response) logger.info(f'Create Session response {api_response}') session_id = api_response['name'].split('/')[-3] operation_id = api_response['name'].split('/')[-1] max_retry_attempt = 5 - lro_response = None - while max_retry_attempt >= 0: - lro_response = await api_client.async_request( - http_method='GET', - path=f'operations/{operation_id}', - request_dict={}, - ) - if lro_response.get('done', None): - break - - await asyncio.sleep(1) - max_retry_attempt -= 1 - - if lro_response is None or not lro_response.get('done', None): - raise TimeoutError( - f'Timeout waiting for operation {operation_id} to complete.' - ) + if _is_vertex_express_mode(self._project, self._location): + # Express mode doesn't support LRO, so we need to poll + # the session resource. + # TODO: remove this once LRO polling is supported in Express mode. + for i in range(max_retry_attempt): + try: + await api_client.async_request( + http_method='GET', + path=( + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' + ), + request_dict={}, + ) + break + except ClientError as e: + logger.info('Polling for session %s: %s', session_id, e) + # Add slight exponential backoff to avoid excessive polling. + await asyncio.sleep(1 + 0.5 * i) + else: + raise TimeoutError('Session creation failed.') + else: + lro_response = None + for _ in range(max_retry_attempt): + lro_response = await api_client.async_request( + http_method='GET', + path=f'operations/{operation_id}', + request_dict={}, + ) + lro_response = _convert_api_response(lro_response) + + if lro_response.get('done', None): + break + + await asyncio.sleep(1) + + if lro_response is None or not lro_response.get('done', None): + raise TimeoutError( + f'Timeout waiting for operation {operation_id} to complete.' 
+ ) # Get session resource get_session_api_response = await api_client.async_request( @@ -118,6 +144,7 @@ async def create_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) + get_session_api_response = _convert_api_response(get_session_api_response) update_timestamp = isoparse( get_session_api_response['updateTime'] @@ -149,6 +176,7 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) + get_session_api_response = _convert_api_response(get_session_api_response) session_id = get_session_api_response['name'].split('/')[-1] update_timestamp = isoparse( @@ -167,9 +195,12 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', request_dict={}, ) + list_events_api_response = _convert_api_response(list_events_api_response) # Handles empty response case - if list_events_api_response.get('httpHeaders', None): + if not list_events_api_response or list_events_api_response.get( + 'httpHeaders', None + ): return session session.events += [ @@ -226,9 +257,10 @@ async def list_sessions( path=path, request_dict={}, ) + api_response = _convert_api_response(api_response) # Handles empty response case - if api_response.get('httpHeaders', None): + if not api_response or api_response.get('httpHeaders', None): return ListSessionsResponse() sessions = [] @@ -303,6 +335,25 @@ def _get_api_client(self): return client._api_client +def _is_vertex_express_mode( + project: Optional[str], location: Optional[str] +) -> bool: + """Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode.""" + return ( + os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in ['true', '1'] + and os.environ.get('GOOGLE_API_KEY', None) is not None + and project is None + and location is None + ) + + +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response + + def _convert_event_to_json(event: Event) -> Dict[str, Any]: metadata_json = { 'partial': event.partial, diff --git a/src/google/adk/telemetry.py b/src/google/adk/telemetry.py index badaec46d..a09c2f55b 100644 --- a/src/google/adk/telemetry.py +++ b/src/google/adk/telemetry.py @@ -195,6 +195,16 @@ def trace_call_llm( llm_response_json, ) + if llm_response.usage_metadata is not None: + span.set_attribute( + 'gen_ai.usage.input_tokens', + llm_response.usage_metadata.prompt_token_count, + ) + span.set_attribute( + 'gen_ai.usage.output_tokens', + llm_response.usage_metadata.total_token_count, + ) + def trace_send_data( invocation_context: InvocationContext, diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 4e5be5959..5a50a7f0c 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -150,7 +150,7 @@ async def run_async( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self._auth_scheme, self._auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() if auth_result.state == 'pending': return { @@ -178,7 +178,7 @@ async def run_async( args['operation'] = 
self._operation args['action'] = self._action logger.info('Running tool: %s with args: %s', self.name, args) - return self._rest_api_tool.call(args=args, tool_context=tool_context) + return await self._rest_api_tool.call(args=args, tool_context=tool_context) def __str__(self): return ( diff --git a/src/google/adk/tools/authenticated_function_tool.py b/src/google/adk/tools/authenticated_function_tool.py new file mode 100644 index 000000000..67cc5885f --- /dev/null +++ b/src/google/adk/tools/authenticated_function_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import logging +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .function_tool import FunctionTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class AuthenticatedFunctionTool(FunctionTool): + """A FunctionTool that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + func: Callable[..., Any], + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """Initializes the AuthenticatedFunctionTool. + + Args: + func: The function to be called. + auth_config: The authentication configuration. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__(func=func) + self._ignore_params.append("credential") + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." 
+ ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "credential" in signature.parameters: + args_to_call["credential"] = credential + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py new file mode 100644 index 000000000..4858e4953 --- /dev/null +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +import logging +from typing import Any +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .base_tool import BaseTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class BaseAuthenticatedTool(BaseTool): + """A base tool class that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + name, + description, + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """ + Args: + name: The name of the tool. + description: The description of the tool. + auth_config: The auth configuration of the tool. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) 
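# A hedged usage sketch for the experimental AuthenticatedFunctionTool defined
# above. The wrapped function may declare a `credential` parameter to receive
# the resolved credential; `calendar_auth_config` and the calendar function are
# assumptions for illustration, not part of this patch.
from typing import Optional

from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool

def list_calendar_events(
    start: str, end: str, credential: Optional[AuthCredential] = None
) -> dict:
  # A real implementation would call a calendar API using the credential.
  return {"status": "ok", "window": f"{start}..{end}"}

calendar_auth_config = None  # assumed: replace with AuthConfig(auth_scheme=..., raw_auth_credential=...)
calendar_tool = AuthenticatedFunctionTool(
    func=list_calendar_events,
    auth_config=calendar_auth_config,
    response_for_auth_required="Pending User Authorization.",
)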
+ """ + super().__init__( + name=name, + description=description, + ) + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, + tool_context=tool_context, + credential=credential, + ) + + @abstractmethod + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + pass diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index 0a99136c4..d0f3abe0e 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -21,9 +21,10 @@ from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode from fastapi.openapi.models import OAuthFlows +import google.auth.credentials from google.auth.exceptions import RefreshError from google.auth.transport.requests import Request -from google.oauth2.credentials import Credentials +import google.oauth2.credentials from pydantic import BaseModel from pydantic import model_validator @@ -40,26 +41,35 @@ @experimental class BigQueryCredentialsConfig(BaseModel): - """Configuration for Google API tools. (Experimental)""" + """Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ # Configure the model to allow arbitrary types like Credentials model_config = {"arbitrary_types_allowed": True} - credentials: Optional[Credentials] = None - """the existing oauth credentials to use. If set,this credential will be used + credentials: Optional[google.auth.credentials.Credentials] = None + """The existing auth credentials to use. If set, this credential will be used for every end user, end users don't need to be involved in the oauthflow. This field is mutually exclusive with client_id, client_secret and scopes. Don't set this field unless you are sure this credential has the permission to access every end user's data. - Example usage: when the agent is deployed in Google Cloud environment and + Example usage 1: When the agent is deployed in Google Cloud environment and the service account (used as application default credentials) has access to all the required BigQuery resource. Setting this credential to allow user to access the BigQuery resource without end users going through oauth flow. - To get application default credential: `google.auth.default(...)`. See more + To get application default credential, use: `google.auth.default(...)`. See more details in https://cloud.google.com/docs/authentication/application-default-credentials. 
+ Example usage 2: When the agent wants to access the user's BigQuery resources + using the service account key credentials. + + To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`. + See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + When the deployed environment cannot provide a pre-existing credential, consider setting below client_id, client_secret and scope for end users to go through oauth flow, so that agent can access the user data. @@ -86,7 +96,9 @@ def __post_init__(self) -> BigQueryCredentialsConfig: " client_id/client_secret/scopes." ) - if self.credentials: + if self.credentials and isinstance( + self.credentials, google.oauth2.credentials.Credentials + ): self.client_id = self.credentials.client_id self.client_secret = self.credentials.client_secret self.scopes = self.credentials.scopes @@ -115,7 +127,7 @@ def __init__(self, credentials_config: BigQueryCredentialsConfig): async def get_valid_credentials( self, tool_context: ToolContext - ) -> Optional[Credentials]: + ) -> Optional[google.auth.credentials.Credentials]: """Get valid credentials, handling refresh and OAuth flow as needed. Args: @@ -127,7 +139,7 @@ async def get_valid_credentials( # First, try to get credentials from the tool context creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None) creds = ( - Credentials.from_authorized_user_info( + google.oauth2.credentials.Credentials.from_authorized_user_info( json.loads(creds_json), self.credentials_config.scopes ) if creds_json @@ -138,6 +150,11 @@ async def get_valid_credentials( if not creds: creds = self.credentials_config.credentials + # If non-oauth credentials are provided then use them as is. This helps + # in flows such as service account keys + if creds and not isinstance(creds, google.oauth2.credentials.Credentials): + return creds + # Check if we have valid credentials if creds and creds.valid: return creds @@ -159,7 +176,7 @@ async def get_valid_credentials( async def _perform_oauth_flow( self, tool_context: ToolContext - ) -> Optional[Credentials]: + ) -> Optional[google.oauth2.credentials.Credentials]: """Perform OAuth flow to get new credentials. 
Args: @@ -199,7 +216,7 @@ async def _perform_oauth_flow( if auth_response: # OAuth flow completed, create credentials - creds = Credentials( + creds = google.oauth2.credentials.Credentials( token=auth_response.oauth2.access_token, refresh_token=auth_response.oauth2.refresh_token, token_uri=auth_scheme.flows.authorizationCode.tokenUrl, diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py index 182734188..50d49ff77 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/bigquery/bigquery_tool.py @@ -19,7 +19,7 @@ from typing import Callable from typing import Optional -from google.oauth2.credentials import Credentials +from google.auth.credentials import Credentials from typing_extensions import override from ...utils.feature_decorator import experimental diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d72761b2d..8b2816ebe 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -15,19 +15,23 @@ from __future__ import annotations import google.api_core.client_info +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials -USER_AGENT = "adk-bigquery-tool" +from ... import version +USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" -def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client: + +def get_bigquery_client( + *, project: str, credentials: Credentials +) -> bigquery.Client: """Get a BigQuery client.""" client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT) bigquery_client = bigquery.Client( - credentials=credentials, client_info=client_info + project=project, credentials=credentials, client_info=client_info ) return bigquery_client diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py index 6e279d59e..64f23d07b 100644 --- a/src/google/adk/tools/bigquery/metadata_tool.py +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from . 
import client @@ -42,7 +42,9 @@ def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]: 'bbc_news'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) datasets = [] for dataset in bq_client.list_datasets(project_id): @@ -106,7 +108,9 @@ def get_dataset_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) dataset = bq_client.get_dataset( bigquery.DatasetReference(project_id, dataset_id) ) @@ -137,7 +141,9 @@ def list_table_ids( 'local_data_for_better_health_county_data'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) tables = [] for table in bq_client.list_tables( @@ -251,7 +257,9 @@ def get_table_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) return bq_client.get_table( bigquery.TableReference( bigquery.DatasetReference(project_id, dataset_id), table_id diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 80b56aad3..147d0b4db 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -16,8 +16,8 @@ import types from typing import Callable +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from . import client from .config import BigQueryToolConfig @@ -72,7 +72,9 @@ def execute_sql( """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) if not config or config.write_mode == WriteMode.BLOCKED: query_job = bq_client.query( query, diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 5bc06e398..90b39e6cb 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -18,9 +18,12 @@ from contextlib import AsyncExitStack from datetime import timedelta import functools +import hashlib +import json import logging import sys from typing import Any +from typing import Dict from typing import Optional from typing import TextIO from typing import Union @@ -105,74 +108,39 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(session_manager_field_name: str): - """Decorator to automatically reinitialize session and retry action. +def retry_on_closed_resource(func): + """Decorator to automatically retry action when MCP session is closed. - When MCP session was closed, the decorator will automatically recreate the - session and retry the action with the same parameters. - - Note: - 1. session_manager_field_name is the name of the class member field that - contains the MCPSessionManager instance. - 2. The session manager must have a reinitialize_session() async method. - - Usage: - class MCPTool: - def __init__(self): - self._mcp_session_manager = MCPSessionManager(...) 
- - @retry_on_closed_resource('_mcp_session_manager') - async def use_session(self): - session = await self._mcp_session_manager.create_session() - await session.call_tool() + When MCP session was closed, the decorator will automatically retry the + action once. The create_session method will handle creating a new session + if the old one was disconnected. Args: - session_manager_field_name: The name of the session manager field. + func: The function to decorate. Returns: The decorated function. """ - def decorator(func): - @functools.wraps(func) # Preserves original function metadata - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except anyio.ClosedResourceError as close_err: - try: - if hasattr(self, session_manager_field_name): - session_manager = getattr(self, session_manager_field_name) - if hasattr(session_manager, 'reinitialize_session') and callable( - getattr(session_manager, 'reinitialize_session') - ): - await session_manager.reinitialize_session() - else: - raise ValueError( - f'Session manager {session_manager_field_name} does not have' - ' reinitialize_session method.' - ) from close_err - else: - raise ValueError( - f'Session manager field {session_manager_field_name} does not' - ' exist in decorated class. Please check the field name in' - ' retry_on_closed_resource decorator.' - ) from close_err - except Exception as reinit_err: - raise RuntimeError( - f'Error reinitializing: {reinit_err}' - ) from reinit_err - return await func(self, *args, **kwargs) - - return wrapper - - return decorator + @functools.wraps(func) # Preserves original function metadata + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except anyio.ClosedResourceError: + # Simply retry the function - create_session will handle + # detecting and replacing disconnected sessions + logger.info('Retrying %s due to closed resource', func.__name__) + return await func(self, *args, **kwargs) + + return wrapper class MCPSessionManager: """Manages MCP client sessions. This class provides methods for creating and initializing MCP client sessions, - handling different connection parameters (Stdio and SSE). + handling different connection parameters (Stdio and SSE) and supporting + session pooling based on authentication headers. """ def __init__( @@ -209,30 +177,125 @@ def __init__( else: self._connection_params = connection_params self._errlog = errlog - # Each session manager maintains its own exit stack for proper cleanup - self._exit_stack: Optional[AsyncExitStack] = None - self._session: Optional[ClientSession] = None + + # Session pool: maps session keys to (session, exit_stack) tuples + self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + # Lock to prevent race conditions in session creation self._session_lock = asyncio.Lock() - async def create_session(self) -> ClientSession: + def _generate_session_key( + self, merged_headers: Optional[Dict[str, str]] = None + ) -> str: + """Generates a session key based on connection params and merged headers. + + For StdioConnectionParams, returns a constant key since headers are not + supported. For SSE and StreamableHTTP connections, generates a key based + on the provided merged headers. + + Args: + merged_headers: Already merged headers (base + additional). + + Returns: + A unique session key string. 
+ """ + if isinstance(self._connection_params, StdioConnectionParams): + # For stdio connections, headers are not supported, so use constant key + return 'stdio_session' + + # For SSE and StreamableHTTP connections, use merged headers + if merged_headers: + headers_json = json.dumps(merged_headers, sort_keys=True) + headers_hash = hashlib.md5(headers_json.encode()).hexdigest() + return f'session_{headers_hash}' + else: + return 'session_no_headers' + + def _merge_headers( + self, additional_headers: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """Merges base connection headers with additional headers. + + Args: + additional_headers: Optional headers to merge with connection headers. + + Returns: + Merged headers dictionary, or None if no headers are provided. + """ + if isinstance(self._connection_params, StdioConnectionParams) or isinstance( + self._connection_params, StdioServerParameters + ): + # Stdio connections don't support headers + return None + + base_headers = {} + if ( + hasattr(self._connection_params, 'headers') + and self._connection_params.headers + ): + base_headers = self._connection_params.headers.copy() + + if additional_headers: + base_headers.update(additional_headers) + + return base_headers + + def _is_session_disconnected(self, session: ClientSession) -> bool: + """Checks if a session is disconnected or closed. + + Args: + session: The ClientSession to check. + + Returns: + True if the session is disconnected, False otherwise. + """ + return session._read_stream._closed or session._write_stream._closed + + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: """Creates and initializes an MCP client session. + This method will check if an existing session for the given headers + is still connected. If it's disconnected, it will be cleaned up and + a new session will be created. + + Args: + headers: Optional headers to include in the session. These will be + merged with any existing connection headers. Only applicable + for SSE and StreamableHTTP connections. + Returns: ClientSession: The initialized MCP client session. 
""" - # Fast path: if session already exists, return it without acquiring lock - if self._session is not None: - return self._session + # Merge headers once at the beginning + merged_headers = self._merge_headers(headers) + + # Generate session key using merged headers + session_key = self._generate_session_key(merged_headers) # Use async lock to prevent race conditions async with self._session_lock: - # Double-check: session might have been created while waiting for lock - if self._session is not None: - return self._session - - # Create a new exit stack for this session - self._exit_stack = AsyncExitStack() + # Check if we have an existing session + if session_key in self._sessions: + session, exit_stack = self._sessions[session_key] + + # Check if the existing session is still connected + if not self._is_session_disconnected(session): + # Session is still good, return it + return session + else: + # Session is disconnected, clean it up + logger.info('Cleaning up disconnected session: %s', session_key) + try: + await exit_stack.aclose() + except Exception as e: + logger.warning('Error during disconnected session cleanup: %s', e) + finally: + del self._sessions[session_key] + + # Create a new session (either first time or replacing disconnected one) + exit_stack = AsyncExitStack() try: if isinstance(self._connection_params, StdioConnectionParams): @@ -243,7 +306,7 @@ async def create_session(self) -> ClientSession: elif isinstance(self._connection_params, SseConnectionParams): client = sse_client( url=self._connection_params.url, - headers=self._connection_params.headers, + headers=merged_headers, timeout=self._connection_params.timeout, sse_read_timeout=self._connection_params.sse_read_timeout, ) @@ -252,7 +315,7 @@ async def create_session(self) -> ClientSession: ): client = streamablehttp_client( url=self._connection_params.url, - headers=self._connection_params.headers, + headers=merged_headers, timeout=timedelta(seconds=self._connection_params.timeout), sse_read_timeout=timedelta( seconds=self._connection_params.sse_read_timeout @@ -266,11 +329,11 @@ async def create_session(self) -> ClientSession: f' {self._connection_params}' ) - transports = await self._exit_stack.enter_async_context(client) + transports = await exit_stack.enter_async_context(client) # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. 
if isinstance(self._connection_params, StdioConnectionParams): - session = await self._exit_stack.enter_async_context( + session = await exit_stack.enter_async_context( ClientSession( *transports[:2], read_timeout_seconds=timedelta( @@ -279,44 +342,38 @@ async def create_session(self) -> ClientSession: ) ) else: - session = await self._exit_stack.enter_async_context( + session = await exit_stack.enter_async_context( ClientSession(*transports[:2]) ) await session.initialize() - self._session = session + # Store session and exit stack in the pool + self._sessions[session_key] = (session, exit_stack) + logger.debug('Created new session: %s', session_key) return session except Exception: # If session creation fails, clean up the exit stack - if self._exit_stack: - await self._exit_stack.aclose() - self._exit_stack = None + if exit_stack: + await exit_stack.aclose() raise async def close(self): - """Closes the session and cleans up resources.""" - if not self._exit_stack: - return + """Closes all sessions and cleans up resources.""" async with self._session_lock: - if self._exit_stack: + for session_key in list(self._sessions.keys()): + _, exit_stack = self._sessions[session_key] try: - await self._exit_stack.aclose() + await exit_stack.aclose() except Exception as e: # Log the error but don't re-raise to avoid blocking shutdown print( - f'Warning: Error during MCP session cleanup: {e}', + 'Warning: Error during MCP session cleanup for' + f' {session_key}: {e}', file=self._errlog, ) finally: - self._exit_stack = None - self._session = None - - async def reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and create a new one - await self.close() - await self.create_session() + del self._sessions[session_key] SseServerParams = SseConnectionParams diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 6553bb2c0..310fc48f1 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,10 +14,13 @@ from __future__ import annotations +import base64 +import json import logging from typing import Optional from google.genai.types import FunctionDeclaration +from google.oauth2.credentials import Credentials from typing_extensions import override from .._gemini_schema_util import _to_gemini_schema @@ -42,13 +45,15 @@ from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme -from ..base_tool import BaseTool +from ...auth.auth_tool import AuthConfig +from ..base_authenticated_tool import BaseAuthenticatedTool +# import from ..tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) -class MCPTool(BaseTool): +class MCPTool(BaseAuthenticatedTool): """Turns an MCP Tool into an ADK Tool. Internally, the tool initializes from a MCP Tool, and uses the MCP Session to @@ -77,19 +82,17 @@ def __init__( Raises: ValueError: If mcp_tool or mcp_session_manager is None. """ - if mcp_tool is None: - raise ValueError("mcp_tool cannot be None") - if mcp_session_manager is None: - raise ValueError("mcp_session_manager cannot be None") super().__init__( name=mcp_tool.name, description=mcp_tool.description if mcp_tool.description else "", + auth_config=AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + if auth_scheme + else None, ) self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager - # TODO(cheliu): Support passing auth to MCP Server. 
- self._auth_scheme = auth_scheme - self._auth_credential = auth_credential @override def _get_declaration(self) -> FunctionDeclaration: @@ -105,8 +108,11 @@ def _get_declaration(self) -> FunctionDeclaration: ) return function_decl - @retry_on_closed_resource("_mcp_session_manager") - async def run_async(self, *, args, tool_context: ToolContext): + @retry_on_closed_resource + @override + async def _run_async_impl( + self, *, args, tool_context: ToolContext, credential: AuthCredential + ): """Runs the tool asynchronously. Args: @@ -116,8 +122,60 @@ async def run_async(self, *, args, tool_context: ToolContext): Returns: Any: The response from the tool. """ + # Extract headers from credential for session pooling + headers = await self._get_headers(tool_context, credential) + # Get the session from the session manager - session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session(headers=headers) response = await session.call_tool(self.name, arguments=args) return response + + async def _get_headers( + self, tool_context: ToolContext, credential: AuthCredential + ) -> Optional[dict[str, str]]: + headers = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + # For API keys, we'll add them as headers since MCP typically uses header-based auth + # The specific header name would depend on the API, using a common default + # TODO Allow user to specify the header name for API keys. + headers = {"X-API-Key": credential.api_key} + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + logger.warning( + "Service account credentials should be exchanged before MCP" + " session creation" + ) + + return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index f55693e86..c01b0cec2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -22,6 +22,8 @@ from typing import Union from ...agents.readonly_context import ReadonlyContext +from ...auth.auth_credential import AuthCredential +from ...auth.auth_schemes import AuthScheme from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate @@ -94,6 +96,8 @@ def __init__( ], tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, errlog: TextIO = sys.stderr, + auth_scheme: Optional[AuthScheme] = None, + auth_credential: Optional[AuthCredential] = None, ): """Initializes the MCPToolset. 
@@ -110,6 +114,8 @@ def __init__( list of tool names to include - A ToolPredicate function for custom filtering logic errlog: TextIO stream for error logging. + auth_scheme: The auth scheme of the tool for tool calling + auth_credential: The auth credential of the tool for tool calling """ super().__init__(tool_filter=tool_filter) @@ -124,8 +130,10 @@ def __init__( connection_params=self._connection_params, errlog=self._errlog, ) + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential - @retry_on_closed_resource("_mcp_session_manager") + @retry_on_closed_resource async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, @@ -151,6 +159,8 @@ async def get_tools( mcp_tool = MCPTool( mcp_tool=tool, mcp_session_manager=self._mcp_session_manager, + auth_scheme=self._auth_scheme, + auth_credential=self._auth_credential, ) if self._is_tool_selected(mcp_tool, readonly_context): diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 1e451fe0f..dee103932 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -345,9 +345,9 @@ def _prepare_request_params( async def run_async( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: - return self.call(args=args, tool_context=tool_context) + return await self.call(args=args, tool_context=tool_context) - def call( + async def call( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: """Executes the REST API call. @@ -364,7 +364,7 @@ def call( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self.auth_scheme, self.auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() auth_state, auth_scheme, auth_credential = ( auth_result.state, auth_result.auth_scheme, diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index c36793fdc..74166b00e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -25,7 +25,7 @@ from ....auth.auth_schemes import AuthScheme from ....auth.auth_schemes import AuthSchemeType from ....auth.auth_tool import AuthConfig -from ....auth.oauth2_credential_fetcher import OAuth2CredentialFetcher +from ....auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher from ...tool_context import ToolContext from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError @@ -146,7 +146,7 @@ def from_tool_context( credential_store, ) - def _get_existing_credential( + async def _get_existing_credential( self, ) -> Optional[AuthCredential]: """Checks for and returns an existing, exchanged credential.""" @@ -156,9 +156,11 @@ def _get_existing_credential( ) if existing_credential: if existing_credential.oauth2: - existing_credential = OAuth2CredentialFetcher( - self.auth_scheme, existing_credential - ).refresh() + refresher = OAuth2CredentialRefresher() + if await refresher.is_refresh_needed(existing_credential): + existing_credential = await 
refresher.refresh( + existing_credential, self.auth_scheme + ) return existing_credential return None @@ -231,10 +233,9 @@ def _external_exchange_required(self, credential) -> bool: AuthCredentialTypes.OPEN_ID_CONNECT, ) and not credential.oauth2.access_token - and not credential.google_oauth2_json ) - def prepare_auth_credentials( + async def prepare_auth_credentials( self, ) -> AuthPreparationResult: """Prepares authentication credentials, handling exchange and user interaction.""" @@ -244,7 +245,7 @@ def prepare_auth_credentials( return AuthPreparationResult(state="done") # Check for existing credential. - existing_credential = self._get_existing_credential() + existing_credential = await self._get_existing_credential() credential = existing_credential or self.auth_credential # fetch credential from adk framework diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 5449f5090..b00cd0329 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -39,6 +39,9 @@ def __init__( self, *, data_store_id: Optional[str] = None, + data_store_specs: Optional[ + list[types.VertexAISearchDataStoreSpec] + ] = None, search_engine_id: Optional[str] = None, filter: Optional[str] = None, max_results: Optional[int] = None, @@ -49,6 +52,8 @@ def __init__( data_store_id: The Vertex AI search data store resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}". + data_store_specs: Specifications that define the specific DataStores to be + searched. It should only be set if engine is used. search_engine_id: The Vertex AI search engine resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}". @@ -64,7 +69,12 @@ def __init__( raise ValueError( 'Either data_store_id or search_engine_id must be specified.' ) + if data_store_specs is not None and search_engine_id is None: + raise ValueError( + 'search_engine_id must be specified if data_store_specs is specified.' + ) self.data_store_id = data_store_id + self.data_store_specs = data_store_specs self.search_engine_id = search_engine_id self.filter = filter self.max_results = max_results @@ -89,6 +99,7 @@ async def process_llm_request( retrieval=types.Retrieval( vertex_ai_search=types.VertexAISearch( datastore=self.data_store_id, + data_store_specs=self.data_store_specs, engine=self.search_engine_id, filter=self.filter, max_results=self.max_results, diff --git a/src/google/adk/version.py b/src/google/adk/version.py index c0b08cc60..1c061dd03 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.3.0" +__version__ = "1.5.0" diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py index e662384ce..05072b899 100644 --- a/tests/integration/models/test_litellm_no_function.py +++ b/tests/integration/models/test_litellm_no_function.py @@ -20,12 +20,25 @@ from google.genai.types import Part import pytest -_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" - +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """You are a helpful assistant.""" +def get_weather(city: str) -> str: + """Simulates a web search. Use it get information on weather. 
+ + Args: + city: A string containing the location to get weather information for. + + Returns: + A string with the simulated weather information for the queried city. + """ + if "sf" in city.lower() or "san francisco" in city.lower(): + return "It's 70 degrees and foggy." + return "It's 80 degrees and sunny." + + @pytest.fixture def oss_llm(): return LiteLlm(model=_TEST_MODEL_NAME) @@ -44,6 +57,48 @@ def llm_request(): ) +@pytest.fixture +def llm_request_with_tools(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="What is the weather in San Francisco?") + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="get_weather", + description="Get the weather in a given location", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "city": types.Schema( + type=types.Type.STRING, + description=( + "The city to get the weather for." + ), + ), + }, + required=["city"], + ), + ) + ] + ) + ], + ), + ) + + @pytest.mark.asyncio async def test_generate_content_async(oss_llm, llm_request): async for response in oss_llm.generate_content_async(llm_request): @@ -51,10 +106,8 @@ async def test_generate_content_async(oss_llm, llm_request): assert response.content.parts[0].text -# Note that, this test disabled streaming because streaming is not supported -# properly in the current test model for now. @pytest.mark.asyncio -async def test_generate_content_async_stream(oss_llm, llm_request): +async def test_generate_content_async(oss_llm, llm_request): responses = [ resp async for resp in oss_llm.generate_content_async( @@ -63,3 +116,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request): ] part = responses[0].content.parts[0] assert len(part.text) > 0 + + +@pytest.mark.asyncio +async def test_generate_content_async_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=False + ) + ] + function_call = responses[0].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_async_stream(oss_llm, llm_request): + responses = [ + resp + async for resp in oss_llm.generate_content_async(llm_request, stream=True) + ] + text = "" + for i in range(len(responses) - 1): + assert responses[i].partial is True + assert responses[i].content.parts[0].text + text += responses[i].content.parts[0].text + + # Last message should be accumulated text + assert responses[-1].content.parts[0].text == text + assert not responses[-1].partial + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py index a2ceb540a..e0d2bc991 100644 --- a/tests/integration/models/test_litellm_with_function.py +++ b/tests/integration/models/test_litellm_with_function.py @@ -13,22 
+13,17 @@ # limitations under the License. from google.adk.models import LlmRequest -from google.adk.models import LlmResponse from google.adk.models.lite_llm import LiteLlm from google.genai import types from google.genai.types import Content from google.genai.types import Part -import litellm import pytest -litellm.add_function_to_prompt = True - -_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" - +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" _SYSTEM_PROMPT = """ You are a helpful assistant, and call tools optionally. -If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs. +If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs. """ @@ -40,7 +35,7 @@ "properties": { "city": { "type": "string", - "description": "The city, e.g. San Francisco", + "description": "The city to get the weather for.", }, }, "required": ["city"], @@ -87,8 +82,6 @@ def llm_request(): ) -# Note that, this test disabled streaming because streaming is not supported -# properly in the current test model for now. @pytest.mark.asyncio async def test_generate_content_asyn_with_function( oss_llm_with_function, llm_request @@ -102,3 +95,18 @@ async def test_generate_content_asyn_with_function( function_call = responses[0].content.parts[0].function_call assert function_call.name == "get_weather" assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_asyn_stream_with_function( + oss_llm_with_function, llm_request +): + responses = [ + resp + async for resp in oss_llm_with_function.generate_content_async( + llm_request, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" diff --git a/tests/unittests/a2a/__init__.py b/tests/unittests/a2a/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/__init__.py b/tests/unittests/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
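# A hedged sketch of consuming the streaming contract exercised by the LiteLLM
# tests earlier in this patch: partial responses carry incremental text and the
# final, non-partial response carries the accumulated text. `llm_request` is
# assumed to be an LlmRequest built like the fixtures above, and running this
# requires Vertex AI credentials.
from google.adk.models.lite_llm import LiteLlm

async def stream_text(llm_request) -> str:
  llm = LiteLlm(model="vertex_ai/meta/llama-3.1-405b-instruct-maas")
  text = ""
  async for resp in llm.generate_content_async(llm_request, stream=True):
    if resp.partial and resp.content and resp.content.parts:
      text += resp.content.parts[0].text or ""
  return text

# e.g. text = asyncio.run(stream_text(llm_request))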
diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py new file mode 100644 index 000000000..2ba8e26b4 --- /dev/null +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -0,0 +1,1214 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.types import DataPart + from a2a.types import Message + from a2a.types import Role + from a2a.types import Task + from a2a.types import TaskArtifactUpdateEvent + from a2a.types import TaskState + from a2a.types import TaskStatusUpdateEvent + from google.adk.a2a.converters.event_converter import _convert_artifact_to_a2a_events + from google.adk.a2a.converters.event_converter import _create_artifact_id + from google.adk.a2a.converters.event_converter import _create_error_status_event + from google.adk.a2a.converters.event_converter import _create_status_update_event + from google.adk.a2a.converters.event_converter import _get_adk_metadata_key + from google.adk.a2a.converters.event_converter import _get_context_metadata + from google.adk.a2a.converters.event_converter import _process_long_running_tool + from google.adk.a2a.converters.event_converter import _serialize_metadata_value + from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR + from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events + from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message + from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE + from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX + from google.adk.agents.invocation_context import InvocationContext + from google.adk.events.event import Event + from google.adk.events.event_actions import EventActions +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + DataPart = DummyTypes() + Message = DummyTypes() + Role = DummyTypes() + Task = DummyTypes() + TaskArtifactUpdateEvent = DummyTypes() + TaskState = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + _convert_artifact_to_a2a_events = lambda *args: None + _create_artifact_id = lambda *args: None + _create_error_status_event = lambda *args: None + _create_status_update_event = lambda *args: None + _get_adk_metadata_key = lambda *args: None + _get_context_metadata = lambda *args: None + _process_long_running_tool = lambda *args: None + _serialize_metadata_value = lambda *args: None + ADK_METADATA_KEY_PREFIX = "adk_" + ARTIFACT_ID_SEPARATOR = "_" + convert_event_to_a2a_events = lambda *args: 
None + convert_event_to_a2a_message = lambda *args: None + DEFAULT_ERROR_MESSAGE = "error" + InvocationContext = DummyTypes() + Event = DummyTypes() + EventActions = DummyTypes() + types = DummyTypes() + else: + raise e + + +class TestEventConverter: + """Test suite for event_converter module.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_session = Mock() + self.mock_session.id = "test-session-id" + + self.mock_artifact_service = Mock() + self.mock_invocation_context = Mock(spec=InvocationContext) + self.mock_invocation_context.app_name = "test-app" + self.mock_invocation_context.user_id = "test-user" + self.mock_invocation_context.session = self.mock_session + self.mock_invocation_context.artifact_service = self.mock_artifact_service + + self.mock_event = Mock(spec=Event) + self.mock_event.invocation_id = "test-invocation-id" + self.mock_event.author = "test-author" + self.mock_event.branch = None + self.mock_event.grounding_metadata = None + self.mock_event.custom_metadata = None + self.mock_event.usage_metadata = None + self.mock_event.error_code = None + self.mock_event.error_message = None + self.mock_event.content = None + self.mock_event.long_running_tool_ids = None + self.mock_event.actions = Mock(spec=EventActions) + self.mock_event.actions.artifact_delta = None + + def test_get_adk_event_metadata_key_success(self): + """Test successful metadata key generation.""" + key = "test_key" + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_get_adk_event_metadata_key_empty_string(self): + """Test metadata key generation with empty string.""" + with pytest.raises(ValueError) as exc_info: + _get_adk_metadata_key("") + assert "cannot be empty or None" in str(exc_info.value) + + def test_get_adk_event_metadata_key_none(self): + """Test metadata key generation with None.""" + with pytest.raises(ValueError) as exc_info: + _get_adk_metadata_key(None) + assert "cannot be empty or None" in str(exc_info.value) + + def test_serialize_metadata_value_with_model_dump(self): + """Test serialization of value with model_dump method.""" + mock_value = Mock() + mock_value.model_dump.return_value = {"key": "value"} + + result = _serialize_metadata_value(mock_value) + + assert result == {"key": "value"} + mock_value.model_dump.assert_called_once_with( + exclude_none=True, by_alias=True + ) + + def test_serialize_metadata_value_with_model_dump_exception(self): + """Test serialization when model_dump raises exception.""" + mock_value = Mock() + mock_value.model_dump.side_effect = Exception("Serialization failed") + + with patch( + "google.adk.a2a.converters.event_converter.logger" + ) as mock_logger: + result = _serialize_metadata_value(mock_value) + + assert result == str(mock_value) + mock_logger.warning.assert_called_once() + + def test_serialize_metadata_value_without_model_dump(self): + """Test serialization of value without model_dump method.""" + value = "simple_string" + result = _serialize_metadata_value(value) + assert result == "simple_string" + + def test_get_context_metadata_success(self): + """Test successful context metadata creation.""" + result = _get_context_metadata( + self.mock_event, self.mock_invocation_context + ) + + assert result is not None + expected_keys = [ + f"{ADK_METADATA_KEY_PREFIX}app_name", + f"{ADK_METADATA_KEY_PREFIX}user_id", + f"{ADK_METADATA_KEY_PREFIX}session_id", + f"{ADK_METADATA_KEY_PREFIX}invocation_id", + f"{ADK_METADATA_KEY_PREFIX}author", + ] + + for key in expected_keys: + assert key 
in result + + def test_get_context_metadata_with_optional_fields(self): + """Test context metadata creation with optional fields.""" + self.mock_event.branch = "test-branch" + self.mock_event.error_code = "ERROR_001" + + mock_metadata = Mock() + mock_metadata.model_dump.return_value = {"test": "value"} + self.mock_event.grounding_metadata = mock_metadata + + result = _get_context_metadata( + self.mock_event, self.mock_invocation_context + ) + + assert result is not None + assert f"{ADK_METADATA_KEY_PREFIX}branch" in result + assert f"{ADK_METADATA_KEY_PREFIX}grounding_metadata" in result + assert result[f"{ADK_METADATA_KEY_PREFIX}branch"] == "test-branch" + + # Check if error_code is in the result - it should be there since we set it + if f"{ADK_METADATA_KEY_PREFIX}error_code" in result: + assert result[f"{ADK_METADATA_KEY_PREFIX}error_code"] == "ERROR_001" + + def test_get_context_metadata_none_event(self): + """Test context metadata creation with None event.""" + with pytest.raises(ValueError) as exc_info: + _get_context_metadata(None, self.mock_invocation_context) + assert "Event cannot be None" in str(exc_info.value) + + def test_get_context_metadata_none_context(self): + """Test context metadata creation with None context.""" + with pytest.raises(ValueError) as exc_info: + _get_context_metadata(self.mock_event, None) + assert "Invocation context cannot be None" in str(exc_info.value) + + def test_create_artifact_id(self): + """Test artifact ID creation.""" + app_name = "test-app" + user_id = "user123" + session_id = "session456" + filename = "test.txt" + version = 1 + + result = _create_artifact_id( + app_name, user_id, session_id, filename, version + ) + expected = f"{app_name}{ARTIFACT_ID_SEPARATOR}{user_id}{ARTIFACT_ID_SEPARATOR}{session_id}{ARTIFACT_ID_SEPARATOR}{filename}{ARTIFACT_ID_SEPARATOR}{version}" + + assert result == expected + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + def test_convert_artifact_to_a2a_events_success(self, mock_convert_part): + """Test successful artifact delta conversion.""" + filename = "test.txt" + version = 1 + task_id = "test-task-id" + context_id = "test-context-id" + + mock_artifact_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types import TextPart + + text_part = TextPart(text="test content") + mock_converted_part = Part(root=text_part) + + self.mock_artifact_service.load_artifact.return_value = mock_artifact_part + mock_convert_part.return_value = mock_converted_part + + result = _convert_artifact_to_a2a_events( + self.mock_event, + self.mock_invocation_context, + filename, + version, + task_id, + context_id, + ) + + assert isinstance(result, TaskArtifactUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.append is False + assert result.lastChunk is True + + # Check artifact properties + assert result.artifact.name == filename + assert result.artifact.metadata["filename"] == filename + assert result.artifact.metadata["version"] == version + assert len(result.artifact.parts) == 1 + assert result.artifact.parts[0].root.text == "test content" + + def test_convert_artifact_to_a2a_events_empty_filename(self): + """Test artifact delta conversion with empty filename.""" + with pytest.raises(ValueError) as exc_info: + _convert_artifact_to_a2a_events( + self.mock_event, self.mock_invocation_context, "", 1, "", "" + ) + assert "Filename cannot be empty" in str(exc_info.value) + + def 
test_convert_artifact_to_a2a_events_negative_version(self): + """Test artifact delta conversion with negative version.""" + with pytest.raises(ValueError) as exc_info: + _convert_artifact_to_a2a_events( + self.mock_event, self.mock_invocation_context, "test.txt", -1, "", "" + ) + assert "Version must be non-negative" in str(exc_info.value) + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + def test_convert_artifact_to_a2a_events_conversion_failure( + self, mock_convert_part + ): + """Test artifact delta conversion when part conversion fails.""" + filename = "test.txt" + version = 1 + + mock_artifact_part = Mock() + self.mock_artifact_service.load_artifact.return_value = mock_artifact_part + mock_convert_part.return_value = None # Simulate conversion failure + + with pytest.raises(RuntimeError) as exc_info: + _convert_artifact_to_a2a_events( + self.mock_event, + self.mock_invocation_context, + filename, + version, + "", + "", + ) + assert "Failed to convert artifact part" in str(exc_info.value) + + def test_process_long_running_tool_marks_tool(self): + """Test processing of long-running tool metadata.""" + mock_a2a_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-123") + mock_a2a_part.root = mock_data_part + + self.mock_event.long_running_tool_ids = {"tool-123"} + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + _process_long_running_tool(mock_a2a_part, self.mock_event) + + expected_key = f"{ADK_METADATA_KEY_PREFIX}is_long_running" + assert mock_data_part.metadata[expected_key] is True + + def test_process_long_running_tool_no_marking(self): + """Test processing when tool should not be marked as long-running.""" + mock_a2a_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-456") + mock_a2a_part.root = mock_data_part + + self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + _process_long_running_tool(mock_a2a_part, self.mock_event) + + expected_key = f"{ADK_METADATA_KEY_PREFIX}is_long_running" + assert expected_key not in mock_data_part.metadata + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_event_to_message_success(self, mock_uuid, mock_convert_part): + """Test successful event to message conversion.""" + mock_uuid.return_value = "test-uuid" + + mock_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types 
import TextPart + + text_part = TextPart(text="test message") + mock_a2a_part = Part(root=text_part) + mock_convert_part.return_value = mock_a2a_part + + mock_content = Mock() + mock_content.parts = [mock_part] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_message( + self.mock_event, self.mock_invocation_context + ) + + assert isinstance(result, Message) + assert result.messageId == "test-uuid" + assert result.role == Role.agent + assert len(result.parts) == 1 + assert result.parts[0].root.text == "test message" + + def test_convert_event_to_message_no_content(self): + """Test event to message conversion with no content.""" + self.mock_event.content = None + + result = convert_event_to_a2a_message( + self.mock_event, self.mock_invocation_context + ) + + assert result is None + + def test_convert_event_to_message_empty_parts(self): + """Test event to message conversion with empty parts.""" + mock_content = Mock() + mock_content.parts = [] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_message( + self.mock_event, self.mock_invocation_context + ) + + assert result is None + + def test_convert_event_to_message_none_event(self): + """Test event to message conversion with None event.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_message(None, self.mock_invocation_context) + assert "Event cannot be None" in str(exc_info.value) + + def test_convert_event_to_message_none_context(self): + """Test event to message conversion with None context.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_message(self.mock_event, None) + assert "Invocation context cannot be None" in str(exc_info.value) + + @patch( + "google.adk.a2a.converters.event_converter.convert_genai_part_to_a2a_part" + ) + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_event_to_message_with_custom_role( + self, mock_uuid, mock_convert_part + ): + """Test event to message conversion with custom role.""" + mock_uuid.return_value = "test-uuid" + + mock_part = Mock() + # Create a proper Part that Pydantic will accept + from a2a.types import Part + from a2a.types import TextPart + + text_part = TextPart(text="test message") + mock_a2a_part = Part(root=text_part) + mock_convert_part.return_value = mock_a2a_part + + mock_content = Mock() + mock_content.parts = [mock_part] + self.mock_event.content = mock_content + + result = convert_event_to_a2a_message( + self.mock_event, self.mock_invocation_context, role=Role.user + ) + + assert isinstance(result, Message) + assert result.messageId == "test-uuid" + assert result.role == Role.user + assert len(result.parts) == 1 + assert result.parts[0].root.text == "test message" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + @patch("google.adk.a2a.converters.event_converter.datetime") + def test_create_error_status_event(self, mock_datetime, mock_uuid): + """Test creation of error status event.""" + mock_uuid.return_value = "test-uuid" + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + self.mock_event.error_message = "Test error message" + task_id = "test-task-id" + context_id = "test-context-id" + + result = _create_error_status_event( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.status.state == TaskState.failed + assert 
result.status.message.parts[0].root.text == "Test error message" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + @patch("google.adk.a2a.converters.event_converter.datetime") + def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): + """Test creation of error status event without error message.""" + mock_uuid.return_value = "test-uuid" + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + task_id = "test-task-id" + context_id = "test-context-id" + + result = _create_error_status_event( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert result.status.message.parts[0].root.text == DEFAULT_ERROR_MESSAGE + + @patch("google.adk.a2a.converters.event_converter.datetime") + def test_create_running_status_event(self, mock_datetime): + """Test creation of running status event.""" + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + mock_message = Mock(spec=Message) + mock_message.parts = [] + task_id = "test-task-id" + context_id = "test-context-id" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.status.state == TaskState.working + assert result.status.message == mock_message + + @patch( + "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" + ) + @patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) + @patch("google.adk.a2a.converters.event_converter._create_error_status_event") + @patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) + def test_convert_event_to_a2a_events_full_scenario( + self, + mock_create_running, + mock_create_error, + mock_convert_message, + mock_convert_artifact, + ): + """Test full event to A2A events conversion scenario.""" + # Setup artifact delta + self.mock_event.actions.artifact_delta = {"file1.txt": 1, "file2.txt": 2} + + # Setup error + self.mock_event.error_code = "ERROR_001" + + # Setup message + mock_message = Mock(spec=Message) + mock_convert_message.return_value = mock_message + + # Setup mock returns + mock_artifact_event1 = Mock() + mock_artifact_event2 = Mock() + mock_convert_artifact.side_effect = [ + mock_artifact_event1, + mock_artifact_event2, + ] + + mock_error_event = Mock() + mock_create_error.return_value = mock_error_event + + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + # Verify artifact delta events + assert mock_convert_artifact.call_count == 2 + + # Verify error event - now called with task_id and context_id parameters + mock_create_error.assert_called_once_with( + self.mock_event, self.mock_invocation_context, None, None + ) + + # Verify running event - now called with task_id and context_id parameters + mock_create_running.assert_called_once_with( + mock_message, self.mock_invocation_context, self.mock_event, None, None + ) + + # Verify result contains all events + assert len(result) == 4 # 2 artifact + 1 error + 1 running + assert mock_artifact_event1 in result + assert mock_artifact_event2 in result + assert mock_error_event in result + assert mock_running_event in result + + def test_convert_event_to_a2a_events_empty_scenario(self): + """Test 
event to A2A events conversion with empty event.""" + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + assert result == [] + + def test_convert_event_to_a2a_events_none_event(self): + """Test event to A2A events conversion with None event.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_events(None, self.mock_invocation_context) + assert "Event cannot be None" in str(exc_info.value) + + def test_convert_event_to_a2a_events_none_context(self): + """Test event to A2A events conversion with None context.""" + with pytest.raises(ValueError) as exc_info: + convert_event_to_a2a_events(self.mock_event, None) + assert "Invocation context cannot be None" in str(exc_info.value) + + @patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) + def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): + """Test event to A2A events conversion with message only.""" + mock_message = Mock(spec=Message) + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + assert len(result) == 1 + assert result[0] == mock_running_event + # Verify the function is called with task_id and context_id parameters + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + None, + None, + ) + + @patch("google.adk.a2a.converters.event_converter.logger") + def test_convert_event_to_a2a_events_exception_handling(self, mock_logger): + """Test exception handling in convert_event_to_a2a_events.""" + # Make convert_event_to_a2a_message raise an exception + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.side_effect = Exception("Test exception") + + with pytest.raises(Exception): + convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context + ) + + mock_logger.error.assert_called_once() + + def test_convert_event_to_a2a_events_with_task_id_and_context_id(self): + """Test event to A2A events conversion with specific task_id and context_id.""" + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 1 + assert result[0] == mock_running_event + + # Verify the function is called with the specific task_id and context_id + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + def test_convert_event_to_a2a_events_with_artifacts_and_custom_ids(self): + """Test event to A2A events conversion with artifacts and custom IDs.""" + # Setup artifact delta + self.mock_event.actions.artifact_delta = 
{"file1.txt": 1} + + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._convert_artifact_to_a2a_events" + ) as mock_convert_artifact: + mock_artifact_event = Mock() + mock_convert_artifact.return_value = mock_artifact_event + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 2 # 1 artifact + 1 status + assert mock_artifact_event in result + assert mock_running_event in result + + # Verify artifact conversion is called with custom IDs + mock_convert_artifact.assert_called_once_with( + self.mock_event, + self.mock_invocation_context, + "file1.txt", + 1, + task_id, + context_id, + ) + + # Verify status update is called with custom IDs + mock_create_running.assert_called_once_with( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + def test_create_status_update_event_with_auth_required_state(self): + """Test creation of status update event with auth_required state.""" + from a2a.types import DataPart + from a2a.types import Part + + # Create a mock message with a part that triggers auth_required state + mock_message = Mock(spec=Message) + mock_part = Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = { + "adk_type": "function_call", + "adk_is_long_running": True, + } + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="request_euc") + mock_part.root = mock_data_part + mock_message.parts = [mock_part] + + task_id = "test-task-id" + context_id = "test-context-id" + + with patch( + "google.adk.a2a.converters.event_converter.datetime" + ) as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", + "is_long_running", + ), + patch( + "google.adk.a2a.converters.event_converter.REQUEST_EUC_FUNCTION_CALL_NAME", + "request_euc", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.status.state == TaskState.auth_required + + def test_create_status_update_event_with_input_required_state(self): + """Test creation of status update event with input_required state.""" + from a2a.types import DataPart + from a2a.types import Part + + # Create a mock message with a part that triggers input_required state + mock_message = Mock(spec=Message) + mock_part 
= Mock() + mock_data_part = Mock(spec=DataPart) + mock_data_part.metadata = { + "adk_type": "function_call", + "adk_is_long_running": True, + } + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="some_other_function") + mock_part.root = mock_data_part + mock_message.parts = [mock_part] + + task_id = "test-task-id" + context_id = "test-context-id" + + with patch( + "google.adk.a2a.converters.event_converter.datetime" + ) as mock_datetime: + mock_datetime.now.return_value.isoformat.return_value = ( + "2023-01-01T00:00:00" + ) + + with ( + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_KEY", + "type", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", + "function_call", + ), + patch( + "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY", + "is_long_running", + ), + patch( + "google.adk.a2a.converters.event_converter.REQUEST_EUC_FUNCTION_CALL_NAME", + "request_euc", + ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, + ): + mock_get_key.side_effect = lambda key: f"adk_{key}" + + result = _create_status_update_event( + mock_message, + self.mock_invocation_context, + self.mock_event, + task_id, + context_id, + ) + + assert isinstance(result, TaskStatusUpdateEvent) + assert result.taskId == task_id + assert result.contextId == context_id + assert result.status.state == TaskState.input_required + + +class TestA2AToEventConverters: + """Test suite for A2A to Event conversion functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_invocation_context = Mock(spec=InvocationContext) + self.mock_invocation_context.branch = "test-branch" + self.mock_invocation_context.invocation_id = "test-invocation-id" + + def test_convert_a2a_task_to_event_with_status_message(self): + """Test converting A2A task with status message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_status = Mock() + mock_status.message = mock_message + mock_task = Mock(spec=Task) + mock_task.status = mock_status + mock_task.history = [] + + # Mock the convert_a2a_message_to_event function + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_event = Mock(spec=Event) + mock_event.invocation_id = "test-invocation-id" + mock_convert_message.return_value = mock_event + + result = convert_a2a_task_to_event( + mock_task, "test-author", self.mock_invocation_context + ) + + # Verify the message converter was called with correct parameters + mock_convert_message.assert_called_once_with( + mock_message, "test-author", self.mock_invocation_context + ) + assert result == mock_event + assert result.invocation_id == "test-invocation-id" + + def test_convert_a2a_task_to_event_with_history_message(self): + """Test converting A2A task with history message when no status message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [mock_message] + + # Mock the convert_a2a_message_to_event function + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_event = Mock(spec=Event) + mock_event.invocation_id = 
"test-invocation-id" + mock_convert_message.return_value = mock_event + + result = convert_a2a_task_to_event(mock_task, "test-author") + + # Verify the message converter was called with correct parameters + mock_convert_message.assert_called_once_with( + mock_message, "test-author", None + ) + assert result == mock_event + + def test_convert_a2a_task_to_event_no_message(self): + """Test converting A2A task with no message.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock task with no message + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [] + + result = convert_a2a_task_to_event( + mock_task, "test-author", self.mock_invocation_context + ) + + # Verify minimal event was created with correct invocation_id + assert result.author == "test-author" + assert result.branch == "test-branch" + assert result.invocation_id == "test-invocation-id" + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_a2a_task_to_event_default_author(self, mock_uuid): + """Test converting A2A task with default author and no invocation context.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock task with no message + mock_task = Mock(spec=Task) + mock_task.status = None + mock_task.history = [] + + # Mock UUID generation + mock_uuid.return_value = "generated-uuid" + + result = convert_a2a_task_to_event(mock_task) + + # Verify default author was used and UUID was generated for invocation_id + assert result.author == "a2a agent" + assert result.branch is None + assert result.invocation_id == "generated-uuid" + + def test_convert_a2a_task_to_event_none_task(self): + """Test converting None task raises ValueError.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + with pytest.raises(ValueError, match="A2A task cannot be None"): + convert_a2a_task_to_event(None) + + def test_convert_a2a_task_to_event_message_conversion_error(self): + """Test error handling when message conversion fails.""" + from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event + + # Create mock message and task + mock_message = Mock(spec=Message) + mock_status = Mock() + mock_status.message = mock_message + mock_task = Mock(spec=Task) + mock_task.status = mock_status + mock_task.history = [] + + # Mock the convert_a2a_message_to_event function to raise an exception + with patch( + "google.adk.a2a.converters.event_converter.convert_a2a_message_to_event" + ) as mock_convert_message: + mock_convert_message.side_effect = Exception("Conversion failed") + + with pytest.raises(RuntimeError, match="Failed to convert task message"): + convert_a2a_task_to_event(mock_task, "test-author") + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_success(self, mock_convert_part): + """Test successful conversion of A2A message to event.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + from google.genai import types as genai_types + + # Create mock parts and message with valid genai Part + mock_a2a_part = Mock() + mock_genai_part = genai_types.Part(text="test content") + mock_convert_part.return_value = mock_genai_part + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify conversion was 
successful + assert result.author == "test-author" + assert result.branch == "test-branch" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 1 + assert result.content.parts[0].text == "test content" + mock_convert_part.assert_called_once_with(mock_a2a_part) + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_with_long_running_tools( + self, mock_convert_part + ): + """Test conversion with long-running tools by mocking the entire flow.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Create mock parts and message + mock_a2a_part = Mock() + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + # Mock the part conversion to return None to simulate long-running tool detection logic + mock_convert_part.return_value = None + + # Patch the long-running tool detection since the main logic is in the actual conversion + with patch( + "google.adk.a2a.converters.event_converter.logger" + ) as mock_logger: + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify basic conversion worked + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + # Parts will be empty since conversion returned None, but that's expected for this test + + def test_convert_a2a_message_to_event_empty_parts(self): + """Test conversion with empty parts list.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + mock_message = Mock(spec=Message) + mock_message.parts = [] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created with empty parts + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 0 + + def test_convert_a2a_message_to_event_none_message(self): + """Test converting None message raises ValueError.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + with pytest.raises(ValueError, match="A2A message cannot be None"): + convert_a2a_message_to_event(None) + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_part_conversion_fails( + self, mock_convert_part + ): + """Test handling when part conversion returns None.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Setup mock to return None (conversion failure) + mock_a2a_part = Mock() + mock_convert_part.return_value = None + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created but with no parts + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 0 + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_part_conversion_exception( + self, mock_convert_part + ): + """Test handling when part conversion raises exception.""" + from 
google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + from google.genai import types as genai_types + + # Setup mock to raise exception + mock_a2a_part1 = Mock() + mock_a2a_part2 = Mock() + mock_genai_part = genai_types.Part(text="successful conversion") + + mock_convert_part.side_effect = [ + Exception("Conversion failed"), # First part fails + mock_genai_part, # Second part succeeds + ] + + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part1, mock_a2a_part2] + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify event was created with only the successfully converted part + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + assert len(result.content.parts) == 1 + assert result.content.parts[0].text == "successful conversion" + + @patch( + "google.adk.a2a.converters.event_converter.convert_a2a_part_to_genai_part" + ) + def test_convert_a2a_message_to_event_missing_tool_id( + self, mock_convert_part + ): + """Test handling of message conversion when part conversion fails.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + # Create mock parts and message + mock_a2a_part = Mock() + mock_message = Mock(spec=Message) + mock_message.parts = [mock_a2a_part] + + # Mock the part conversion to return None + mock_convert_part.return_value = None + + result = convert_a2a_message_to_event( + mock_message, "test-author", self.mock_invocation_context + ) + + # Verify basic conversion worked + assert result.author == "test-author" + assert result.invocation_id == "test-invocation-id" + assert result.content.role == "model" + # Parts will be empty since conversion returned None + assert len(result.content.parts) == 0 + + @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") + def test_convert_a2a_message_to_event_default_author(self, mock_uuid): + """Test conversion with default author and no invocation context.""" + from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event + + mock_message = Mock(spec=Message) + mock_message.parts = [] + + # Mock UUID generation + mock_uuid.return_value = "generated-uuid" + + result = convert_a2a_message_to_event(mock_message) + + # Verify default author was used and UUID was generated for invocation_id + assert result.author == "a2a agent" + assert result.branch is None + assert result.invocation_id == "generated-uuid" diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py new file mode 100644 index 000000000..1e8f0d4a3 --- /dev/null +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -0,0 +1,760 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
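The test_event_converter.py suite above drives the A2A-to-ADK direction entirely through mocks; as a rough standalone sketch of the same path (not part of the patch, assuming Python 3.10+ with the a2a extra installed and real a2a.types objects rather than mocks):

import uuid

from a2a.types import Message
from a2a.types import Part
from a2a.types import Role
from a2a.types import TextPart
from google.adk.a2a.converters.event_converter import convert_a2a_message_to_event

message = Message(
    messageId=str(uuid.uuid4()),
    role=Role.agent,
    parts=[Part(root=TextPart(text="hello from a remote A2A agent"))],
)
# With no author or invocation context supplied, the converter falls back to
# the default "a2a agent" author and generates a fresh invocation id.
event = convert_a2a_message_to_event(message)
print(event.author, event.content.parts[0].text)

The same module's convert_event_to_a2a_events covers the opposite direction, emitting TaskArtifactUpdateEvent and TaskStatusUpdateEvent objects as the tests above verify.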
+ +import json +import sys +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a import types as a2a_types + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY + from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part + from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part + from google.adk.a2a.converters.utils import _get_adk_metadata_key + from google.genai import types as genai_types +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + a2a_types = DummyTypes() + genai_types = DummyTypes() + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = "function_call" + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = "function_response" + A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = "code_execution_result" + A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = "executable_code" + A2A_DATA_PART_METADATA_TYPE_KEY = "type" + convert_a2a_part_to_genai_part = lambda x: None + convert_genai_part_to_a2a_part = lambda x: None + _get_adk_metadata_key = lambda x: f"adk_{x}" + else: + raise e + + +class TestConvertA2aPartToGenaiPart: + """Test cases for convert_a2a_part_to_genai_part function.""" + + def test_convert_text_part(self): + """Test conversion of A2A TextPart to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="Hello, world!")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == "Hello, world!" 
+ + def test_convert_file_part_with_uri(self): + """Test conversion of A2A FilePart with URI to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri="gs://bucket/file.txt", mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.file_data is not None + assert result.file_data.file_uri == "gs://bucket/file.txt" + assert result.file_data.mime_type == "text/plain" + + def test_convert_file_part_with_bytes(self): + """Test conversion of A2A FilePart with bytes to GenAI Part.""" + # Arrange + test_bytes = b"test file content" + # A2A FileWithBytes expects base64-encoded string + import base64 + + base64_encoded = base64.b64encode(test_bytes).decode("utf-8") + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=base64_encoded, mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.inline_data is not None + # The converter decodes base64 back to original bytes + assert result.inline_data.data == test_bytes + assert result.inline_data.mime_type == "text/plain" + + def test_convert_data_part_function_call(self): + """Test conversion of A2A DataPart with function call metadata.""" + # Arrange + function_call_data = { + "name": "test_function", + "args": {"param1": "value1", "param2": 42}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_call_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_call is not None + assert result.function_call.name == "test_function" + assert result.function_call.args == {"param1": "value1", "param2": 42} + + def test_convert_data_part_function_response(self): + """Test conversion of A2A DataPart with function response metadata.""" + # Arrange + function_response_data = { + "name": "test_function", + "response": {"result": "success", "data": [1, 2, 3]}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_response_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, + "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_response is not None + assert result.function_response.name == "test_function" + assert result.function_response.response == { + "result": "success", + "data": [1, 2, 3], + } + + def test_convert_data_part_without_special_metadata(self): + """Test conversion of A2A DataPart without special metadata to text.""" + # Arrange + data = {"key": "value", "number": 123} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata={"other": "metadata"}) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == 
json.dumps(data) + + def test_convert_data_part_no_metadata(self): + """Test conversion of A2A DataPart with no metadata to text.""" + # Arrange + data = {"key": "value", "array": [1, 2, 3]} + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_unsupported_file_type(self): + """Test handling of unsupported file types.""" + + # Arrange - Create a mock unsupported file type + class UnsupportedFileType: + pass + + # Create a part manually since FilePart validation might reject it + mock_file_part = Mock() + mock_file_part.file = UnsupportedFileType() + a2a_part = Mock() + a2a_part.root = mock_file_part + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + def test_convert_unsupported_part_type(self): + """Test handling of unsupported part types.""" + + # Arrange - Create a mock unsupported part type + class UnsupportedPartType: + pass + + mock_part = Mock() + mock_part.root = UnsupportedPartType() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(mock_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestConvertGenaiPartToA2aPart: + """Test cases for convert_genai_part_to_a2a_part function.""" + + def test_convert_text_part(self): + """Test conversion of GenAI text Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part(text="Hello, world!") + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.TextPart) + assert result.root.text == "Hello, world!" + + def test_convert_text_part_with_thought(self): + """Test conversion of GenAI text Part with thought to A2A Part.""" + # Arrange - thought is a boolean field in genai_types.Part + genai_part = genai_types.Part(text="Hello, world!", thought=True) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.TextPart) + assert result.root.text == "Hello, world!" 
+ assert result.root.metadata is not None + assert result.root.metadata[_get_adk_metadata_key("thought")] == True + + def test_convert_file_data_part(self): + """Test conversion of GenAI file_data Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part( + file_data=genai_types.FileData( + file_uri="gs://bucket/file.txt", mime_type="text/plain" + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithUri) + assert result.root.file.uri == "gs://bucket/file.txt" + assert result.root.file.mimeType == "text/plain" + + def test_convert_inline_data_part(self): + """Test conversion of GenAI inline_data Part to A2A Part.""" + # Arrange + test_bytes = b"test file content" + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="text/plain") + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility + import base64 + + expected_base64 = base64.b64encode(test_bytes).decode("utf-8") + assert result.root.file.bytes == expected_base64 + assert result.root.file.mimeType == "text/plain" + + def test_convert_inline_data_part_with_video_metadata(self): + """Test conversion of GenAI inline_data Part with video metadata to A2A Part.""" + # Arrange + test_bytes = b"test video content" + video_metadata = genai_types.VideoMetadata(fps=30.0) + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="video/mp4"), + video_metadata=video_metadata, + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + assert result.root.metadata is not None + assert _get_adk_metadata_key("video_metadata") in result.root.metadata + + def test_convert_function_call_part(self): + """Test conversion of GenAI function_call Part to A2A Part.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_call.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + + def test_convert_function_response_part(self): + """Test conversion of GenAI function_response Part to A2A Part.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, 
a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_response.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + + def test_convert_code_execution_result_part(self): + """Test conversion of GenAI code_execution_result Part to A2A Part.""" + # Arrange + code_execution_result = genai_types.CodeExecutionResult( + outcome=genai_types.Outcome.OUTCOME_OK, output="Hello, World!" + ) + genai_part = genai_types.Part(code_execution_result=code_execution_result) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = code_execution_result.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + ) + + def test_convert_executable_code_part(self): + """Test conversion of GenAI executable_code Part to A2A Part.""" + # Arrange + executable_code = genai_types.ExecutableCode( + language=genai_types.Language.PYTHON, code="print('Hello, World!')" + ) + genai_part = genai_types.Part(executable_code=executable_code) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = executable_code.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + ) + + def test_convert_unsupported_part(self): + """Test handling of unsupported GenAI Part types.""" + # Arrange - Create a GenAI Part with no recognized fields + genai_part = genai_types.Part() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestRoundTripConversions: + """Test cases for round-trip conversions to ensure consistency.""" + + def test_text_part_round_trip(self): + """Test round-trip conversion for text parts.""" + # Arrange + original_text = "Hello, world!" 
+ a2a_part = a2a_types.Part(root=a2a_types.TextPart(text=original_text)) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.TextPart) + assert result_a2a_part.root.text == original_text + + def test_file_uri_round_trip(self): + """Test round-trip conversion for file parts with URI.""" + # Arrange + original_uri = "gs://bucket/file.txt" + original_mime_type = "text/plain" + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=original_uri, mimeType=original_mime_type + ) + ) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.FilePart) + assert isinstance(result_a2a_part.root.file, a2a_types.FileWithUri) + assert result_a2a_part.root.file.uri == original_uri + assert result_a2a_part.root.file.mimeType == original_mime_type + + def test_file_bytes_round_trip(self): + """Test round-trip conversion for file parts with bytes.""" + # Arrange + original_bytes = b"test file content for round trip" + original_mime_type = "application/octet-stream" + + # Start with GenAI part (the more common starting point) + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=original_bytes, mime_type=original_mime_type + ) + ) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.inline_data is not None + assert result_genai_part.inline_data.data == original_bytes + assert result_genai_part.inline_data.mime_type == original_mime_type + + def test_function_call_round_trip(self): + """Test round-trip conversion for function call parts.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.function_call is not None + assert result_genai_part.function_call.name == function_call.name + assert result_genai_part.function_call.args == function_call.args + + def test_function_response_round_trip(self): + """Test round-trip conversion for function response parts.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.function_response is not None + assert 
result_genai_part.function_response.name == function_response.name + assert ( + result_genai_part.function_response.response + == function_response.response + ) + + def test_code_execution_result_round_trip(self): + """Test round-trip conversion for code execution result parts.""" + # Arrange + code_execution_result = genai_types.CodeExecutionResult( + outcome=genai_types.Outcome.OUTCOME_OK, output="Hello, World!" + ) + genai_part = genai_types.Part(code_execution_result=code_execution_result) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.code_execution_result is not None + assert ( + result_genai_part.code_execution_result.outcome + == code_execution_result.outcome + ) + assert ( + result_genai_part.code_execution_result.output + == code_execution_result.output + ) + + def test_executable_code_round_trip(self): + """Test round-trip conversion for executable code parts.""" + # Arrange + executable_code = genai_types.ExecutableCode( + language=genai_types.Language.PYTHON, code="print('Hello, World!')" + ) + genai_part = genai_types.Part(executable_code=executable_code) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.executable_code is not None + assert ( + result_genai_part.executable_code.language == executable_code.language + ) + assert result_genai_part.executable_code.code == executable_code.code + + +class TestEdgeCases: + """Test cases for edge cases and error conditions.""" + + def test_empty_text_part(self): + """Test conversion of empty text part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == "" + + def test_none_input_a2a_to_genai(self): + """Test handling of None input for A2A to GenAI conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_a2a_part_to_genai_part(None) + + def test_none_input_genai_to_a2a(self): + """Test handling of None input for GenAI to A2A conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_genai_part_to_a2a_part(None) + + def test_data_part_with_complex_data(self): + """Test conversion of DataPart with complex nested data.""" + # Arrange + complex_data = { + "nested": { + "array": [1, 2, {"inner": "value"}], + "boolean": True, + "null_value": None, + }, + "unicode": "Hello 世界 🌍", + } + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=complex_data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(complex_data) + + def test_data_part_with_empty_metadata(self): + """Test conversion of DataPart with empty metadata dict.""" + # Arrange + data = {"key": "value"} + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data, metadata={})) + + # Act + result = 
convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(data) + + +class TestNewConstants: + """Test cases for new constants and functionality.""" + + def test_new_constants_exist(self): + """Test that new constants are defined.""" + assert ( + A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + == "code_execution_result" + ) + assert A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE == "executable_code" + + def test_convert_a2a_data_part_with_code_execution_result_metadata(self): + """Test conversion of A2A DataPart with code execution result metadata.""" + # Arrange + code_execution_result_data = { + "outcome": "OUTCOME_OK", + "output": "Hello, World!", + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=code_execution_result_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + # Now it should convert back to a proper CodeExecutionResult + assert result.code_execution_result is not None + assert ( + result.code_execution_result.outcome == genai_types.Outcome.OUTCOME_OK + ) + assert result.code_execution_result.output == "Hello, World!" + + def test_convert_a2a_data_part_with_executable_code_metadata(self): + """Test conversion of A2A DataPart with executable code metadata.""" + # Arrange + executable_code_data = { + "language": "PYTHON", + "code": "print('Hello, World!')", + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=executable_code_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + # Now it should convert back to a proper ExecutableCode + assert result.executable_code is not None + assert result.executable_code.language == genai_types.Language.PYTHON + assert result.executable_code.code == "print('Hello, World!')" diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py new file mode 100644 index 000000000..08266751e --- /dev/null +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -0,0 +1,497 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.server.agent_execution import RequestContext + from google.adk.a2a.converters.request_converter import _get_user_id + from google.adk.a2a.converters.request_converter import convert_a2a_request_to_adk_run_args + from google.adk.runners import RunConfig + from google.genai import types as genai_types +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + a2a_types = DummyTypes() + genai_types = DummyTypes() + RequestContext = DummyTypes() + RunConfig = DummyTypes() + _get_user_id = lambda x, y: None + convert_a2a_request_to_adk_run_args = lambda x: None + else: + raise e + + +class TestGetUserId: + """Test cases for _get_user_id function.""" + + def test_get_user_id_from_call_context(self): + """Test getting user ID from call context when auth is enabled.""" + # Arrange + mock_user = Mock() + mock_user.user_name = "authenticated_user" + + mock_call_context = Mock() + mock_call_context.user = mock_user + + request = Mock(spec=RequestContext) + request.call_context = mock_call_context + request.message = Mock() + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "context_user") + + # Assert + assert result == "authenticated_user" + + def test_get_user_id_from_context_when_no_call_context(self): + """Test getting user ID from context when call context is not available.""" + # Arrange + request = Mock(spec=RequestContext) + request.call_context = None + request.message = Mock() + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "context_user") + + # Assert + assert result == "context_user" + + def test_get_user_id_from_context_when_call_context_has_no_user(self): + """Test getting user ID from context when call context has no user.""" + # Arrange + mock_call_context = Mock() + mock_call_context.user = None + + request = Mock(spec=RequestContext) + request.call_context = mock_call_context + request.message = Mock() + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "context_user") + + # Assert + assert result == "context_user" + + def test_get_user_id_from_message_metadata(self): + """Test getting user ID from message metadata when context user is not available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = {"adk_user_id": "message_user"} + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "ADK_USER_message_user" + + def test_get_user_id_from_task_metadata(self): + """Test getting user ID from task metadata when message metadata is not available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + + mock_task = Mock() + mock_task.metadata = {"adk_user_id": "task_user"} + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = mock_task + request.task_id = "task123" + 
+ # Act + result = _get_user_id(request, "") + + # Assert + assert result == "ADK_USER_task_user" + + def test_get_user_id_fallback_to_task_id(self): + """Test fallback to task ID when no other user ID is available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + mock_message.messageId = "msg456" + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "temp_user_task123" + + def test_get_user_id_fallback_to_message_id(self): + """Test fallback to message ID when no task ID is available.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + mock_message.messageId = "msg456" + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = None + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "TEMP_USER_msg456" + + def test_get_user_id_message_metadata_empty(self): + """Test getting user ID when message metadata exists but doesn't contain user_id.""" + # Arrange + mock_message = Mock() + mock_message.metadata = {"other_key": "other_value"} + mock_message.messageId = "msg456" + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = None + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "temp_user_task123" + + def test_get_user_id_task_metadata_empty(self): + """Test getting user ID when task metadata exists but doesn't contain user_id.""" + # Arrange + mock_message = Mock() + mock_message.metadata = None + mock_message.messageId = "msg456" + + mock_task = Mock() + mock_task.metadata = {"other_key": "other_value"} + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.current_task = mock_task + request.task_id = "task123" + + # Act + result = _get_user_id(request, "") + + # Assert + assert result == "temp_user_task123" + + +class TestConvertA2aRequestToAdkRunArgs: + """Test cases for convert_a2a_request_to_adk_run_args function.""" + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_basic( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test basic conversion of A2A request to ADK run args.""" + # Arrange + mock_part1 = Mock() + mock_part2 = Mock() + + mock_message = Mock() + mock_message.parts = [mock_part1, mock_part2] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = "ADK/app/user/session" + + mock_from_context_id.return_value = ( + "app_name", + "user_from_context", + "session123", + ) + mock_get_user_id.return_value = "final_user" + + # Create proper genai_types.Part objects instead of mocks + mock_genai_part1 = genai_types.Part(text="test part 1") + mock_genai_part2 = genai_types.Part(text="test part 2") + mock_convert_part.side_effect = [mock_genai_part1, mock_genai_part2] + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "final_user" + assert result["session_id"] == "session123" + assert 
isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2] + assert isinstance(result["run_config"], RunConfig) + + # Verify calls + mock_from_context_id.assert_called_once_with("ADK/app/user/session") + mock_get_user_id.assert_called_once_with(request, "user_from_context") + assert mock_convert_part.call_count == 2 + mock_convert_part.assert_any_call(mock_part1) + mock_convert_part.assert_any_call(mock_part2) + + def test_convert_a2a_request_no_message_raises_error(self): + """Test that conversion raises ValueError when message is None.""" + # Arrange + request = Mock(spec=RequestContext) + request.message = None + + # Act & Assert + with pytest.raises(ValueError, match="Request message cannot be None"): + convert_a2a_request_to_adk_run_args(request) + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_empty_parts( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test conversion with empty parts list.""" + # Arrange + mock_message = Mock() + mock_message.parts = [] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = "ADK/app/user/session" + + mock_from_context_id.return_value = ( + "app_name", + "user_from_context", + "session123", + ) + mock_get_user_id.return_value = "final_user" + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "final_user" + assert result["session_id"] == "session123" + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [] + assert isinstance(result["run_config"], RunConfig) + + # Verify convert_part wasn't called + mock_convert_part.assert_not_called() + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_none_context_id( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test conversion when context_id is None.""" + # Arrange + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = None + + mock_from_context_id.return_value = (None, None, None) + mock_get_user_id.return_value = "fallback_user" + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "fallback_user" + assert result["session_id"] is None + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) + + # Verify calls + mock_from_context_id.assert_called_once_with(None) + mock_get_user_id.assert_called_once_with(request, None) + + @patch( + 
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + @patch("google.adk.a2a.converters.request_converter._get_user_id") + def test_convert_a2a_request_invalid_context_id( + self, mock_get_user_id, mock_from_context_id, mock_convert_part + ): + """Test conversion when context_id is invalid format.""" + # Arrange + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = "invalid_format" + + mock_from_context_id.return_value = (None, None, None) + mock_get_user_id.return_value = "fallback_user" + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert result["user_id"] == "fallback_user" + assert result["session_id"] is None + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) + + # Verify calls + mock_from_context_id.assert_called_once_with("invalid_format") + mock_get_user_id.assert_called_once_with(request, None) + + +class TestIntegration: + """Integration test cases combining both functions.""" + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + def test_end_to_end_conversion_with_auth_user(self, mock_convert_part): + """Test end-to-end conversion with authenticated user.""" + # Arrange + mock_user = Mock() + mock_user.user_name = "auth_user" + + mock_call_context = Mock() + mock_call_context.user = mock_user + + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + + request = Mock(spec=RequestContext) + request.call_context = mock_call_context + request.message = mock_message + request.context_id = "ADK/myapp/context_user/mysession" + request.current_task = None + request.task_id = "task123" + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert ( + result["user_id"] == "auth_user" + ) # Should use authenticated user, not context user + assert result["session_id"] == "mysession" + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) + + @patch( + "google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part" + ) + @patch("google.adk.a2a.converters.request_converter._from_a2a_context_id") + def test_end_to_end_conversion_with_fallback_user( + self, mock_from_context_id, mock_convert_part + ): + """Test end-to-end conversion with fallback user ID.""" + # Arrange + mock_part = Mock() + mock_message = Mock() + mock_message.parts = [mock_part] + mock_message.messageId = "msg789" + mock_message.metadata = None + + request = Mock(spec=RequestContext) + request.call_context = None + request.message = mock_message + request.context_id = "invalid_format" + request.current_task = None + request.task_id = None + + # Mock 
the utils function to return None values for invalid context + mock_from_context_id.return_value = (None, None, None) + + # Create proper genai_types.Part object instead of mock + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part.return_value = mock_genai_part + + # Act + result = convert_a2a_request_to_adk_run_args(request) + + # Assert + assert result is not None + assert ( + result["user_id"] == "TEMP_USER_msg789" + ) # Should fallback to message ID + assert result["session_id"] is None + assert isinstance(result["new_message"], genai_types.Content) + assert result["new_message"].role == "user" + assert result["new_message"].parts == [mock_genai_part] + assert isinstance(result["run_config"], RunConfig) diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py new file mode 100644 index 000000000..f919cbd00 --- /dev/null +++ b/tests/unittests/a2a/converters/test_utils.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +import pytest + + +class TestUtilsFunctions: + """Test suite for utils module functions.""" + + def test_get_adk_metadata_key_success(self): + """Test successful metadata key generation.""" + key = "test_key" + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_get_adk_metadata_key_empty_string(self): + """Test metadata key generation with empty string.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key("") + + def test_get_adk_metadata_key_none(self): + """Test metadata key generation with None.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key(None) + + def test_get_adk_metadata_key_whitespace(self): + """Test metadata key generation with whitespace string.""" + key = " " + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_to_a2a_context_id_success(self): + """Test successful context ID generation.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session" + assert result == expected + + def test_to_a2a_context_id_empty_app_name(self): + """Test context ID generation with empty app name.""" + with pytest.raises( + ValueError, + match=( + "All 
parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("", "user", "session") + + def test_to_a2a_context_id_empty_user_id(self): + """Test context ID generation with empty user ID.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "", "session") + + def test_to_a2a_context_id_empty_session_id(self): + """Test context ID generation with empty session ID.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "user", "") + + def test_to_a2a_context_id_none_values(self): + """Test context ID generation with None values.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id(None, "user", "session") + + def test_to_a2a_context_id_special_characters(self): + """Test context ID generation with special characters.""" + app_name = "test-app@2024" + user_id = "user_123" + session_id = "session-456" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}/test-app@2024/user_123/session-456" + assert result == expected + + def test_from_a2a_context_id_success(self): + """Test successful context ID parsing.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app" + assert user_id == "test-user" + assert session_id == "test-session" + + def test_from_a2a_context_id_none_input(self): + """Test context ID parsing with None input.""" + result = _from_a2a_context_id(None) + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_string(self): + """Test context ID parsing with empty string.""" + result = _from_a2a_context_id("") + assert result == (None, None, None) + + def test_from_a2a_context_id_invalid_prefix(self): + """Test context ID parsing with invalid prefix.""" + context_id = "INVALID/test-app/test-user/test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_few_parts(self): + """Test context ID parsing with too few parts.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_many_parts(self): + """Test context ID parsing with too many parts.""" + context_id = ( + f"{ADK_CONTEXT_ID_PREFIX}/test-app/test-user/test-session/extra" + ) + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_components(self): + """Test context ID parsing with empty components.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}//test-user/test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_no_dollar_separator(self): + """Test context ID parsing without dollar separators.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}-test-app-test-user-test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_roundtrip_context_id(self): + """Test roundtrip conversion: to -> from.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + # Convert to 
context ID + context_id = _to_a2a_context_id(app_name, user_id, session_id) + + # Convert back + parsed_app, parsed_user, parsed_session = _from_a2a_context_id(context_id) + + assert parsed_app == app_name + assert parsed_user == user_id + assert parsed_session == session_id + + def test_from_a2a_context_id_special_characters(self): + """Test context ID parsing with special characters.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}/test-app@2024/user_123/session-456" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app@2024" + assert user_id == "user_123" + assert session_id == "session-456" diff --git a/tests/unittests/agents/test_llm_agent_include_contents.py b/tests/unittests/agents/test_llm_agent_include_contents.py new file mode 100644 index 000000000..d4d76cf4e --- /dev/null +++ b/tests/unittests/agents/test_llm_agent_include_contents.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for LlmAgent include_contents field behavior.""" + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.genai import types +import pytest + +from .. import testing_utils + + +@pytest.mark.asyncio +async def test_include_contents_default_behavior(): + """Test that include_contents='default' preserves conversation history including tool interactions.""" + + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + "First response", + types.Part.from_function_call( + name="simple_tool", args={"message": "second"} + ), + "Second response", + ] + ) + + agent = LlmAgent( + name="test_agent", + model=mock_model, + include_contents="default", + instruction="You are a helpful assistant", + tools=[simple_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + runner.run("First message") + runner.run("Second message") + + # First turn requests + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "First message") + ] + + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ] + + # Second turn should include full conversation history + assert testing_utils.simplify_contents(mock_model.requests[2].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ("model", "First response"), + ("user", "Second message"), + ] + 
+ # Second turn with tool should include full history + current tool interaction + assert testing_utils.simplify_contents(mock_model.requests[3].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ("model", "First response"), + ("user", "Second message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "second"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: second"} + ), + ), + ] + + +@pytest.mark.asyncio +async def test_include_contents_none_behavior(): + """Test that include_contents='none' excludes conversation history but includes current input.""" + + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + "First response", + "Second response", + ] + ) + + agent = LlmAgent( + name="test_agent", + model=mock_model, + include_contents="none", + instruction="You are a helpful assistant", + tools=[simple_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + runner.run("First message") + runner.run("Second message") + + # First turn behavior + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "First message") + ] + + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "First message"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "first"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", response={"result": "Tool processed: first"} + ), + ), + ] + + # Second turn should only have current input, no history + assert testing_utils.simplify_contents(mock_model.requests[2].contents) == [ + ("user", "Second message") + ] + + # System instruction and tools should be preserved + assert ( + "You are a helpful assistant" + in mock_model.requests[0].config.system_instruction + ) + assert len(mock_model.requests[0].config.tools) > 0 + + +@pytest.mark.asyncio +async def test_include_contents_none_sequential_agents(): + """Test include_contents='none' with sequential agents.""" + + agent1_model = testing_utils.MockModel.create( + responses=["Agent1 response: XYZ"] + ) + agent1 = LlmAgent( + name="agent1", + model=agent1_model, + instruction="You are Agent1", + ) + + agent2_model = testing_utils.MockModel.create( + responses=["Agent2 final response"] + ) + agent2 = LlmAgent( + name="agent2", + model=agent2_model, + include_contents="none", + instruction="You are Agent2", + ) + + sequential_agent = SequentialAgent( + name="sequential_test_agent", sub_agents=[agent1, agent2] + ) + + runner = testing_utils.InMemoryRunner(sequential_agent) + events = runner.run("Original user request") + + assert len(events) == 2 + assert events[0].author == "agent1" + assert events[1].author == "agent2" + + # Agent1 sees original user request + agent1_contents = testing_utils.simplify_contents( + agent1_model.requests[0].contents + ) + assert ("user", "Original user request") in agent1_contents + + # Agent2 with include_contents='none' should not see original request + agent2_contents = testing_utils.simplify_contents( + agent2_model.requests[0].contents 
+ ) + + assert not any( + "Original user request" in str(content) for _, content in agent2_contents + ) + assert any( + "Agent1 response" in str(content) for _, content in agent2_contents + ) diff --git a/tests/unittests/auth/__init__.py b/tests/unittests/auth/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/credential_service/__init__.py b/tests/unittests/auth/credential_service/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/credential_service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/credential_service/test_in_memory_credential_service.py b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py new file mode 100644 index 000000000..9312f72a3 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py @@ -0,0 +1,323 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestInMemoryCredentialService: + """Tests for the InMemoryCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create an InMemoryCredentialService instance for testing.""" + return InMemoryCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "test_app" + mock_invocation_context.user_id = "test_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different app/user for testing isolation.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "another_app" + mock_invocation_context.user_id = "another_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + def test_init(self, credential_service): + """Test that the service initializes with an empty store.""" + assert isinstance(credential_service._credentials, dict) + assert len(credential_service._credentials) == 0 + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and 
loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different app/user contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + def test_get_bucket_for_current_context_creates_nested_structure( + self, credential_service, 
tool_context + ): + """Test that _get_bucket_for_current_context creates the proper nested structure.""" + storage = credential_service._get_bucket_for_current_context(tool_context) + + # Verify the nested structure was created + assert "test_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert isinstance(storage, dict) + assert storage is credential_service._credentials["test_app"]["test_user"] + + def test_get_bucket_for_current_context_reuses_existing( + self, credential_service, tool_context + ): + """Test that _get_bucket_for_current_context reuses existing structure.""" + # Create initial structure + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage1["test_key"] = "test_value" + + # Get storage again + storage2 = credential_service._get_bucket_for_current_context(tool_context) + + # Verify it's the same storage instance + assert storage1 is storage2 + assert storage2["test_key"] == "test_value" + + def test_get_storage_different_apps( + self, credential_service, tool_context, another_tool_context + ): + """Test that different apps get different storage instances.""" + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage2 = credential_service._get_bucket_for_current_context( + another_tool_context + ) + + # Verify they are different storage instances + assert storage1 is not storage2 + + # Verify the structure + assert "test_app" in credential_service._credentials + assert "another_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert "another_user" in credential_service._credentials["another_app"] + + @pytest.mark.asyncio + async def test_same_user_different_apps( + self, credential_service, auth_config + ): + """Test that the same user in different apps get isolated storage.""" + # Create two contexts with same user but different apps + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "app1" + mock_invocation_context1.user_id = "same_user" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "app2" + mock_invocation_context2.user_id = "same_user" + context2._invocation_context = mock_invocation_context2 + + # Save credential in app1 + await credential_service.save_credential(auth_config, context1) + + # Try to load from app2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify app1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None + + @pytest.mark.asyncio + async def test_same_app_different_users( + self, credential_service, auth_config + ): + """Test that different users in the same app get isolated storage.""" + # Create two contexts with same app but different users + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "same_app" + mock_invocation_context1.user_id = "user1" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "same_app" + mock_invocation_context2.user_id = "user2" + context2._invocation_context = mock_invocation_context2 + + # Save credential for user1 + await 
credential_service.save_credential(auth_config, context1) + + # Try to load for user2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify user1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None diff --git a/tests/unittests/auth/credential_service/test_session_state_credential_service.py b/tests/unittests/auth/credential_service/test_session_state_credential_service.py new file mode 100644 index 000000000..610a9d3d1 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_session_state_credential_service.py @@ -0,0 +1,355 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.session_state_credential_service import SessionStateCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestSessionStateCredentialService: + """Tests for the SessionStateCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create a SessionStateCredentialService instance for testing.""" + return SessionStateCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + # Create a state dictionary that behaves like session state + mock_context.state = {} + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different state for testing 
isolation.""" + mock_context = Mock(spec=ToolContext) + # Create a separate state dictionary to simulate different session + mock_context.state = {} + return mock_context + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different tool contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context (should not find it) + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + 
auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + @pytest.mark.asyncio + async def test_save_credential_with_none_exchanged_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving when exchanged_auth_credential is None.""" + # Set exchanged credential to None + auth_config.exchanged_auth_credential = None + + # Save the credential (should save None) + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify None was saved + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_load_credential_with_empty_credential_key( + self, credential_service, auth_config, tool_context + ): + """Test loading credential with empty credential key.""" + # Set credential key to empty string + auth_config.credential_key = "" + + # Save first to have something to load + await credential_service.save_credential(auth_config, tool_context) + + # Load should work with empty key + result = await credential_service.load_credential(auth_config, tool_context) + assert result == auth_config.exchanged_auth_credential + + @pytest.mark.asyncio + async def test_state_persistence_across_operations( + self, credential_service, auth_config, tool_context + ): + """Test that state persists correctly across multiple operations.""" + # Initially, no credential should exist + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + # Save a credential + await credential_service.save_credential(auth_config, tool_context) + + # Verify it was saved + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result == auth_config.exchanged_auth_credential + + # Update and save again + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="new_client_id", + client_secret="new_client_secret", + redirect_uri="https://new.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + await credential_service.save_credential(auth_config, tool_context) + + # Verify the update persisted + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "new_client_id" + + @pytest.mark.asyncio + async def test_credential_key_uniqueness( + self, credential_service, oauth2_auth_scheme, tool_context + ): + """Test that different credential keys create separate storage slots.""" + # Create credentials with same content but different keys + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="same_client", + client_secret="same_secret", + 
redirect_uri="https://same.com/callback", + ), + ) + + config_key1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=credential, + exchanged_auth_credential=credential, + credential_key="unique_key_1", + ) + + config_key2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=credential, + exchanged_auth_credential=credential, + credential_key="unique_key_2", + ) + + # Save credential with first key + await credential_service.save_credential(config_key1, tool_context) + + # Verify it's stored under first key + result1 = await credential_service.load_credential( + config_key1, tool_context + ) + assert result1 is not None + + # Verify it's not accessible under second key + result2 = await credential_service.load_credential( + config_key2, tool_context + ) + assert result2 is None + + # Save under second key + await credential_service.save_credential(config_key2, tool_context) + + # Now both should be accessible + result1 = await credential_service.load_credential( + config_key1, tool_context + ) + result2 = await credential_service.load_credential( + config_key2, tool_context + ) + assert result1 is not None + assert result2 is not None + assert result1 == result2 # Same credential content + + def test_direct_state_access( + self, credential_service, auth_config, tool_context + ): + """Test that the service correctly uses tool_context.state for storage.""" + # Verify that the state starts empty + assert len(tool_context.state) == 0 + + # Save a credential (this is async but we're testing the state directly) + credential_key = auth_config.credential_key + test_credential = auth_config.exchanged_auth_credential + + # Directly set the state to simulate save_credential behavior + tool_context.state[credential_key] = test_credential + + # Verify the credential is in the state + assert credential_key in tool_context.state + assert tool_context.state[credential_key] == test_credential + + # Verify we can retrieve it using the get method (simulating load_credential) + retrieved = tool_context.state.get(credential_key) + assert retrieved == test_credential diff --git a/tests/unittests/auth/exchanger/__init__.py b/tests/unittests/auth/exchanger/__init__.py new file mode 100644 index 000000000..5fb8a262b --- /dev/null +++ b/tests/unittests/auth/exchanger/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for credential exchanger.""" diff --git a/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py new file mode 100644 index 000000000..66b858232 --- /dev/null +++ b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the CredentialExchangerRegistry.""" + +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.exchanger.base_credential_exchanger import BaseCredentialExchanger +from google.adk.auth.exchanger.credential_exchanger_registry import CredentialExchangerRegistry +import pytest + + +class MockCredentialExchanger(BaseCredentialExchanger): + """Mock credential exchanger for testing.""" + + def __init__(self, exchange_result: Optional[AuthCredential] = None): + self.exchange_result = exchange_result or AuthCredential( + auth_type=AuthCredentialTypes.HTTP + ) + + def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Mock exchange method.""" + return self.exchange_result + + +class TestCredentialExchangerRegistry: + """Test cases for CredentialExchangerRegistry.""" + + def test_initialization(self): + """Test that the registry initializes with an empty exchangers dictionary.""" + registry = CredentialExchangerRegistry() + + # Access the private attribute for testing + assert hasattr(registry, '_exchangers') + assert isinstance(registry._exchangers, dict) + assert len(registry._exchangers) == 0 + + def test_register_single_exchanger(self): + """Test registering a single exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Verify the exchanger was registered + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_register_multiple_exchangers(self): + """Test registering multiple exchangers for different credential types.""" + registry = CredentialExchangerRegistry() + + api_key_exchanger = MockCredentialExchanger() + oauth2_exchanger = MockCredentialExchanger() + service_account_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, api_key_exchanger) + registry.register(AuthCredentialTypes.OAUTH2, oauth2_exchanger) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, service_account_exchanger + ) + + # Verify all exchangers were registered correctly + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is api_key_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.OAUTH2) is oauth2_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.SERVICE_ACCOUNT) + is service_account_exchanger + ) + + def test_register_overwrites_existing_exchanger(self): + """Test that registering an exchanger for an existing type overwrites the previous one.""" + registry = CredentialExchangerRegistry() + + first_exchanger = MockCredentialExchanger() + second_exchanger = MockCredentialExchanger() + + # Register first exchanger + registry.register(AuthCredentialTypes.API_KEY, first_exchanger) + assert ( + 
registry.get_exchanger(AuthCredentialTypes.API_KEY) is first_exchanger + ) + + # Register second exchanger for the same type + registry.register(AuthCredentialTypes.API_KEY, second_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is second_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) + is not first_exchanger + ) + + def test_get_exchanger_returns_correct_instance(self): + """Test that get_exchanger returns the correct exchanger instance.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.HTTP, mock_exchanger) + + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.HTTP) + assert retrieved_exchanger is mock_exchanger + assert isinstance(retrieved_exchanger, BaseCredentialExchanger) + + def test_get_exchanger_nonexistent_type_returns_none(self): + """Test that get_exchanger returns None for non-existent credential types.""" + registry = CredentialExchangerRegistry() + + # Try to get an exchanger that was never registered + result = registry.get_exchanger(AuthCredentialTypes.OAUTH2) + assert result is None + + def test_get_exchanger_after_registration_and_removal(self): + """Test behavior when an exchanger is registered and then the registry is cleared indirectly.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + # Register exchanger + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is mock_exchanger + + # Clear the internal dictionary (simulating some edge case) + registry._exchangers.clear() + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is None + + def test_register_with_all_credential_types(self): + """Test registering exchangers for all available credential types.""" + registry = CredentialExchangerRegistry() + + exchangers = {} + credential_types = [ + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + AuthCredentialTypes.SERVICE_ACCOUNT, + ] + + # Register an exchanger for each credential type + for cred_type in credential_types: + exchanger = MockCredentialExchanger() + exchangers[cred_type] = exchanger + registry.register(cred_type, exchanger) + + # Verify all exchangers can be retrieved + for cred_type in credential_types: + retrieved_exchanger = registry.get_exchanger(cred_type) + assert retrieved_exchanger is exchangers[cred_type] + + def test_register_with_mock_exchanger_using_magicmock(self): + """Test registering with a MagicMock exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MagicMock(spec=BaseCredentialExchanger) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_registry_isolation(self): + """Test that different registry instances are isolated from each other.""" + registry1 = CredentialExchangerRegistry() + registry2 = CredentialExchangerRegistry() + + exchanger1 = MockCredentialExchanger() + exchanger2 = MockCredentialExchanger() + + # Register different exchangers in different registry instances + registry1.register(AuthCredentialTypes.API_KEY, exchanger1) + registry2.register(AuthCredentialTypes.API_KEY, exchanger2) + + # Verify isolation + assert registry1.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger1 + assert 
registry2.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger2 + assert ( + registry1.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger2 + ) + assert ( + registry2.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger1 + ) + + def test_exchanger_functionality_through_registry(self): + """Test that exchangers registered in the registry function correctly.""" + registry = CredentialExchangerRegistry() + + # Create a mock exchanger with specific return value + expected_result = AuthCredential(auth_type=AuthCredentialTypes.HTTP) + mock_exchanger = MockCredentialExchanger(exchange_result=expected_result) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Get the exchanger and test its functionality + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + input_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY) + + result = retrieved_exchanger.exchange(input_credential) + assert result is expected_result + + def test_register_none_exchanger(self): + """Test that registering None as an exchanger works (edge case).""" + registry = CredentialExchangerRegistry() + + # This should work but return None when retrieved + registry.register(AuthCredentialTypes.API_KEY, None) + + result = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert result is None + + def test_internal_dictionary_structure(self): + """Test the internal structure of the registry.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.OAUTH2, mock_exchanger) + + # Verify internal dictionary structure + assert AuthCredentialTypes.OAUTH2 in registry._exchangers + assert registry._exchangers[AuthCredentialTypes.OAUTH2] is mock_exchanger + assert len(registry._exchangers) == 1 diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py new file mode 100644 index 000000000..ef1dbbbee --- /dev/null +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -0,0 +1,220 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
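+
+"""Unit tests for OAuth2CredentialExchanger: existing-token passthrough, successful code-for-token exchange, missing auth_scheme, session-creation and fetch_token failures, and the authlib-unavailable path."""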
+ +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.exchanger.base_credential_exchanger import CredentialExchangError +from google.adk.auth.exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger +import pytest + + +class TestOAuth2CredentialExchanger: + """Test suite for OAuth2CredentialExchanger.""" + + @pytest.mark.asyncio + async def test_exchange_with_existing_token(self): + """Test exchange method when access token already exists.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return the same credential since access token already exists + assert result == credential + assert result.oauth2.access_token == "existing_token" + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_success(self, mock_oauth2_session): + """Test successful token exchange.""" + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Verify token exchange was successful + assert result.oauth2.access_token == "new_access_token" + assert result.oauth2.refresh_token == "new_refresh_token" + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_missing_auth_scheme(self): + """Test exchange with missing auth_scheme raises ValueError.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + try: + await exchanger.exchange(credential, None) + assert False, "Should have raised ValueError" + except CredentialExchangError as e: + assert "auth_scheme is required" in str(e) + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async 
def test_exchange_no_session(self, mock_oauth2_session): + """Test exchange when OAuth2Session cannot be created.""" + # Mock to return None for create_oauth2_session + mock_oauth2_session.return_value = None + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret to trigger session creation failure + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when session creation fails + assert result == credential + assert result.oauth2.access_token is None + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_fetch_token_failure(self, mock_oauth2_session): + """Test exchange when fetch_token fails.""" + # Setup mock to raise exception during fetch_token + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.fetch_token.side_effect = Exception("Token fetch failed") + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when fetch_token fails + assert result == credential + assert result.oauth2.access_token is None + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_authlib_not_available(self): + """Test exchange when authlib is not available.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + + # Mock AUTHLIB_AVIALABLE to False + with patch( + "google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVIALABLE", + False, + ): + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when authlib is not available + assert result == credential + assert result.oauth2.access_token is None diff --git a/tests/unittests/auth/refresher/__init__.py b/tests/unittests/auth/refresher/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/refresher/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/refresher/test_credential_refresher_registry.py b/tests/unittests/auth/refresher/test_credential_refresher_registry.py new file mode 100644 index 000000000..b00cc4da8 --- /dev/null +++ b/tests/unittests/auth/refresher/test_credential_refresher_registry.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CredentialRefresherRegistry.""" + +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher +from google.adk.auth.refresher.credential_refresher_registry import CredentialRefresherRegistry + + +class TestCredentialRefresherRegistry: + """Tests for the CredentialRefresherRegistry class.""" + + def test_init(self): + """Test that registry initializes with empty refreshers dictionary.""" + registry = CredentialRefresherRegistry() + assert registry._refreshers == {} + + def test_register_refresher(self): + """Test registering a refresher instance for a credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher + + def test_register_multiple_refreshers(self): + """Test registering multiple refresher instances for different credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_openid_refresher = Mock(spec=BaseCredentialRefresher) + mock_service_account_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, mock_openid_refresher + ) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, mock_service_account_refresher + ) + + assert ( + registry._refreshers[AuthCredentialTypes.OAUTH2] + == mock_oauth2_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.OPEN_ID_CONNECT] + == mock_openid_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.SERVICE_ACCOUNT] + == mock_service_account_refresher + ) + + def test_register_overwrite_existing_refresher(self): + """Test that registering a refresher overwrites an existing one for the same credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher_1 = Mock(spec=BaseCredentialRefresher) + mock_refresher_2 = 
Mock(spec=BaseCredentialRefresher) + + # Register first refresher + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_1) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_1 + + # Register second refresher for same credential type + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_2) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_2 + + def test_get_refresher_existing(self): + """Test getting a refresher instance for a registered credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result == mock_refresher + + def test_get_refresher_non_existing(self): + """Test getting a refresher instance for a non-registered credential type returns None.""" + registry = CredentialRefresherRegistry() + + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None + + def test_get_refresher_after_registration(self): + """Test getting refresher instances for multiple credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_api_key_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register(AuthCredentialTypes.API_KEY, mock_api_key_refresher) + + # Get registered refreshers + oauth2_result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + api_key_result = registry.get_refresher(AuthCredentialTypes.API_KEY) + + assert oauth2_result == mock_oauth2_refresher + assert api_key_result == mock_api_key_refresher + + # Get non-registered refresher + http_result = registry.get_refresher(AuthCredentialTypes.HTTP) + assert http_result is None + + def test_register_all_credential_types(self): + """Test registering refreshers for all available credential types.""" + registry = CredentialRefresherRegistry() + + refreshers = {} + for credential_type in AuthCredentialTypes: + mock_refresher = Mock(spec=BaseCredentialRefresher) + refreshers[credential_type] = mock_refresher + registry.register(credential_type, mock_refresher) + + # Verify all refreshers are registered correctly + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result == refreshers[credential_type] + + def test_empty_registry_get_refresher(self): + """Test getting refresher from empty registry returns None for any credential type.""" + registry = CredentialRefresherRegistry() + + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result is None + + def test_registry_independence(self): + """Test that multiple registry instances are independent.""" + registry1 = CredentialRefresherRegistry() + registry2 = CredentialRefresherRegistry() + + mock_refresher1 = Mock(spec=BaseCredentialRefresher) + mock_refresher2 = Mock(spec=BaseCredentialRefresher) + + registry1.register(AuthCredentialTypes.OAUTH2, mock_refresher1) + registry2.register(AuthCredentialTypes.OAUTH2, mock_refresher2) + + # Verify registries are independent + assert ( + registry1.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher1 + ) + assert ( + registry2.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher2 + ) + assert registry1.get_refresher( + AuthCredentialTypes.OAUTH2 + ) != 
registry2.get_refresher(AuthCredentialTypes.OAUTH2) + + def test_register_with_none_refresher(self): + """Test registering None as a refresher instance.""" + registry = CredentialRefresherRegistry() + + # This should technically work as the registry accepts any value + registry.register(AuthCredentialTypes.OAUTH2, None) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None diff --git a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py new file mode 100644 index 000000000..3342fcb05 --- /dev/null +++ b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py @@ -0,0 +1,179 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher +import pytest + + +class TestOAuth2CredentialRefresher: + """Test suite for OAuth2CredentialRefresher.""" + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_not_expired(self, mock_oauth2_token): + """Test needs_refresh when token is not expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = False + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) + 3600, + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert not needs_refresh + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_expired(self, mock_oauth2_token): + """Test needs_refresh when token is expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + 
credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert needs_refresh + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @patch("google.adk.auth.oauth2_credential_util.OAuth2Token") + @pytest.mark.asyncio + async def test_refresh_token_expired_success( + self, mock_oauth2_token, mock_oauth2_session + ): + """Test successful token refresh when token is expired.""" + # Setup mock token + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + # Setup mock session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "refreshed_access_token", + "refresh_token": "refreshed_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.refresh_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="old_token", + refresh_token="old_refresh_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + # Verify token refresh was successful + assert result.oauth2.access_token == "refreshed_access_token" + assert result.oauth2.refresh_token == "refreshed_refresh_token" + mock_client.refresh_token.assert_called_once() + + @pytest.mark.asyncio + async def test_refresh_no_oauth2_credential(self): + """Test refresh with no OAuth2 credential returns original.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + assert result == credential + + @pytest.mark.asyncio + async def test_needs_refresh_no_oauth2_credential(self): + """Test needs_refresh with no OAuth2 credential returns False.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert not needs_refresh diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index aaed35a19..f0d730d02 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -13,8 +13,11 @@ # limitations under the License. 
import copy +import time +from unittest.mock import Mock from unittest.mock import patch +from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import APIKey from fastapi.openapi.models import APIKeyIn from fastapi.openapi.models import OAuth2 @@ -405,7 +408,8 @@ def test_get_auth_response_not_exists(self, auth_config): class TestParseAndStoreAuthResponse: """Tests for the parse_and_store_auth_response method.""" - def test_non_oauth_scheme(self, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_exchanged): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_exchanged) @@ -416,7 +420,7 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): handler = AuthHandler(auth_config) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) credential_key = auth_config.credential_key assert ( @@ -424,7 +428,10 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): ) @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token") - def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_oauth_scheme( + self, mock_exchange_token, auth_config_with_exchanged + ): """Test with an OAuth auth scheme.""" mock_exchange_token.return_value = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -434,7 +441,7 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): handler = AuthHandler(auth_config_with_exchanged) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) credential_key = auth_config_with_exchanged.credential_key assert state["temp:" + credential_key] == mock_exchange_token.return_value @@ -444,20 +451,20 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): class TestExchangeAuthToken: """Tests for the exchange_auth_token method.""" - def test_token_exchange_not_supported( + @pytest.mark.asyncio + async def test_token_exchange_not_supported( self, auth_config_with_auth_code, monkeypatch ): """Test when token exchange is not supported.""" - monkeypatch.setattr( - "google.adk.auth.oauth2_credential_fetcher.AUTHLIB_AVIALABLE", False - ) + monkeypatch.setattr("google.adk.auth.auth_handler.AUTHLIB_AVIALABLE", False) handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config_with_auth_code.exchanged_auth_credential - def test_openid_missing_token_endpoint( + @pytest.mark.asyncio + async def test_openid_missing_token_endpoint( self, openid_auth_scheme, oauth2_credentials_with_auth_code ): """Test OpenID Connect without a token endpoint.""" @@ -472,11 +479,12 @@ def test_openid_missing_token_endpoint( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_oauth2_missing_token_url( + @pytest.mark.asyncio + async def test_oauth2_missing_token_url( self, oauth2_auth_scheme, oauth2_credentials_with_auth_code ): """Test OAuth2 without a token URL.""" @@ -491,11 +499,12 @@ def test_oauth2_missing_token_url( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == 
oauth2_credentials_with_auth_code - def test_non_oauth_scheme(self, auth_config_with_auth_code): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_auth_code): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_auth_code) @@ -504,11 +513,12 @@ def test_non_oauth_scheme(self, auth_config_with_auth_code): ) handler = AuthHandler(auth_config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config.exchanged_auth_credential - def test_missing_credentials(self, oauth2_auth_scheme): + @pytest.mark.asyncio + async def test_missing_credentials(self, oauth2_auth_scheme): """Test with missing credentials.""" empty_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2) @@ -518,11 +528,12 @@ def test_missing_credentials(self, oauth2_auth_scheme): ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == empty_credential - def test_credentials_with_token( + @pytest.mark.asyncio + async def test_credentials_with_token( self, auth_config, oauth2_credentials_with_token ): """Test when credentials already have a token.""" @@ -533,18 +544,29 @@ def test_credentials_with_token( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_token - @patch( - "google.adk.auth.oauth2_credential_fetcher.OAuth2Session", - MockOAuth2Session, - ) - def test_successful_token_exchange(self, auth_config_with_auth_code): + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_successful_token_exchange( + self, mock_oauth2_session, auth_config_with_auth_code + ): """Test a successful token exchange.""" + # Setup mock OAuth2Session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "mock_access_token", + "refresh_token": "mock_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result.oauth2.access_token == "mock_access_token" assert result.oauth2.refresh_token == "mock_refresh_token" diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py new file mode 100644 index 000000000..8e3638dd6 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager.py @@ -0,0 +1,545 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
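+
+"""Unit tests for CredentialManager: validation, readiness checks, loading from the credential service, exchange/refresh delegation, and saving credentials."""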
+ +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from fastapi.openapi.models import HTTPBearer +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +import pytest + + +class TestCredentialManager: + """Test suite for CredentialManager.""" + + def test_init(self): + """Test CredentialManager initialization.""" + auth_config = Mock(spec=AuthConfig) + manager = CredentialManager(auth_config) + assert manager._auth_config == auth_config + + @pytest.mark.asyncio + async def test_request_credential(self): + """Test request_credential method.""" + auth_config = Mock(spec=AuthConfig) + tool_context = Mock() + tool_context.request_credential = Mock() + + manager = CredentialManager(auth_config) + await manager.request_credential(tool_context) + + tool_context.request_credential.assert_called_once_with(auth_config) + + @pytest.mark.asyncio + async def test_load_auth_credentials_success(self): + """Test load_auth_credential with successful flow.""" + # Create mocks + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + # Mock the credential that will be returned + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + tool_context = Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=mock_credential) + manager._exchange_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._refresh_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify all methods were called + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_called_once_with(mock_credential) + manager._refresh_credential.assert_called_once_with(mock_credential) + manager._save_credential.assert_called_once_with( + tool_context, mock_credential + ) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_auth_credentials_no_credential(self): + """Test load_auth_credential when no credential is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + tool_context = 
Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=None) + manager._exchange_credential = AsyncMock() + manager._refresh_credential = AsyncMock() + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify methods were called but no credential returned + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_not_called() + manager._refresh_credential.assert_not_called() + manager._save_credential.assert_not_called() + + assert result is None + + @pytest.mark.asyncio + async def test_load_existing_credential_already_exchanged(self): + """Test _load_existing_credential when credential is already exchanged.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + auth_config.exchanged_auth_credential = mock_credential + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock(return_value=None) + + result = await manager._load_existing_credential(tool_context) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_existing_credential_with_credential_service(self): + """Test _load_existing_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + + mock_credential = Mock(spec=AuthCredential) + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock( + return_value=mock_credential + ) + + result = await manager._load_existing_credential(tool_context) + + manager._load_from_credential_service.assert_called_once_with(tool_context) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_with_service(self): + """Test _load_from_credential_service from tool context when credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = Mock() + credential_service.load_credential = AsyncMock(return_value=mock_credential) + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await manager._load_from_credential_service(tool_context) + + credential_service.load_credential.assert_called_once_with( + auth_config, tool_context + ) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_no_service(self): + """Test _load_from_credential_service when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await 
manager._load_from_credential_service(tool_context) + + assert result is None + + @pytest.mark.asyncio + async def test_save_credential_with_service(self): + """Test _save_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = AsyncMock() + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + credential_service.save_credential.assert_called_once_with( + auth_config, tool_context + ) + assert auth_config.exchanged_auth_credential == mock_credential + + @pytest.mark.asyncio + async def test_save_credential_no_service(self): + """Test _save_credential when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + mock_credential = Mock(spec=AuthCredential) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + # Should not raise an error, and credential should not be set in auth_config + # when there's no credential service (according to implementation) + assert auth_config.exchanged_auth_credential is None + + @pytest.mark.asyncio + async def test_refresh_credential_oauth2(self): + """Test _refresh_credential with OAuth2 credential.""" + mock_oauth2_auth = Mock(spec=OAuth2Auth) + + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + # Mock refresher + mock_refresher = Mock() + mock_refresher.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher.refresh = AsyncMock(return_value=mock_credential) + + auth_config.raw_auth_credential = mock_credential + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return our mock refresher + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=mock_refresher, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + mock_refresher.is_refresh_needed.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + mock_refresher.refresh.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + assert result == mock_credential + assert was_refreshed is True + + @pytest.mark.asyncio + async def test_refresh_credential_no_refresher(self): + """Test _refresh_credential with credential that has no refresher.""" + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return None (no refresher available) + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=None, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + assert result == mock_credential + assert was_refreshed is False + + @pytest.mark.asyncio + async def test_is_credential_ready_api_key(self): + """Test 
_is_credential_ready with API key credential.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is True + + @pytest.mark.asyncio + async def test_is_credential_ready_oauth2(self): + """Test _is_credential_ready with OAuth2 credential (needs processing).""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is False + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_oauth2(self): + """Test _validate_credential with no raw credential for OAuth2.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_openid(self): + """Test _validate_credential with no raw credential for OpenID Connect.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.openIdConnect + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_other_scheme(self): + """Test _validate_credential with no raw credential for other schemes.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.apiKey + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + await manager._validate_credential() + + # Should return without error for non-OAuth2/OpenID schemes + + @pytest.mark.asyncio + async def test_validate_credential_oauth2_missing_oauth2_field(self): + """Test _validate_credential with OAuth2 credential missing oauth2 field.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + mock_raw_credential.oauth2 = None + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises( + ValueError, match="auth_config.raw_credential.oauth2 required" + ): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_exchange_credentials_service_account(self): + """Test _exchange_credential with service account credential (no exchanger available).""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + manager = CredentialManager(auth_config) + + # Mock the exchanger registry 
to return None (no exchanger available) + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=None + ): + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_raw_credential + assert was_exchanged is False + + @pytest.mark.asyncio + async def test_exchange_credential_no_exchanger(self): + """Test _exchange_credential with credential that has no exchanger.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the exchanger registry to return None (no exchanger available) + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=None + ): + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_raw_credential + assert was_exchanged is False + + +# Test fixtures +@pytest.fixture +def oauth2_auth_scheme(): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + +@pytest.fixture +def openid_auth_scheme(): + """Create an OpenID Connect auth scheme for testing.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + + +@pytest.fixture +def bearer_auth_scheme(): + """Create a Bearer auth scheme for testing.""" + return HTTPBearer(bearerFormat="JWT") + + +@pytest.fixture +def oauth2_credential(): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + +@pytest.fixture +def service_account_credential(): + """Create service account credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE" + " KEY-----\n" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="123456789", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + ), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + +@pytest.fixture +def api_key_credential(): + """Create API key credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key="test-api-key", + ) + + +@pytest.fixture +def http_bearer_credential(): + """Create HTTP Bearer credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="bearer-token"), + ), + ) diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py 
b/tests/unittests/auth/test_oauth2_credential_fetcher.py deleted file mode 100644 index 0b9b5a3c1..000000000 --- a/tests/unittests/auth/test_oauth2_credential_fetcher.py +++ /dev/null @@ -1,441 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from unittest.mock import Mock -from unittest.mock import patch - -from authlib.oauth2.rfc6749 import OAuth2Token -from fastapi.openapi.models import OAuth2 -from fastapi.openapi.models import OAuthFlowAuthorizationCode -from fastapi.openapi.models import OAuthFlows -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_schemes import OpenIdConnectWithConfig -from google.adk.auth.oauth2_credential_fetcher import OAuth2CredentialFetcher - - -class TestOAuth2CredentialFetcher: - """Test suite for OAuth2CredentialFetcher.""" - - def test_init(self): - """Test OAuth2CredentialFetcher initialization.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - assert fetcher._auth_scheme == scheme - assert fetcher._auth_credential == credential - - def test_oauth2_session_openid_connect(self): - """Test _oauth2_session with OpenID Connect scheme.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - state="test_state", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() - - assert client is not None - assert token_endpoint == "https://example.com/token" - assert client.client_id == "test_client_id" - assert client.client_secret == "test_client_secret" - - def test_oauth2_session_oauth2_scheme(self): - """Test _oauth2_session with OAuth2 scheme.""" - flows = OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://example.com/auth", - tokenUrl="https://example.com/token", - scopes={"read": "Read access", "write": "Write access"}, - ) - ) - scheme = OAuth2(type_="oauth2", flows=flows) - credential = AuthCredential( - 
auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() - - assert client is not None - assert token_endpoint == "https://example.com/token" - - def test_oauth2_session_invalid_scheme(self): - """Test _oauth2_session with invalid scheme.""" - scheme = Mock() # Invalid scheme type - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() - - assert client is None - assert token_endpoint is None - - def test_oauth2_session_missing_credentials(self): - """Test _oauth2_session with missing credentials.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - # Missing client_secret - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() - - assert client is None - assert token_endpoint is None - - def test_update_credential(self): - """Test _update_credential method.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - tokens = OAuth2Token({ - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - }) - - fetcher._update_credential(tokens) - - assert credential.oauth2.access_token == "new_access_token" - assert credential.oauth2.refresh_token == "new_refresh_token" - assert credential.oauth2.expires_at == int(time.time()) + 3600 - assert credential.oauth2.expires_in == 3600 - - def test_exchange_with_existing_token(self): - """Test exchange method when access token already exists.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="existing_token", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token == "existing_token" - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_success(self, mock_oauth2_session): - """Test successful token exchange.""" - scheme = 
OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri=( - "https://example.com/callback?code=auth_code&state=test_state" - ), - ), - ) - - # Mock the OAuth2Session - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" - mock_client.fetch_token.assert_called_once() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_with_auth_code(self, mock_oauth2_session): - """Test token exchange with auth code.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_code="test_auth_code", - ), - ) - - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - } - mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - mock_client.fetch_token.assert_called_once() - - def test_exchange_no_session(self): - """Test exchange when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri="https://example.com/callback?code=auth_code", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token is None - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_not_expired( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test refresh when token is not expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="current_token", - refresh_token="refresh_token", - expires_at=int(time.time()) + 3600, - 
expires_in=3600, - ), - ) - - # Mock token not expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = False - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "current_token" - mock_oauth2_session.assert_not_called() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_expired_success( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test successful token refresh when token is expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, # Expired - expires_in=3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - # Mock refresh token response - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "refreshed_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.refresh_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result.oauth2.access_token == "refreshed_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" - mock_client.refresh_token.assert_called_once_with( - url="https://example.com/token", - refresh_token="refresh_token", - ) - - def test_refresh_no_oauth2_credential(self): - """Test refresh when oauth2 credential is missing.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP) # No oauth2 - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - def test_refresh_no_session(self, mock_oauth2_token): - """Test refresh when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "expired_token" # Unchanged diff --git 
a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py new file mode 100644 index 000000000..aba6a9923 --- /dev/null +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock + +from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens + + +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" + + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" + scheme = Mock() # Invalid scheme type + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), 
+ ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + + update_credential_with_tokens(credential, tokens) + + assert credential.oauth2.access_token == "new_access_token" + assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.expires_at == int(time.time()) + 3600 + assert credential.oauth2.expires_in == 3600 diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 65c1eee3b..aec7a020b 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -40,7 +40,7 @@ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) -logger = logging.getLogger(__name__) +logger = logging.getLogger("google_adk." + __name__) # Here we create a dummy agent module that get_fast_api_app expects @@ -138,6 +138,7 @@ async def mock_run_evals_for_fast_api(*args, **kwargs): final_eval_status=1, # Matches expected (assuming 1 is PASSED) user_id="test_user", # Placeholder, adapt if needed session_id="test_session_for_eval_case", # Placeholder + eval_set_file="test_eval_set_file", # Placeholder overall_eval_metric_results=[{ # Matches expected "metricName": "tool_trajectory_avg_score", "threshold": 0.5, @@ -372,7 +373,7 @@ def add_eval_case(self, app_name, eval_set_id, eval_case): @pytest.fixture def mock_eval_set_results_manager(): - """Create a mock eval set results manager.""" + """Create a mock local eval set results manager.""" # Storage for eval set results. 
eval_set_results = {} diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 1721885f3..2139a8c20 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -129,6 +129,7 @@ def _echo(msg: str) -> None: artifact_service = cli.InMemoryArtifactService() session_service = cli.InMemorySessionService() + credential_service = cli.InMemoryCredentialService() dummy_root = types.SimpleNamespace(name="root") session = await cli.run_input_file( @@ -137,6 +138,7 @@ def _echo(msg: str) -> None: root_agent=dummy_root, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=str(input_path), ) @@ -199,9 +201,10 @@ async def test_run_interactively_whitespace_and_exit( ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - svc = cli.InMemorySessionService() - sess = await svc.create_session(app_name="dummy", user_id="u") + session_service = cli.InMemorySessionService() + sess = await session_service.create_session(app_name="dummy", user_id="u") artifact_service = cli.InMemoryArtifactService() + credential_service = cli.InMemoryCredentialService() root_agent = types.SimpleNamespace(name="root") # fake user input: blank -> 'hello' -> 'exit' @@ -212,7 +215,9 @@ async def test_run_interactively_whitespace_and_exit( echoed: list[str] = [] monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) - await cli.run_interactively(root_agent, artifact_service, sess, svc) + await cli.run_interactively( + root_agent, artifact_service, sess, session_service, credential_service + ) # verify: assistant echoed once with 'echo:hello' assert any("echo:hello" in m for m in echoed) diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index 312844db8..d3b2a538c 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -162,6 +162,7 @@ def _recording_copytree(*args: Any, **kwargs: Any): trace_to_cloud=True, with_ui=True, verbosity="info", + log_level="info", session_service_uri="sqlite://", artifact_service_uri="gs://bucket", memory_service_uri="rag://", @@ -206,6 +207,7 @@ def _fake_rmtree(path: str | Path, *a: Any, **k: Any) -> None: trace_to_cloud=False, with_ui=False, verbosity="info", + log_level="info", adk_version="1.0.0", session_service_uri=None, artifact_service_uri=None, diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index ad204005e..2b93226db 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -23,6 +23,7 @@ 'GOOGLE_API_KEY': 'fake_google_api_key', 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project', 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location', + 'ADK_ALLOW_WIP_FEATURES': 'true', } ENV_SETUPS = { diff --git a/tests/unittests/evaluation/test_final_response_match_v1.py b/tests/unittests/evaluation/test_final_response_match_v1.py new file mode 100644 index 000000000..d5544a5a1 --- /dev/null +++ b/tests/unittests/evaluation/test_final_response_match_v1.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetric +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.final_response_match_v1 import _calculate_rouge_1_scores +from google.adk.evaluation.final_response_match_v1 import RougeEvaluator +from google.genai import types as genai_types +import pytest + + +def _create_test_rouge_evaluator(threshold: float) -> RougeEvaluator: + return RougeEvaluator( + EvalMetric(metric_name="response_match_score", threshold=threshold) + ) + + +def _create_test_invocations( + candidate: str, reference: str +) -> tuple[Invocation, Invocation]: + """Returns tuple of (actual_invocation, expected_invocation).""" + return Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=candidate)] + ), + ), Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=reference)] + ), + ) + + +def test_calculate_rouge_1_scores_empty_candidate_and_reference(): + candidate = "" + reference = "" + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores_empty_candidate(): + candidate = "" + reference = "This is a test reference." + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores_empty_reference(): + candidate = "This is a test candidate response." + reference = "" + rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == 0 + assert rouge_1_score.recall == 0 + assert rouge_1_score.fmeasure == 0 + + +def test_calculate_rouge_1_scores(): + candidate = "This is a test candidate response." + reference = "This is a test reference." 
+ rouge_1_score = _calculate_rouge_1_scores(candidate, reference) + assert rouge_1_score.precision == pytest.approx(2 / 3) + assert rouge_1_score.recall == pytest.approx(4 / 5) + assert rouge_1_score.fmeasure == pytest.approx(8 / 11) + + +@pytest.mark.parametrize( + "candidates, references, expected_score, expected_status", + [ + ( + ["The quick brown fox jumps.", "hello world"], + ["The quick brown fox jumps over the lazy dog.", "hello"], + 0.69048, # (5/7 + 2/3) / 2 + EvalStatus.FAILED, + ), + ( + ["This is a test.", "Another test case."], + ["This is a test.", "This is a different test."], + 0.625, # (1 + 1/4) / 2 + EvalStatus.FAILED, + ), + ( + ["No matching words here.", "Second candidate."], + ["Completely different text.", "Another reference."], + 0.0, # (0 + 0) / 2 + EvalStatus.FAILED, + ), + ( + ["Same words", "Same words"], + ["Same words", "Same words"], + 1.0, + EvalStatus.PASSED, + ), + ], +) +def test_rouge_evaluator_multiple_invocations( + candidates: list[str], + references: list[str], + expected_score: float, + expected_status: EvalStatus, +): + rouge_evaluator = _create_test_rouge_evaluator(threshold=0.8) + actual_invocations = [] + expected_invocations = [] + for candidate, reference in zip(candidates, references): + actual_invocation, expected_invocation = _create_test_invocations( + candidate, reference + ) + actual_invocations.append(actual_invocation) + expected_invocations.append(expected_invocation) + + evaluation_result = rouge_evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + assert evaluation_result.overall_score == pytest.approx( + expected_score, rel=1e-3 + ) + assert evaluation_result.overall_eval_status == expected_status diff --git a/tests/unittests/evaluation/test_response_evaluator.py b/tests/unittests/evaluation/test_response_evaluator.py index bbaa694f2..839b7188a 100644 --- a/tests/unittests/evaluation/test_response_evaluator.py +++ b/tests/unittests/evaluation/test_response_evaluator.py @@ -16,7 +16,10 @@ from unittest.mock import MagicMock from unittest.mock import patch +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.evaluator import EvalStatus from google.adk.evaluation.response_evaluator import ResponseEvaluator +from google.genai import types as genai_types import pandas as pd import pytest from vertexai.preview.evaluation import MetricPromptTemplateExamples @@ -63,7 +66,7 @@ "google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval" ) class TestResponseEvaluator: - """A class to help organize "patch" that are applicabple to all tests.""" + """A class to help organize "patch" that are applicable to all tests.""" def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval): """Test evaluate function raises ValueError for an empty list.""" @@ -77,6 +80,40 @@ def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval): ResponseEvaluator.evaluate([], ["response_evaluation_score"]) mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called + def test_evaluate_invocations_rouge_metric(self, mock_perform_eval): + """Test evaluate_invocations function for Rouge metric.""" + actual_invocations = [ + Invocation( + user_content=genai_types.Content( + parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[ + genai_types.Part(text="This is a test candidate response.") + ] + ), + ) + ] + expected_invocations = [ + Invocation( + user_content=genai_types.Content( 
parts=[genai_types.Part(text="This is a test query.")] + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text="This is a test reference.")] + ), + ) + ] + evaluator = ResponseEvaluator( + threshold=0.8, metric_name="response_match_score" + ) + evaluation_result = evaluator.evaluate_invocations( + actual_invocations, expected_invocations + ) + assert evaluation_result.overall_score == pytest.approx(8 / 11) + # ROUGE-1 F1 is approx. 0.73 < 0.8 threshold, so eval status is FAILED. + assert evaluation_result.overall_eval_status == EvalStatus.FAILED + def test_evaluate_determines_metrics_correctly_for_perform_eval( self, mock_perform_eval ): diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py new file mode 100644 index 000000000..a330852a1 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -0,0 +1,361 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows import contents +from google.adk.flows.llm_flows.contents import _convert_foreign_event +from google.adk.flows.llm_flows.contents import _get_contents +from google.adk.flows.llm_flows.contents import _merge_function_response_events +from google.adk.flows.llm_flows.contents import _rearrange_events_for_async_function_responses_in_history +from google.adk.flows.llm_flows.contents import _rearrange_events_for_latest_function_response +from google.adk.models import LlmRequest +from google.genai import types +import pytest + +from ... 
import testing_utils + + +@pytest.mark.asyncio +async def test_content_processor_no_contents(): + """Test ContentLlmRequestProcessor when include_contents is 'none'.""" + agent = Agent(model="gemini-1.5-flash", name="agent", include_contents="none") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events + assert len(events) == 0 + # Contents should not be set when include_contents is 'none' + assert llm_request.contents == [] + + +@pytest.mark.asyncio +async def test_content_processor_with_contents(): + """Test ContentLlmRequestProcessor when include_contents is not 'none'.""" + agent = Agent(model="gemini-1.5-flash", name="agent") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Add some test events to the session + test_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + invocation_context.session.events = [test_event] + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events (processor doesn't emit events, just modifies request) + assert len(events) == 0 + # Contents should be set + assert llm_request.contents is not None + assert len(llm_request.contents) == 1 + assert llm_request.contents[0].role == "user" + assert llm_request.contents[0].parts[0].text == "Hello" + + +@pytest.mark.asyncio +async def test_content_processor_non_llm_agent(): + """Test ContentLlmRequestProcessor with non-LLM agent.""" + from google.adk.agents.base_agent import BaseAgent + + # Create a base agent (not LLM agent) + agent = BaseAgent(name="base_agent") + llm_request = LlmRequest(model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + # Collect events from async generator + events = [] + async for event in contents.request_processor.run_async( + invocation_context, llm_request + ): + events.append(event) + + # Should not yield any events and not modify request + assert len(events) == 0 + assert llm_request.contents == [] + + +def test_get_contents_empty_events(): + """Test _get_contents with empty events list.""" + contents_result = _get_contents(None, [], "test_agent") + assert contents_result == [] + + +def test_get_contents_with_events(): + """Test _get_contents with valid events.""" + test_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents(None, [test_event], "test_agent") + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + +def test_get_contents_filters_empty_events(): + """Test _get_contents filters out events with empty content.""" + # Event with empty text + empty_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content(role="user", parts=[types.Part.from_text(text="")]), + ) + + # Event without content + no_content_event = Event( + invocation_id="test_inv", + 
author="user", + ) + + # Valid event + valid_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ), + ) + + contents_result = _get_contents( + None, [empty_event, no_content_event, valid_event], "test_agent" + ) + assert len(contents_result) == 1 + assert contents_result[0].role == "user" + assert contents_result[0].parts[0].text == "Hello" + + +def test_convert_foreign_event(): + """Test _convert_foreign_event function.""" + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Agent response")] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] said: Agent response" in converted_event.content.parts[1].text + ) + + +def test_convert_event_with_function_call(): + """Test _convert_foreign_event with function call.""" + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] called tool `test_function`" + in converted_event.content.parts[1].text + ) + assert "{'param': 'value'}" in converted_event.content.parts[1].text + + +def test_convert_event_with_function_response(): + """Test _convert_foreign_event with function response.""" + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + agent_event = Event( + invocation_id="test_inv", + author="agent1", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + converted_event = _convert_foreign_event(agent_event) + + assert converted_event.author == "user" + assert converted_event.content.role == "user" + assert len(converted_event.content.parts) == 2 + assert converted_event.content.parts[0].text == "For context:" + assert ( + "[agent1] `test_function` tool returned result:" + in converted_event.content.parts[1].text + ) + assert "{'result': 'success'}" in converted_event.content.parts[1].text + + +def test_merge_function_response_events(): + """Test _merge_function_response_events function.""" + # Create initial function response event + function_response1 = types.FunctionResponse( + id="func_123", name="test_function", response={"status": "pending"} + ) + + initial_event = Event( + invocation_id="test_inv", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response1)] + ), + ) + + # Create final function response event + function_response2 = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + final_event = Event( + invocation_id="test_inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response2)] + ), + ) + + merged_event = 
_merge_function_response_events([initial_event, final_event]) + + assert ( + merged_event.invocation_id == "test_inv" + ) # Should keep initial event ID + assert len(merged_event.content.parts) == 1 + # The first part should be replaced with the final response + assert merged_event.content.parts[0].function_response.response == { + "result": "success" + } + + +def test_rearrange_events_for_async_function_responses(): + """Test _rearrange_events_for_async_function_responses_in_history function.""" + # Create function call event + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create function response event + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test rearrangement + events = [call_event, response_event] + rearranged = _rearrange_events_for_async_function_responses_in_history(events) + + # Should have both events in correct order + assert len(rearranged) == 2 + assert rearranged[0] == call_event + assert rearranged[1] == response_event + + +def test_rearrange_events_for_latest_function_response(): + """Test _rearrange_events_for_latest_function_response function.""" + # Create function call event + function_call = types.FunctionCall( + id="func_123", name="test_function", args={"param": "value"} + ) + + call_event = Event( + invocation_id="test_inv1", + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Create intermediate event + intermediate_event = Event( + invocation_id="test_inv2", + author="agent", + content=types.Content( + role="model", parts=[types.Part.from_text(text="Processing...")] + ), + ) + + # Create function response event + function_response = types.FunctionResponse( + id="func_123", name="test_function", response={"result": "success"} + ) + + response_event = Event( + invocation_id="test_inv3", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + # Test with matching function call and response + events = [call_event, intermediate_event, response_event] + rearranged = _rearrange_events_for_latest_function_response(events) + + # Should remove intermediate events and merge responses + assert len(rearranged) == 2 + assert rearranged[0] == call_event diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 2c5ef9bce..720af516d 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -17,6 +17,9 @@ from typing import Callable from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.sessions.session import Session from google.adk.tools import ToolContext from google.adk.tools.function_tool import FunctionTool from google.genai import types @@ -256,3 +259,136 @@ def increase_by_one(x: int) -> int: assert part.function_response.id is None assert 
events[0].content.parts[0].function_call.id.startswith('adk-') assert events[1].content.parts[0].function_response.id.startswith('adk-') + + +def test_find_function_call_event_no_function_response_in_last_event(): + """Test when last event has no function response.""" + events = [ + Event( + invocation_id='inv1', + author='user', + content=types.Content(role='user', parts=[types.Part(text='Hello')]), + ) + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_empty_session_events(): + """Test when session has no events.""" + events = [] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_but_no_matching_call(): + """Test when last event has function response but no matching call found.""" + # Create a function response + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + events = [ + Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', + parts=[types.Part(text='Some other response')], + ), + ), + Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', + parts=[types.Part(function_response=function_response)], + ), + ), + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_with_matching_call(): + """Test when last event has function response with matching function call.""" + # Create a function call + function_call = types.FunctionCall(id='func_123', name='test_func', args={}) + + # Create a function response with matching ID + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + call_event = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', parts=[types.Part(function_response=function_response)] + ), + ) + + events = [call_event, response_event] + + result = find_matching_function_call(events) + assert result == call_event + + +def test_find_function_call_event_multiple_function_responses(): + """Test when last event has multiple function responses.""" + # Create function calls + function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={}) + function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={}) + + # Create function responses + function_response1 = types.FunctionResponse( + id='func_123', name='test_func1', response={} + ) + function_response2 = types.FunctionResponse( + id='func_456', name='test_func2', response={} + ) + + call_event1 = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call1)] + ), + ) + + call_event2 = Event( + invocation_id='inv2', + author='agent2', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call2)] + ), + ) + + response_event = Event( + invocation_id='inv3', + author='user', + content=types.Content( + role='user', + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + events = [call_event1, call_event2, response_event] + + # Should return the first matching function call event found + result = find_matching_function_call(events) + assert result == 
call_event1 # First match (func_123) diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py new file mode 100644 index 000000000..2fbf3291c --- /dev/null +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any +from unittest import mock + +from google.adk.events import Event +from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService +from google.adk.sessions import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='333', + last_update_time=22333, + events=[ + Event( + id='444', + invocation_id='123', + author='user', + timestamp=12345, + content=types.Content(parts=[types.Part(text='test_content')]), + ), + # Empty event, should be ignored + Event( + id='555', + invocation_id='456', + author='user', + timestamp=12345, + ), + ], +) + +MOCK_SESSION_WITH_EMPTY_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='444', + last_update_time=22333, +) + + +RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$' +GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$' + + +class MockApiClient: + """Mocks the API Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.async_request = mock.AsyncMock() + self.async_request.side_effect = self._mock_async_request + + async def _mock_async_request( + self, http_method: str, path: str, request_dict: dict[str, Any] + ): + """Mocks the API Client request method.""" + if http_method == 'POST': + if re.match(GENERATE_MEMORIES_REGEX, path): + return {} + elif re.match(RETRIEVE_MEMORIES_REGEX, path): + if ( + request_dict.get('scope', None) + and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME + ): + return { + 'retrievedMemories': [ + { + 'memory': { + 'fact': 'test_content', + }, + 'updateTime': '2024-12-12T12:12:12.123456Z', + }, + ], + } + else: + return {'retrievedMemories': []} + else: + raise ValueError(f'Unsupported path: {path}') + else: + raise ValueError(f'Unsupported http method: {http_method}') + + +def mock_vertex_ai_memory_bank_service(): + """Creates a mock Vertex AI Memory Bank service for testing.""" + return VertexAiMemoryBankService( + project='test-project', + location='test-location', + agent_engine_id='123', + ) + + +@pytest.fixture +def mock_get_api_client(): + api_client = MockApiClient() + with mock.patch( + 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client', + return_value=api_client, + ): + yield api_client + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await 
memory_service.add_session_to_memory(MOCK_SESSION) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:generate', + request_dict={ + 'direct_contents_source': { + 'events': [ + { + 'content': { + 'parts': [ + {'text': 'test_content'}, + ], + }, + }, + ], + }, + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + }, + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_add_empty_session_to_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) + + mock_get_api_client.async_request.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_search_memory(mock_get_api_client): + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + mock_get_api_client.async_request.assert_awaited_once_with( + http_method='POST', + path='reasoningEngines/123/memories:retrieve', + request_dict={ + 'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + 'similarity_search_params': {'search_query': 'query'}, + }, + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'test_content' diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae..b9b1fb409 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -416,9 +416,26 @@ def __init__(self, acompletion_mock, completion_mock): self.completion_mock = completion_mock async def acompletion(self, model, messages, tools, **kwargs): - return await self.acompletion_mock( - model=model, messages=messages, tools=tools, **kwargs - ) + if kwargs.get("stream", False): + kwargs_copy = dict(kwargs) + kwargs_copy.pop("stream", None) + + async def stream_generator(): + stream_data = self.completion_mock( + model=model, + messages=messages, + tools=tools, + stream=True, + **kwargs_copy, + ) + for item in stream_data: + yield item + + return stream_generator() + else: + return await self.acompletion_mock( + model=model, messages=messages, tools=tools, **kwargs + ) def completion(self, model, messages, tools, stream, **kwargs): return self.completion_mock( @@ -1194,11 +1211,11 @@ async def test_generate_content_async_stream( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" mock_completion.assert_called_once() _, kwargs = mock_completion.call_args @@ -1257,11 +1274,11 @@ async def test_generate_content_async_stream_with_usage_metadata( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert 
responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" assert responses[3].usage_metadata.prompt_token_count == 10 assert responses[3].usage_metadata.candidates_token_count == 5 @@ -1430,3 +1447,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_completion_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py new file mode 100644 index 000000000..8d5bd2418 --- /dev/null +++ b/tests/unittests/test_runners.py @@ -0,0 +1,310 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + + +class MockAgent(BaseAgent): + """Mock agent for unit testing.""" + + def __init__( + self, + name: str, + parent_agent: Optional[BaseAgent] = None, + ): + super().__init__(name=name, sub_agents=[]) + # BaseAgent doesn't have disallow_transfer_to_parent field + # This is intentional as we want to test non-LLM agents + if parent_agent: + self.parent_agent = parent_agent + + async def _run_async_impl(self, invocation_context): + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ), + ) + + +class MockLlmAgent(LlmAgent): + """Mock LLM agent for unit testing.""" + + def __init__( + self, + name: str, + disallow_transfer_to_parent: bool = False, + parent_agent: Optional[BaseAgent] = None, + ): + # Use a string model instead of mock + super().__init__(name=name, model="gemini-1.5-pro", sub_agents=[]) + self.disallow_transfer_to_parent = disallow_transfer_to_parent + self.parent_agent = parent_agent + + async def _run_async_impl(self, invocation_context): + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test LLM response")] + ), + ) + + +class TestRunnerFindAgentToRun: + """Tests for Runner._find_agent_to_run method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_service = InMemorySessionService() + self.artifact_service = InMemoryArtifactService() + + # Create test agents + self.root_agent = MockLlmAgent("root_agent") + self.sub_agent1 = MockLlmAgent("sub_agent1", parent_agent=self.root_agent) + self.sub_agent2 = MockLlmAgent("sub_agent2", parent_agent=self.root_agent) + self.non_transferable_agent = MockLlmAgent( + "non_transferable", + disallow_transfer_to_parent=True, + parent_agent=self.root_agent, + ) + + self.root_agent.sub_agents = [ + self.sub_agent1, + self.sub_agent2, + self.non_transferable_agent, + ] + + self.runner = Runner( + app_name="test_app", + agent=self.root_agent, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + def test_find_agent_to_run_with_function_response_scenario(self): + """Test finding agent when last event is function response.""" + # Create a function call from sub_agent1 + function_call = types.FunctionCall(id="func_123", name="test_func", args={}) + function_response = types.FunctionResponse( + id="func_123", name="test_func", response={} + ) + + call_event = Event( + invocation_id="inv1", + author="sub_agent1", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id="inv2", + author="user", + content=types.Content( + role="user", parts=[types.Part(function_response=function_response)] + ), + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event, response_event], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.sub_agent1 + + def 
test_find_agent_to_run_returns_root_agent_when_no_events(self): + """Test that root agent is returned when session has no non-user events.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="user", + content=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_returns_root_agent_when_found_in_events(self): + """Test that root agent is returned when it's found in session events.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="root_agent", + content=types.Content( + role="model", parts=[types.Part(text="Root response")] + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_returns_transferable_sub_agent(self): + """Test that transferable sub agent is returned when found.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="sub_agent1", + content=types.Content( + role="model", parts=[types.Part(text="Sub agent response")] + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.sub_agent1 + + def test_find_agent_to_run_skips_non_transferable_agent(self): + """Test that non-transferable agent is skipped and root agent is returned.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="non_transferable", + content=types.Content( + role="model", + parts=[types.Part(text="Non-transferable response")], + ), + ) + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_skips_unknown_agent(self): + """Test that unknown agent is skipped and root agent is returned.""" + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[ + Event( + invocation_id="inv1", + author="unknown_agent", + content=types.Content( + role="model", + parts=[types.Part(text="Unknown agent response")], + ), + ), + Event( + invocation_id="inv2", + author="root_agent", + content=types.Content( + role="model", parts=[types.Part(text="Root response")] + ), + ), + ], + ) + + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.root_agent + + def test_find_agent_to_run_function_response_takes_precedence(self): + """Test that function response scenario takes precedence over other logic.""" + # Create a function call from sub_agent2 + function_call = types.FunctionCall(id="func_456", name="test_func", args={}) + function_response = types.FunctionResponse( + id="func_456", name="test_func", response={} + ) + + call_event = Event( + invocation_id="inv1", + author="sub_agent2", + content=types.Content( + role="model", parts=[types.Part(function_call=function_call)] + ), + ) + + # Add another event from root_agent + root_event = Event( + invocation_id="inv2", + author="root_agent", + content=types.Content( + role="model", parts=[types.Part(text="Root response")] + ), + ) + + response_event = Event( + invocation_id="inv3", + author="user", + content=types.Content( + role="user", 
parts=[types.Part(function_response=function_response)] + ), + ) + + session = Session( + id="test_session", + user_id="test_user", + app_name="test_app", + events=[call_event, root_event, response_event], + ) + + # Should return sub_agent2 due to function response, not root_agent + result = self.runner._find_agent_to_run(session, self.root_agent) + assert result == self.sub_agent2 + + def test_is_transferable_across_agent_tree_with_llm_agent(self): + """Test _is_transferable_across_agent_tree with LLM agent.""" + result = self.runner._is_transferable_across_agent_tree(self.sub_agent1) + assert result is True + + def test_is_transferable_across_agent_tree_with_non_transferable_agent(self): + """Test _is_transferable_across_agent_tree with non-transferable agent.""" + result = self.runner._is_transferable_across_agent_tree( + self.non_transferable_agent + ) + assert result is False + + def test_is_transferable_across_agent_tree_with_non_llm_agent(self): + """Test _is_transferable_across_agent_tree with non-LLM agent.""" + non_llm_agent = MockAgent("non_llm_agent") + # MockAgent inherits from BaseAgent, not LlmAgent, so it should return False + result = self.runner._is_transferable_across_agent_tree(non_llm_agent) + assert result is False diff --git a/tests/unittests/test_telemetry.py b/tests/unittests/test_telemetry.py index 1b8ee1b16..debdc802e 100644 --- a/tests/unittests/test_telemetry.py +++ b/tests/unittests/test_telemetry.py @@ -141,6 +141,36 @@ async def test_trace_call_llm_function_response_includes_part_from_bytes( assert llm_request_json_str.count('') == 2 +@pytest.mark.asyncio +async def test_trace_call_llm_usage_metadata(monkeypatch, mock_span_fixture): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest( + config=types.GenerateContentConfig(system_instruction=''), + ) + llm_response = LlmResponse( + turn_complete=True, + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=100, prompt_token_count=50 + ), + ) + trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) + + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.usage.input_tokens', 50), + mock.call('gen_ai.usage.output_tokens', 100), + ] + assert mock_span_fixture.set_attribute.call_count == 9 + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + def test_trace_tool_call_with_scalar_response( monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture ): diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py index cd37a105e..c9b542e51 100644 --- a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -20,6 +20,7 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult from google.genai.types import FunctionDeclaration from google.genai.types import Schema from 
google.genai.types import Type @@ -50,7 +51,9 @@ def mock_rest_api_tool(): "required": ["user_id", "page_size", "filter", "connection_name"], } mock_tool._operation_parser = mock_parser - mock_tool.call.return_value = {"status": "success", "data": "mock_data"} + mock_tool.call = mock.AsyncMock( + return_value={"status": "success", "data": "mock_data"} + ) return mock_tool @@ -179,9 +182,6 @@ async def test_run_with_auth_async_none_token( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) # Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None mock_auth_credential_without_token = AuthCredential( auth_type=AuthCredentialTypes.HTTP, @@ -190,8 +190,12 @@ async def test_run_with_auth_async_none_token( credentials=HttpCredentials(token=None), # Token is None ), ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = ( - mock_auth_credential_without_token + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=( + AuthPreparationResult( + state="done", auth_credential=mock_auth_credential_without_token + ) + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance @@ -229,18 +233,18 @@ async def test_run_with_auth_async( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token="mocked_token"), - ), + + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=AuthPreparationResult( + state="done", + auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="mocked_token"), + ), + ), + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance result = await integration_tool_with_auth.run_async( diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py new file mode 100644 index 000000000..e8b373416 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -0,0 +1,129 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
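The integration-connector test changes above replace attribute assignment on a MagicMock with an AsyncMock that resolves to a real AuthPreparationResult. A small reusable sketch of that pattern, using only imports already shown in this diff; the helper name make_auth_handler_mock is hypothetical.

from unittest import mock

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult


def make_auth_handler_mock(token: str) -> mock.MagicMock:
  """Illustrative helper: a ToolAuthHandler mock whose async
  prepare_auth_credentials resolves to a completed AuthPreparationResult."""
  handler = mock.MagicMock()
  handler.prepare_auth_credentials = mock.AsyncMock(
      return_value=AuthPreparationResult(
          state="done",
          auth_credential=AuthCredential(
              auth_type=AuthCredentialTypes.HTTP,
              http=HttpAuth(
                  scheme="bearer",
                  credentials=HttpCredentials(token=token),
              ),
          ),
      )
  )
  return handler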
+ +from __future__ import annotations + +import os +import re +from unittest import mock + +from google.adk.tools.bigquery.client import get_bigquery_client +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2.credentials import Credentials +import pytest + + +def test_bigquery_client_project(): + """Test BigQuery client project.""" + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the client has the desired project set + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_explicit(): + """Test BigQuery client creation does not invoke default auth.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_default_auth(): + """Test BigQuery client creation invokes default auth to set the project.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate credentials + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Simulate output of the default auth + mock_default_auth.return_value = (mock_creds, "test-gcp-project") + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock_creds, + ) + + # Verify that default auth was called once to set the client project + mock_default_auth.assert_called_once() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_env(): + """Test BigQuery client creation sets the project from environment variable.""" + # Let's simulate the project set in environment variables + with mock.patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True + ): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). 
For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_user_agent(): + """Test BigQuery client user agent.""" + with mock.patch( + "google.cloud.bigquery.client.Connection", autospec=True + ) as mock_connection: + # Trigger the BigQuery client creation + get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the tracking user agent was set + client_info_arg = mock_connection.call_args[1].get("client_info") + assert client_info_arg is not None + assert re.search( + r"adk-bigquery-tool google-adk/([0-9A-Za-z._\-+/]+)", + client_info_arg.user_agent, + ) diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 9fa152fc2..05af3aaf3 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest import mock from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig # Mock the Google OAuth and API dependencies -from google.oauth2.credentials import Credentials +import google.auth.credentials +import google.oauth2.credentials import pytest @@ -27,22 +28,46 @@ class TestBigQueryCredentials: either existing credentials or client ID/secret pairs are provided. """ - def test_valid_credentials_object(self): - """Test that providing valid Credentials object works correctly. + def test_valid_credentials_object_auth_credentials(self): + """Test that providing valid Credentials object works correctly with + google.auth.credentials.Credentials. When a user already has valid OAuth credentials, they should be able to pass them directly without needing to provide client ID/secret. """ - # Create a mock credentials object with the expected attributes - mock_creds = Mock(spec=Credentials) - mock_creds.client_id = "test_client_id" - mock_creds.client_secret = "test_client_secret" - mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"] + # Create a mock auth credentials object + # auth_creds = google.auth.credentials.Credentials() + auth_creds = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + config = BigQueryCredentialsConfig(credentials=auth_creds) + + # Verify that the credentials are properly stored and attributes are extracted + assert config.credentials == auth_creds + assert config.client_id is None + assert config.client_secret is None + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + + def test_valid_credentials_object_oauth2_credentials(self): + """Test that providing valid Credentials object works correctly with + google.oauth2.credentials.Credentials. + + When a user already has valid OAuth credentials, they should be able + to pass them directly without needing to provide client ID/secret. 
+ """ + # Create a mock oauth2 credentials object + oauth2_creds = google.oauth2.credentials.Credentials( + "test_token", + client_id="test_client_id", + client_secret="test_client_secret", + scopes=["https://www.googleapis.com/auth/calendar"], + ) - config = BigQueryCredentialsConfig(credentials=mock_creds) + config = BigQueryCredentialsConfig(credentials=oauth2_creds) # Verify that the credentials are properly stored and attributes are extracted - assert config.credentials == mock_creds + assert config.credentials == oauth2_creds assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" assert config.scopes == ["https://www.googleapis.com/auth/calendar"] diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py index 95d8b00d6..47d955906 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py @@ -22,9 +22,10 @@ from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager +from google.auth.credentials import Credentials as AuthCredentials from google.auth.exceptions import RefreshError # Mock the Google OAuth and API dependencies -from google.oauth2.credentials import Credentials +from google.oauth2.credentials import Credentials as OAuthCredentials import pytest @@ -64,9 +65,16 @@ def manager(self, credentials_config): """Create a credentials manager instance for testing.""" return BigQueryCredentialsManager(credentials_config) + @pytest.mark.parametrize( + ("credentials_class",), + [ + pytest.param(OAuthCredentials, id="oauth"), + pytest.param(AuthCredentials, id="auth"), + ], + ) @pytest.mark.asyncio async def test_get_valid_credentials_with_valid_existing_creds( - self, manager, mock_tool_context + self, manager, mock_tool_context, credentials_class ): """Test that valid existing credentials are returned immediately. @@ -74,7 +82,7 @@ async def test_get_valid_credentials_with_valid_existing_creds( should be needed. This is the optimal happy path scenario. """ # Create mock credentials that are already valid - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=credentials_class) mock_creds.valid = True manager.credentials_config.credentials = mock_creds @@ -85,6 +93,34 @@ async def test_get_valid_credentials_with_valid_existing_creds( mock_tool_context.get_auth_response.assert_not_called() mock_tool_context.request_credential.assert_not_called() + @pytest.mark.parametrize( + ("valid",), + [ + pytest.param(False, id="invalid"), + pytest.param(True, id="valid"), + ], + ) + @pytest.mark.asyncio + async def test_get_valid_credentials_with_existing_non_oauth_creds( + self, manager, mock_tool_context, valid + ): + """Test that existing non-oauth credentials are returned immediately. + + When credentials are of non-oauth type, no refresh or OAuth flow + is triggered irrespective of whether it is valid or not. 
+ """ + # Create mock credentials that are already valid + mock_creds = Mock(spec=AuthCredentials) + mock_creds.valid = valid + manager.credentials_config.credentials = mock_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + assert result == mock_creds + # Verify no OAuth flow was triggered + mock_tool_context.get_auth_response.assert_not_called() + mock_tool_context.request_credential.assert_not_called() + @pytest.mark.asyncio async def test_get_credentials_from_cache_when_none_in_manager( self, manager, mock_tool_context @@ -113,7 +149,7 @@ async def test_get_credentials_from_cache_when_none_in_manager( with patch( "google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = True mock_from_json.return_value = mock_creds @@ -179,7 +215,7 @@ async def test_refresh_cached_credentials_success( mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json # Create expired cached credentials with refresh token - mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = False mock_cached_creds.expired = True mock_cached_creds.refresh_token = "valid_refresh_token" @@ -227,7 +263,7 @@ async def test_get_valid_credentials_with_refresh_success( users from having to re-authenticate for every expired token. """ # Create expired credentials with refresh token - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = False mock_creds.expired = True mock_creds.refresh_token = "refresh_token" @@ -257,7 +293,7 @@ async def test_get_valid_credentials_with_refresh_failure( gracefully fall back to requesting a new OAuth flow. """ # Create expired credentials that fail to refresh - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = False mock_creds.expired = True mock_creds.refresh_token = "expired_refresh_token" @@ -287,7 +323,7 @@ async def test_oauth_flow_completion_with_caching( mock_tool_context.get_auth_response.return_value = mock_auth_response # Create a mock credentials instance that will represent our created credentials - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) # Make the JSON match what a real Credentials object would produce mock_creds_json = ( '{"token": "new_access_token", "refresh_token": "new_refresh_token",' @@ -300,7 +336,7 @@ async def test_oauth_flow_completion_with_caching( # Use the full module path as it appears in the project structure with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials", + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: result = await manager.get_valid_credentials(mock_tool_context) @@ -361,7 +397,7 @@ async def test_cache_persistence_across_manager_instances( mock_tool_context.get_auth_response.return_value = mock_auth_response # Create the mock credentials instance that will be returned by the constructor - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) # Make sure our mock JSON matches the structure that real Credentials objects produce mock_creds_json = ( '{"token": "cached_access_token", "refresh_token":' @@ -376,7 +412,7 @@ async def test_cache_persistence_across_manager_instances( # Use the correct module path - without the 'src.' 
prefix with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials", + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: # Complete OAuth flow with first manager @@ -396,9 +432,9 @@ async def test_cache_persistence_across_manager_instances( # Mock the from_authorized_user_info method for the second manager with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info" + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: - mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = True mock_from_json.return_value = mock_cached_creds diff --git a/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py new file mode 100644 index 000000000..14ecea558 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.bigquery import metadata_tool +from google.auth.exceptions import DefaultCredentialsError +from google.cloud import bigquery +from google.oauth2.credentials import Credentials +import pytest + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_dataset_ids(mock_default_auth, mock_list_datasets): + """Test list_dataset_ids tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_datasets.return_value = [ + bigquery.DatasetReference(project, "dataset1"), + bigquery.DatasetReference(project, "dataset2"), + ] + result = metadata_tool.list_dataset_ids(project, mock_credentials) + assert result == ["dataset1", "dataset2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_dataset_info(mock_default_auth, mock_get_dataset): + """Test get_dataset_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_dataset.return_value = mock.create_autospec( + Credentials, instance=True + ) + 
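  # Reasoning behind the simulated DefaultCredentialsError above: get_dataset_info
  # is expected to build its BigQuery client from the explicitly passed
  # credentials, conceptually something like
  #   client = get_bigquery_client(project="my_project_id", credentials=mock_credentials)
  #   client.get_dataset("my_dataset_id")
  # (illustrative only; the exact internals are not asserted here). If the tool
  # ever fell back to google.auth.default(), the call below would surface the
  # simulated error instead of returning a result.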
result = metadata_tool.get_dataset_info( + "my_project_id", "my_dataset_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_table_ids(mock_default_auth, mock_list_tables): + """Test list_table_ids tool invocation.""" + project = "my_project_id" + dataset = "my_dataset_id" + dataset_ref = bigquery.DatasetReference(project, dataset) + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_tables.return_value = [ + bigquery.TableReference(dataset_ref, "table1"), + bigquery.TableReference(dataset_ref, "table2"), + ] + result = metadata_tool.list_table_ids(project, dataset, mock_credentials) + assert result == ["table1", "table2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_table_info(mock_default_auth, mock_get_table): + """Test get_table_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_table.return_value = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_table_info( + "my_project_id", "my_dataset_id", "my_table_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index 35d44ef81..3cb8c3c4a 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import os import textwrap from typing import Optional from unittest import mock @@ -24,6 +25,7 @@ from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode from google.adk.tools.bigquery.query_tool import execute_sql +from google.auth.exceptions import DefaultCredentialsError from google.cloud import bigquery from google.oauth2.credentials import Credentials import pytest @@ -227,14 +229,8 @@ async def test_execute_sql_declaration_write(tool_config): @pytest.mark.parametrize( ("write_mode",), [ - pytest.param( - WriteMode.BLOCKED, - id="blocked", - ), - pytest.param( - WriteMode.ALLOWED, - id="allowed", - ), + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), ], ) def test_execute_sql_select_stmt(write_mode): @@ -279,7 +275,7 @@ def test_execute_sql_select_stmt(write_mode): ], ) def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): - """Test execute_sql tool for SELECT query when writes are blocked.""" + """Test execute_sql 
tool for non-SELECT query when writes are blocked.""" project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) @@ -318,7 +314,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): ], ) def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): - """Test execute_sql tool for SELECT query when writes are blocked.""" + """Test execute_sql tool for non-SELECT query when writes are blocked.""" project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) @@ -342,3 +338,45 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", } + + +@pytest.mark.parametrize( + ("write_mode",), + [ + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) +@mock.patch("google.cloud.bigquery.Client.query", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_execute_sql_no_default_auth( + mock_default_auth, mock_query, mock_query_and_wait, write_mode +): + """Test execute_sql tool invocation does not involve calling default auth.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + query_result = [{"num": 123}] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = BigQueryToolConfig(write_mode=write_mode) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + mock_query.return_value = query_job + + # Simulate the result of query_and_wait API + mock_query_and_wait.return_value = query_result + + # Test the tool worked without invoking default auth + result = execute_sql(project, query, credentials, tool_config) + assert result == {"status": "SUCCESS", "rows": query_result} + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index ea9990b9f..4129dc512 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -96,9 +96,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools): ], ) @pytest.mark.asyncio -async def test_bigquery_toolset_unknown_tool_raises( - selected_tools, returned_tools -): +async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools): """Test BigQuery toolset with filter. This test verifies the behavior of the BigQuery toolset when filter is diff --git a/tests/unittests/tools/mcp_tool/__init__.py b/tests/unittests/tools/mcp_tool/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/tools/mcp_tool/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
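For readers of the execute_sql tests above, a rough usage sketch with explicit user credentials and a read-only config. The token and project values are placeholders, and arguments are passed positionally as in the tests.

from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
from google.adk.tools.bigquery.query_tool import execute_sql
from google.oauth2.credentials import Credentials

# Placeholder credentials and project; supply real values in practice.
creds = Credentials("your-oauth2-access-token")
read_only_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)

result = execute_sql(
    "your-gcp-project", "SELECT 123 AS num", creds, read_only_config
)
# Expected shape on success: {"status": "SUCCESS", "rows": [...]}. In BLOCKED
# mode, non-SELECT statements instead return {"status": "ERROR", ...}.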
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py new file mode 100644 index 000000000..559e51719 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -0,0 +1,364 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +from io import StringIO +import json +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager + from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource + from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyClass: + pass + + MCPSessionManager = DummyClass + retry_on_closed_resource = lambda x: x + SseConnectionParams = DummyClass + StdioConnectionParams = DummyClass + StreamableHTTPConnectionParams = DummyClass + else: + raise e + +# Import real MCP classes +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockClientSession: + """Mock ClientSession for testing.""" + + def __init__(self): + self._read_stream = Mock() + self._write_stream = Mock() + self._read_stream._closed = False + self._write_stream._closed = False + self.initialize = AsyncMock() + + +class MockAsyncExitStack: + """Mock AsyncExitStack for testing.""" + + def __init__(self): + self.aclose = AsyncMock() + self.enter_async_context = AsyncMock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class TestMCPSessionManager: + """Test suite for MCPSessionManager class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + 
command="test_command", args=[] + ) + self.mock_stdio_connection_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=5.0 + ) + + def test_init_with_stdio_server_parameters(self): + """Test initialization with StdioServerParameters (deprecated).""" + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.logger" + ) as mock_logger: + manager = MCPSessionManager(self.mock_stdio_params) + + # Should log deprecation warning + mock_logger.warning.assert_called_once() + assert "StdioServerParameters is not recommended" in str( + mock_logger.warning.call_args + ) + + # Should convert to StdioConnectionParams + assert isinstance(manager._connection_params, StdioConnectionParams) + assert manager._connection_params.server_params == self.mock_stdio_params + assert manager._connection_params.timeout == 5 + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + assert manager._connection_params == self.mock_stdio_connection_params + assert manager._errlog == sys.stderr + assert manager._sessions == {} + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", + headers={"Authorization": "Bearer token"}, + timeout=10.0, + ) + manager = MCPSessionManager(sse_params) + + assert manager._connection_params == sse_params + + def test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", timeout=15.0 + ) + manager = MCPSessionManager(http_params) + + assert manager._connection_params == http_params + + def test_generate_session_key_stdio(self): + """Test session key generation for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # For stdio, headers should be ignored and return constant key + key1 = manager._generate_session_key({"Authorization": "Bearer token"}) + key2 = manager._generate_session_key(None) + + assert key1 == "stdio_session" + assert key2 == "stdio_session" + assert key1 == key2 + + def test_generate_session_key_sse(self): + """Test session key generation for SSE connections.""" + sse_params = SseConnectionParams(url="https://example.com/mcp") + manager = MCPSessionManager(sse_params) + + headers1 = {"Authorization": "Bearer token1"} + headers2 = {"Authorization": "Bearer token2"} + + key1 = manager._generate_session_key(headers1) + key2 = manager._generate_session_key(headers2) + key3 = manager._generate_session_key(headers1) + + # Different headers should generate different keys + assert key1 != key2 + # Same headers should generate same key + assert key1 == key3 + + # Should be deterministic hash + headers_json = json.dumps(headers1, sort_keys=True) + expected_hash = hashlib.md5(headers_json.encode()).hexdigest() + assert key1 == f"session_{expected_hash}" + + def test_merge_headers_stdio(self): + """Test header merging for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Stdio connections don't support headers + headers = manager._merge_headers({"Authorization": "Bearer token"}) + assert headers is None + + def test_merge_headers_sse(self): + """Test header merging for SSE connections.""" + base_headers = {"Content-Type": "application/json"} + sse_params = SseConnectionParams( + 
url="https://example.com/mcp", headers=base_headers + ) + manager = MCPSessionManager(sse_params) + + # With additional headers + additional = {"Authorization": "Bearer token"} + merged = manager._merge_headers(additional) + + expected = { + "Content-Type": "application/json", + "Authorization": "Bearer token", + } + assert merged == expected + + def test_is_session_disconnected(self): + """Test session disconnection detection.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session + session = MockClientSession() + + # Not disconnected + assert not manager._is_session_disconnected(session) + + # Disconnected - read stream closed + session._read_stream._closed = True + assert manager._is_session_disconnected(session) + + @pytest.mark.asyncio + async def test_create_session_stdio_new(self): + """Test creating a new stdio session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_session = MockClientSession() + mock_exit_stack = MockAsyncExitStack() + + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ) as mock_stdio: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_exit_stack_class: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" + ) as mock_session_class: + + # Setup mocks + mock_exit_stack_class.return_value = mock_exit_stack + mock_stdio.return_value = AsyncMock() + mock_exit_stack.enter_async_context.side_effect = [ + ("read", "write"), # First call returns transports + mock_session, # Second call returns session + ] + mock_session_class.return_value = mock_session + + # Create session + session = await manager.create_session() + + # Verify session creation + assert session == mock_session + assert len(manager._sessions) == 1 + assert "stdio_session" in manager._sessions + + # Verify session was initialized + mock_session.initialize.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_reuse_existing(self): + """Test reusing an existing connected session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock existing session + existing_session = MockClientSession() + existing_exit_stack = MockAsyncExitStack() + manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) + + # Session is connected + existing_session._read_stream._closed = False + existing_session._write_stream._closed = False + + session = await manager.create_session() + + # Should reuse existing session + assert session == existing_session + assert len(manager._sessions) == 1 + + # Should not create new session + existing_session.initialize.assert_not_called() + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup of all sessions.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + await manager.close() + + # All sessions should be closed + exit_stack1.aclose.assert_called_once() + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + @pytest.mark.asyncio + async def test_close_with_errors(self): + """Test cleanup when some sessions fail to close.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add 
mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + exit_stack1.aclose.side_effect = Exception("Close error 1") + + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + custom_errlog = StringIO() + manager._errlog = custom_errlog + + # Should not raise exception + await manager.close() + + # Good session should still be closed + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + # Error should be logged + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCP session cleanup" in error_output + assert "Close error 1" in error_output + + +def test_retry_on_closed_resource_decorator(): + """Test the retry_on_closed_resource decorator.""" + + call_count = 0 + + @retry_on_closed_resource + async def mock_function(self): + nonlocal call_count + call_count += 1 + if call_count == 1: + import anyio + + raise anyio.ClosedResourceError("Resource closed") + return "success" + + @pytest.mark.asyncio + async def test_retry(): + nonlocal call_count + call_count = 0 + + mock_self = Mock() + result = await mock_function(mock_self) + + assert result == "success" + assert call_count == 2 # First call fails, second succeeds + + # Run the test + import asyncio + + asyncio.run(test_retry()) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py new file mode 100644 index 000000000..82e3f2234 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -0,0 +1,360 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
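The session-manager tests above exercise MCPSessionManager entirely through mocks. A rough end-to-end sketch of driving it directly, assuming Python 3.10+ and a hypothetical stdio MCP server command named my-mcp-server:

import asyncio

from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
from mcp import StdioServerParameters


async def main() -> None:
  # Assumed: an MCP server reachable via a local stdio command.
  params = StdioConnectionParams(
      server_params=StdioServerParameters(command="my-mcp-server", args=[]),
      timeout=5.0,
  )
  manager = MCPSessionManager(params)
  try:
    # create_session() returns an initialized (and pooled) mcp ClientSession.
    session = await manager.create_session()
    result = await session.call_tool("some_tool", arguments={"param": "value"})
    print(result)
  finally:
    # close() tears down every pooled session, logging (not raising) errors.
    await manager.close()


if __name__ == "__main__":
  asyncio.run(main())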
+ +import sys +from typing import Any +from typing import Dict +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager + from google.adk.tools.mcp_tool.mcp_tool import MCPTool + from google.adk.tools.tool_context import ToolContext + from google.genai.types import FunctionDeclaration +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyClass: + pass + + MCPSessionManager = DummyClass + MCPTool = DummyClass + ToolContext = DummyClass + FunctionDeclaration = DummyClass + else: + raise e + + +# Mock MCP Tool from mcp.types +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name="test_tool", description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "First parameter"}, + "param2": {"type": "integer", "description": "Second parameter"}, + }, + "required": ["param1"], + } + + +class TestMCPTool: + """Test suite for MCPTool class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_mcp_tool = MockMCPTool() + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization without auth.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test tool description" + assert tool._mcp_tool == self.mock_mcp_tool + assert tool._mcp_session_manager == self.mock_session_manager + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances instead of mocks + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + # The auth config is stored in the parent class _credentials_manager + assert tool._credentials_manager is not None + assert tool._credentials_manager._auth_config.auth_scheme == auth_scheme + assert ( + tool._credentials_manager._auth_config.raw_auth_credential + == auth_credential + ) + + def test_init_with_empty_description(self): + """Test initialization with empty description.""" + mock_tool = MockMCPTool(description=None) + tool = MCPTool( + mcp_tool=mock_tool, + 
mcp_session_manager=self.mock_session_manager, + ) + + assert tool.description == "" + + def test_get_declaration(self): + """Test function declaration generation.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + declaration = tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_tool" + assert declaration.description == "Test tool description" + assert declaration.parameters is not None + + @pytest.mark.asyncio + async def test_run_async_impl_no_auth(self): + """Test running tool without authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=None + ) + + assert result == expected_response + self.mock_session_manager.create_session.assert_called_once_with( + headers=None + ) + # Fix: call_tool uses 'arguments' parameter, not positional args + self.mock_session.call_tool.assert_called_once_with( + "test_tool", arguments=args + ) + + @pytest.mark.asyncio + async def test_run_async_impl_with_oauth2(self): + """Test running tool with OAuth2 authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create OAuth2 credential + oauth2_auth = OAuth2Auth(access_token="test_access_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + assert result == expected_response + # Check that headers were passed correctly + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == {"Authorization": "Bearer test_access_token"} + + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer bearer_token"} + + 
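  # Taken together, the _get_headers tests in this class pin down the
  # credential-to-header mapping the MCP tool relies on (summarized here for
  # readability; the individual tests remain the authoritative expectations):
  #   OAuth2 access token        -> {"Authorization": "Bearer <access_token>"}
  #   HTTP scheme "bearer"       -> {"Authorization": "Bearer <token>"}
  #   HTTP scheme "basic"        -> {"Authorization": "Basic <base64(user:pass)>"}
  #   API key                    -> {"X-API-Key": "<api_key>"}
  #   Other HTTP schemes         -> {"Authorization": "<scheme> <token>"}
  #   Service account / no cred  -> no headers (None)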
@pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key(self): + """Test header generation for API Key credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"X-API-Key": "my_api_key"} + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, None) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create service account credential + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should return None as service account credentials are not supported for direct header generation + assert headers is None + + @pytest.mark.asyncio + async def test_run_async_impl_retry_decorator(self): + """Test that the retry decorator is applied correctly.""" + # This is more of an integration test to ensure the decorator is present + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Check that the method has the retry decorator + assert hasattr(tool._run_async_impl, "__wrapped__") + + @pytest.mark.asyncio + async def test_get_headers_http_custom_scheme(self): + """Test header generation for custom HTTP scheme.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "custom custom_token"} + + def test_init_validation(self): + """Test that initialization validates required parameters.""" + # This test ensures that the MCPTool properly handles its dependencies + with pytest.raises(TypeError): + MCPTool() # Missing required parameters + + with 
pytest.raises(TypeError): + MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py new file mode 100644 index 000000000..d5e6ae243 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -0,0 +1,286 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import StringIO +import sys +import unittest +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager + from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams + from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams + from google.adk.tools.mcp_tool.mcp_tool import MCPTool + from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset + from mcp import StdioServerParameters +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyClass: + pass + + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + MCPSessionManager = DummyClass + SseConnectionParams = DummyClass + StdioConnectionParams = DummyClass + StreamableHTTPConnectionParams = DummyClass + MCPTool = DummyClass + MCPToolset = DummyClass + else: + raise e + + +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name, description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": {"param": {"type": "string"}}, + } + + +class MockListToolsResult: + """Mock ListToolsResult for testing.""" + + def __init__(self, tools): + self.tools = tools + + +class TestMCPToolset: + """Test suite for MCPToolset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization with StdioServerParameters.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Note: StdioServerParameters gets converted to StdioConnectionParams internally + assert toolset._errlog == sys.stderr + assert 
toolset._auth_scheme is None + assert toolset._auth_credential is None + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + stdio_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=10.0 + ) + toolset = MCPToolset(connection_params=stdio_params) + + assert toolset._connection_params == stdio_params + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers={"Authorization": "Bearer token"} + ) + toolset = MCPToolset(connection_params=sse_params) + + assert toolset._connection_params == sse_params + + def test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", + headers={"Content-Type": "application/json"}, + ) + toolset = MCPToolset(connection_params=http_params) + + assert toolset._connection_params == http_params + + def test_init_with_tool_filter_list(self): + """Test initialization with tool filter as list.""" + tool_filter = ["tool1", "tool2"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + + # The tool filter is stored in the parent BaseToolset class + # We can verify it by checking the filtering behavior in get_tools + assert toolset._is_tool_selected is not None + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + from google.adk.auth.auth_credential import OAuth2Auth + + auth_credential = AuthCredential( + auth_type="oauth2", + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + assert toolset._auth_scheme == auth_scheme + assert toolset._auth_credential == auth_credential + + def test_init_missing_connection_params(self): + """Test initialization with missing connection params raises error.""" + with pytest.raises(ValueError, match="Missing connection params"): + MCPToolset(connection_params=None) + + @pytest.mark.asyncio + async def test_get_tools_basic(self): + """Test getting tools without filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 3 + for tool in tools: + assert isinstance(tool, MCPTool) + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + assert tools[2].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_list_filter(self): + """Test getting tools with list-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + tool_filter = ["tool1", "tool3"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + toolset._mcp_session_manager = 
self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_function_filter(self): + """Test getting tools with function-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("read_file"), + MockMCPTool("write_file"), + MockMCPTool("list_directory"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + def file_tools_filter(tool, context): + """Filter for file-related tools only.""" + return "file" in tool.name + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=file_tools_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "read_file" + assert tools[1].name == "write_file" + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + await toolset.close() + + self.mock_session_manager.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_exception(self): + """Test cleanup when session manager raises exception.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + # Mock close to raise an exception + self.mock_session_manager.close = AsyncMock( + side_effect=Exception("Cleanup error") + ) + + custom_errlog = StringIO() + toolset._errlog = custom_errlog + + # Should not raise exception + await toolset.close() + + # Should log the error + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCPToolset cleanup" in error_output + assert "Cleanup error" in error_output + + @pytest.mark.asyncio + async def test_get_tools_retry_decorator(self): + """Test that get_tools has retry decorator applied.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Check that the method has the retry decorator + assert hasattr(toolset.get_tools, "__wrapped__") diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 303dda69d..c4cbea7b9 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -14,6 +14,7 @@ import json +from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -194,7 +195,8 @@ def test_get_declaration( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_success( + @pytest.mark.asyncio + async def test_call_success( self, mock_request, mock_tool_context, @@ -217,7 +219,7 @@ def test_call_success( ) # Call the method - result = tool.call(args={}, tool_context=mock_tool_context) + result = await tool.call(args={}, tool_context=mock_tool_context) # Check the result assert result == {"result": "success"} @@ -225,7 +227,8 @@ def test_call_success( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_auth_pending( + @pytest.mark.asyncio + async def test_call_auth_pending( self, mock_request, sample_endpoint, @@ -246,12 +249,14 @@ def 
test_call_auth_pending( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "pending" + mock_prepare_result = MagicMock() + mock_prepare_result.state = "pending" + mock_tool_auth_handler_instance.prepare_auth_credentials = AsyncMock( + return_value=mock_prepare_result ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance - response = tool.call(args={}, tool_context=None) + response = await tool.call(args={}, tool_context=None) assert response == { "pending": True, "message": "Needs your authorization to access your data.", diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index 8db151fc8..e405ce5b8 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -116,7 +116,8 @@ def openid_connect_credential(): return credential -def test_openid_connect_no_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_no_auth_response( openid_connect_scheme, openid_connect_credential ): # Setup Mock exchanger @@ -132,12 +133,13 @@ def test_openid_connect_no_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'pending' assert result.auth_credential == openid_connect_credential -def test_openid_connect_with_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_with_auth_response( openid_connect_scheme, openid_connect_credential, monkeypatch ): mock_exchanger = MockOpenIdConnectCredentialExchanger( @@ -166,7 +168,7 @@ def test_openid_connect_with_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP assert 'test_access_token' in result.auth_credential.http.credentials.token @@ -178,7 +180,8 @@ def test_openid_connect_with_auth_response( mock_auth_handler.get_auth_response.assert_called_once() -def test_openid_connect_existing_token( +@pytest.mark.asyncio +async def test_openid_connect_existing_token( openid_connect_scheme, openid_connect_credential ): _, existing_credential = token_to_scheme_credential( @@ -198,16 +201,17 @@ def test_openid_connect_existing_token( openid_connect_credential, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential == existing_credential @patch( - 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialFetcher' + 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher' ) -def test_openid_connect_existing_oauth2_token_refresh( - mock_oauth2_fetcher, openid_connect_scheme, openid_connect_credential +@pytest.mark.asyncio +async def test_openid_connect_existing_oauth2_token_refresh( + mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential ): """Test that OAuth2 tokens are refreshed 
when existing credentials are found.""" # Create existing OAuth2 credential @@ -232,10 +236,13 @@ def test_openid_connect_existing_oauth2_token_refresh( ), ) - # Setup mock OAuth2CredentialFetcher - mock_fetcher_instance = MagicMock() - mock_fetcher_instance.refresh.return_value = refreshed_credential - mock_oauth2_fetcher.return_value = mock_fetcher_instance + # Setup mock OAuth2CredentialRefresher + from unittest.mock import AsyncMock + + mock_refresher_instance = MagicMock() + mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential) + mock_oauth2_refresher.return_value = mock_refresher_instance tool_context = create_mock_tool_context() credential_store = ToolContextCredentialStore(tool_context=tool_context) @@ -253,13 +260,17 @@ def test_openid_connect_existing_oauth2_token_refresh( credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() + + # Verify OAuth2CredentialRefresher was called for refresh + mock_oauth2_refresher.assert_called_once() - # Verify OAuth2CredentialFetcher was called for refresh - mock_oauth2_fetcher.assert_called_once_with( - openid_connect_scheme, existing_credential + mock_refresher_instance.is_refresh_needed.assert_called_once_with( + existing_credential + ) + mock_refresher_instance.refresh.assert_called_once_with( + existing_credential, openid_connect_scheme ) - mock_fetcher_instance.refresh.assert_called_once() assert result.state == 'done' # The result should contain the refreshed credential after exchange diff --git a/tests/unittests/tools/test_authenticated_function_tool.py b/tests/unittests/tools/test_authenticated_function_tool.py new file mode 100644 index 000000000..88454032a --- /dev/null +++ b/tests/unittests/tools/test_authenticated_function_tool.py @@ -0,0 +1,541 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool +from google.adk.tools.tool_context import ToolContext +import pytest + +# Test functions for different scenarios + + +def sync_function_no_credential(arg1: str, arg2: int) -> str: + """Test sync function without credential parameter.""" + return f"sync_result_{arg1}_{arg2}" + + +async def async_function_no_credential(arg1: str, arg2: int) -> str: + """Test async function without credential parameter.""" + return f"async_result_{arg1}_{arg2}" + + +def sync_function_with_credential(arg1: str, credential: AuthCredential) -> str: + """Test sync function with credential parameter.""" + return f"sync_cred_result_{arg1}_{credential.auth_type.value}" + + +async def async_function_with_credential( + arg1: str, credential: AuthCredential +) -> str: + """Test async function with credential parameter.""" + return f"async_cred_result_{arg1}_{credential.auth_type.value}" + + +def sync_function_with_tool_context( + arg1: str, tool_context: ToolContext +) -> str: + """Test sync function with tool_context parameter.""" + return f"sync_context_result_{arg1}" + + +async def async_function_with_both( + arg1: str, tool_context: ToolContext, credential: AuthCredential +) -> str: + """Test async function with both tool_context and credential parameters.""" + return f"async_both_result_{arg1}_{credential.auth_type.value}" + + +def function_with_optional_args( + arg1: str, arg2: str = "default", credential: AuthCredential = None +) -> str: + """Test function with optional arguments.""" + cred_type = credential.auth_type.value if credential else "none" + return f"optional_result_{arg1}_{arg2}_{cred_type}" + + +class MockCallable: + """Test callable class for testing.""" + + def __init__(self): + self.__name__ = "MockCallable" + self.__doc__ = "Test callable documentation" + + def __call__(self, arg1: str, credential: AuthCredential) -> str: + return f"callable_result_{arg1}_{credential.auth_type.value}" + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + # Create a mock auth_type that returns the expected value + mock_auth_type = Mock() + mock_auth_type.value = "oauth2" + credential.auth_type = mock_auth_type + return credential + + +class TestAuthenticatedFunctionTool: + """Test suite for AuthenticatedFunctionTool.""" + + def test_init_with_sync_function(self): + """Test initialization with synchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, + auth_config=auth_config, + response_for_auth_required="Please authenticate", + ) + + assert tool.name == "sync_function_no_credential" + assert ( + tool.description == "Test sync function without credential parameter." 
+ ) + assert tool.func == sync_function_no_credential + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == "Please authenticate" + assert "credential" in tool._ignore_params + + def test_init_with_async_function(self): + """Test initialization with asynchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=async_function_no_credential, auth_config=auth_config + ) + + assert tool.name == "async_function_no_credential" + assert ( + tool.description == "Test async function without credential parameter." + ) + assert tool.func == async_function_no_credential + assert tool._response_for_auth_required is None + + def test_init_with_callable(self): + """Test initialization with callable object.""" + auth_config = _create_mock_auth_config() + test_callable = MockCallable() + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + + assert tool.name == "MockCallable" + assert tool.description == "Test callable documentation" + assert tool.func == test_callable + + def test_init_no_auth_config(self): + """Test initialization without auth_config.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + + assert tool._credentials_manager is None + + @pytest.mark.asyncio + async def test_run_async_sync_function_no_credential_manager(self): + """Test run_async with sync function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_async_function_no_credential_manager(self): + """Test run_async with async function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=async_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "async_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"sync_cred_result_test_{credential.auth_type.value}" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_async_function_with_credential(self): + """Test run_async with async function 
that expects credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_cred_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, + auth_config=auth_config, + response_for_auth_required="Custom auth required", + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Custom auth required" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_default_message(self): + """Test run_async when no credential is available with default message.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." 
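+    # "Pending User Authorization." is the built-in default response asserted
+    # here when no custom response_for_auth_required is supplied.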
+ + @pytest.mark.asyncio + async def test_run_async_function_without_credential_param(self): + """Test run_async with function that doesn't have credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Credential should not be passed to function since it doesn't have the parameter + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_function_with_tool_context(self): + """Test run_async with function that has tool_context parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_tool_context, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_context_result_test" + + @pytest.mark.asyncio + async def test_run_async_function_with_both_params(self): + """Test run_async with function that has both tool_context and credential parameters.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_both, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_both_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_function_with_optional_credential(self): + """Test run_async with function that has optional credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=function_with_optional_args, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert ( + result == f"optional_result_test_default_{credential.auth_type.value}" + ) + + @pytest.mark.asyncio + async def test_run_async_callable_object(self): + """Test run_async with callable object.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + test_callable = MockCallable() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + 
mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"callable_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_propagates_function_exception(self): + """Test that run_async propagates exceptions from the wrapped function.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + def failing_function(arg1: str, credential: AuthCredential) -> str: + raise ValueError("Function failed") + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=failing_function, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(ValueError, match="Function failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_missing_required_args(self): + """Test run_async with missing required arguments.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} # Missing arg2 + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should return error dict indicating missing parameters + assert isinstance(result, dict) + assert "error" in result + assert "arg2" in result["error"] + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_credential_in_ignore_params(self): + """Test that 'credential' is added to ignore_params during initialization.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + assert "credential" in tool._ignore_params + + @pytest.mark.asyncio + async def test_run_async_with_none_credential(self): + """Test run_async when credential is None but function expects it.""" + tool = AuthenticatedFunctionTool(func=function_with_optional_args) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "optional_result_test_default_none" + + def test_signature_inspection(self): + """Test that the tool correctly inspects function signatures.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + signature = inspect.signature(tool.func) + assert "credential" in signature.parameters + assert "arg1" 
in signature.parameters + + @pytest.mark.asyncio + async def test_args_to_call_modification(self): + """Test that args_to_call is properly modified with credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + # Create a spy function to check what arguments are passed + original_args = {} + + def spy_function(arg1: str, credential: AuthCredential) -> str: + nonlocal original_args + original_args = {"arg1": arg1, "credential": credential} + return "spy_result" + + tool = AuthenticatedFunctionTool(func=spy_function, auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "spy_result" + assert original_args is not None + assert original_args["arg1"] == "test" + assert original_args["credential"] == credential diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py new file mode 100644 index 000000000..55454224d --- /dev/null +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -0,0 +1,343 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestAuthenticatedTool(BaseAuthenticatedTool): + """Test implementation of BaseAuthenticatedTool for testing purposes.""" + + def __init__( + self, + name="test_auth_tool", + description="Test authenticated tool", + auth_config=None, + unauthenticated_response=None, + ): + super().__init__( + name=name, + description=description, + auth_config=auth_config, + response_for_auth_required=unauthenticated_response, + ) + self.run_impl_called = False + self.run_impl_result = "test_result" + + async def _run_async_impl(self, *, args, tool_context, credential): + """Test implementation of the abstract method.""" + self.run_impl_called = True + self.last_args = args + self.last_tool_context = tool_context + self.last_credential = credential + return self.run_impl_result + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + credential.auth_type = AuthCredentialTypes.OAUTH2 + return credential + + +class TestBaseAuthenticatedTool: + """Test suite for BaseAuthenticatedTool.""" + + def test_init_with_auth_config(self): + """Test initialization with auth_config.""" + auth_config = _create_mock_auth_config() + unauthenticated_response = {"error": "Not authenticated"} + + tool = _TestAuthenticatedTool( + name="test_tool", + description="Test description", + auth_config=auth_config, + unauthenticated_response=unauthenticated_response, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test description" + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == unauthenticated_response + + def test_init_with_no_auth_config(self): + """Test initialization without auth_config.""" + tool = _TestAuthenticatedTool() + + assert tool.name == "test_auth_tool" + assert tool.description == "Test authenticated tool" + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._credentials_manager is None + + def test_init_with_default_unauthenticated_response(self): + """Test initialization with default unauthenticated response.""" + auth_config = _create_mock_auth_config() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._response_for_auth_required is None + + @pytest.mark.asyncio + async def test_run_async_no_credentials_manager(self): + """Test run_async when no credentials manager is configured.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + 
result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential is None + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential == credential + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." 
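+    # The wrapped _run_async_impl must not run when no credential is
+    # available; only request_credential should have been triggered.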
+ assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_custom_response(self): + """Test run_async when no credential is available with custom response.""" + auth_config = _create_mock_auth_config() + custom_response = { + "status": "authentication_required", + "message": "Please login", + } + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_string_response(self): + """Test run_async when no credential is available with string response.""" + auth_config = _create_mock_auth_config() + custom_response = "Custom authentication required message" + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + + @pytest.mark.asyncio + async def test_run_async_propagates_impl_exception(self): + """Test that run_async propagates exceptions from _run_async_impl.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + # Make the implementation raise an exception + async def failing_impl(*, args, tool_context, credential): + raise ValueError("Implementation failed") + + tool._run_async_impl = failing_impl + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(ValueError, match="Implementation failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_with_different_args_types(self): + """Test run_async with different argument types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + + # Test with empty args + result = await tool.run_async(args={}, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == {} + + # Test with complex args + complex_args = { + "string_param": 
"test", + "number_param": 42, + "list_param": [1, 2, 3], + "dict_param": {"nested": "value"}, + } + result = await tool.run_async(args=complex_args, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == complex_args + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_abstract_nature(self): + """Test that BaseAuthenticatedTool cannot be instantiated directly.""" + with pytest.raises(TypeError): + # This should fail because _run_async_impl is abstract + BaseAuthenticatedTool(name="test", description="test") + + @pytest.mark.asyncio + async def test_run_async_return_values(self): + """Test run_async with different return value types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {} + + # Test with None return + tool.run_impl_result = None + result = await tool.run_async(args=args, tool_context=tool_context) + assert result is None + + # Test with dict return + tool.run_impl_result = {"key": "value"} + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == {"key": "value"} + + # Test with list return + tool.run_impl_result = [1, 2, 3] + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == [1, 2, 3]