diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index a504fde0d..565ee1dca 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -43,11 +43,19 @@ jobs: run: | uv venv .venv source .venv/bin/activate - uv sync --extra test --extra eval + uv sync --extra test --extra eval --extra a2a - name: Run unit tests with pytest run: | source .venv/bin/activate - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + if [[ "${{ matrix.python-version }}" == "3.9" ]]; then + pytest tests/unittests \ + --ignore=tests/unittests/a2a \ + --ignore=tests/unittests/tools/mcp_tool \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + else + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py + fi \ No newline at end of file diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml new file mode 100644 index 000000000..2e258857f --- /dev/null +++ b/.github/workflows/triage.yml @@ -0,0 +1,43 @@ +name: ADK Issue Triaging Agent + +on: + issues: + types: [opened, reopened] + schedule: + - cron: '0 */6 * * *' # every 6h + +jobs: + agent-triage-issues: + runs-on: ubuntu-latest + permissions: + issues: write + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install requests google-adk + + - name: Run Triaging Script + env: + GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} + GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} + GOOGLE_GENAI_USE_VERTEXAI: 0 + OWNER: 'google' + REPO: 'adk-python' + INTERACTIVE: 0 + EVENT_NAME: ${{ github.event_name }} # 'issues', 'schedule', etc. + ISSUE_NUMBER: ${{ github.event.issue.number }} + ISSUE_TITLE: ${{ github.event.issue.title }} + ISSUE_BODY: ${{ github.event.issue.body }} + ISSUE_COUNT_TO_PROCESS: '3' # Process 3 issues at a time on schedule + run: python contributing/samples/adk_triaging_agent/main.py diff --git a/contributing/samples/adk_triaging_agent/README.md b/contributing/samples/adk_triaging_agent/README.md new file mode 100644 index 000000000..be4071b61 --- /dev/null +++ b/contributing/samples/adk_triaging_agent/README.md @@ -0,0 +1,67 @@ +# ADK Issue Triaging Assistant + +The ADK Issue Triaging Assistant is a Python-based agent designed to help manage and triage GitHub issues for the `google/adk-python` repository. It uses a large language model to analyze new and unlabelled issues, recommend appropriate labels based on a predefined set of rules, and apply them. + +This agent can be operated in two distinct modes: an interactive mode for local use or as a fully automated GitHub Actions workflow. + +--- + +## Interactive Mode + +This mode allows you to run the agent locally to review its recommendations in real-time before any changes are made to your repository's issues. + +### Features +* **Web Interface**: The agent's interactive mode can be rendered in a web browser using the ADK's `adk web` command. +* **User Approval**: In interactive mode, the agent is instructed to ask for your confirmation before applying a label to a GitHub issue. + +### Running in Interactive Mode +To run the agent in interactive mode, first set the required environment variables. Then, execute the following command in your terminal: + +```bash +adk web +``` +This will start a local server and provide a URL to access the agent's web interface in your browser. + +--- + +## GitHub Workflow Mode + +For automated, hands-off issue triaging, the agent can be integrated directly into your repository's CI/CD pipeline using a GitHub Actions workflow. + +### Workflow Triggers +The GitHub workflow is configured to run on specific triggers: + +1. **Issue Events**: The workflow executes automatically whenever a new issue is `opened` or an existing one is `reopened`. + +2. **Scheduled Runs**: The workflow also runs on a recurring schedule (every 6 hours) to process any unlabelled issues that may have been missed. + +### Automated Labeling +When running as part of the GitHub workflow, the agent operates non-interactively. It identifies the best label and applies it directly without requiring user approval. This behavior is configured by setting the `INTERACTIVE` environment variable to `0` in the workflow file. + +### Workflow Configuration +The workflow is defined in a YAML file (`.github/workflows/triage.yml`). This file contains the steps to check out the code, set up the Python environment, install dependencies, and run the triaging script with the necessary environment variables and secrets. + +--- + +## Setup and Configuration + +Whether running in interactive or workflow mode, the agent requires the following setup. + +### Dependencies +The agent requires the following Python libraries. + +```bash +pip install --upgrade pip +pip install google-adk requests +``` + +### Environment Variables +The following environment variables are required for the agent to connect to the necessary services. + +* `GITHUB_TOKEN`: **(Required)** A GitHub Personal Access Token with `issues:write` permissions. Needed for both interactive and workflow modes. +* `GOOGLE_API_KEY`: **(Required)** Your API key for the Gemini API. Needed for both interactive and workflow modes. +* `OWNER`: The GitHub organization or username that owns the repository (e.g., `google`). Needed for both modes. +* `REPO`: The name of the GitHub repository (e.g., `adk-python`). Needed for both modes. +* `INTERACTIVE`: Controls the agent's interaction mode. For the automated workflow, this is set to `0`. For interactive mode, it should be set to `1` or left unset. + +For local execution in interactive mode, you can place these variables in a `.env` file in the project's root directory. For the GitHub workflow, they should be configured as repository secrets. \ No newline at end of file diff --git a/contributing/samples/adk_triaging_agent/__init__.py b/contributing/samples/adk_triaging_agent/__init__.py new file mode 100755 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/adk_triaging_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index 2720e5b46..ecf574572 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ -13,61 +13,73 @@ # limitations under the License. import os -import random -import time from google.adk import Agent -from google.adk.tools.tool_context import ToolContext -from google.genai import types import requests -# Read the PAT from the environment variable -GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") # Ensure you've set this in your shell +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") if not GITHUB_TOKEN: raise ValueError("GITHUB_TOKEN environment variable not set") -# Repository information -OWNER = "google" -REPO = "adk-python" +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +BOT_LABEL = os.getenv("BOT_LABEL", "bot_triaged") -# Base URL for the GitHub API BASE_URL = "https://api.github.com" -# Headers including the Authorization header headers = { "Authorization": f"token {GITHUB_TOKEN}", "Accept": "application/vnd.github.v3+json", } +ALLOWED_LABELS = [ + "documentation", + "services", + "question", + "tools", + "eval", + "live", + "models", + "tracing", + "core", + "web", +] -def list_issues(per_page: int): + +def is_interactive(): + return os.environ.get("INTERACTIVE", "1").lower() in ["true", "1"] + + +def list_issues(issue_count: int): """ Generator to list all issues for the repository by handling pagination. Args: - per_page: number of pages to return per page. + issue_count: number of issues to return """ - state = "open" - # only process the 1st page for testing for now - page = 1 - results = [] - url = ( # :contentReference[oaicite:16]{index=16} - f"{BASE_URL}/repos/{OWNER}/{REPO}/issues" - ) - # Warning: let's only handle max 10 issues at a time to avoid bad results - params = {"state": state, "per_page": per_page, "page": page} - response = requests.get(url, headers=headers, params=params) - response.raise_for_status() # :contentReference[oaicite:17]{index=17} - issues = response.json() + query = f"repo:{OWNER}/{REPO} is:open is:issue no:label" + + unlabelled_issues = [] + url = f"{BASE_URL}/search/issues" + + params = { + "q": query, + "sort": "created", + "order": "desc", + "per_page": issue_count, + "page": 1, + } + response = requests.get(url, headers=headers, params=params, timeout=60) + response.raise_for_status() + json_response = response.json() + issues = json_response.get("items", None) if not issues: return [] for issue in issues: - # Skip pull requests (issues API returns PRs as well) - if "pull_request" in issue: - continue - results.append(issue) - return results + if not issue.get("labels", None) or len(issue["labels"]) == 0: + unlabelled_issues.append(issue) + return unlabelled_issues def add_label_to_issue(issue_number: str, label: str): @@ -78,41 +90,56 @@ def add_label_to_issue(issue_number: str, label: str): issue_number: issue number of the Github issue, in string foramt. label: label to assign """ + print(f"Attempting to add label '{label}' to issue #{issue_number}") + if label not in ALLOWED_LABELS: + error_message = ( + f"Error: Label '{label}' is not an allowed label. Will not apply." + ) + print(error_message) + return {"status": "error", "message": error_message, "applied_label": None} + url = f"{BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}/labels" - payload = [label] - response = requests.post(url, headers=headers, json=payload) + payload = [label, BOT_LABEL] + response = requests.post(url, headers=headers, json=payload, timeout=60) response.raise_for_status() return response.json() +approval_instruction = ( + "Only label them when the user approves the labeling!" + if is_interactive() + else ( + "Do not ask for user approval for labeling! If you can't find a" + " appropriate labels for the issue, do not label it." + ) +) + root_agent = Agent( model="gemini-2.5-pro-preview-05-06", name="adk_triaging_assistant", description="Triage ADK issues.", - instruction=""" - You are a Github adk-python repo triaging bot. You will help get issues, and label them. + instruction=f""" + You are a Github adk-python repo triaging bot. You will help get issues, and recommend a label. + IMPORTANT: {approval_instruction} Here are the rules for labeling: - If the user is asking about documentation-related questions, label it with "documentation". - If it's about session, memory services, label it with "services" - - If it's about UI/web, label it with "question" + - If it's about UI/web, label it with "web" + - If the user is asking about a question, label it with "question" - If it's related to tools, label it with "tools" - If it's about agent evalaution, then label it with "eval". - If it's about streaming/live, label it with "live". - If it's about model support(non-Gemini, like Litellm, Ollama, OpenAI models), label it with "models". - If it's about tracing, label it with "tracing". - If it's agent orchestration, agent definition, label it with "core". - - If you can't find a appropriate labels for the issue, return the issues to user to decide. + - If you can't find a appropriate labels for the issue, follow the previous instruction that starts with "IMPORTANT:". + + Present the followings in an easy to read format highlighting issue number and your label. + - the issue summary in a few sentence + - your label recommendation and justification """, tools=[ list_issues, add_label_to_issue, ], - generate_content_config=types.GenerateContentConfig( - safety_settings=[ - types.SafetySetting( # avoid false alarm about rolling dice. - category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=types.HarmBlockThreshold.OFF, - ), - ] - ), ) diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py new file mode 100644 index 000000000..a749b26fc --- /dev/null +++ b/contributing/samples/adk_triaging_agent/main.py @@ -0,0 +1,164 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import time + +import agent +from dotenv import load_dotenv +from google.adk.agents.run_config import RunConfig +from google.adk.runners import InMemoryRunner +from google.adk.sessions import Session +from google.genai import types +import requests + +load_dotenv(override=True) + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +BASE_URL = "https://api.github.com" +headers = { + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + +if not GITHUB_TOKEN: + print( + "Warning: GITHUB_TOKEN environment variable not set. API calls might" + " fail." + ) + + +async def fetch_specific_issue_details(issue_number: int): + """Fetches details for a single issue if it's unlabelled.""" + if not GITHUB_TOKEN: + print("Cannot fetch issue details: GITHUB_TOKEN is not set.") + return None + + url = f"{BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" + print(f"Fetching details for specific issue: {url}") + try: + response = requests.get(url, headers=headers, timeout=60) + response.raise_for_status() + issue_data = response.json() + if not issue_data.get("labels") or len(issue_data["labels"]) == 0: + print(f"Issue #{issue_number} is unlabelled. Proceeding.") + return { + "number": issue_data["number"], + "title": issue_data["title"], + "body": issue_data.get("body", ""), + } + else: + print(f"Issue #{issue_number} is already labelled. Skipping.") + return None + except requests.exceptions.RequestException as e: + print(f"Error fetching issue #{issue_number}: {e}") + if hasattr(e, "response") and e.response is not None: + print(f"Response content: {e.response.text}") + return None + + +async def main(): + app_name = "triage_app" + user_id_1 = "triage_user" + runner = InMemoryRunner( + agent=agent.root_agent, + app_name=app_name, + ) + session_11 = await runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_agent_prompt(session: Session, prompt_text: str): + content = types.Content( + role="user", parts=[types.Part.from_text(text=prompt_text)] + ) + print(f"\n>>>> Agent Prompt: {prompt_text}") + final_agent_response_parts = [] + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=False), + ): + if event.content.parts and event.content.parts[0].text: + print(f"** {event.author} (ADK): {event.content.parts[0].text}") + if event.author == agent.root_agent.name: + final_agent_response_parts.append(event.content.parts[0].text) + print(f"<<<< Agent Final Output: {''.join(final_agent_response_parts)}\n") + + event_name = os.getenv("EVENT_NAME") + issue_number_str = os.getenv("ISSUE_NUMBER") + + if event_name == "issues" and issue_number_str: + print(f"EVENT: Processing specific issue due to '{event_name}' event.") + try: + issue_number = int(issue_number_str) + specific_issue = await fetch_specific_issue_details(issue_number) + + if specific_issue: + prompt = ( + f"A new GitHub issue #{specific_issue['number']} has been opened or" + f" reopened. Title: \"{specific_issue['title']}\"\nBody:" + f" \"{specific_issue['body']}\"\n\nBased on the rules, recommend an" + " appropriate label and its justification." + " Then, use the 'add_label_to_issue' tool to apply the label " + "directly to this issue." + f" The issue number is {specific_issue['number']}." + ) + await run_agent_prompt(session_11, prompt) + else: + print( + f"No unlabelled issue details found for #{issue_number} or an error" + " occurred. Skipping agent interaction." + ) + + except ValueError: + print(f"Error: Invalid ISSUE_NUMBER received: {issue_number_str}") + + else: + print(f"EVENT: Processing batch of issues (event: {event_name}).") + issue_count_str = os.getenv("ISSUE_COUNT_TO_PROCESS", "3") + try: + num_issues_to_process = int(issue_count_str) + except ValueError: + print(f"Warning: Invalid ISSUE_COUNT_TO_PROCESS. Defaulting to 3.") + num_issues_to_process = 3 + + prompt = ( + f"List the first {num_issues_to_process} unlabelled open issues from" + f" the {OWNER}/{REPO} repository. For each issue, provide a summary," + " recommend a label with justification, and then use the" + " 'add_label_to_issue' tool to apply the recommended label directly." + ) + await run_agent_prompt(session_11, prompt) + + +if __name__ == "__main__": + start_time = time.time() + print( + "Script start time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(start_time)), + ) + print("------------------------------------") + asyncio.run(main()) + end_time = time.time() + print("------------------------------------") + print( + "Script end time:", + time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(end_time)), + ) + print("Total script execution time:", f"{end_time - start_time:.2f} seconds") diff --git a/contributing/samples/hello_world_ma/agent.py b/contributing/samples/hello_world_ma/agent.py index 98e79a397..a6bf78a9e 100755 --- a/contributing/samples/hello_world_ma/agent.py +++ b/contributing/samples/hello_world_ma/agent.py @@ -15,6 +15,7 @@ import random from google.adk.agents import Agent +from google.adk.examples.example import Example from google.adk.tools.example_tool import ExampleTool from google.genai import types @@ -66,43 +67,47 @@ def check_prime(nums: list[int]) -> str: ) -example_tool = ExampleTool([ - { - "input": { - "role": "user", - "parts": [{"text": "Roll a 6-sided die."}], - }, - "output": [ - {"role": "model", "parts": [{"text": "I rolled a 4 for you."}]} - ], - }, - { - "input": { - "role": "user", - "parts": [{"text": "Is 7 a prime number?"}], - }, - "output": [{ - "role": "model", - "parts": [{"text": "Yes, 7 is a prime number."}], - }], - }, - { - "input": { - "role": "user", - "parts": [{"text": "Roll a 10-sided die and check if it's prime."}], - }, - "output": [ - { - "role": "model", - "parts": [{"text": "I rolled an 8 for you."}], - }, - { - "role": "model", - "parts": [{"text": "8 is not a prime number."}], - }, - ], - }, -]) +example_tool = ExampleTool( + examples=[ + Example( + input=types.UserContent( + parts=[types.Part(text="Roll a 6-sided die.")] + ), + output=[ + types.ModelContent( + parts=[types.Part(text="I rolled a 4 for you.")] + ) + ], + ), + Example( + input=types.UserContent( + parts=[types.Part(text="Is 7 a prime number?")] + ), + output=[ + types.ModelContent( + parts=[types.Part(text="Yes, 7 is a prime number.")] + ) + ], + ), + Example( + input=types.UserContent( + parts=[ + types.Part( + text="Roll a 10-sided die and check if it's prime." + ) + ] + ), + output=[ + types.ModelContent( + parts=[types.Part(text="I rolled an 8 for you.")] + ), + types.ModelContent( + parts=[types.Part(text="8 is not a prime number.")] + ), + ], + ), + ] +) prime_agent = Agent( name="prime_agent", diff --git a/contributing/samples/live_bidi_streaming_agent/__init__.py b/contributing/samples/live_bidi_streaming_agent/__init__.py new file mode 100755 index 000000000..c48963cdc --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/live_bidi_streaming_agent/agent.py b/contributing/samples/live_bidi_streaming_agent/agent.py new file mode 100755 index 000000000..2896bd70f --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/agent.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk import Agent +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + result = random.randint(1, sides) + if not 'rolls' in tool_context.state: + tool_context.state['rolls'] = [] + + tool_context.state['rolls'] = tool_context.state['rolls'] + [result] + return result + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + 'No prime numbers found.' + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model='gemini-2.0-flash-live-preview-04-09', # for Vertex project + # model='gemini-2.0-flash-live-001', # for AI studio key + name='hello_world_agent', + description=( + 'hello world agent that can roll a dice of 8 sides and check prime' + ' numbers.' + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) diff --git a/contributing/samples/live_bidi_streaming_agent/readme.md b/contributing/samples/live_bidi_streaming_agent/readme.md new file mode 100644 index 000000000..6a9258f3e --- /dev/null +++ b/contributing/samples/live_bidi_streaming_agent/readme.md @@ -0,0 +1,37 @@ +# Simplistic Live (Bidi-Streaming) Agent +This project provides a basic example of a live, bidirectional streaming agent +designed for testing and experimentation. + +You can see full documentation [here](https://google.github.io/adk-docs/streaming/). + +## Getting Started + +Follow these steps to get the agent up and running: + +1. **Start the ADK Web Server** + Open your terminal, navigate to the root directory that contains the + `live_bidi_streaming_agent` folder, and execute the following command: + ```bash + adk web + ``` + +2. **Access the ADK Web UI** + Once the server is running, open your web browser and navigate to the URL + provided in the terminal (it will typically be `http://localhost:8000`). + +3. **Select the Agent** + In the top-left corner of the ADK Web UI, use the dropdown menu to select + this agent. + +4. **Start Streaming** + Click on either the **Audio** or **Video** icon located near the chat input + box to begin the streaming session. + +5. **Interact with the Agent** + You can now begin talking to the agent, and it will respond in real-time. + +## Usage Notes + +* You only need to click the **Audio** or **Video** button once to initiate the + stream. The current version does not support stopping and restarting the stream + by clicking the button again during a session. diff --git a/contributing/samples/mcp_streamablehttp_agent/README.md b/contributing/samples/mcp_streamablehttp_agent/README.md index 1c211dd71..547a0788d 100644 --- a/contributing/samples/mcp_streamablehttp_agent/README.md +++ b/contributing/samples/mcp_streamablehttp_agent/README.md @@ -1,8 +1,7 @@ -This agent connects to a local MCP server via sse. +This agent connects to a local MCP server via Streamable HTTP. To run this agent, start the local MCP server first by : ```bash uv run filesystem_server.py ``` - diff --git a/contributing/samples/mcp_streamablehttp_agent/agent.py b/contributing/samples/mcp_streamablehttp_agent/agent.py index 61d59e051..f165c4c1b 100644 --- a/contributing/samples/mcp_streamablehttp_agent/agent.py +++ b/contributing/samples/mcp_streamablehttp_agent/agent.py @@ -18,7 +18,6 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPServerParams from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset -from google.adk.tools.mcp_tool.mcp_toolset import SseServerParams _allowed_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/contributing/samples/quickstart/agent.py b/contributing/samples/quickstart/agent.py index fdd6b7f9d..b251069ad 100644 --- a/contributing/samples/quickstart/agent.py +++ b/contributing/samples/quickstart/agent.py @@ -29,7 +29,7 @@ def get_weather(city: str) -> dict: "status": "success", "report": ( "The weather in New York is sunny with a temperature of 25 degrees" - " Celsius (41 degrees Fahrenheit)." + " Celsius (77 degrees Fahrenheit)." ), } else: diff --git a/pyproject.toml b/pyproject.toml index 158025c7e..8ece4db81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,27 +25,32 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start - "authlib>=1.5.1", # For RestAPI Tool - "click>=8.1.8", # For CLI tools - "fastapi>=0.115.0", # FastAPI framework - "google-api-python-client>=2.157.0", # Google API client discovery - "google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store. - "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool - "google-cloud-speech>=2.30.0", # For Audio Transcription - "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service - "google-genai>=1.17.0", # Google GenAI SDK - "graphviz>=0.20.2", # Graphviz for graph rendering - "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset - "opentelemetry-api>=1.31.0", # OpenTelemetry + "anyio>=4.9.0;python_version>='3.10'", # For MCP Session Manager + "authlib>=1.5.1", # For RestAPI Tool + "click>=8.1.8", # For CLI tools + "fastapi>=0.115.0", # FastAPI framework + "google-api-python-client>=2.157.0", # Google API client discovery + "google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store. + "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool + "google-cloud-speech>=2.30.0", # For Audio Transcription + "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service + "google-genai>=1.17.0", # Google GenAI SDK + "graphviz>=0.20.2", # Graphviz for graph rendering + "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset + "opentelemetry-api>=1.31.0", # OpenTelemetry "opentelemetry-exporter-gcp-trace>=1.9.0", "opentelemetry-sdk>=1.31.0", - "pydantic>=2.0, <3.0.0", # For data validation/models - "python-dotenv>=1.0.0", # To manage environment variables - "PyYAML>=6.0.2", # For APIHubToolset. - "sqlalchemy>=2.0", # SQL database ORM - "tzlocal>=5.3", # Time zone utilities + "pydantic>=2.0, <3.0.0", # For data validation/models + "python-dateutil>=2.9.0.post0", # For Vertext AI Session Service + "python-dotenv>=1.0.0", # To manage environment variables + "PyYAML>=6.0.2", # For APIHubToolset. + "requests>=2.32.4", + "sqlalchemy>=2.0", # SQL database ORM + "starlette>=0.46.2", # For FastAPI CLI "typing-extensions>=4.5, <5", - "uvicorn>=0.34.0", # ASGI server for FastAPI + "tzlocal>=5.3", # Time zone utilities + "uvicorn>=0.34.0", # ASGI server for FastAPI + "websockets>=15.0.1", # For BaseLlmFlow # go/keep-sorted end ] dynamic = ["version"] @@ -71,6 +76,12 @@ dev = [ # go/keep-sorted end ] +a2a = [ + # go/keep-sorted start + "a2a-sdk>=0.2.7;python_version>='3.10'" + # go/keep-sorted end +] + eval = [ # go/keep-sorted start "google-cloud-aiplatform[evaluation]>=1.87.0", @@ -84,7 +95,7 @@ test = [ "anthropic>=0.43.0", # For anthropic model tests "langchain-community>=0.3.17", "langgraph>=0.2.60", # For LangGraphAgent - "litellm>=1.71.2", # For LiteLLM tests + "litellm>=1.71.2", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "pytest-asyncio>=0.25.0", @@ -140,24 +151,30 @@ pyink-annotation-pragmas = [ requires = ["flit_core >=3.8,<4"] build-backend = "flit_core.buildapi" + [tool.flit.sdist] include = ['src/**/*', 'README.md', 'pyproject.toml', 'LICENSE'] exclude = ['src/**/*.sh'] + [tool.flit.module] name = "google.adk" include = ["py.typed"] + [tool.isort] profile = "google" single_line_exclusions = [] +line_length = 200 # Prevent line wrap flickering. known_third_party = ["google.adk"] + [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" + [tool.mypy] python_version = "3.9" exclude = "tests/" diff --git a/src/google/adk/a2a/__init__.py b/src/google/adk/a2a/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/converters/__init__.py b/src/google/adk/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py new file mode 100644 index 000000000..2d94abd7c --- /dev/null +++ b/src/google/adk/a2a/converters/part_converter.py @@ -0,0 +1,177 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +module containing utilities for conversion betwen A2A Part and Google GenAI Part +""" + +from __future__ import annotations + +import json +import logging +import sys +from typing import Optional + +try: + from a2a import types as a2a_types +except ImportError as e: + if sys.version_info < (3, 10): + raise ImportError( + 'A2A Tool requires Python 3.10 or above. Please upgrade your Python' + ' version.' + ) from e + else: + raise e + +from google.genai import types as genai_types + +from ...utils.feature_decorator import working_in_progress + +logger = logging.getLogger('google_adk.' + __name__) + +A2A_DATA_PART_METADATA_TYPE_KEY = 'type' +A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call' +A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' + + +@working_in_progress +def convert_a2a_part_to_genai_part( + a2a_part: a2a_types.Part, +) -> Optional[genai_types.Part]: + """Convert an A2A Part to a Google GenAI Part.""" + part = a2a_part.root + if isinstance(part, a2a_types.TextPart): + return genai_types.Part(text=part.text) + + if isinstance(part, a2a_types.FilePart): + if isinstance(part.file, a2a_types.FileWithUri): + return genai_types.Part( + file_data=genai_types.FileData( + file_uri=part.file.uri, mime_type=part.file.mimeType + ) + ) + + elif isinstance(part.file, a2a_types.FileWithBytes): + return genai_types.Part( + inline_data=genai_types.Blob( + data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType + ) + ) + else: + logger.warning( + 'Cannot convert unsupported file type: %s for A2A part: %s', + type(part.file), + a2a_part, + ) + return None + + if isinstance(part, a2a_types.DataPart): + # Conver the Data Part to funcall and function reponse. + # This is mainly for converting human in the loop and auth request and + # response. + # TODO once A2A defined how to suervice such information, migrate below + # logic accordinlgy + if part.metadata and A2A_DATA_PART_METADATA_TYPE_KEY in part.metadata: + if ( + part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ): + return genai_types.Part( + function_call=genai_types.FunctionCall.model_validate( + part.data, by_alias=True + ) + ) + if ( + part.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ): + return genai_types.Part( + function_response=genai_types.FunctionResponse.model_validate( + part.data, by_alias=True + ) + ) + return genai_types.Part(text=json.dumps(part.data)) + + logger.warning( + 'Cannot convert unsupported part type: %s for A2A part: %s', + type(part), + a2a_part, + ) + return None + + +@working_in_progress +def convert_genai_part_to_a2a_part( + part: genai_types.Part, +) -> Optional[a2a_types.Part]: + """Convert a Google GenAI Part to an A2A Part.""" + if part.text: + return a2a_types.TextPart(text=part.text) + + if part.file_data: + return a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=part.file_data.file_uri, + mimeType=part.file_data.mime_type, + ) + ) + + if part.inline_data: + return a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=part.inline_data.data, + mimeType=part.inline_data.mime_type, + ) + ) + ) + + # Conver the funcall and function reponse to A2A DataPart. + # This is mainly for converting human in the loop and auth request and + # response. + # TODO once A2A defined how to suervice such information, migrate below + # logic accordinlgy + if part.function_call: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.function_call.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + }, + ) + ) + + if part.function_response: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.function_response.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + }, + ) + ) + + logger.warning( + 'Cannot convert unsupported part for Google GenAI part: %s', + part, + ) + return None diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index f70371535..765f22a2c 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -22,6 +22,7 @@ from pydantic import ConfigDict from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session @@ -115,6 +116,7 @@ class InvocationContext(BaseModel): artifact_service: Optional[BaseArtifactService] = None session_service: BaseSessionService memory_service: Optional[BaseMemoryService] = None + credential_service: Optional[BaseCredentialService] = None invocation_id: str """The id of this invocation context. Readonly.""" diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 566fe8606..c9a50a0ae 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum import logging import sys @@ -68,6 +70,15 @@ class RunConfig(BaseModel): input_audio_transcription: Optional[types.AudioTranscriptionConfig] = None """Input transcription for live agents with audio input from user.""" + realtime_input_config: Optional[types.RealtimeInputConfig] = None + """Realtime input config for live agents with audio input from user.""" + + enable_affective_dialog: Optional[bool] = None + """If enabled, the model will detect emotions and adapt its responses accordingly.""" + + proactivity: Optional[types.ProactivityConfig] = None + """Configures the proactivity of the model. This allows the model to respond proactively to the input and to ignore irrelevant input.""" + max_llm_calls: int = 500 """ A limit on the total number of llm calls for a given run. diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index db6fa9767..34d04dde9 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum from typing import Any from typing import Dict @@ -75,6 +77,8 @@ class OAuth2Auth(BaseModelWithConfig): auth_code: Optional[str] = None access_token: Optional[str] = None refresh_token: Optional[str] = None + expires_at: Optional[int] = None + expires_in: Optional[int] = None class ServiceAccountCredential(BaseModelWithConfig): diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 1607dcadf..473f31413 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -16,16 +16,13 @@ from typing import TYPE_CHECKING -from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import SecurityBase from .auth_credential import AuthCredential -from .auth_credential import AuthCredentialTypes -from .auth_credential import OAuth2Auth from .auth_schemes import AuthSchemeType -from .auth_schemes import OAuthGrantType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig +from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger if TYPE_CHECKING: from ..sessions.state import State @@ -33,86 +30,31 @@ try: from authlib.integrations.requests_client import OAuth2Session - SUPPORT_TOKEN_EXCHANGE = True + AUTHLIB_AVIALABLE = True except ImportError: - SUPPORT_TOKEN_EXCHANGE = False + AUTHLIB_AVIALABLE = False class AuthHandler: + """A handler that handles the auth flow in Agent Development Kit to help + orchestrate the credential request and response flow (e.g. OAuth flow) + This class should only be used by Agent Development Kit. + """ def __init__(self, auth_config: AuthConfig): self.auth_config = auth_config - def exchange_auth_token( + async def exchange_auth_token( self, ) -> AuthCredential: - """Generates an auth token from the authorization response. - - Returns: - An AuthCredential object containing the access token. - - Raises: - ValueError: If the token endpoint is not configured in the auth - scheme. - AuthCredentialMissingError: If the access token cannot be retrieved - from the token endpoint. - """ - auth_scheme = self.auth_config.auth_scheme - auth_credential = self.auth_config.exchanged_auth_credential - if not SUPPORT_TOKEN_EXCHANGE: - return auth_credential - if isinstance(auth_scheme, OpenIdConnectWithConfig): - if not hasattr(auth_scheme, "token_endpoint"): - return self.auth_config.exchanged_auth_credential - token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes - elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): - return self.auth_config.exchanged_auth_credential - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) - else: - return self.auth_config.exchanged_auth_credential - - if ( - not auth_credential - or not auth_credential.oauth2 - or not auth_credential.oauth2.client_id - or not auth_credential.oauth2.client_secret - or auth_credential.oauth2.access_token - or auth_credential.oauth2.refresh_token - ): - return self.auth_config.exchanged_auth_credential - - client = OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ) - tokens = client.fetch_token( - token_endpoint, - authorization_response=auth_credential.oauth2.auth_response_uri, - code=auth_credential.oauth2.auth_code, - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - ) - - updated_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - access_token=tokens.get("access_token"), - refresh_token=tokens.get("refresh_token"), - ), + exchanger = OAuth2CredentialExchanger() + return await exchanger.exchange( + self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme ) - return updated_credential - def parse_and_store_auth_response(self, state: State) -> None: + async def parse_and_store_auth_response(self, state: State) -> None: - credential_key = "temp:" + self.auth_config.get_credential_key() + credential_key = "temp:" + self.auth_config.credential_key state[credential_key] = self.auth_config.exchanged_auth_credential if not isinstance( @@ -123,14 +65,14 @@ def parse_and_store_auth_response(self, state: State) -> None: ): return - state[credential_key] = self.exchange_auth_token() + state[credential_key] = await self.exchange_auth_token() def _validate(self) -> None: if not self.auth_scheme: raise ValueError("auth_scheme is empty.") def get_auth_response(self, state: State) -> AuthCredential: - credential_key = "temp:" + self.auth_config.get_credential_key() + credential_key = "temp:" + self.auth_config.credential_key return state.get(credential_key, None) def generate_auth_request(self) -> AuthConfig: @@ -204,6 +146,13 @@ def generate_auth_uri( ValueError: If the authorization endpoint is not configured in the auth scheme. """ + if not AUTHLIB_AVIALABLE: + return ( + self.auth_config.raw_auth_credential.model_copy(deep=True) + if self.auth_config.raw_auth_credential + else None + ) + auth_scheme = self.auth_config.auth_scheme auth_credential = self.auth_config.raw_auth_credential diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 0c964ed96..b06774973 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -67,9 +67,9 @@ async def run_async( # function call request_euc_function_call_ids.add(function_call_response.id) auth_config = AuthConfig.model_validate(function_call_response.response) - AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - state=invocation_context.session.state - ) + await AuthHandler( + auth_config=auth_config + ).parse_and_store_auth_response(state=invocation_context.session.state) break if not request_euc_function_call_ids: diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index a1a19ab87..53c571d42 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -14,6 +14,10 @@ from __future__ import annotations +from typing import Optional + +from typing_extensions import deprecated + from .auth_credential import AuthCredential from .auth_credential import BaseModelWithConfig from .auth_schemes import AuthScheme @@ -45,11 +49,23 @@ class AuthConfig(BaseModelWithConfig): this field to guide the user through the OAuth2 flow and fill auth response in this field""" + credential_key: Optional[str] = None + """A user specified key used to load and save this credential in a credential + service. + """ + + def __init__(self, **data): + super().__init__(**data) + if self.credential_key: + return + self.credential_key = self.get_credential_key() + + @deprecated("This method is deprecated. Use credential_key instead.") def get_credential_key(self): - """Generates a hash key based on auth_scheme and raw_auth_credential. This - hash key can be used to store / retrieve exchanged_auth_credential in a - credentials store. + """Builds a hash key based on auth_scheme and raw_auth_credential used to + save / load this credential to / from a credentials service. """ + auth_scheme = self.auth_scheme if auth_scheme.model_extra: @@ -62,7 +78,7 @@ def get_credential_key(self): ) auth_credential = self.raw_auth_credential - if auth_credential.model_extra: + if auth_credential and auth_credential.model_extra: auth_credential = auth_credential.model_copy(deep=True) auth_credential.model_extra.clear() credential_name = ( diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py new file mode 100644 index 000000000..0dbf006ab --- /dev/null +++ b/src/google/adk/auth/credential_manager.py @@ -0,0 +1,261 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from ..tools.tool_context import ToolContext +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_credential import AuthCredentialTypes +from .auth_schemes import AuthSchemeType +from .auth_tool import AuthConfig +from .exchanger.base_credential_exchanger import BaseCredentialExchanger +from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry +from .refresher.base_credential_refresher import BaseCredentialRefresher +from .refresher.credential_refresher_registry import CredentialRefresherRegistry + + +@experimental +class CredentialManager: + """Manages authentication credentials through a structured workflow. + + The CredentialManager orchestrates the complete lifecycle of authentication + credentials, from initial loading to final preparation for use. It provides + a centralized interface for handling various credential types and authentication + schemes while maintaining proper credential hygiene (refresh, exchange, caching). + + This class is only for use by Agent Development Kit. + + Args: + auth_config: Configuration containing authentication scheme and credentials + + Example: + ```python + auth_config = AuthConfig( + auth_scheme=oauth2_scheme, + raw_auth_credential=service_account_credential + ) + manager = CredentialManager(auth_config) + + # Register custom exchanger if needed + manager.register_credential_exchanger( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialExchanger() + ) + + # Register custom refresher if needed + manager.register_credential_refresher( + AuthCredentialTypes.CUSTOM_TYPE, + CustomCredentialRefresher() + ) + + # Load and prepare credential + credential = await manager.load_auth_credential(tool_context) + ``` + """ + + def __init__( + self, + auth_config: AuthConfig, + ): + self._auth_config = auth_config + self._exchanger_registry = CredentialExchangerRegistry() + self._refresher_registry = CredentialRefresherRegistry() + + # Register default exchangers and refreshers + # TODO: support service account credential exchanger + from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher + + oauth2_refresher = OAuth2CredentialRefresher() + self._refresher_registry.register( + AuthCredentialTypes.OAUTH2, oauth2_refresher + ) + self._refresher_registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher + ) + + def register_credential_exchanger( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register a credential exchanger for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchanger_registry.register(credential_type, exchanger_instance) + + async def request_credential(self, tool_context: ToolContext) -> None: + tool_context.request_credential(self._auth_config) + + async def get_auth_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load and prepare authentication credential through a structured workflow.""" + + # Step 1: Validate credential configuration + await self._validate_credential() + + # Step 2: Check if credential is already ready (no processing needed) + if self._is_credential_ready(): + return self._auth_config.raw_auth_credential + + # Step 3: Try to load existing processed credential + credential = await self._load_existing_credential(tool_context) + + # Step 4: If no existing credential, load from auth response + # TODO instead of load from auth response, we can store auth response in + # credential service. + was_from_auth_response = False + if not credential: + credential = await self._load_from_auth_response(tool_context) + was_from_auth_response = True + + # Step 5: If still no credential available, return None + if not credential: + return None + + # Step 6: Exchange credential if needed (e.g., service account to access token) + credential, was_exchanged = await self._exchange_credential(credential) + + # Step 7: Refresh credential if expired + if not was_exchanged: + credential, was_refreshed = await self._refresh_credential(credential) + + # Step 8: Save credential if it was modified + if was_from_auth_response or was_exchanged or was_refreshed: + await self._save_credential(tool_context, credential) + + return credential + + async def _load_existing_credential( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load existing credential from credential service or cached exchanged credential.""" + + # Try loading from credential service first + credential = await self._load_from_credential_service(tool_context) + if credential: + return credential + + # Check if we have a cached exchanged credential + if self._auth_config.exchanged_auth_credential: + return self._auth_config.exchanged_auth_credential + + return None + + async def _load_from_credential_service( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Note: This should be made async in a future refactor + # For now, assuming synchronous operation + return await credential_service.load_credential( + self._auth_config, tool_context + ) + return None + + async def _load_from_auth_response( + self, tool_context: ToolContext + ) -> Optional[AuthCredential]: + """Load credential from auth response in tool context.""" + return tool_context.get_auth_response(self._auth_config) + + async def _exchange_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Exchange credential if needed and return the credential and whether it was exchanged.""" + exchanger = self._exchanger_registry.get_exchanger(credential.auth_type) + if not exchanger: + return credential, False + + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchanged_credential, True + + async def _refresh_credential( + self, credential: AuthCredential + ) -> tuple[AuthCredential, bool]: + """Refresh credential if expired and return the credential and whether it was refreshed.""" + refresher = self._refresher_registry.get_refresher(credential.auth_type) + if not refresher: + return credential, False + + if await refresher.is_refresh_needed( + credential, self._auth_config.auth_scheme + ): + refreshed_credential = await refresher.refresh( + credential, self._auth_config.auth_scheme + ) + return refreshed_credential, True + + return credential, False + + def _is_credential_ready(self) -> bool: + """Check if credential is ready to use without further processing.""" + raw_credential = self._auth_config.raw_auth_credential + if not raw_credential: + return False + + # Simple credentials that don't need exchange or refresh + return raw_credential.auth_type in ( + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + # Add other simple auth types as needed + ) + + async def _validate_credential(self) -> None: + """Validate credential configuration and raise errors if invalid.""" + if not self._auth_config.raw_auth_credential: + if self._auth_config.auth_scheme.type_ in ( + AuthSchemeType.oauth2, + AuthSchemeType.openIdConnect, + ): + raise ValueError( + "raw_auth_credential is required for auth_scheme type " + f"{self._auth_config.auth_scheme.type_}" + ) + + raw_credential = self._auth_config.raw_auth_credential + if raw_credential: + if ( + raw_credential.auth_type + in ( + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + ) + and not raw_credential.oauth2 + ): + raise ValueError( + "auth_config.raw_credential.oauth2 required for credential type " + f"{raw_credential.auth_type}" + ) + # Additional validation can be added here + + async def _save_credential( + self, tool_context: ToolContext, credential: AuthCredential + ) -> None: + """Save credential to credential service if available.""" + credential_service = tool_context._invocation_context.credential_service + if credential_service: + # Update the exchanged credential in config + self._auth_config.exchanged_auth_credential = credential + await credential_service.save_credential(self._auth_config, tool_context) diff --git a/src/google/adk/auth/credential_service/__init__.py b/src/google/adk/auth/credential_service/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/auth/credential_service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/auth/credential_service/base_credential_service.py b/src/google/adk/auth/credential_service/base_credential_service.py new file mode 100644 index 000000000..fc6cd500d --- /dev/null +++ b/src/google/adk/auth/credential_service/base_credential_service.py @@ -0,0 +1,75 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from typing import Optional + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig + + +@experimental +class BaseCredentialService(ABC): + """Abstract class for Service that loads / saves tool credentials from / to + the backend credential store.""" + + @abstractmethod + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + """ + Loads the credential by auth config and current tool context from the + backend credential store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to load the credential. + + tool_context: The context of the current invocation when the tool is + trying to load the credential. + + Returns: + Optional[AuthCredential]: the credential saved in the store. + + """ + + @abstractmethod + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + """ + Saves the exchanged_auth_credential in auth config to the backend credential + store. + + Args: + auth_config: The auth config which contains the auth scheme and auth + credential information. auth_config.get_credential_key will be used to + build the key to save the credential. + + tool_context: The context of the current invocation when the tool is + trying to save the credential. + + Returns: + None + """ diff --git a/src/google/adk/auth/credential_service/in_memory_credential_service.py b/src/google/adk/auth/credential_service/in_memory_credential_service.py new file mode 100644 index 000000000..f6f51b35a --- /dev/null +++ b/src/google/adk/auth/credential_service/in_memory_credential_service.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from typing_extensions import override + +from ...tools.tool_context import ToolContext +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_tool import AuthConfig +from .base_credential_service import BaseCredentialService + + +@experimental +class InMemoryCredentialService(BaseCredentialService): + """Class for in memory implementation of credential service(Experimental)""" + + def __init__(self): + super().__init__() + self._credentials = {} + + @override + async def load_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> Optional[AuthCredential]: + credential_bucket = self._get_bucket_for_current_context(tool_context) + return credential_bucket.get(auth_config.credential_key) + + @override + async def save_credential( + self, + auth_config: AuthConfig, + tool_context: ToolContext, + ) -> None: + credential_bucket = self._get_bucket_for_current_context(tool_context) + credential_bucket[auth_config.credential_key] = ( + auth_config.exchanged_auth_credential + ) + + def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str: + app_name = tool_context._invocation_context.app_name + user_id = tool_context._invocation_context.user_id + + if app_name not in self._credentials: + self._credentials[app_name] = {} + if user_id not in self._credentials[app_name]: + self._credentials[app_name][user_id] = {} + return self._credentials[app_name][user_id] diff --git a/src/google/adk/auth/exchanger/__init__.py b/src/google/adk/auth/exchanger/__init__.py new file mode 100644 index 000000000..3b0fbb246 --- /dev/null +++ b/src/google/adk/auth/exchanger/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential exchanger module.""" + +from .base_credential_exchanger import BaseCredentialExchanger + +__all__ = [ + "BaseCredentialExchanger", +] diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py new file mode 100644 index 000000000..b09adb80a --- /dev/null +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base credential exchanger interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredential +from ..auth_schemes import AuthScheme + + +class CredentialExchangError(Exception): + """Base exception for credential exchange errors.""" + + +@experimental +class BaseCredentialExchanger(abc.ABC): + """Base interface for credential exchangers. + + Credential exchangers are responsible for exchanging credentials from + one format or scheme to another. + """ + + @abc.abstractmethod + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange credential if needed. + + Args: + auth_credential: The credential to exchange. + auth_scheme: The authentication scheme (optional, some exchangers don't need it). + + Returns: + The exchanged credential. + + Raises: + CredentialExchangError: If credential exchange fails. + """ + pass diff --git a/src/google/adk/auth/exchanger/credential_exchanger_registry.py b/src/google/adk/auth/exchanger/credential_exchanger_registry.py new file mode 100644 index 000000000..5af7f3c1a --- /dev/null +++ b/src/google/adk/auth/exchanger/credential_exchanger_registry.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential exchanger registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from ...utils.feature_decorator import experimental +from ..auth_credential import AuthCredentialTypes +from .base_credential_exchanger import BaseCredentialExchanger + + +@experimental +class CredentialExchangerRegistry: + """Registry for credential exchanger instances.""" + + def __init__(self): + self._exchangers: Dict[AuthCredentialTypes, BaseCredentialExchanger] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + exchanger_instance: BaseCredentialExchanger, + ) -> None: + """Register an exchanger instance for a credential type. + + Args: + credential_type: The credential type to register for. + exchanger_instance: The exchanger instance to register. + """ + self._exchangers[credential_type] = exchanger_instance + + def get_exchanger( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialExchanger]: + """Get the exchanger instance for a credential type. + + Args: + credential_type: The credential type to get exchanger for. + + Returns: + The exchanger instance if registered, None otherwise. + """ + return self._exchangers.get(credential_type) diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py new file mode 100644 index 000000000..768457e1a --- /dev/null +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 credential exchanger implementation.""" + +from __future__ import annotations + +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import OAuthGrantType +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from typing_extensions import override + +from .base_credential_exchanger import BaseCredentialExchanger +from .base_credential_exchanger import CredentialExchangError + +try: + from authlib.integrations.requests_client import OAuth2Session + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialExchanger(BaseCredentialExchanger): + """Exchanges OAuth2 credentials from authorization responses.""" + + @override + async def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Exchange OAuth2 credential from authorization response. + if credential exchange failed, the original credential will be returned. + + Args: + auth_credential: The OAuth2 credential to exchange. + auth_scheme: The OAuth2 authentication scheme. + + Returns: + The exchanged credential with access token. + + Raises: + CredentialExchangError: If auth_scheme is missing. + """ + if not auth_scheme: + raise CredentialExchangError( + "auth_scheme is required for OAuth2 credential exchange" + ) + + if not AUTHLIB_AVIALABLE: + # If authlib is not available, we cannot exchange the credential. + # We return the original credential without exchange. + # The client using this tool can decide to exchange the credential + # themselves using other lib. + logger.warning( + "authlib is not available, skipping OAuth2 credential exchange." + ) + return auth_credential + + if auth_credential.oauth2 and auth_credential.oauth2.access_token: + return auth_credential + + client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) + if not client: + logger.warning("Could not create OAuth2 session for token exchange") + return auth_credential + + try: + tokens = client.fetch_token( + token_endpoint, + authorization_response=auth_credential.oauth2.auth_response_uri, + code=auth_credential.oauth2.auth_code, + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully exchanged OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise errors in this case + logger.error("Failed to exchange OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py new file mode 100644 index 000000000..51ed4d29f --- /dev/null +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Tuple + +from fastapi.openapi.models import OAuth2 + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_schemes import AuthScheme +from .auth_schemes import OpenIdConnectWithConfig + +try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +def create_oauth2_session( + auth_scheme: AuthScheme, + auth_credential: AuthCredential, +) -> Tuple[Optional[OAuth2Session], Optional[str]]: + """Create an OAuth2 session for token operations. + + Args: + auth_scheme: The authentication scheme configuration. + auth_credential: The authentication credential. + + Returns: + Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session. + """ + if isinstance(auth_scheme, OpenIdConnectWithConfig): + if not hasattr(auth_scheme, "token_endpoint"): + return None, None + token_endpoint = auth_scheme.token_endpoint + scopes = auth_scheme.scopes + elif isinstance(auth_scheme, OAuth2): + if ( + not auth_scheme.flows.authorizationCode + or not auth_scheme.flows.authorizationCode.tokenUrl + ): + return None, None + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + else: + return None, None + + if ( + not auth_credential + or not auth_credential.oauth2 + or not auth_credential.oauth2.client_id + or not auth_credential.oauth2.client_secret + ): + return None, None + + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) + + +@experimental +def update_credential_with_tokens( + auth_credential: AuthCredential, tokens: OAuth2Token +) -> None: + """Update the credential with new tokens. + + Args: + auth_credential: The authentication credential to update. + tokens: The OAuth2Token object containing new token information. + """ + auth_credential.oauth2.access_token = tokens.get("access_token") + auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) diff --git a/src/google/adk/auth/refresher/__init__.py b/src/google/adk/auth/refresher/__init__.py new file mode 100644 index 000000000..27d7245dc --- /dev/null +++ b/src/google/adk/auth/refresher/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher module.""" + +from .base_credential_refresher import BaseCredentialRefresher + +__all__ = [ + "BaseCredentialRefresher", +] diff --git a/src/google/adk/auth/refresher/base_credential_refresher.py b/src/google/adk/auth/refresher/base_credential_refresher.py new file mode 100644 index 000000000..230b07d09 --- /dev/null +++ b/src/google/adk/auth/refresher/base_credential_refresher.py @@ -0,0 +1,74 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base credential refresher interface.""" + +from __future__ import annotations + +import abc +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.utils.feature_decorator import experimental + + +class CredentialRefresherError(Exception): + """Base exception for credential refresh errors.""" + + +@experimental +class BaseCredentialRefresher(abc.ABC): + """Base interface for credential refreshers. + + Credential refreshers are responsible for checking if a credential is expired + or needs to be refreshed, and for refreshing it if necessary. + """ + + @abc.abstractmethod + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Checks if a credential needs to be refreshed. + + Args: + auth_credential: The credential to check. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + pass + + @abc.abstractmethod + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refreshes a credential if needed. + + Args: + auth_credential: The credential to refresh. + auth_scheme: The authentication scheme (optional, some refreshers don't need it). + + Returns: + The refreshed credential. + + Raises: + CredentialRefresherError: If credential refresh fails. + """ + pass diff --git a/src/google/adk/auth/refresher/credential_refresher_registry.py b/src/google/adk/auth/refresher/credential_refresher_registry.py new file mode 100644 index 000000000..90975d66d --- /dev/null +++ b/src/google/adk/auth/refresher/credential_refresher_registry.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credential refresher registry.""" + +from __future__ import annotations + +from typing import Dict +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.utils.feature_decorator import experimental + +from .base_credential_refresher import BaseCredentialRefresher + + +@experimental +class CredentialRefresherRegistry: + """Registry for credential refresher instances.""" + + def __init__(self): + self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = {} + + def register( + self, + credential_type: AuthCredentialTypes, + refresher_instance: BaseCredentialRefresher, + ) -> None: + """Register a refresher instance for a credential type. + + Args: + credential_type: The credential type to register for. + refresher_instance: The refresher instance to register. + """ + self._refreshers[credential_type] = refresher_instance + + def get_refresher( + self, credential_type: AuthCredentialTypes + ) -> Optional[BaseCredentialRefresher]: + """Get the refresher instance for a credential type. + + Args: + credential_type: The credential type to get refresher for. + + Returns: + The refresher instance if registered, None otherwise. + """ + return self._refreshers.get(credential_type) diff --git a/src/google/adk/auth/refresher/oauth2_credential_refresher.py b/src/google/adk/auth/refresher/oauth2_credential_refresher.py new file mode 100644 index 000000000..4c19520ce --- /dev/null +++ b/src/google/adk/auth/refresher/oauth2_credential_refresher.py @@ -0,0 +1,126 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 credential refresher implementation.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +from google.adk.utils.feature_decorator import experimental +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from typing_extensions import override + +from .base_credential_refresher import BaseCredentialRefresher + +try: + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class OAuth2CredentialRefresher(BaseCredentialRefresher): + """Refreshes OAuth2 credentials including Google OAuth2 JSON credentials.""" + + @override + async def is_refresh_needed( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> bool: + """Check if the OAuth2 credential needs to be refreshed. + + Args: + auth_credential: The OAuth2 credential to check. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + True if the credential needs to be refreshed, False otherwise. + """ + + # Handle regular OAuth2 credentials + if auth_credential.oauth2: + if not AUTHLIB_AVIALABLE: + return False + + return OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired() + + return False + + @override + async def refresh( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Refresh the OAuth2 credential. + If refresh failed, return the original credential. + + Args: + auth_credential: The OAuth2 credential to refresh. + auth_scheme: The OAuth2 authentication scheme (optional for Google OAuth2 JSON). + + Returns: + The refreshed credential. + + """ + + # Handle regular OAuth2 credentials + if auth_credential.oauth2 and auth_scheme: + if not AUTHLIB_AVIALABLE: + return auth_credential + + if not auth_credential.oauth2: + return auth_credential + + if OAuth2Token({ + "expires_at": auth_credential.oauth2.expires_at, + "expires_in": auth_credential.oauth2.expires_in, + }).is_expired(): + client, token_endpoint = create_oauth2_session( + auth_scheme, auth_credential + ) + if not client: + logger.warning("Could not create OAuth2 session for token refresh") + return auth_credential + + try: + tokens = client.refresh_token( + url=token_endpoint, + refresh_token=auth_credential.oauth2.refresh_token, + ) + update_credential_with_tokens(auth_credential, tokens) + logger.debug("Successfully refreshed OAuth2 tokens") + except Exception as e: + # TODO reconsider whether we should raise error when refresh failed. + logger.error("Failed to refresh OAuth2 tokens: %s", e) + # Return original credential on failure + return auth_credential + + return auth_credential diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index aceb3fcce..79d0bfe65 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -24,6 +24,8 @@ from ..agents.llm_agent import LlmAgent from ..artifacts import BaseArtifactService from ..artifacts import InMemoryArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService @@ -43,6 +45,7 @@ async def run_input_file( root_agent: LlmAgent, artifact_service: BaseArtifactService, session_service: BaseSessionService, + credential_service: BaseCredentialService, input_path: str, ) -> Session: runner = Runner( @@ -50,6 +53,7 @@ async def run_input_file( agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) @@ -75,12 +79,14 @@ async def run_interactively( artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, + credential_service: BaseCredentialService, ) -> None: runner = Runner( app_name=session.app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, ) while True: query = input('[user]: ') @@ -125,6 +131,7 @@ async def run_cli( artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() user_id = 'test_user' session = await session_service.create_session( @@ -141,6 +148,7 @@ async def run_cli( root_agent=root_agent, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=input_file, ) elif saved_session_file: @@ -163,6 +171,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) else: click.echo(f'Running agent {root_agent.name}, type exit to exit.') @@ -171,6 +180,7 @@ async def run_cli( artifact_service, session, session_service, + credential_service, ) if save_session: diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4ef8ae6c2..46e008655 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -57,6 +57,7 @@ from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService +from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..errors.not_found_error import NotFoundError from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput @@ -276,7 +277,6 @@ async def internal_lifespan(app: FastAPI): memory_service = InMemoryMemoryService() # Build the Session service - agent_engine_id = "" if session_service_uri: if session_service_uri.startswith("agentengine://"): # Create vertex session service @@ -285,8 +285,9 @@ async def internal_lifespan(app: FastAPI): raise click.ClickException("Agent engine id can not be empty.") envs.load_dotenv_for_agent("", agents_dir) session_service = VertexAiSessionService( - os.environ["GOOGLE_CLOUD_PROJECT"], - os.environ["GOOGLE_CLOUD_LOCATION"], + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, ) else: session_service = DatabaseSessionService(db_url=session_service_uri) @@ -305,6 +306,9 @@ async def internal_lifespan(app: FastAPI): else: artifact_service = InMemoryArtifactService() + # Build the Credential service + credential_service = InMemoryCredentialService() + # initialize Agent Loader agent_loader = AgentLoader(agents_dir) @@ -357,8 +361,6 @@ def get_session_trace(session_id: str) -> Any: async def get_session( app_name: str, user_id: str, session_id: str ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -371,8 +373,6 @@ async def get_session( response_model_exclude_none=True, ) async def list_sessions(app_name: str, user_id: str) -> list[Session]: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name list_sessions_response = await session_service.list_sessions( app_name=app_name, user_id=user_id ) @@ -393,8 +393,6 @@ async def create_session_with_id( session_id: str, state: Optional[dict[str, Any]] = None, ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name if ( await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id @@ -418,14 +416,19 @@ async def create_session( app_name: str, user_id: str, state: Optional[dict[str, Any]] = None, + events: Optional[list[Event]] = None, ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name logger.info("New session created") - return await session_service.create_session( + session = await session_service.create_session( app_name=app_name, user_id=user_id, state=state ) + if events: + for event in events: + await session_service.append_event(session=session, event=event) + + return session + def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str: return os.path.join( agents_dir, @@ -660,8 +663,6 @@ def list_eval_results(app_name: str) -> list[str]: @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") async def delete_session(app_name: str, user_id: str, session_id: str): - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name await session_service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -677,7 +678,6 @@ async def load_artifact( artifact_name: str, version: Optional[int] = Query(None), ) -> Optional[types.Part]: - app_name = agent_engine_id if agent_engine_id else app_name artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, @@ -700,7 +700,6 @@ async def load_artifact_version( artifact_name: str, version_id: int, ) -> Optional[types.Part]: - app_name = agent_engine_id if agent_engine_id else app_name artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, @@ -719,7 +718,6 @@ async def load_artifact_version( async def list_artifact_names( app_name: str, user_id: str, session_id: str ) -> list[str]: - app_name = agent_engine_id if agent_engine_id else app_name return await artifact_service.list_artifact_keys( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -731,7 +729,6 @@ async def list_artifact_names( async def list_artifact_versions( app_name: str, user_id: str, session_id: str, artifact_name: str ) -> list[int]: - app_name = agent_engine_id if agent_engine_id else app_name return await artifact_service.list_versions( app_name=app_name, user_id=user_id, @@ -745,7 +742,6 @@ async def list_artifact_versions( async def delete_artifact( app_name: str, user_id: str, session_id: str, artifact_name: str ): - app_name = agent_engine_id if agent_engine_id else app_name await artifact_service.delete_artifact( app_name=app_name, user_id=user_id, @@ -755,10 +751,8 @@ async def delete_artifact( @app.post("/run", response_model_exclude_none=True) async def agent_run(req: AgentRunRequest) -> list[Event]: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else req.app_name session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -776,11 +770,9 @@ async def agent_run(req: AgentRunRequest) -> list[Event]: @app.post("/run_sse") async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else req.app_name # SSE endpoint session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -818,8 +810,6 @@ async def event_generator(): async def get_event_graph( app_name: str, user_id: str, session_id: str, event_id: str ): - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -875,8 +865,6 @@ async def agent_live_run( ) -> None: await websocket.accept() - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -940,11 +928,12 @@ async def _get_runner_async(app_name: str) -> Runner: return runner_dict[app_name] root_agent = agent_loader.load_agent(app_name) runner = Runner( - app_name=agent_engine_id if agent_engine_id else app_name, + app_name=app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, + credential_service=credential_service, ) runner_dict[app_name] = runner return runner diff --git a/src/google/adk/evaluation/_eval_set_results_manager_utils.py b/src/google/adk/evaluation/_eval_set_results_manager_utils.py new file mode 100644 index 000000000..8505e68d1 --- /dev/null +++ b/src/google/adk/evaluation/_eval_set_results_manager_utils.py @@ -0,0 +1,44 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time + +from .eval_result import EvalCaseResult +from .eval_result import EvalSetResult + + +def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str: + """Sanitizes the eval set result name.""" + return eval_set_result_name.replace("/", "_") + + +def create_eval_set_result( + app_name: str, + eval_set_id: str, + eval_case_results: list[EvalCaseResult], +) -> EvalSetResult: + """Creates a new EvalSetResult given eval_case_results.""" + timestamp = time.time() + eval_set_result_id = f"{app_name}_{eval_set_id}_{timestamp}" + eval_set_result_name = _sanitize_eval_set_result_name(eval_set_result_id) + eval_set_result = EvalSetResult( + eval_set_result_id=eval_set_result_id, + eval_set_result_name=eval_set_result_name, + eval_set_id=eval_set_id, + eval_case_results=eval_case_results, + creation_timestamp=timestamp, + ) + return eval_set_result diff --git a/src/google/adk/evaluation/_eval_sets_manager_utils.py b/src/google/adk/evaluation/_eval_sets_manager_utils.py new file mode 100644 index 000000000..b7e12dd37 --- /dev/null +++ b/src/google/adk/evaluation/_eval_sets_manager_utils.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional + +from ..errors.not_found_error import NotFoundError +from .eval_case import EvalCase +from .eval_set import EvalSet +from .eval_sets_manager import EvalSetsManager + +logger = logging.getLogger("google_adk." + __name__) + + +def get_eval_set_from_app_and_id( + eval_sets_manager: EvalSetsManager, app_name: str, eval_set_id: str +) -> EvalSet: + """Returns an EvalSet if found, otherwise raises NotFoundError.""" + eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) + if not eval_set: + raise NotFoundError(f"Eval set `{eval_set_id}` not found.") + return eval_set + + +def get_eval_case_from_eval_set( + eval_set: EvalSet, eval_case_id: str +) -> Optional[EvalCase]: + """Returns an EvalCase if found, otherwise None.""" + eval_case_to_find = None + + # Look up the eval case by eval_case_id + for eval_case in eval_set.eval_cases: + if eval_case.eval_id == eval_case_id: + eval_case_to_find = eval_case + break + + return eval_case_to_find + + +def add_eval_case_to_eval_set( + eval_set: EvalSet, eval_case: EvalCase +) -> EvalSet: + """Adds an eval case to an eval set and returns the updated eval set.""" + eval_case_id = eval_case.eval_id + + if [x for x in eval_set.eval_cases if x.eval_id == eval_case_id]: + raise ValueError( + f"Eval id `{eval_case_id}` already exists in `{eval_set.eval_set_id}`" + " eval set.", + ) + + eval_set.eval_cases.append(eval_case) + return eval_set + + +def update_eval_case_in_eval_set( + eval_set: EvalSet, updated_eval_case: EvalCase +) -> EvalSet: + """Updates an eval case in an eval set and returns the updated eval set.""" + # Find the eval case to be updated. + eval_case_id = updated_eval_case.eval_id + eval_case_to_update = get_eval_case_from_eval_set(eval_set, eval_case_id) + + if not eval_case_to_update: + raise NotFoundError( + f"Eval case `{eval_case_id}` not found in eval set" + f" `{eval_set.eval_set_id}`." + ) + + # Remove the existing eval case and add the updated eval case. + eval_set.eval_cases.remove(eval_case_to_update) + eval_set.eval_cases.append(updated_eval_case) + return eval_set + + +def delete_eval_case_from_eval_set( + eval_set: EvalSet, eval_case_id: str +) -> EvalSet: + """Deletes an eval case from an eval set and returns the updated eval set.""" + # Find the eval case to be deleted. + eval_case_to_delete = get_eval_case_from_eval_set(eval_set, eval_case_id) + + if not eval_case_to_delete: + raise NotFoundError( + f"Eval case `{eval_case_id}` not found in eval set" + f" `{eval_set.eval_set_id}`." + ) + + # Remove the existing eval case. + logger.info( + "EvalCase`%s` was found in the eval set. It will be removed permanently.", + eval_case_id, + ) + eval_set.eval_cases.remove(eval_case_to_delete) + return eval_set diff --git a/src/google/adk/evaluation/eval_result.py b/src/google/adk/evaluation/eval_result.py index 8f87a14b4..96e8d3c98 100644 --- a/src/google/adk/evaluation/eval_result.py +++ b/src/google/adk/evaluation/eval_result.py @@ -36,8 +36,9 @@ class EvalCaseResult(BaseModel): populate_by_name=True, ) - eval_set_file: str = Field( + eval_set_file: Optional[str] = Field( deprecated=True, + default=None, description="This field is deprecated, use eval_set_id instead.", ) eval_set_id: str = "" @@ -49,11 +50,15 @@ class EvalCaseResult(BaseModel): final_eval_status: EvalStatus """Final eval status for this eval case.""" - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( - deprecated=True, - description=( - "This field is deprecated, use overall_eval_metric_results instead." - ), + eval_metric_results: Optional[list[tuple[EvalMetric, EvalMetricResult]]] = ( + Field( + deprecated=True, + default=None, + description=( + "This field is deprecated, use overall_eval_metric_results" + " instead." + ), + ) ) overall_eval_metric_results: list[EvalMetricResult] @@ -80,7 +85,7 @@ class EvalSetResult(BaseModel): populate_by_name=True, ) eval_set_result_id: str - eval_set_result_name: str + eval_set_result_name: Optional[str] = None eval_set_id: str eval_case_results: list[EvalCaseResult] = Field(default_factory=list) creation_timestamp: float = 0.0 diff --git a/src/google/adk/evaluation/eval_set_results_manager.py b/src/google/adk/evaluation/eval_set_results_manager.py index 5a300ed14..588e823ba 100644 --- a/src/google/adk/evaluation/eval_set_results_manager.py +++ b/src/google/adk/evaluation/eval_set_results_manager.py @@ -16,6 +16,7 @@ from abc import ABC from abc import abstractmethod +from typing import Optional from .eval_result import EvalCaseResult from .eval_result import EvalSetResult @@ -38,7 +39,11 @@ def save_eval_set_result( def get_eval_set_result( self, app_name: str, eval_set_result_id: str ) -> EvalSetResult: - """Returns an EvalSetResult identified by app_name and eval_set_result_id.""" + """Returns the EvalSetResult from app_name and eval_set_result_id. + + Raises: + NotFoundError: If the EvalSetResult is not found. + """ raise NotImplementedError() @abstractmethod diff --git a/src/google/adk/evaluation/gcs_eval_set_results_manager.py b/src/google/adk/evaluation/gcs_eval_set_results_manager.py new file mode 100644 index 000000000..860d932ff --- /dev/null +++ b/src/google/adk/evaluation/gcs_eval_set_results_manager.py @@ -0,0 +1,121 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging + +from google.cloud import exceptions as cloud_exceptions +from google.cloud import storage +from typing_extensions import override + +from ..errors.not_found_error import NotFoundError +from ._eval_set_results_manager_utils import create_eval_set_result +from .eval_result import EvalCaseResult +from .eval_result import EvalSetResult +from .eval_set_results_manager import EvalSetResultsManager + +logger = logging.getLogger("google_adk." + __name__) + +_EVAL_HISTORY_DIR = "evals/eval_history" +_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json" + + +class GcsEvalSetResultsManager(EvalSetResultsManager): + """An EvalSetResultsManager that stores eval results in a GCS bucket.""" + + def __init__(self, bucket_name: str, **kwargs): + """Initializes the GcsEvalSetsManager. + + Args: + bucket_name: The name of the bucket to use. + **kwargs: Keyword arguments to pass to the Google Cloud Storage client. + """ + self.bucket_name = bucket_name + self.storage_client = storage.Client(**kwargs) + self.bucket = self.storage_client.bucket(self.bucket_name) + # Check if the bucket exists. + if not self.bucket.exists(): + raise ValueError( + f"Bucket `{self.bucket_name}` does not exist. Please create it before" + " using the GcsEvalSetsManager." + ) + + def _get_eval_history_dir(self, app_name: str) -> str: + return f"{app_name}/{_EVAL_HISTORY_DIR}" + + def _get_eval_set_result_blob_name( + self, app_name: str, eval_set_result_id: str + ) -> str: + eval_history_dir = self._get_eval_history_dir(app_name) + return f"{eval_history_dir}/{eval_set_result_id}{_EVAL_SET_RESULT_FILE_EXTENSION}" + + def _write_eval_set_result( + self, blob_name: str, eval_set_result: EvalSetResult + ): + """Writes an EvalSetResult to GCS.""" + blob = self.bucket.blob(blob_name) + blob.upload_from_string( + eval_set_result.model_dump_json(indent=2), + content_type="application/json", + ) + + @override + def save_eval_set_result( + self, + app_name: str, + eval_set_id: str, + eval_case_results: list[EvalCaseResult], + ) -> None: + """Creates and saves a new EvalSetResult given eval_case_results.""" + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + + eval_set_result_blob_name = self._get_eval_set_result_blob_name( + app_name, eval_set_result.eval_set_result_id + ) + logger.info("Writing eval result to blob: %s", eval_set_result_blob_name) + self._write_eval_set_result(eval_set_result_blob_name, eval_set_result) + + @override + def get_eval_set_result( + self, app_name: str, eval_set_result_id: str + ) -> EvalSetResult: + """Returns an EvalSetResult from app_name and eval_set_result_id.""" + eval_set_result_blob_name = self._get_eval_set_result_blob_name( + app_name, eval_set_result_id + ) + blob = self.bucket.blob(eval_set_result_blob_name) + if not blob.exists(): + raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.") + eval_set_result_data = blob.download_as_text() + return EvalSetResult.model_validate_json(eval_set_result_data) + + @override + def list_eval_set_results(self, app_name: str) -> list[str]: + """Returns the eval result ids that belong to the given app_name.""" + eval_history_dir = self._get_eval_history_dir(app_name) + eval_set_results = [] + try: + for blob in self.bucket.list_blobs(prefix=eval_history_dir): + eval_set_result_id = blob.name.split("/")[-1].removesuffix( + _EVAL_SET_RESULT_FILE_EXTENSION + ) + eval_set_results.append(eval_set_result_id) + return sorted(eval_set_results) + except cloud_exceptions.NotFound as e: + raise ValueError( + f"App `{app_name}` not found in GCS bucket `{self.bucket_name}`." + ) from e diff --git a/src/google/adk/evaluation/gcs_eval_sets_manager.py b/src/google/adk/evaluation/gcs_eval_sets_manager.py new file mode 100644 index 000000000..fe5d8c9b5 --- /dev/null +++ b/src/google/adk/evaluation/gcs_eval_sets_manager.py @@ -0,0 +1,196 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import re +import time +from typing import Optional + +from google.cloud import exceptions as cloud_exceptions +from google.cloud import storage +from typing_extensions import override + +from ._eval_sets_manager_utils import add_eval_case_to_eval_set +from ._eval_sets_manager_utils import delete_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_set_from_app_and_id +from ._eval_sets_manager_utils import update_eval_case_in_eval_set +from .eval_case import EvalCase +from .eval_set import EvalSet +from .eval_sets_manager import EvalSetsManager + +logger = logging.getLogger("google_adk." + __name__) + +_EVAL_SETS_DIR = "evals/eval_sets" +_EVAL_SET_FILE_EXTENSION = ".evalset.json" + + +class GcsEvalSetsManager(EvalSetsManager): + """An EvalSetsManager that stores eval sets in a GCS bucket.""" + + def __init__(self, bucket_name: str, **kwargs): + """Initializes the GcsEvalSetsManager. + + Args: + bucket_name: The name of the bucket to use. + **kwargs: Keyword arguments to pass to the Google Cloud Storage client. + """ + self.bucket_name = bucket_name + self.storage_client = storage.Client(**kwargs) + self.bucket = self.storage_client.bucket(self.bucket_name) + # Check if the bucket exists. + if not self.bucket.exists(): + raise ValueError( + f"Bucket `{self.bucket_name}` does not exist. Please create it " + "before using the GcsEvalSetsManager." + ) + + def _get_eval_sets_dir(self, app_name: str) -> str: + return f"{app_name}/{_EVAL_SETS_DIR}" + + def _get_eval_set_blob_name(self, app_name: str, eval_set_id: str) -> str: + eval_sets_dir = self._get_eval_sets_dir(app_name) + return f"{eval_sets_dir}/{eval_set_id}{_EVAL_SET_FILE_EXTENSION}" + + def _validate_id(self, id_name: str, id_value: str): + pattern = r"^[a-zA-Z0-9_]+$" + if not bool(re.fullmatch(pattern, id_value)): + raise ValueError( + f"Invalid {id_name}. {id_name} should have the `{pattern}` format", + ) + + def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet): + """Writes an EvalSet to GCS.""" + blob = self.bucket.blob(blob_name) + blob.upload_from_string( + eval_set.model_dump_json(indent=2), + content_type="application/json", + ) + + def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): + eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) + self._write_eval_set_to_blob(eval_set_blob_name, eval_set) + + @override + def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: + """Returns an EvalSet identified by an app_name and eval_set_id.""" + eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) + blob = self.bucket.blob(eval_set_blob_name) + if not blob.exists(): + return None + eval_set_data = blob.download_as_text() + return EvalSet.model_validate_json(eval_set_data) + + @override + def create_eval_set(self, app_name: str, eval_set_id: str): + """Creates an empty EvalSet and saves it to GCS.""" + self._validate_id(id_name="Eval Set Id", id_value=eval_set_id) + new_eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id) + if self.bucket.blob(new_eval_set_blob_name).exists(): + raise ValueError( + f"Eval set `{eval_set_id}` already exists for app `{app_name}`." + ) + logger.info("Creating eval set blob: `%s`", new_eval_set_blob_name) + new_eval_set = EvalSet( + eval_set_id=eval_set_id, + name=eval_set_id, + eval_cases=[], + creation_timestamp=time.time(), + ) + self._write_eval_set_to_blob(new_eval_set_blob_name, new_eval_set) + + @override + def list_eval_sets(self, app_name: str) -> list[str]: + """Returns a list of EvalSet ids that belong to the given app_name.""" + eval_sets_dir = self._get_eval_sets_dir(app_name) + eval_sets = [] + try: + for blob in self.bucket.list_blobs(prefix=eval_sets_dir): + if not blob.name.endswith(_EVAL_SET_FILE_EXTENSION): + continue + eval_set_id = blob.name.split("/")[-1].removesuffix( + _EVAL_SET_FILE_EXTENSION + ) + eval_sets.append(eval_set_id) + return sorted(eval_sets) + except cloud_exceptions.NotFound as e: + raise ValueError( + f"App `{app_name}` not found in GCS bucket `{self.bucket_name}`." + ) from e + + @override + def get_eval_case( + self, app_name: str, eval_set_id: str, eval_case_id: str + ) -> Optional[EvalCase]: + """Returns an EvalCase identified by an app_name, eval_set_id and eval_case_id.""" + eval_set = self.get_eval_set(app_name, eval_set_id) + if not eval_set: + return None + return get_eval_case_from_eval_set(eval_set, eval_case_id) + + @override + def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase): + """Adds the given EvalCase to an existing EvalSet. + + Args: + app_name: The name of the app. + eval_set_id: The id of the eval set containing the eval case to update. + eval_case: The EvalCase to add. + + Raises: + NotFoundError: If the eval set is not found. + ValueError: If the eval case already exists in the eval set. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = add_eval_case_to_eval_set(eval_set, eval_case) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) + + @override + def update_eval_case( + self, app_name: str, eval_set_id: str, updated_eval_case: EvalCase + ): + """Updates an existing EvalCase. + + Args: + app_name: The name of the app. + eval_set_id: The id of the eval set containing the eval case to update. + updated_eval_case: The updated EvalCase. Overwrites the existing EvalCase + using the eval_id field. + + Raises: + NotFoundError: If the eval set or the eval case is not found. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = update_eval_case_in_eval_set(eval_set, updated_eval_case) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) + + @override + def delete_eval_case( + self, app_name: str, eval_set_id: str, eval_case_id: str + ): + """Deletes the EvalCase with the given eval_case_id from the given EvalSet. + + Args: + app_name: The name of the app. + eval_set_id: The id of the eval set containing the eval case to delete. + eval_case_id: The id of the eval case to delete. + + Raises: + NotFoundError: If the eval set or the eval case to delete is not found. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = delete_eval_case_from_eval_set(eval_set, eval_case_id) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index 598af7f96..3a66f888c 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -17,10 +17,11 @@ import json import logging import os -import time from typing_extensions import override +from ..errors.not_found_error import NotFoundError +from ._eval_set_results_manager_utils import create_eval_set_result from .eval_result import EvalCaseResult from .eval_result import EvalSetResult from .eval_set_results_manager import EvalSetResultsManager @@ -31,10 +32,6 @@ _EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json" -def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str: - return eval_set_result_name.replace("/", "_") - - class LocalEvalSetResultsManager(EvalSetResultsManager): """An EvalSetResult manager that stores eval set results locally on disk.""" @@ -49,15 +46,8 @@ def save_eval_set_result( eval_case_results: list[EvalCaseResult], ) -> None: """Creates and saves a new EvalSetResult given eval_case_results.""" - timestamp = time.time() - eval_set_result_id = app_name + "_" + eval_set_id + "_" + str(timestamp) - eval_set_result_name = _sanitize_eval_set_result_name(eval_set_result_id) - eval_set_result = EvalSetResult( - eval_set_result_id=eval_set_result_id, - eval_set_result_name=eval_set_result_name, - eval_set_id=eval_set_id, - eval_case_results=eval_case_results, - creation_timestamp=timestamp, + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results ) # Write eval result file, with eval_set_result_name. app_eval_history_dir = self._get_eval_history_dir(app_name) @@ -67,7 +57,7 @@ def save_eval_set_result( eval_set_result_json = eval_set_result.model_dump_json() eval_set_result_file_path = os.path.join( app_eval_history_dir, - eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION, + eval_set_result.eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION, ) logger.info("Writing eval result to file: %s", eval_set_result_file_path) with open(eval_set_result_file_path, "w") as f: @@ -87,9 +77,7 @@ def get_eval_set_result( + _EVAL_SET_RESULT_FILE_EXTENSION ) if not os.path.exists(maybe_eval_result_file_path): - raise ValueError( - f"Eval set result `{eval_set_result_id}` does not exist." - ) + raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.") with open(maybe_eval_result_file_path, "r") as file: eval_result_data = json.load(file) return EvalSetResult.model_validate_json(eval_result_data) diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index e01ecd357..0e93b9201 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -27,7 +27,11 @@ from pydantic import ValidationError from typing_extensions import override -from ..errors.not_found_error import NotFoundError +from ._eval_sets_manager_utils import add_eval_case_to_eval_set +from ._eval_sets_manager_utils import delete_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_case_from_eval_set +from ._eval_sets_manager_utils import get_eval_set_from_app_and_id +from ._eval_sets_manager_utils import update_eval_case_in_eval_set from .eval_case import EvalCase from .eval_case import IntermediateData from .eval_case import Invocation @@ -218,7 +222,7 @@ def create_eval_set(self, app_name: str, eval_set_id: str): eval_cases=[], creation_timestamp=time.time(), ) - self._write_eval_set(new_eval_set_path, new_eval_set) + self._write_eval_set_to_path(new_eval_set_path, new_eval_set) @override def list_eval_sets(self, app_name: str) -> list[str]: @@ -233,51 +237,27 @@ def list_eval_sets(self, app_name: str) -> list[str]: return sorted(eval_sets) - @override - def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase): - """Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id. - - Raises: - NotFoundError: If the eval set is not found. - """ - eval_case_id = eval_case.eval_id - self._validate_id(id_name="Eval Case Id", id_value=eval_case_id) - - eval_set = self.get_eval_set(app_name, eval_set_id) - - if not eval_set: - raise NotFoundError(f"Eval set `{eval_set_id}` not found.") - - if [x for x in eval_set.eval_cases if x.eval_id == eval_case_id]: - raise ValueError( - f"Eval id `{eval_case_id}` already exists in `{eval_set_id}`" - " eval set.", - ) - - eval_set.eval_cases.append(eval_case) - - eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - self._write_eval_set(eval_set_file_path, eval_set) - @override def get_eval_case( self, app_name: str, eval_set_id: str, eval_case_id: str ) -> Optional[EvalCase]: """Returns an EvalCase if found, otherwise None.""" eval_set = self.get_eval_set(app_name, eval_set_id) - if not eval_set: return None + return get_eval_case_from_eval_set(eval_set, eval_case_id) - eval_case_to_find = None + @override + def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase): + """Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id. - # Look up the eval case by eval_case_id - for eval_case in eval_set.eval_cases: - if eval_case.eval_id == eval_case_id: - eval_case_to_find = eval_case - break + Raises: + NotFoundError: If the eval set is not found. + """ + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = add_eval_case_to_eval_set(eval_set, eval_case) - return eval_case_to_find + self._save_eval_set(app_name, eval_set_id, updated_eval_set) @override def update_eval_case( @@ -288,28 +268,9 @@ def update_eval_case( Raises: NotFoundError: If the eval set or the eval case is not found. """ - eval_case_id = updated_eval_case.eval_id - - # Find the eval case to be updated. - eval_case_to_update = self.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_update: - # Remove the eval case from the existing eval set. - eval_set = self.get_eval_set(app_name, eval_set_id) - eval_set.eval_cases.remove(eval_case_to_update) - - # Add the updated eval case to the existing eval set. - eval_set.eval_cases.append(updated_eval_case) - - # Persit the eval set. - eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - self._write_eval_set(eval_set_file_path, eval_set) - else: - raise NotFoundError( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found.", - ) + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = update_eval_case_in_eval_set(eval_set, updated_eval_case) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) @override def delete_eval_case( @@ -320,25 +281,9 @@ def delete_eval_case( Raises: NotFoundError: If the eval set or the eval case to delete is not found. """ - # Find the eval case that needs to be deleted. - eval_case_to_remove = self.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_remove: - logger.info( - "EvalCase`%s` was found in the eval set. It will be removed" - " permanently.", - eval_case_id, - ) - eval_set = self.get_eval_set(app_name, eval_set_id) - eval_set.eval_cases.remove(eval_case_to_remove) - eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - self._write_eval_set(eval_set_file_path, eval_set) - else: - raise NotFoundError( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found.", - ) + eval_set = get_eval_set_from_app_and_id(self, app_name, eval_set_id) + updated_eval_set = delete_eval_case_from_eval_set(eval_set, eval_case_id) + self._save_eval_set(app_name, eval_set_id, updated_eval_set) def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: return os.path.join( @@ -354,6 +299,10 @@ def _validate_id(self, id_name: str, id_value: str): f"Invalid {id_name}. {id_name} should have the `{pattern}` format", ) - def _write_eval_set(self, eval_set_path: str, eval_set: EvalSet): + def _write_eval_set_to_path(self, eval_set_path: str, eval_set: EvalSet): with open(eval_set_path, "w") as f: f.write(eval_set.model_dump_json(indent=2)) + + def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet): + eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) + self._write_eval_set_to_path(eval_set_file_path, eval_set) diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index d48c8cd20..ee5c83da1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -65,6 +65,15 @@ async def run_async( llm_request.live_connect_config.input_audio_transcription = ( invocation_context.run_config.input_audio_transcription ) + llm_request.live_connect_config.realtime_input_config = ( + invocation_context.run_config.realtime_input_config + ) + llm_request.live_connect_config.enable_affective_dialog = ( + invocation_context.run_config.enable_affective_dialog + ) + llm_request.live_connect_config.proactivity = ( + invocation_context.run_config.proactivity + ) # TODO: handle tool append here, instead of in BaseTool.process_llm_request. diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2541ac664..2772550c2 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -288,8 +288,7 @@ async def handle_function_calls_live( trace_tool_call( tool=tool, args=function_args, - response_event_id=function_response_event.id, - function_response=function_response, + function_response_event=function_response_event, ) function_response_events.append(function_response_event) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 96b95ac5a..a3a0e0962 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -201,7 +201,7 @@ def function_declaration_to_tool_param( class Claude(BaseLlm): - """ "Integration with Claude models served from Vertex AI. + """Integration with Claude models served from Vertex AI. Attributes: model: The name of the Claude model. diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 400aab3d6..36aea3b20 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging from typing import AsyncGenerator diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index fd1d8e582..bff2b675c 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -95,6 +95,13 @@ async def generate_content_async( ) logger.info(_build_request_log(llm_request)) + # add tracking headers to custom headers given it will override the headers + # set in the api client constructor + if llm_request.config and llm_request.config.http_options: + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + llm_request.config.http_options.headers.update(self._tracking_headers) + if stream: responses = await self.api_client.aio.models.generate_content_stream( model=llm_request.model, @@ -201,24 +208,21 @@ def _tracking_headers(self) -> dict[str, str]: return tracking_headers @cached_property - def _live_api_client(self) -> Client: + def _live_api_version(self) -> str: if self._api_backend == GoogleLLMVariant.VERTEX_AI: # use beta version for vertex api - api_version = 'v1beta1' - # use default api version for vertex - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=api_version - ) - ) + return 'v1beta1' else: # use v1alpha for using API KEY from Google AI Studio - api_version = 'v1alpha' - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=api_version - ) - ) + return 'v1alpha' + + @cached_property + def _live_api_client(self) -> Client: + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers, api_version=self._live_api_version + ) + ) @contextlib.asynccontextmanager async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: @@ -230,6 +234,21 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: Yields: BaseLlmConnection, the connection to the Gemini model. """ + # add tracking headers to custom headers and set api_version given + # the customized http options will override the one set in the api client + # constructor + if ( + llm_request.live_connect_config + and llm_request.live_connect_config.http_options + ): + if not llm_request.live_connect_config.http_options.headers: + llm_request.live_connect_config.http_options.headers = {} + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers + ) + llm_request.live_connect_config.http_options.api_version = ( + self._live_api_version + ) llm_request.live_connect_config.system_instruction = types.Content( role='system', diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf..dce5ed7c4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -739,11 +739,12 @@ async def generate_content_async( _message_to_generate_content_response( ChatCompletionAssistantMessage( role="assistant", - content="", + content=text, tool_calls=tool_calls, ) ) ) + text = "" function_calls.clear() elif finish_reason == "stop" and text: aggregated_llm_response = _message_to_generate_content_response( diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index c4fcdfb9e..01412a2b3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -17,7 +17,6 @@ import asyncio import logging import queue -import threading from typing import AsyncGenerator from typing import Generator from typing import Optional @@ -34,6 +33,7 @@ from .agents.run_config import RunConfig from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService +from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event from .memory.base_memory_service import BaseMemoryService @@ -73,6 +73,8 @@ class Runner: """The session service for the runner.""" memory_service: Optional[BaseMemoryService] = None """The memory service for the runner.""" + credential_service: Optional[BaseCredentialService] = None + """The credential service for the runner.""" def __init__( self, @@ -82,6 +84,7 @@ def __init__( artifact_service: Optional[BaseArtifactService] = None, session_service: BaseSessionService, memory_service: Optional[BaseMemoryService] = None, + credential_service: Optional[BaseCredentialService] = None, ): """Initializes the Runner. @@ -97,6 +100,7 @@ def __init__( self.artifact_service = artifact_service self.session_service = session_service self.memory_service = memory_service + self.credential_service = credential_service def run( self, @@ -418,6 +422,7 @@ def _new_invocation_context( artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, + credential_service=self.credential_service, invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 5d6bed2e6..258dcd933 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -22,7 +22,6 @@ import urllib.parse from dateutil import parser -from google.genai import types from typing_extensions import override from google import genai @@ -40,15 +39,27 @@ class VertexAiSessionService(BaseSessionService): - """Connects to the managed Vertex AI Session Service.""" + """Connects to the Vertex AI Agent Engine Session Service using GenAI API client. + + https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview + """ def __init__( self, - project: str = None, - location: str = None, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, ): - self.project = project - self.location = location + """Initializes the VertexAiSessionService. + + Args: + project: The project id of the project to use. + location: The location of the project to use. + agent_engine_id: The resource ID of the agent engine to use. + """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id @override async def create_session( @@ -64,14 +75,13 @@ async def create_session( 'User-provided Session id is not supported for' ' VertexAISessionService.' ) - - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() session_json_dict = {'user_id': user_id} if state: session_json_dict['session_state'] = state - api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions', @@ -130,10 +140,10 @@ async def get_session( session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() # Get session resource - api_client = _get_api_client(self.project, self.location) get_session_api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', @@ -203,14 +213,14 @@ async def get_session( async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() path = f'reasoningEngines/{reasoning_engine_id}/sessions' if user_id: parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='') path = path + f'?filter=user_id={parsed_user_id}' - api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='GET', path=path, @@ -236,8 +246,9 @@ async def list_sessions( async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) - api_client = _get_api_client(self.project, self.location) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() + try: await api_client.async_request( http_method='DELETE', @@ -253,8 +264,8 @@ async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. await super().append_event(session=session, event=event) - reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) - api_client = _get_api_client(self.project, self.location) + reasoning_engine_id = self._get_reasoning_engine_id(session.app_name) + api_client = self._get_api_client() await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', @@ -262,15 +273,34 @@ async def append_event(self, session: Session, event: Event) -> Event: ) return event + def _get_reasoning_engine_id(self, app_name: str): + if self._agent_engine_id: + return self._agent_engine_id -def _get_api_client(project: str, location: str): - """Instantiates an API client for the given project and location. + if app_name.isdigit(): + return app_name - It needs to be instantiated inside each request so that the event loop - management. - """ - client = genai.Client(vertexai=True, project=project, location=location) - return client._api_client + pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' + match = re.fullmatch(pattern, app_name) + + if not bool(match): + raise ValueError( + f'App name {app_name} is not valid. It should either be the full' + ' ReasoningEngine resource name, or the reasoning engine id.' + ) + + return match.groups()[-1] + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client def _convert_event_to_json(event: Event) -> Dict[str, Any]: @@ -366,19 +396,3 @@ def _from_api_event(api_event: Dict[str, Any]) -> Event: ) return event - - -def _parse_reasoning_engine_id(app_name: str): - if app_name.isdigit(): - return app_name - - pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' - match = re.fullmatch(pattern, app_name) - - if not bool(match): - raise ValueError( - f'App name {app_name} is not valid. It should either be the full' - ' ReasoningEngine resource name, or the reasoning engine id.' - ) - - return match.groups()[-1] diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 9c4de3660..d1137b58e 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -97,17 +97,6 @@ async def run_async( if isinstance(self.agent, LlmAgent) and self.agent.input_schema: input_value = self.agent.input_schema.model_validate(args) - else: - input_value = args['request'] - - if isinstance(self.agent, LlmAgent) and self.agent.input_schema: - if isinstance(input_value, dict): - input_value = self.agent.input_schema.model_validate(input_value) - if not isinstance(input_value, self.agent.input_schema): - raise ValueError( - f'Input value {input_value} is not of type' - f' `{self.agent.input_schema}`.' - ) content = types.Content( role='user', parts=[ @@ -119,7 +108,7 @@ async def run_async( else: content = types.Content( role='user', - parts=[types.Part.from_text(text=input_value)], + parts=[types.Part.from_text(text=args['request'])], ) runner = Runner( app_name=self.agent.name, @@ -145,15 +134,11 @@ async def run_async( if not last_event or not last_event.content or not last_event.content.parts: return '' + merged_text = '\n'.join(p.text for p in last_event.content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: - merged_text = '\n'.join( - [p.text for p in last_event.content.parts if p.text] - ) tool_result = self.agent.output_schema.model_validate_json( merged_text ).model_dump(exclude_none=True) else: - tool_result = '\n'.join( - [p.text for p in last_event.content.parts if p.text] - ) + tool_result = merged_text return tool_result diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 4e5be5959..5a50a7f0c 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -150,7 +150,7 @@ async def run_async( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self._auth_scheme, self._auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() if auth_result.state == 'pending': return { @@ -178,7 +178,7 @@ async def run_async( args['operation'] = self._operation args['action'] = self._action logger.info('Running tool: %s with args: %s', self.name, args) - return self._rest_api_tool.call(args=args, tool_context=tool_context) + return await self._rest_api_tool.call(args=args, tool_context=tool_context) def __str__(self): return ( diff --git a/src/google/adk/tools/authenticated_function_tool.py b/src/google/adk/tools/authenticated_function_tool.py new file mode 100644 index 000000000..67cc5885f --- /dev/null +++ b/src/google/adk/tools/authenticated_function_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import logging +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .function_tool import FunctionTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class AuthenticatedFunctionTool(FunctionTool): + """A FunctionTool that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + func: Callable[..., Any], + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """Initializes the AuthenticatedFunctionTool. + + Args: + func: The function to be called. + auth_config: The authentication configuration. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__(func=func) + self._ignore_params.append("credential") + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "credential" in signature.parameters: + args_to_call["credential"] = credential + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py new file mode 100644 index 000000000..4858e4953 --- /dev/null +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +import logging +from typing import Any +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .base_tool import BaseTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class BaseAuthenticatedTool(BaseTool): + """A base tool class that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + name, + description, + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """ + Args: + name: The name of the tool. + description: The description of the tool. + auth_config: The auth configuration of the tool. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__( + name=name, + description=description, + ) + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, + tool_context=tool_context, + credential=credential, + ) + + @abstractmethod + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + pass diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index 27c3a786c..0a99136c4 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -31,12 +31,14 @@ from ...auth.auth_credential import AuthCredentialTypes from ...auth.auth_credential import OAuth2Auth from ...auth.auth_tool import AuthConfig +from ...utils.feature_decorator import experimental from ..tool_context import ToolContext BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] +@experimental class BigQueryCredentialsConfig(BaseModel): """Configuration for Google API tools. (Experimental)""" diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py index 553012208..182734188 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/bigquery/bigquery_tool.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import inspect from typing import Any @@ -21,6 +22,7 @@ from google.oauth2.credentials import Credentials from typing_extensions import override +from ...utils.feature_decorator import experimental from ..function_tool import FunctionTool from ..tool_context import ToolContext from .bigquery_credentials import BigQueryCredentialsConfig @@ -28,6 +30,7 @@ from .config import BigQueryToolConfig +@experimental class BigQueryTool(FunctionTool): """GoogleApiTool class for tools that call Google APIs. diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 5543d103f..313cf4990 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -26,11 +26,13 @@ from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate +from ...utils.feature_decorator import experimental from .bigquery_credentials import BigQueryCredentialsConfig from .bigquery_tool import BigQueryTool from .config import BigQueryToolConfig +@experimental class BigQueryToolset(BaseToolset): """BigQuery Toolset contains tools for interacting with BigQuery data and metadata.""" diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d72761b2d..23f1befc5 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -18,16 +18,20 @@ from google.cloud import bigquery from google.oauth2.credentials import Credentials -USER_AGENT = "adk-bigquery-tool" +from ... import version +USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" -def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client: + +def get_bigquery_client( + *, project: str, credentials: Credentials +) -> bigquery.Client: """Get a BigQuery client.""" client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT) bigquery_client = bigquery.Client( - credentials=credentials, client_info=client_info + project=project, credentials=credentials, client_info=client_info ) return bigquery_client diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py index 6e279d59e..4f5400611 100644 --- a/src/google/adk/tools/bigquery/metadata_tool.py +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -42,7 +42,9 @@ def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]: 'bbc_news'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) datasets = [] for dataset in bq_client.list_datasets(project_id): @@ -106,7 +108,9 @@ def get_dataset_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) dataset = bq_client.get_dataset( bigquery.DatasetReference(project_id, dataset_id) ) @@ -137,7 +141,9 @@ def list_table_ids( 'local_data_for_better_health_county_data'] """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) tables = [] for table in bq_client.list_tables( @@ -251,7 +257,9 @@ def get_table_info( } """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) return bq_client.get_table( bigquery.TableReference( bigquery.DatasetReference(project_id, dataset_id), table_id diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 80b56aad3..d3a94fda7 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -72,7 +72,9 @@ def execute_sql( """ try: - bq_client = client.get_bigquery_client(credentials=credentials) + bq_client = client.get_bigquery_client( + project=project_id, credentials=credentials + ) if not config or config.write_mode == WriteMode.BLOCKED: query_job = bq_client.query( query, diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 3a07a6fe1..90b39e6cb 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -14,12 +14,16 @@ from __future__ import annotations +import asyncio from contextlib import AsyncExitStack from datetime import timedelta import functools +import hashlib +import json import logging import sys from typing import Any +from typing import Dict from typing import Optional from typing import TextIO from typing import Union @@ -34,7 +38,6 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client except ImportError as e: - import sys if sys.version_info < (3, 10): raise ImportError( @@ -105,69 +108,39 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(async_reinit_func_name: str): - """Decorator to automatically reinitialize session and retry action. +def retry_on_closed_resource(func): + """Decorator to automatically retry action when MCP session is closed. - When MCP session was closed, the decorator will automatically recreate the - session and retry the action with the same parameters. - - Note: - 1. async_reinit_func_name is the name of the class member function that - reinitializes the MCP session. - 2. Both the decorated function and the async_reinit_func_name must be async - functions. - - Usage: - class MCPTool: - ... - async def create_session(self): - self.session = ... - - @retry_on_closed_resource('create_session') - async def use_session(self): - await self.session.call_tool() + When MCP session was closed, the decorator will automatically retry the + action once. The create_session method will handle creating a new session + if the old one was disconnected. Args: - async_reinit_func_name: The name of the async function to recreate session. + func: The function to decorate. Returns: The decorated function. """ - def decorator(func): - @functools.wraps(func) # Preserves original function metadata - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except anyio.ClosedResourceError as close_err: - try: - if hasattr(self, async_reinit_func_name) and callable( - getattr(self, async_reinit_func_name) - ): - async_init_fn = getattr(self, async_reinit_func_name) - await async_init_fn() - else: - raise ValueError( - f'Function {async_reinit_func_name} does not exist in decorated' - ' class. Please check the function name in' - ' retry_on_closed_resource decorator.' - ) from close_err - except Exception as reinit_err: - raise RuntimeError( - f'Error reinitializing: {reinit_err}' - ) from reinit_err - return await func(self, *args, **kwargs) - - return wrapper - - return decorator + @functools.wraps(func) # Preserves original function metadata + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except anyio.ClosedResourceError: + # Simply retry the function - create_session will handle + # detecting and replacing disconnected sessions + logger.info('Retrying %s due to closed resource', func.__name__) + return await func(self, *args, **kwargs) + + return wrapper class MCPSessionManager: """Manages MCP client sessions. This class provides methods for creating and initializing MCP client sessions, - handling different connection parameters (Stdio and SSE). + handling different connection parameters (Stdio and SSE) and supporting + session pooling based on authentication headers. """ def __init__( @@ -204,93 +177,203 @@ def __init__( else: self._connection_params = connection_params self._errlog = errlog - # Each session manager maintains its own exit stack for proper cleanup - self._exit_stack: Optional[AsyncExitStack] = None - self._session: Optional[ClientSession] = None - async def create_session(self) -> ClientSession: + # Session pool: maps session keys to (session, exit_stack) tuples + self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + + # Lock to prevent race conditions in session creation + self._session_lock = asyncio.Lock() + + def _generate_session_key( + self, merged_headers: Optional[Dict[str, str]] = None + ) -> str: + """Generates a session key based on connection params and merged headers. + + For StdioConnectionParams, returns a constant key since headers are not + supported. For SSE and StreamableHTTP connections, generates a key based + on the provided merged headers. + + Args: + merged_headers: Already merged headers (base + additional). + + Returns: + A unique session key string. + """ + if isinstance(self._connection_params, StdioConnectionParams): + # For stdio connections, headers are not supported, so use constant key + return 'stdio_session' + + # For SSE and StreamableHTTP connections, use merged headers + if merged_headers: + headers_json = json.dumps(merged_headers, sort_keys=True) + headers_hash = hashlib.md5(headers_json.encode()).hexdigest() + return f'session_{headers_hash}' + else: + return 'session_no_headers' + + def _merge_headers( + self, additional_headers: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """Merges base connection headers with additional headers. + + Args: + additional_headers: Optional headers to merge with connection headers. + + Returns: + Merged headers dictionary, or None if no headers are provided. + """ + if isinstance(self._connection_params, StdioConnectionParams) or isinstance( + self._connection_params, StdioServerParameters + ): + # Stdio connections don't support headers + return None + + base_headers = {} + if ( + hasattr(self._connection_params, 'headers') + and self._connection_params.headers + ): + base_headers = self._connection_params.headers.copy() + + if additional_headers: + base_headers.update(additional_headers) + + return base_headers + + def _is_session_disconnected(self, session: ClientSession) -> bool: + """Checks if a session is disconnected or closed. + + Args: + session: The ClientSession to check. + + Returns: + True if the session is disconnected, False otherwise. + """ + return session._read_stream._closed or session._write_stream._closed + + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: """Creates and initializes an MCP client session. + This method will check if an existing session for the given headers + is still connected. If it's disconnected, it will be cleaned up and + a new session will be created. + + Args: + headers: Optional headers to include in the session. These will be + merged with any existing connection headers. Only applicable + for SSE and StreamableHTTP connections. + Returns: ClientSession: The initialized MCP client session. """ - if self._session is not None: - return self._session + # Merge headers once at the beginning + merged_headers = self._merge_headers(headers) + + # Generate session key using merged headers + session_key = self._generate_session_key(merged_headers) + + # Use async lock to prevent race conditions + async with self._session_lock: + # Check if we have an existing session + if session_key in self._sessions: + session, exit_stack = self._sessions[session_key] + + # Check if the existing session is still connected + if not self._is_session_disconnected(session): + # Session is still good, return it + return session + else: + # Session is disconnected, clean it up + logger.info('Cleaning up disconnected session: %s', session_key) + try: + await exit_stack.aclose() + except Exception as e: + logger.warning('Error during disconnected session cleanup: %s', e) + finally: + del self._sessions[session_key] + + # Create a new session (either first time or replacing disconnected one) + exit_stack = AsyncExitStack() - # Create a new exit stack for this session - self._exit_stack = AsyncExitStack() - - try: - if isinstance(self._connection_params, StdioConnectionParams): - client = stdio_client( - server=self._connection_params.server_params, - errlog=self._errlog, - ) - elif isinstance(self._connection_params, SseConnectionParams): - client = sse_client( - url=self._connection_params.url, - headers=self._connection_params.headers, - timeout=self._connection_params.timeout, - sse_read_timeout=self._connection_params.sse_read_timeout, - ) - elif isinstance(self._connection_params, StreamableHTTPConnectionParams): - client = streamablehttp_client( - url=self._connection_params.url, - headers=self._connection_params.headers, - timeout=timedelta(seconds=self._connection_params.timeout), - sse_read_timeout=timedelta( - seconds=self._connection_params.sse_read_timeout - ), - terminate_on_close=self._connection_params.terminate_on_close, - ) - else: - raise ValueError( - 'Unable to initialize connection. Connection should be' - ' StdioServerParameters or SseServerParams, but got' - f' {self._connection_params}' - ) - - transports = await self._exit_stack.enter_async_context(client) - # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams - # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. - if isinstance(self._connection_params, StdioConnectionParams): - session = await self._exit_stack.enter_async_context( - ClientSession( - *transports[:2], - read_timeout_seconds=timedelta( - seconds=self._connection_params.timeout - ), - ) - ) - else: - session = await self._exit_stack.enter_async_context( - ClientSession(*transports[:2]) - ) - await session.initialize() - - self._session = session - return session - - except Exception: - # If session creation fails, clean up the exit stack - if self._exit_stack: - await self._exit_stack.aclose() - self._exit_stack = None - raise + try: + if isinstance(self._connection_params, StdioConnectionParams): + client = stdio_client( + server=self._connection_params.server_params, + errlog=self._errlog, + ) + elif isinstance(self._connection_params, SseConnectionParams): + client = sse_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=self._connection_params.timeout, + sse_read_timeout=self._connection_params.sse_read_timeout, + ) + elif isinstance( + self._connection_params, StreamableHTTPConnectionParams + ): + client = streamablehttp_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=timedelta(seconds=self._connection_params.timeout), + sse_read_timeout=timedelta( + seconds=self._connection_params.sse_read_timeout + ), + terminate_on_close=self._connection_params.terminate_on_close, + ) + else: + raise ValueError( + 'Unable to initialize connection. Connection should be' + ' StdioServerParameters or SseServerParams, but got' + f' {self._connection_params}' + ) + + transports = await exit_stack.enter_async_context(client) + # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams + # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. + if isinstance(self._connection_params, StdioConnectionParams): + session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta( + seconds=self._connection_params.timeout + ), + ) + ) + else: + session = await exit_stack.enter_async_context( + ClientSession(*transports[:2]) + ) + await session.initialize() + + # Store session and exit stack in the pool + self._sessions[session_key] = (session, exit_stack) + logger.debug('Created new session: %s', session_key) + return session + + except Exception: + # If session creation fails, clean up the exit stack + if exit_stack: + await exit_stack.aclose() + raise async def close(self): - """Closes the session and cleans up resources.""" - if self._exit_stack: - try: - await self._exit_stack.aclose() - except Exception as e: - # Log the error but don't re-raise to avoid blocking shutdown - print( - f'Warning: Error during MCP session cleanup: {e}', file=self._errlog - ) - finally: - self._exit_stack = None - self._session = None + """Closes all sessions and cleans up resources.""" + async with self._session_lock: + for session_key in list(self._sessions.keys()): + _, exit_stack = self._sessions[session_key] + try: + await exit_stack.aclose() + except Exception as e: + # Log the error but don't re-raise to avoid blocking shutdown + print( + 'Warning: Error during MCP session cleanup for' + f' {session_key}: {e}', + file=self._errlog, + ) + finally: + del self._sessions[session_key] SseServerParams = SseConnectionParams diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 463202b18..310fc48f1 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,10 +14,13 @@ from __future__ import annotations +import base64 +import json import logging from typing import Optional from google.genai.types import FunctionDeclaration +from google.oauth2.credentials import Credentials from typing_extensions import override from .._gemini_schema_util import _to_gemini_schema @@ -42,14 +45,16 @@ from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme -from ..base_tool import BaseTool +from ...auth.auth_tool import AuthConfig +from ..base_authenticated_tool import BaseAuthenticatedTool +# import from ..tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) -class MCPTool(BaseTool): - """Turns a MCP Tool into a Vertex Agent Framework Tool. +class MCPTool(BaseAuthenticatedTool): + """Turns an MCP Tool into an ADK Tool. Internally, the tool initializes from a MCP Tool, and uses the MCP Session to call the tool. @@ -63,9 +68,9 @@ def __init__( auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, ): - """Initializes a MCPTool. + """Initializes an MCPTool. - This tool wraps a MCP Tool interface and uses a session manager to + This tool wraps an MCP Tool interface and uses a session manager to communicate with the MCP server. Args: @@ -77,19 +82,17 @@ def __init__( Raises: ValueError: If mcp_tool or mcp_session_manager is None. """ - if mcp_tool is None: - raise ValueError("mcp_tool cannot be None") - if mcp_session_manager is None: - raise ValueError("mcp_session_manager cannot be None") super().__init__( name=mcp_tool.name, description=mcp_tool.description if mcp_tool.description else "", + auth_config=AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + if auth_scheme + else None, ) self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager - # TODO(cheliu): Support passing auth to MCP Server. - self._auth_scheme = auth_scheme - self._auth_credential = auth_credential @override def _get_declaration(self) -> FunctionDeclaration: @@ -105,26 +108,74 @@ def _get_declaration(self) -> FunctionDeclaration: ) return function_decl - @retry_on_closed_resource("_reinitialize_session") - async def run_async(self, *, args, tool_context: ToolContext): + @retry_on_closed_resource + @override + async def _run_async_impl( + self, *, args, tool_context: ToolContext, credential: AuthCredential + ): """Runs the tool asynchronously. Args: args: The arguments as a dict to pass to the tool. - tool_context: The tool context from upper level ADK agent. + tool_context: The tool context of the current invocation. Returns: Any: The response from the tool. """ + # Extract headers from credential for session pooling + headers = await self._get_headers(tool_context, credential) + # Get the session from the session manager - session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session(headers=headers) - # TODO(cheliu): Support passing tool context to MCP Server. response = await session.call_tool(self.name, arguments=args) return response - async def _reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and create a new one - await self._mcp_session_manager.close() - await self._mcp_session_manager.create_session() + async def _get_headers( + self, tool_context: ToolContext, credential: AuthCredential + ) -> Optional[dict[str, str]]: + headers = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + # For API keys, we'll add them as headers since MCP typically uses header-based auth + # The specific header name would depend on the API, using a common default + # TODO Allow user to specify the header name for API keys. + headers = {"X-API-Key": credential.api_key} + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + logger.warning( + "Service account credentials should be exchanged before MCP" + " session creation" + ) + + return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 8076752b4..c01b0cec2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -22,16 +22,16 @@ from typing import Union from ...agents.readonly_context import ReadonlyContext +from ...auth.auth_credential import AuthCredential +from ...auth.auth_schemes import AuthScheme from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_closed_resource from .mcp_session_manager import SseConnectionParams -from .mcp_session_manager import SseServerParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams -from .mcp_session_manager import StreamableHTTPServerParams # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. @@ -96,6 +96,8 @@ def __init__( ], tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, errlog: TextIO = sys.stderr, + auth_scheme: Optional[AuthScheme] = None, + auth_credential: Optional[AuthCredential] = None, ): """Initializes the MCPToolset. @@ -112,6 +114,8 @@ def __init__( list of tool names to include - A ToolPredicate function for custom filtering logic errlog: TextIO stream for error logging. + auth_scheme: The auth scheme of the tool for tool calling + auth_credential: The auth credential of the tool for tool calling """ super().__init__(tool_filter=tool_filter) @@ -126,10 +130,10 @@ def __init__( connection_params=self._connection_params, errlog=self._errlog, ) + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential - self._session = None - - @retry_on_closed_resource("_reinitialize_session") + @retry_on_closed_resource async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, @@ -144,11 +148,10 @@ async def get_tools( List[BaseTool]: A list of tools available under the specified context. """ # Get session from session manager - if not self._session: - self._session = await self._mcp_session_manager.create_session() + session = await self._mcp_session_manager.create_session() # Fetch available tools from the MCP server - tools_response: ListToolsResult = await self._session.list_tools() + tools_response: ListToolsResult = await session.list_tools() # Apply filtering based on context and tool_filter tools = [] @@ -156,20 +159,14 @@ async def get_tools( mcp_tool = MCPTool( mcp_tool=tool, mcp_session_manager=self._mcp_session_manager, + auth_scheme=self._auth_scheme, + auth_credential=self._auth_credential, ) if self._is_tool_selected(mcp_tool, readonly_context): tools.append(mcp_tool) return tools - async def _reinitialize_session(self): - """Reinitializes the session when connection is lost.""" - # Close the old session and clear cache - await self._mcp_session_manager.close() - self._session = await self._mcp_session_manager.create_session() - - # Tools will be reloaded on next get_tools call - async def close(self) -> None: """Performs cleanup and releases resources held by the toolset. @@ -182,7 +179,3 @@ async def close(self) -> None: except Exception as e: # Log the error but don't re-raise to avoid blocking shutdown print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog) - finally: - # Clear cached tools - self._tools_cache = None - self._tools_loaded = False diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 1e451fe0f..dee103932 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -345,9 +345,9 @@ def _prepare_request_params( async def run_async( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: - return self.call(args=args, tool_context=tool_context) + return await self.call(args=args, tool_context=tool_context) - def call( + async def call( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: """Executes the REST API call. @@ -364,7 +364,7 @@ def call( tool_auth_handler = ToolAuthHandler.from_tool_context( tool_context, self.auth_scheme, self.auth_credential ) - auth_result = tool_auth_handler.prepare_auth_credentials() + auth_result = await tool_auth_handler.prepare_auth_credentials() auth_state, auth_scheme, auth_credential = ( auth_result.state, auth_result.auth_scheme, diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 71e760e30..74166b00e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import logging from typing import Literal from typing import Optional -from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from ....auth.auth_credential import AuthCredential @@ -25,6 +25,7 @@ from ....auth.auth_schemes import AuthScheme from ....auth.auth_schemes import AuthSchemeType from ....auth.auth_tool import AuthConfig +from ....auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher from ...tool_context import ToolContext from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError @@ -95,10 +96,9 @@ def store_credential( auth_credential: Optional[AuthCredential], ): if self.tool_context: - serializable_credential = jsonable_encoder( - auth_credential, exclude_none=True + self.tool_context.state[key] = auth_credential.model_dump( + exclude_none=True ) - self.tool_context.state[key] = serializable_credential def remove_credential(self, key: str): del self.tool_context.state[key] @@ -146,20 +146,22 @@ def from_tool_context( credential_store, ) - def _handle_existing_credential( + async def _get_existing_credential( self, - ) -> Optional[AuthPreparationResult]: + ) -> Optional[AuthCredential]: """Checks for and returns an existing, exchanged credential.""" if self.credential_store: existing_credential = self.credential_store.get_credential( self.auth_scheme, self.auth_credential ) if existing_credential: - return AuthPreparationResult( - state="done", - auth_scheme=self.auth_scheme, - auth_credential=existing_credential, - ) + if existing_credential.oauth2: + refresher = OAuth2CredentialRefresher() + if await refresher.is_refresh_needed(existing_credential): + existing_credential = await refresher.refresh( + existing_credential, self.auth_scheme + ) + return existing_credential return None def _exchange_credential( @@ -223,7 +225,17 @@ def _get_auth_response(self) -> AuthCredential: ) ) - def prepare_auth_credentials( + def _external_exchange_required(self, credential) -> bool: + return ( + credential.auth_type + in ( + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + ) + and not credential.oauth2.access_token + ) + + async def prepare_auth_credentials( self, ) -> AuthPreparationResult: """Prepares authentication credentials, handling exchange and user interaction.""" @@ -233,31 +245,41 @@ def prepare_auth_credentials( return AuthPreparationResult(state="done") # Check for existing credential. - existing_result = self._handle_existing_credential() - if existing_result: - return existing_result + existing_credential = await self._get_existing_credential() + credential = existing_credential or self.auth_credential # fetch credential from adk framework # Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require # multi-step exchange: # client_id , client_secret -> auth_uri -> auth_code -> access_token - # -> bearer token # adk framework supports exchange access_token already - fetched_credential = self._get_auth_response() or self.auth_credential - - exchanged_credential = self._exchange_credential(fetched_credential) + # for other credential, adk can also get back the credential directly + if not credential or self._external_exchange_required(credential): + credential = self._get_auth_response() + # store fetched credential + if credential: + self._store_credential(credential) + else: + self._request_credential() + return AuthPreparationResult( + state="pending", + auth_scheme=self.auth_scheme, + auth_credential=self.auth_credential, + ) - if exchanged_credential: - self._store_credential(exchanged_credential) - return AuthPreparationResult( - state="done", - auth_scheme=self.auth_scheme, - auth_credential=exchanged_credential, - ) - else: - self._request_credential() - return AuthPreparationResult( - state="pending", - auth_scheme=self.auth_scheme, - auth_credential=self.auth_credential, - ) + # here exchangers are doing two different thing: + # for service account the exchanger is doing actualy token exchange + # while for oauth2 it's actually doing the credentail conversion + # from OAuth2 credential to HTTP credentails for setting credential in + # http header + # TODO cleanup the logic: + # 1. service account token exchanger should happen before we store them in + # the token store + # 2. blow line should only do credential conversion + + exchanged_credential = self._exchange_credential(credential) + return AuthPreparationResult( + state="done", + auth_scheme=self.auth_scheme, + auth_credential=exchanged_credential, + ) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index c370e2a72..b00cd0329 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -39,6 +39,9 @@ def __init__( self, *, data_store_id: Optional[str] = None, + data_store_specs: Optional[ + list[types.VertexAISearchDataStoreSpec] + ] = None, search_engine_id: Optional[str] = None, filter: Optional[str] = None, max_results: Optional[int] = None, @@ -49,6 +52,8 @@ def __init__( data_store_id: The Vertex AI search data store resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}". + data_store_specs: Specifications that define the specific DataStores to be + searched. It should only be set if engine is used. search_engine_id: The Vertex AI search engine resource ID in the format of "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}". @@ -64,7 +69,12 @@ def __init__( raise ValueError( 'Either data_store_id or search_engine_id must be specified.' ) + if data_store_specs is not None and search_engine_id is None: + raise ValueError( + 'search_engine_id must be specified if data_store_specs is specified.' + ) self.data_store_id = data_store_id + self.data_store_specs = data_store_specs self.search_engine_id = search_engine_id self.filter = filter self.max_results = max_results @@ -76,8 +86,8 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if llm_request.model and llm_request.model.startswith('gemini-'): - if llm_request.model.startswith('gemini-1') and llm_request.config.tools: + if llm_request.model and 'gemini-' in llm_request.model: + if 'gemini-1' in llm_request.model and llm_request.config.tools: raise ValueError( 'Vertex AI search tool can not be used with other tools in Gemini' ' 1.x.' @@ -89,6 +99,7 @@ async def process_llm_request( retrieval=types.Retrieval( vertex_ai_search=types.VertexAISearch( datastore=self.data_store_id, + data_store_specs=self.data_store_specs, engine=self.search_engine_id, filter=self.filter, max_results=self.max_results, diff --git a/src/google/adk/utils/feature_decorator.py b/src/google/adk/utils/feature_decorator.py index 301637319..d597063ae 100644 --- a/src/google/adk/utils/feature_decorator.py +++ b/src/google/adk/utils/feature_decorator.py @@ -15,61 +15,126 @@ from __future__ import annotations import functools +import os from typing import Callable from typing import cast +from typing import Optional from typing import TypeVar from typing import Union import warnings +from dotenv import load_dotenv + T = TypeVar("T", bound=Union[Callable, type]) def _make_feature_decorator( - *, label: str, default_message: str -) -> Callable[[str], Callable[[T], T]]: - def decorator_factory(message: str = default_message) -> Callable[[T], T]: - def decorator(obj: T) -> T: - obj_name = getattr(obj, "__name__", type(obj).__name__) - warn_msg = f"[{label.upper()}] {obj_name}: {message}" + *, + label: str, + default_message: str, + block_usage: bool = False, + bypass_env_var: Optional[str] = None, +) -> Callable: + def decorator_factory(message_or_obj=None): + # Case 1: Used as @decorator without parentheses + # message_or_obj is the decorated class/function + if message_or_obj is not None and ( + isinstance(message_or_obj, type) or callable(message_or_obj) + ): + return _create_decorator( + default_message, label, block_usage, bypass_env_var + )(message_or_obj) + + # Case 2: Used as @decorator() with or without message + # message_or_obj is either None or a string message + message = ( + message_or_obj if isinstance(message_or_obj, str) else default_message + ) + return _create_decorator(message, label, block_usage, bypass_env_var) - if isinstance(obj, type): # decorating a class - orig_init = obj.__init__ + return decorator_factory - @functools.wraps(orig_init) - def new_init(self, *args, **kwargs): - warnings.warn(warn_msg, category=UserWarning, stacklevel=2) - return orig_init(self, *args, **kwargs) - obj.__init__ = new_init # type: ignore[attr-defined] - return cast(T, obj) +def _create_decorator( + message: str, label: str, block_usage: bool, bypass_env_var: Optional[str] +) -> Callable[[T], T]: + def decorator(obj: T) -> T: + obj_name = getattr(obj, "__name__", type(obj).__name__) + msg = f"[{label.upper()}] {obj_name}: {message}" - elif callable(obj): # decorating a function or method + if isinstance(obj, type): # decorating a class + orig_init = obj.__init__ - @functools.wraps(obj) - def wrapper(*args, **kwargs): - warnings.warn(warn_msg, category=UserWarning, stacklevel=2) - return obj(*args, **kwargs) + @functools.wraps(orig_init) + def new_init(self, *args, **kwargs): + # Load .env file if dotenv is available + load_dotenv() - return cast(T, wrapper) + # Check if usage should be bypassed via environment variable at call time + should_bypass = ( + bypass_env_var is not None + and os.environ.get(bypass_env_var, "").lower() == "true" + ) - else: - raise TypeError( - f"@{label} can only be applied to classes or callable objects" + if should_bypass: + # Bypass completely - no warning, no error + pass + elif block_usage: + raise RuntimeError(msg) + else: + warnings.warn(msg, category=UserWarning, stacklevel=2) + return orig_init(self, *args, **kwargs) + + obj.__init__ = new_init # type: ignore[attr-defined] + return cast(T, obj) + + elif callable(obj): # decorating a function or method + + @functools.wraps(obj) + def wrapper(*args, **kwargs): + # Load .env file if dotenv is available + load_dotenv() + + # Check if usage should be bypassed via environment variable at call time + should_bypass = ( + bypass_env_var is not None + and os.environ.get(bypass_env_var, "").lower() == "true" ) - return decorator + if should_bypass: + # Bypass completely - no warning, no error + pass + elif block_usage: + raise RuntimeError(msg) + else: + warnings.warn(msg, category=UserWarning, stacklevel=2) + return obj(*args, **kwargs) - return decorator_factory + return cast(T, wrapper) + + else: + raise TypeError( + f"@{label} can only be applied to classes or callable objects" + ) + + return decorator working_in_progress = _make_feature_decorator( label="WIP", default_message=( - "This feature is a work in progress and may be incomplete or unstable." + "This feature is a work in progress and is not working completely. ADK" + " users are not supposed to use it." ), + block_usage=True, + bypass_env_var="ADK_ALLOW_WIP_FEATURES", ) """Mark a class or function as a work in progress. +By default, decorated functions/classes will raise RuntimeError when used. +Set ADK_ALLOW_WIP_FEATURES=true environment variable to bypass this restriction. +ADK users are not supposed to set this environment variable. + Sample usage: ``` @@ -92,8 +157,19 @@ def my_wip_function(): Sample usage: ``` -@experimental("This API may have breaking change in the future.") +# Use with default message +@experimental class ExperimentalClass: pass + +# Use with custom message +@experimental("This API may have breaking change in the future.") +class CustomExperimentalClass: + pass + +# Use with empty parentheses (same as default message) +@experimental() +def experimental_function(): + pass ``` """ diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py new file mode 100644 index 000000000..05072b899 --- /dev/null +++ b/tests/integration/models/test_litellm_no_function.py @@ -0,0 +1,165 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.models import LlmRequest +from google.adk.models import LlmResponse +from google.adk.models.lite_llm import LiteLlm +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" + +_SYSTEM_PROMPT = """You are a helpful assistant.""" + + +def get_weather(city: str) -> str: + """Simulates a web search. Use it get information on weather. + + Args: + city: A string containing the location to get weather information for. + + Returns: + A string with the simulated weather information for the queried city. + """ + if "sf" in city.lower() or "san francisco" in city.lower(): + return "It's 70 degrees and foggy." + return "It's 80 degrees and sunny." + + +@pytest.fixture +def oss_llm(): + return LiteLlm(model=_TEST_MODEL_NAME) + + +@pytest.fixture +def llm_request(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[Content(role="user", parts=[Part.from_text(text="hello")])], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + ), + ) + + +@pytest.fixture +def llm_request_with_tools(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="What is the weather in San Francisco?") + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="get_weather", + description="Get the weather in a given location", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "city": types.Schema( + type=types.Type.STRING, + description=( + "The city to get the weather for." + ), + ), + }, + required=["city"], + ), + ) + ] + ) + ], + ), + ) + + +@pytest.mark.asyncio +async def test_generate_content_async(oss_llm, llm_request): + async for response in oss_llm.generate_content_async(llm_request): + assert isinstance(response, LlmResponse) + assert response.content.parts[0].text + + +@pytest.mark.asyncio +async def test_generate_content_async(oss_llm, llm_request): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request, stream=False + ) + ] + part = responses[0].content.parts[0] + assert len(part.text) > 0 + + +@pytest.mark.asyncio +async def test_generate_content_async_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=False + ) + ] + function_call = responses[0].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_async_stream(oss_llm, llm_request): + responses = [ + resp + async for resp in oss_llm.generate_content_async(llm_request, stream=True) + ] + text = "" + for i in range(len(responses) - 1): + assert responses[i].partial is True + assert responses[i].content.parts[0].text + text += responses[i].content.parts[0].text + + # Last message should be accumulated text + assert responses[-1].content.parts[0].text == text + assert not responses[-1].partial + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_tools( + oss_llm, llm_request_with_tools +): + responses = [ + resp + async for resp in oss_llm.generate_content_async( + llm_request_with_tools, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py new file mode 100644 index 000000000..799c55e5c --- /dev/null +++ b/tests/integration/models/test_litellm_with_function.py @@ -0,0 +1,115 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.models import LlmRequest +from google.adk.models.lite_llm import LiteLlm +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import litellm +import pytest + +litellm.add_function_to_prompt = True + +_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas" + +_SYSTEM_PROMPT = """ +You are a helpful assistant, and call tools optionally. +If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs. +""" + + +_FUNCTIONS = [{ + "name": "get_weather", + "description": "Get the weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to get the weather for.", + }, + }, + "required": ["city"], + }, +}] + + +def get_weather(city: str) -> str: + """Simulates a web search. Use it get information on weather. + + Args: + city: A string containing the location to get weather information for. + + Returns: + A string with the simulated weather information for the queried city. + """ + if "sf" in city.lower() or "san francisco" in city.lower(): + return "It's 70 degrees and foggy." + return "It's 80 degrees and sunny." + + +@pytest.fixture +def oss_llm_with_function(): + return LiteLlm(model=_TEST_MODEL_NAME, functions=_FUNCTIONS) + + +@pytest.fixture +def llm_request(): + return LlmRequest( + model=_TEST_MODEL_NAME, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text="What is the weather in San Francisco?") + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction=_SYSTEM_PROMPT, + ), + ) + + +@pytest.mark.asyncio +async def test_generate_content_asyn_with_function( + oss_llm_with_function, llm_request +): + responses = [ + resp + async for resp in oss_llm_with_function.generate_content_async( + llm_request, stream=False + ) + ] + function_call = responses[0].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" + + +@pytest.mark.asyncio +async def test_generate_content_asyn_stream_with_function( + oss_llm_with_function, llm_request +): + responses = [ + resp + async for resp in oss_llm_with_function.generate_content_async( + llm_request, stream=True + ) + ] + function_call = responses[-1].content.parts[0].function_call + assert function_call.name == "get_weather" + assert function_call.args["city"] == "San Francisco" diff --git a/tests/unittests/a2a/__init__.py b/tests/unittests/a2a/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/__init__.py b/tests/unittests/a2a/converters/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py new file mode 100644 index 000000000..5ad6cd62d --- /dev/null +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -0,0 +1,443 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import Mock +from unittest.mock import patch + +from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part +from google.genai import types as genai_types +import pytest + + +class TestConvertA2aPartToGenaiPart: + """Test cases for convert_a2a_part_to_genai_part function.""" + + def test_convert_text_part(self): + """Test conversion of A2A TextPart to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="Hello, world!")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == "Hello, world!" + + def test_convert_file_part_with_uri(self): + """Test conversion of A2A FilePart with URI to GenAI Part.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri="gs://bucket/file.txt", mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.file_data is not None + assert result.file_data.file_uri == "gs://bucket/file.txt" + assert result.file_data.mime_type == "text/plain" + + def test_convert_file_part_with_bytes(self): + """Test conversion of A2A FilePart with bytes to GenAI Part.""" + # Arrange + test_bytes = b"test file content" + # Note: A2A FileWithBytes converts bytes to string automatically + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithBytes( + bytes=test_bytes, mimeType="text/plain" + ) + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.inline_data is not None + # Source code now properly converts A2A string back to bytes for GenAI Blob + assert result.inline_data.data == test_bytes + assert result.inline_data.mime_type == "text/plain" + + def test_convert_data_part_function_call(self): + """Test conversion of A2A DataPart with function call metadata.""" + # Arrange + function_call_data = { + "name": "test_function", + "args": {"param1": "value1", "param2": 42}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_call_data, + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_call is not None + assert result.function_call.name == "test_function" + assert result.function_call.args == {"param1": "value1", "param2": 42} + + def test_convert_data_part_function_response(self): + """Test conversion of A2A DataPart with function response metadata.""" + # Arrange + function_response_data = { + "name": "test_function", + "response": {"result": "success", "data": [1, 2, 3]}, + } + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data=function_response_data, + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.function_response is not None + assert result.function_response.name == "test_function" + assert result.function_response.response == { + "result": "success", + "data": [1, 2, 3], + } + + def test_convert_data_part_without_special_metadata(self): + """Test conversion of A2A DataPart without special metadata to text.""" + # Arrange + data = {"key": "value", "number": 123} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata={"other": "metadata"}) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_data_part_no_metadata(self): + """Test conversion of A2A DataPart with no metadata to text.""" + # Arrange + data = {"key": "value", "array": [1, 2, 3]} + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert isinstance(result, genai_types.Part) + assert result.text == json.dumps(data) + + def test_convert_unsupported_file_type(self): + """Test handling of unsupported file types.""" + + # Arrange - Create a mock unsupported file type + class UnsupportedFileType: + pass + + # Create a part manually since FilePart validation might reject it + mock_file_part = Mock() + mock_file_part.file = UnsupportedFileType() + a2a_part = Mock() + a2a_part.root = mock_file_part + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + def test_convert_unsupported_part_type(self): + """Test handling of unsupported part types.""" + + # Arrange - Create a mock unsupported part type + class UnsupportedPartType: + pass + + mock_part = Mock() + mock_part.root = UnsupportedPartType() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_a2a_part_to_genai_part(mock_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestConvertGenaiPartToA2aPart: + """Test cases for convert_genai_part_to_a2a_part function.""" + + def test_convert_text_part(self): + """Test conversion of GenAI text Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part(text="Hello, world!") + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.TextPart) + assert result.text == "Hello, world!" + + def test_convert_file_data_part(self): + """Test conversion of GenAI file_data Part to A2A Part.""" + # Arrange + genai_part = genai_types.Part( + file_data=genai_types.FileData( + file_uri="gs://bucket/file.txt", mime_type="text/plain" + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.FilePart) + assert isinstance(result.file, a2a_types.FileWithUri) + assert result.file.uri == "gs://bucket/file.txt" + assert result.file.mimeType == "text/plain" + + def test_convert_inline_data_part(self): + """Test conversion of GenAI inline_data Part to A2A Part.""" + # Arrange + test_bytes = b"test file content" + genai_part = genai_types.Part( + inline_data=genai_types.Blob(data=test_bytes, mime_type="text/plain") + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.FilePart) + assert isinstance(result.root.file, a2a_types.FileWithBytes) + # A2A FileWithBytes stores bytes as strings + assert result.root.file.bytes == test_bytes.decode("utf-8") + assert result.root.file.mimeType == "text/plain" + + def test_convert_function_call_part(self): + """Test conversion of GenAI function_call Part to A2A Part.""" + # Arrange + function_call = genai_types.FunctionCall( + name="test_function", args={"param1": "value1", "param2": 42} + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_call.model_dump(by_alias=True, exclude_none=True) + assert result.root.data == expected_data + assert ( + result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + + def test_convert_function_response_part(self): + """Test conversion of GenAI function_response Part to A2A Part.""" + # Arrange + function_response = genai_types.FunctionResponse( + name="test_function", response={"result": "success", "data": [1, 2, 3]} + ) + genai_part = genai_types.Part(function_response=function_response) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + expected_data = function_response.model_dump( + by_alias=True, exclude_none=True + ) + assert result.root.data == expected_data + assert ( + result.root.metadata[A2A_DATA_PART_METADATA_TYPE_KEY] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ) + + def test_convert_unsupported_part(self): + """Test handling of unsupported GenAI Part types.""" + # Arrange - Create a GenAI Part with no recognized fields + genai_part = genai_types.Part() + + # Act + with patch( + "google.adk.a2a.converters.part_converter.logger" + ) as mock_logger: + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is None + mock_logger.warning.assert_called_once() + + +class TestRoundTripConversions: + """Test cases for round-trip conversions to ensure consistency.""" + + def test_text_part_round_trip(self): + """Test round-trip conversion for text parts.""" + # Arrange + original_text = "Hello, world!" + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text=original_text)) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.TextPart) + assert result_a2a_part.text == original_text + + def test_file_uri_round_trip(self): + """Test round-trip conversion for file parts with URI.""" + # Arrange + original_uri = "gs://bucket/file.txt" + original_mime_type = "text/plain" + a2a_part = a2a_types.Part( + root=a2a_types.FilePart( + file=a2a_types.FileWithUri( + uri=original_uri, mimeType=original_mime_type + ) + ) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.FilePart) + assert isinstance(result_a2a_part.file, a2a_types.FileWithUri) + assert result_a2a_part.file.uri == original_uri + assert result_a2a_part.file.mimeType == original_mime_type + + +class TestEdgeCases: + """Test cases for edge cases and error conditions.""" + + def test_empty_text_part(self): + """Test conversion of empty text part.""" + # Arrange + a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="")) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == "" + + def test_none_input_a2a_to_genai(self): + """Test handling of None input for A2A to GenAI conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_a2a_part_to_genai_part(None) + + def test_none_input_genai_to_a2a(self): + """Test handling of None input for GenAI to A2A conversion.""" + # This test depends on how the function handles None input + # If it should raise an exception, we test for that + with pytest.raises(AttributeError): + convert_genai_part_to_a2a_part(None) + + def test_data_part_with_complex_data(self): + """Test conversion of DataPart with complex nested data.""" + # Arrange + complex_data = { + "nested": { + "array": [1, 2, {"inner": "value"}], + "boolean": True, + "null_value": None, + }, + "unicode": "Hello 世界 🌍", + } + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=complex_data)) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(complex_data) + + def test_data_part_with_empty_metadata(self): + """Test conversion of DataPart with empty metadata dict.""" + # Arrange + data = {"key": "value"} + a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data, metadata={})) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.text == json.dumps(data) diff --git a/tests/unittests/auth/credential_service/test_in_memory_credential_service.py b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py new file mode 100644 index 000000000..9312f72a3 --- /dev/null +++ b/tests/unittests/auth/credential_service/test_in_memory_credential_service.py @@ -0,0 +1,323 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestInMemoryCredentialService: + """Tests for the InMemoryCredentialService class.""" + + @pytest.fixture + def credential_service(self): + """Create an InMemoryCredentialService instance for testing.""" + return InMemoryCredentialService() + + @pytest.fixture + def oauth2_auth_scheme(self): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + @pytest.fixture + def oauth2_credentials(self): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + @pytest.fixture + def auth_config(self, oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + exchanged_credential = oauth2_credentials.model_copy(deep=True) + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + @pytest.fixture + def tool_context(self): + """Create a mock ToolContext for testing.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "test_app" + mock_invocation_context.user_id = "test_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + @pytest.fixture + def another_tool_context(self): + """Create another mock ToolContext with different app/user for testing isolation.""" + mock_context = Mock(spec=ToolContext) + mock_invocation_context = Mock() + mock_invocation_context.app_name = "another_app" + mock_invocation_context.user_id = "another_user" + mock_context._invocation_context = mock_invocation_context + return mock_context + + def test_init(self, credential_service): + """Test that the service initializes with an empty store.""" + assert isinstance(credential_service._credentials, dict) + assert len(credential_service._credentials) == 0 + + @pytest.mark.asyncio + async def test_load_credential_not_found( + self, credential_service, auth_config, tool_context + ): + """Test loading a credential that doesn't exist returns None.""" + result = await credential_service.load_credential(auth_config, tool_context) + assert result is None + + @pytest.mark.asyncio + async def test_save_and_load_credential( + self, credential_service, auth_config, tool_context + ): + """Test saving and then loading a credential.""" + # Save the credential + await credential_service.save_credential(auth_config, tool_context) + + # Load the credential + result = await credential_service.load_credential(auth_config, tool_context) + + # Verify the credential was saved and loaded correctly + assert result is not None + assert result == auth_config.exchanged_auth_credential + assert result.auth_type == AuthCredentialTypes.OAUTH2 + assert result.oauth2.client_id == "mock_client_id" + + @pytest.mark.asyncio + async def test_save_credential_updates_existing( + self, credential_service, auth_config, tool_context, oauth2_credentials + ): + """Test that saving a credential updates an existing one.""" + # Save initial credential + await credential_service.save_credential(auth_config, tool_context) + + # Create a new credential and update the auth_config + new_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="updated_client_id", + client_secret="updated_client_secret", + redirect_uri="https://updated.com/callback", + ), + ) + auth_config.exchanged_auth_credential = new_credential + + # Save the updated credential + await credential_service.save_credential(auth_config, tool_context) + + # Load and verify the credential was updated + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + assert result.oauth2.client_id == "updated_client_id" + assert result.oauth2.client_secret == "updated_client_secret" + + @pytest.mark.asyncio + async def test_credentials_isolated_by_context( + self, credential_service, auth_config, tool_context, another_tool_context + ): + """Test that credentials are isolated between different app/user contexts.""" + # Save credential in first context + await credential_service.save_credential(auth_config, tool_context) + + # Try to load from another context + result = await credential_service.load_credential( + auth_config, another_tool_context + ) + assert result is None + + # Verify original context still has the credential + result = await credential_service.load_credential(auth_config, tool_context) + assert result is not None + + @pytest.mark.asyncio + async def test_multiple_credentials_same_context( + self, credential_service, tool_context, oauth2_auth_scheme + ): + """Test storing multiple credentials in the same context with different keys.""" + # Create two different auth configs with different credential keys + cred1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client1", + client_secret="secret1", + redirect_uri="https://example1.com/callback", + ), + ) + + cred2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client2", + client_secret="secret2", + redirect_uri="https://example2.com/callback", + ), + ) + + auth_config1 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred1, + exchanged_auth_credential=cred1, + credential_key="key1", + ) + + auth_config2 = AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=cred2, + exchanged_auth_credential=cred2, + credential_key="key2", + ) + + # Save both credentials + await credential_service.save_credential(auth_config1, tool_context) + await credential_service.save_credential(auth_config2, tool_context) + + # Load and verify both credentials + result1 = await credential_service.load_credential( + auth_config1, tool_context + ) + result2 = await credential_service.load_credential( + auth_config2, tool_context + ) + + assert result1 is not None + assert result2 is not None + assert result1.oauth2.client_id == "client1" + assert result2.oauth2.client_id == "client2" + + def test_get_bucket_for_current_context_creates_nested_structure( + self, credential_service, tool_context + ): + """Test that _get_bucket_for_current_context creates the proper nested structure.""" + storage = credential_service._get_bucket_for_current_context(tool_context) + + # Verify the nested structure was created + assert "test_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert isinstance(storage, dict) + assert storage is credential_service._credentials["test_app"]["test_user"] + + def test_get_bucket_for_current_context_reuses_existing( + self, credential_service, tool_context + ): + """Test that _get_bucket_for_current_context reuses existing structure.""" + # Create initial structure + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage1["test_key"] = "test_value" + + # Get storage again + storage2 = credential_service._get_bucket_for_current_context(tool_context) + + # Verify it's the same storage instance + assert storage1 is storage2 + assert storage2["test_key"] == "test_value" + + def test_get_storage_different_apps( + self, credential_service, tool_context, another_tool_context + ): + """Test that different apps get different storage instances.""" + storage1 = credential_service._get_bucket_for_current_context(tool_context) + storage2 = credential_service._get_bucket_for_current_context( + another_tool_context + ) + + # Verify they are different storage instances + assert storage1 is not storage2 + + # Verify the structure + assert "test_app" in credential_service._credentials + assert "another_app" in credential_service._credentials + assert "test_user" in credential_service._credentials["test_app"] + assert "another_user" in credential_service._credentials["another_app"] + + @pytest.mark.asyncio + async def test_same_user_different_apps( + self, credential_service, auth_config + ): + """Test that the same user in different apps get isolated storage.""" + # Create two contexts with same user but different apps + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "app1" + mock_invocation_context1.user_id = "same_user" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "app2" + mock_invocation_context2.user_id = "same_user" + context2._invocation_context = mock_invocation_context2 + + # Save credential in app1 + await credential_service.save_credential(auth_config, context1) + + # Try to load from app2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify app1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None + + @pytest.mark.asyncio + async def test_same_app_different_users( + self, credential_service, auth_config + ): + """Test that different users in the same app get isolated storage.""" + # Create two contexts with same app but different users + context1 = Mock(spec=ToolContext) + mock_invocation_context1 = Mock() + mock_invocation_context1.app_name = "same_app" + mock_invocation_context1.user_id = "user1" + context1._invocation_context = mock_invocation_context1 + + context2 = Mock(spec=ToolContext) + mock_invocation_context2 = Mock() + mock_invocation_context2.app_name = "same_app" + mock_invocation_context2.user_id = "user2" + context2._invocation_context = mock_invocation_context2 + + # Save credential for user1 + await credential_service.save_credential(auth_config, context1) + + # Try to load for user2 (should not find it) + result = await credential_service.load_credential(auth_config, context2) + assert result is None + + # Verify user1 still has the credential + result = await credential_service.load_credential(auth_config, context1) + assert result is not None diff --git a/tests/unittests/auth/exchanger/__init__.py b/tests/unittests/auth/exchanger/__init__.py new file mode 100644 index 000000000..5fb8a262b --- /dev/null +++ b/tests/unittests/auth/exchanger/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for credential exchanger.""" diff --git a/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py new file mode 100644 index 000000000..66b858232 --- /dev/null +++ b/tests/unittests/auth/exchanger/test_credential_exchanger_registry.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the CredentialExchangerRegistry.""" + +from typing import Optional +from unittest.mock import MagicMock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.exchanger.base_credential_exchanger import BaseCredentialExchanger +from google.adk.auth.exchanger.credential_exchanger_registry import CredentialExchangerRegistry +import pytest + + +class MockCredentialExchanger(BaseCredentialExchanger): + """Mock credential exchanger for testing.""" + + def __init__(self, exchange_result: Optional[AuthCredential] = None): + self.exchange_result = exchange_result or AuthCredential( + auth_type=AuthCredentialTypes.HTTP + ) + + def exchange( + self, + auth_credential: AuthCredential, + auth_scheme: Optional[AuthScheme] = None, + ) -> AuthCredential: + """Mock exchange method.""" + return self.exchange_result + + +class TestCredentialExchangerRegistry: + """Test cases for CredentialExchangerRegistry.""" + + def test_initialization(self): + """Test that the registry initializes with an empty exchangers dictionary.""" + registry = CredentialExchangerRegistry() + + # Access the private attribute for testing + assert hasattr(registry, '_exchangers') + assert isinstance(registry._exchangers, dict) + assert len(registry._exchangers) == 0 + + def test_register_single_exchanger(self): + """Test registering a single exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Verify the exchanger was registered + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_register_multiple_exchangers(self): + """Test registering multiple exchangers for different credential types.""" + registry = CredentialExchangerRegistry() + + api_key_exchanger = MockCredentialExchanger() + oauth2_exchanger = MockCredentialExchanger() + service_account_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.API_KEY, api_key_exchanger) + registry.register(AuthCredentialTypes.OAUTH2, oauth2_exchanger) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, service_account_exchanger + ) + + # Verify all exchangers were registered correctly + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is api_key_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.OAUTH2) is oauth2_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.SERVICE_ACCOUNT) + is service_account_exchanger + ) + + def test_register_overwrites_existing_exchanger(self): + """Test that registering an exchanger for an existing type overwrites the previous one.""" + registry = CredentialExchangerRegistry() + + first_exchanger = MockCredentialExchanger() + second_exchanger = MockCredentialExchanger() + + # Register first exchanger + registry.register(AuthCredentialTypes.API_KEY, first_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is first_exchanger + ) + + # Register second exchanger for the same type + registry.register(AuthCredentialTypes.API_KEY, second_exchanger) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) is second_exchanger + ) + assert ( + registry.get_exchanger(AuthCredentialTypes.API_KEY) + is not first_exchanger + ) + + def test_get_exchanger_returns_correct_instance(self): + """Test that get_exchanger returns the correct exchanger instance.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.HTTP, mock_exchanger) + + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.HTTP) + assert retrieved_exchanger is mock_exchanger + assert isinstance(retrieved_exchanger, BaseCredentialExchanger) + + def test_get_exchanger_nonexistent_type_returns_none(self): + """Test that get_exchanger returns None for non-existent credential types.""" + registry = CredentialExchangerRegistry() + + # Try to get an exchanger that was never registered + result = registry.get_exchanger(AuthCredentialTypes.OAUTH2) + assert result is None + + def test_get_exchanger_after_registration_and_removal(self): + """Test behavior when an exchanger is registered and then the registry is cleared indirectly.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + # Register exchanger + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is mock_exchanger + + # Clear the internal dictionary (simulating some edge case) + registry._exchangers.clear() + assert registry.get_exchanger(AuthCredentialTypes.API_KEY) is None + + def test_register_with_all_credential_types(self): + """Test registering exchangers for all available credential types.""" + registry = CredentialExchangerRegistry() + + exchangers = {} + credential_types = [ + AuthCredentialTypes.API_KEY, + AuthCredentialTypes.HTTP, + AuthCredentialTypes.OAUTH2, + AuthCredentialTypes.OPEN_ID_CONNECT, + AuthCredentialTypes.SERVICE_ACCOUNT, + ] + + # Register an exchanger for each credential type + for cred_type in credential_types: + exchanger = MockCredentialExchanger() + exchangers[cred_type] = exchanger + registry.register(cred_type, exchanger) + + # Verify all exchangers can be retrieved + for cred_type in credential_types: + retrieved_exchanger = registry.get_exchanger(cred_type) + assert retrieved_exchanger is exchangers[cred_type] + + def test_register_with_mock_exchanger_using_magicmock(self): + """Test registering with a MagicMock exchanger.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MagicMock(spec=BaseCredentialExchanger) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert retrieved_exchanger is mock_exchanger + + def test_registry_isolation(self): + """Test that different registry instances are isolated from each other.""" + registry1 = CredentialExchangerRegistry() + registry2 = CredentialExchangerRegistry() + + exchanger1 = MockCredentialExchanger() + exchanger2 = MockCredentialExchanger() + + # Register different exchangers in different registry instances + registry1.register(AuthCredentialTypes.API_KEY, exchanger1) + registry2.register(AuthCredentialTypes.API_KEY, exchanger2) + + # Verify isolation + assert registry1.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger1 + assert registry2.get_exchanger(AuthCredentialTypes.API_KEY) is exchanger2 + assert ( + registry1.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger2 + ) + assert ( + registry2.get_exchanger(AuthCredentialTypes.API_KEY) is not exchanger1 + ) + + def test_exchanger_functionality_through_registry(self): + """Test that exchangers registered in the registry function correctly.""" + registry = CredentialExchangerRegistry() + + # Create a mock exchanger with specific return value + expected_result = AuthCredential(auth_type=AuthCredentialTypes.HTTP) + mock_exchanger = MockCredentialExchanger(exchange_result=expected_result) + + registry.register(AuthCredentialTypes.API_KEY, mock_exchanger) + + # Get the exchanger and test its functionality + retrieved_exchanger = registry.get_exchanger(AuthCredentialTypes.API_KEY) + input_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY) + + result = retrieved_exchanger.exchange(input_credential) + assert result is expected_result + + def test_register_none_exchanger(self): + """Test that registering None as an exchanger works (edge case).""" + registry = CredentialExchangerRegistry() + + # This should work but return None when retrieved + registry.register(AuthCredentialTypes.API_KEY, None) + + result = registry.get_exchanger(AuthCredentialTypes.API_KEY) + assert result is None + + def test_internal_dictionary_structure(self): + """Test the internal structure of the registry.""" + registry = CredentialExchangerRegistry() + mock_exchanger = MockCredentialExchanger() + + registry.register(AuthCredentialTypes.OAUTH2, mock_exchanger) + + # Verify internal dictionary structure + assert AuthCredentialTypes.OAUTH2 in registry._exchangers + assert registry._exchangers[AuthCredentialTypes.OAUTH2] is mock_exchanger + assert len(registry._exchangers) == 1 diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py new file mode 100644 index 000000000..ef1dbbbee --- /dev/null +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -0,0 +1,220 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.exchanger.base_credential_exchanger import CredentialExchangError +from google.adk.auth.exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger +import pytest + + +class TestOAuth2CredentialExchanger: + """Test suite for OAuth2CredentialExchanger.""" + + @pytest.mark.asyncio + async def test_exchange_with_existing_token(self): + """Test exchange method when access token already exists.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return the same credential since access token already exists + assert result == credential + assert result.oauth2.access_token == "existing_token" + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_success(self, mock_oauth2_session): + """Test successful token exchange.""" + # Setup mock + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Verify token exchange was successful + assert result.oauth2.access_token == "new_access_token" + assert result.oauth2.refresh_token == "new_refresh_token" + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_missing_auth_scheme(self): + """Test exchange with missing auth_scheme raises ValueError.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + exchanger = OAuth2CredentialExchanger() + try: + await exchanger.exchange(credential, None) + assert False, "Should have raised ValueError" + except CredentialExchangError as e: + assert "auth_scheme is required" in str(e) + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_no_session(self, mock_oauth2_session): + """Test exchange when OAuth2Session cannot be created.""" + # Mock to return None for create_oauth2_session + mock_oauth2_session.return_value = None + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret to trigger session creation failure + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when session creation fails + assert result == credential + assert result.oauth2.access_token is None + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_exchange_fetch_token_failure(self, mock_oauth2_session): + """Test exchange when fetch_token fails.""" + # Setup mock to raise exception during fetch_token + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.fetch_token.side_effect = Exception("Token fetch failed") + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when fetch_token fails + assert result == credential + assert result.oauth2.access_token is None + mock_client.fetch_token.assert_called_once() + + @pytest.mark.asyncio + async def test_exchange_authlib_not_available(self): + """Test exchange when authlib is not available.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + ), + ) + + exchanger = OAuth2CredentialExchanger() + + # Mock AUTHLIB_AVIALABLE to False + with patch( + "google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVIALABLE", + False, + ): + result = await exchanger.exchange(credential, scheme) + + # Should return original credential when authlib is not available + assert result == credential + assert result.oauth2.access_token is None diff --git a/tests/unittests/auth/refresher/__init__.py b/tests/unittests/auth/refresher/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/auth/refresher/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/auth/refresher/test_credential_refresher_registry.py b/tests/unittests/auth/refresher/test_credential_refresher_registry.py new file mode 100644 index 000000000..b00cc4da8 --- /dev/null +++ b/tests/unittests/auth/refresher/test_credential_refresher_registry.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CredentialRefresherRegistry.""" + +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher +from google.adk.auth.refresher.credential_refresher_registry import CredentialRefresherRegistry + + +class TestCredentialRefresherRegistry: + """Tests for the CredentialRefresherRegistry class.""" + + def test_init(self): + """Test that registry initializes with empty refreshers dictionary.""" + registry = CredentialRefresherRegistry() + assert registry._refreshers == {} + + def test_register_refresher(self): + """Test registering a refresher instance for a credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher + + def test_register_multiple_refreshers(self): + """Test registering multiple refresher instances for different credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_openid_refresher = Mock(spec=BaseCredentialRefresher) + mock_service_account_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register( + AuthCredentialTypes.OPEN_ID_CONNECT, mock_openid_refresher + ) + registry.register( + AuthCredentialTypes.SERVICE_ACCOUNT, mock_service_account_refresher + ) + + assert ( + registry._refreshers[AuthCredentialTypes.OAUTH2] + == mock_oauth2_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.OPEN_ID_CONNECT] + == mock_openid_refresher + ) + assert ( + registry._refreshers[AuthCredentialTypes.SERVICE_ACCOUNT] + == mock_service_account_refresher + ) + + def test_register_overwrite_existing_refresher(self): + """Test that registering a refresher overwrites an existing one for the same credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher_1 = Mock(spec=BaseCredentialRefresher) + mock_refresher_2 = Mock(spec=BaseCredentialRefresher) + + # Register first refresher + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_1) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_1 + + # Register second refresher for same credential type + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_2) + assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_2 + + def test_get_refresher_existing(self): + """Test getting a refresher instance for a registered credential type.""" + registry = CredentialRefresherRegistry() + mock_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_refresher) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result == mock_refresher + + def test_get_refresher_non_existing(self): + """Test getting a refresher instance for a non-registered credential type returns None.""" + registry = CredentialRefresherRegistry() + + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None + + def test_get_refresher_after_registration(self): + """Test getting refresher instances for multiple credential types.""" + registry = CredentialRefresherRegistry() + mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher) + mock_api_key_refresher = Mock(spec=BaseCredentialRefresher) + + registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher) + registry.register(AuthCredentialTypes.API_KEY, mock_api_key_refresher) + + # Get registered refreshers + oauth2_result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + api_key_result = registry.get_refresher(AuthCredentialTypes.API_KEY) + + assert oauth2_result == mock_oauth2_refresher + assert api_key_result == mock_api_key_refresher + + # Get non-registered refresher + http_result = registry.get_refresher(AuthCredentialTypes.HTTP) + assert http_result is None + + def test_register_all_credential_types(self): + """Test registering refreshers for all available credential types.""" + registry = CredentialRefresherRegistry() + + refreshers = {} + for credential_type in AuthCredentialTypes: + mock_refresher = Mock(spec=BaseCredentialRefresher) + refreshers[credential_type] = mock_refresher + registry.register(credential_type, mock_refresher) + + # Verify all refreshers are registered correctly + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result == refreshers[credential_type] + + def test_empty_registry_get_refresher(self): + """Test getting refresher from empty registry returns None for any credential type.""" + registry = CredentialRefresherRegistry() + + for credential_type in AuthCredentialTypes: + result = registry.get_refresher(credential_type) + assert result is None + + def test_registry_independence(self): + """Test that multiple registry instances are independent.""" + registry1 = CredentialRefresherRegistry() + registry2 = CredentialRefresherRegistry() + + mock_refresher1 = Mock(spec=BaseCredentialRefresher) + mock_refresher2 = Mock(spec=BaseCredentialRefresher) + + registry1.register(AuthCredentialTypes.OAUTH2, mock_refresher1) + registry2.register(AuthCredentialTypes.OAUTH2, mock_refresher2) + + # Verify registries are independent + assert ( + registry1.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher1 + ) + assert ( + registry2.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher2 + ) + assert registry1.get_refresher( + AuthCredentialTypes.OAUTH2 + ) != registry2.get_refresher(AuthCredentialTypes.OAUTH2) + + def test_register_with_none_refresher(self): + """Test registering None as a refresher instance.""" + registry = CredentialRefresherRegistry() + + # This should technically work as the registry accepts any value + registry.register(AuthCredentialTypes.OAUTH2, None) + result = registry.get_refresher(AuthCredentialTypes.OAUTH2) + + assert result is None diff --git a/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py new file mode 100644 index 000000000..3342fcb05 --- /dev/null +++ b/tests/unittests/auth/refresher/test_oauth2_credential_refresher.py @@ -0,0 +1,179 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock +from unittest.mock import patch + +from authlib.oauth2.rfc6749 import OAuth2Token +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.refresher.oauth2_credential_refresher import OAuth2CredentialRefresher +import pytest + + +class TestOAuth2CredentialRefresher: + """Test suite for OAuth2CredentialRefresher.""" + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_not_expired(self, mock_oauth2_token): + """Test needs_refresh when token is not expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = False + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) + 3600, + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert not needs_refresh + + @patch("google.adk.auth.refresher.oauth2_credential_refresher.OAuth2Token") + @pytest.mark.asyncio + async def test_needs_refresh_token_expired(self, mock_oauth2_token): + """Test needs_refresh when token is expired.""" + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="existing_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, scheme) + + assert needs_refresh + + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @patch("google.adk.auth.oauth2_credential_util.OAuth2Token") + @pytest.mark.asyncio + async def test_refresh_token_expired_success( + self, mock_oauth2_token, mock_oauth2_session + ): + """Test successful token refresh when token is expired.""" + # Setup mock token + mock_token_instance = Mock() + mock_token_instance.is_expired.return_value = True + mock_oauth2_token.return_value = mock_token_instance + + # Setup mock session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "refreshed_access_token", + "refresh_token": "refreshed_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.refresh_token.return_value = mock_tokens + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + access_token="old_token", + refresh_token="old_refresh_token", + expires_at=int(time.time()) - 3600, # Expired + ), + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + # Verify token refresh was successful + assert result.oauth2.access_token == "refreshed_access_token" + assert result.oauth2.refresh_token == "refreshed_refresh_token" + mock_client.refresh_token.assert_called_once() + + @pytest.mark.asyncio + async def test_refresh_no_oauth2_credential(self): + """Test refresh with no OAuth2 credential returns original.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + result = await refresher.refresh(credential, scheme) + + assert result == credential + + @pytest.mark.asyncio + async def test_needs_refresh_no_oauth2_credential(self): + """Test needs_refresh with no OAuth2 credential returns False.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + # No oauth2 field + ) + + refresher = OAuth2CredentialRefresher() + needs_refresh = await refresher.is_refresh_needed(credential, None) + + assert not needs_refresh diff --git a/tests/unittests/auth/test_auth_config.py b/tests/unittests/auth/test_auth_config.py index 7a20cfc63..a398ef321 100644 --- a/tests/unittests/auth/test_auth_config.py +++ b/tests/unittests/auth/test_auth_config.py @@ -68,10 +68,28 @@ def auth_config(oauth2_auth_scheme, oauth2_credentials): ) -def test_get_credential_key(auth_config): +@pytest.fixture +def auth_config_with_key(oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + credential_key="test_key", + ) + + +def test_custom_credential_key(auth_config_with_key): + """Test using custom credential key.""" + + key = auth_config_with_key.credential_key + assert key == "test_key" + + +def test_credential_key(auth_config): """Test generating a unique credential key.""" - key = auth_config.get_credential_key() + key = auth_config.credential_key assert key.startswith("adk_oauth2_") assert "_oauth2_" in key @@ -80,8 +98,8 @@ def test_get_credential_key_with_extras(auth_config): """Test generating a key when model_extra exists.""" # Add model_extra to test cleanup - original_key = auth_config.get_credential_key() - key = auth_config.get_credential_key() + original_key = auth_config.credential_key + key = auth_config.credential_key auth_config.auth_scheme.model_extra["extra_field"] = "value" auth_config.raw_auth_credential.model_extra["extra_field"] = "value" diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index c717682ac..f0d730d02 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -13,8 +13,11 @@ # limitations under the License. import copy +import time +from unittest.mock import Mock from unittest.mock import patch +from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import APIKey from fastapi.openapi.models import APIKeyIn from fastapi.openapi.models import OAuth2 @@ -387,7 +390,7 @@ def test_get_auth_response_exists( state = MockState() # Store a credential in the state - credential_key = auth_config.get_credential_key() + credential_key = auth_config.credential_key state["temp:" + credential_key] = oauth2_credentials_with_auth_uri result = handler.get_auth_response(state) @@ -405,7 +408,8 @@ def test_get_auth_response_not_exists(self, auth_config): class TestParseAndStoreAuthResponse: """Tests for the parse_and_store_auth_response method.""" - def test_non_oauth_scheme(self, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_exchanged): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_exchanged) @@ -416,15 +420,18 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged): handler = AuthHandler(auth_config) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) - credential_key = auth_config.get_credential_key() + credential_key = auth_config.credential_key assert ( state["temp:" + credential_key] == auth_config.exchanged_auth_credential ) @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token") - def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): + @pytest.mark.asyncio + async def test_oauth_scheme( + self, mock_exchange_token, auth_config_with_exchanged + ): """Test with an OAuth auth scheme.""" mock_exchange_token.return_value = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -434,9 +441,9 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): handler = AuthHandler(auth_config_with_exchanged) state = MockState() - handler.parse_and_store_auth_response(state) + await handler.parse_and_store_auth_response(state) - credential_key = auth_config_with_exchanged.get_credential_key() + credential_key = auth_config_with_exchanged.credential_key assert state["temp:" + credential_key] == mock_exchange_token.return_value assert mock_exchange_token.called @@ -444,20 +451,20 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): class TestExchangeAuthToken: """Tests for the exchange_auth_token method.""" - def test_token_exchange_not_supported( + @pytest.mark.asyncio + async def test_token_exchange_not_supported( self, auth_config_with_auth_code, monkeypatch ): """Test when token exchange is not supported.""" - monkeypatch.setattr( - "google.adk.auth.auth_handler.SUPPORT_TOKEN_EXCHANGE", False - ) + monkeypatch.setattr("google.adk.auth.auth_handler.AUTHLIB_AVIALABLE", False) handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config_with_auth_code.exchanged_auth_credential - def test_openid_missing_token_endpoint( + @pytest.mark.asyncio + async def test_openid_missing_token_endpoint( self, openid_auth_scheme, oauth2_credentials_with_auth_code ): """Test OpenID Connect without a token endpoint.""" @@ -472,11 +479,12 @@ def test_openid_missing_token_endpoint( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_oauth2_missing_token_url( + @pytest.mark.asyncio + async def test_oauth2_missing_token_url( self, oauth2_auth_scheme, oauth2_credentials_with_auth_code ): """Test OAuth2 without a token URL.""" @@ -491,11 +499,12 @@ def test_oauth2_missing_token_url( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_auth_code - def test_non_oauth_scheme(self, auth_config_with_auth_code): + @pytest.mark.asyncio + async def test_non_oauth_scheme(self, auth_config_with_auth_code): """Test with a non-OAuth auth scheme.""" # Modify the auth scheme type to be non-OAuth auth_config = copy.deepcopy(auth_config_with_auth_code) @@ -504,11 +513,12 @@ def test_non_oauth_scheme(self, auth_config_with_auth_code): ) handler = AuthHandler(auth_config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == auth_config.exchanged_auth_credential - def test_missing_credentials(self, oauth2_auth_scheme): + @pytest.mark.asyncio + async def test_missing_credentials(self, oauth2_auth_scheme): """Test with missing credentials.""" empty_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2) @@ -518,11 +528,12 @@ def test_missing_credentials(self, oauth2_auth_scheme): ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == empty_credential - def test_credentials_with_token( + @pytest.mark.asyncio + async def test_credentials_with_token( self, auth_config, oauth2_credentials_with_token ): """Test when credentials already have a token.""" @@ -533,15 +544,29 @@ def test_credentials_with_token( ) handler = AuthHandler(config) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result == oauth2_credentials_with_token - @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) - def test_successful_token_exchange(self, auth_config_with_auth_code): + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + @pytest.mark.asyncio + async def test_successful_token_exchange( + self, mock_oauth2_session, auth_config_with_auth_code + ): """Test a successful token exchange.""" + # Setup mock OAuth2Session + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_tokens = OAuth2Token({ + "access_token": "mock_access_token", + "refresh_token": "mock_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + mock_client.fetch_token.return_value = mock_tokens + handler = AuthHandler(auth_config_with_auth_code) - result = handler.exchange_auth_token() + result = await handler.exchange_auth_token() assert result.oauth2.access_token == "mock_access_token" assert result.oauth2.refresh_token == "mock_refresh_token" diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py new file mode 100644 index 000000000..8e3638dd6 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager.py @@ -0,0 +1,545 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from fastapi.openapi.models import HTTPBearer +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +import pytest + + +class TestCredentialManager: + """Test suite for CredentialManager.""" + + def test_init(self): + """Test CredentialManager initialization.""" + auth_config = Mock(spec=AuthConfig) + manager = CredentialManager(auth_config) + assert manager._auth_config == auth_config + + @pytest.mark.asyncio + async def test_request_credential(self): + """Test request_credential method.""" + auth_config = Mock(spec=AuthConfig) + tool_context = Mock() + tool_context.request_credential = Mock() + + manager = CredentialManager(auth_config) + await manager.request_credential(tool_context) + + tool_context.request_credential.assert_called_once_with(auth_config) + + @pytest.mark.asyncio + async def test_load_auth_credentials_success(self): + """Test load_auth_credential with successful flow.""" + # Create mocks + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + # Mock the credential that will be returned + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + tool_context = Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=mock_credential) + manager._exchange_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._refresh_credential = AsyncMock( + return_value=(mock_credential, False) + ) + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify all methods were called + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_called_once_with(mock_credential) + manager._refresh_credential.assert_called_once_with(mock_credential) + manager._save_credential.assert_called_once_with( + tool_context, mock_credential + ) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_auth_credentials_no_credential(self): + """Test load_auth_credential when no credential is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.exchanged_auth_credential = None + + tool_context = Mock() + + manager = CredentialManager(auth_config) + + # Mock the private methods + manager._validate_credential = AsyncMock() + manager._is_credential_ready = Mock(return_value=False) + manager._load_existing_credential = AsyncMock(return_value=None) + manager._load_from_auth_response = AsyncMock(return_value=None) + manager._exchange_credential = AsyncMock() + manager._refresh_credential = AsyncMock() + manager._save_credential = AsyncMock() + + result = await manager.get_auth_credential(tool_context) + + # Verify methods were called but no credential returned + manager._validate_credential.assert_called_once() + manager._is_credential_ready.assert_called_once() + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) + manager._exchange_credential.assert_not_called() + manager._refresh_credential.assert_not_called() + manager._save_credential.assert_not_called() + + assert result is None + + @pytest.mark.asyncio + async def test_load_existing_credential_already_exchanged(self): + """Test _load_existing_credential when credential is already exchanged.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + auth_config.exchanged_auth_credential = mock_credential + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock(return_value=None) + + result = await manager._load_existing_credential(tool_context) + + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_existing_credential_with_credential_service(self): + """Test _load_existing_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + + mock_credential = Mock(spec=AuthCredential) + + tool_context = Mock() + + manager = CredentialManager(auth_config) + manager._load_from_credential_service = AsyncMock( + return_value=mock_credential + ) + + result = await manager._load_existing_credential(tool_context) + + manager._load_from_credential_service.assert_called_once_with(tool_context) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_with_service(self): + """Test _load_from_credential_service from tool context when credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = Mock() + credential_service.load_credential = AsyncMock(return_value=mock_credential) + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await manager._load_from_credential_service(tool_context) + + credential_service.load_credential.assert_called_once_with( + auth_config, tool_context + ) + assert result == mock_credential + + @pytest.mark.asyncio + async def test_load_from_credential_service_no_service(self): + """Test _load_from_credential_service when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + result = await manager._load_from_credential_service(tool_context) + + assert result is None + + @pytest.mark.asyncio + async def test_save_credential_with_service(self): + """Test _save_credential with credential service.""" + auth_config = Mock(spec=AuthConfig) + mock_credential = Mock(spec=AuthCredential) + + # Mock credential service + credential_service = AsyncMock() + + # Mock invocation context + invocation_context = Mock() + invocation_context.credential_service = credential_service + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + credential_service.save_credential.assert_called_once_with( + auth_config, tool_context + ) + assert auth_config.exchanged_auth_credential == mock_credential + + @pytest.mark.asyncio + async def test_save_credential_no_service(self): + """Test _save_credential when no credential service is available.""" + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = None + mock_credential = Mock(spec=AuthCredential) + + # Mock invocation context with no credential service + invocation_context = Mock() + invocation_context.credential_service = None + + tool_context = Mock() + tool_context._invocation_context = invocation_context + + manager = CredentialManager(auth_config) + await manager._save_credential(tool_context, mock_credential) + + # Should not raise an error, and credential should not be set in auth_config + # when there's no credential service (according to implementation) + assert auth_config.exchanged_auth_credential is None + + @pytest.mark.asyncio + async def test_refresh_credential_oauth2(self): + """Test _refresh_credential with OAuth2 credential.""" + mock_oauth2_auth = Mock(spec=OAuth2Auth) + + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + # Mock refresher + mock_refresher = Mock() + mock_refresher.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher.refresh = AsyncMock(return_value=mock_credential) + + auth_config.raw_auth_credential = mock_credential + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return our mock refresher + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=mock_refresher, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + mock_refresher.is_refresh_needed.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + mock_refresher.refresh.assert_called_once_with( + mock_credential, auth_config.auth_scheme + ) + assert result == mock_credential + assert was_refreshed is True + + @pytest.mark.asyncio + async def test_refresh_credential_no_refresher(self): + """Test _refresh_credential with credential that has no refresher.""" + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the refresher registry to return None (no refresher available) + with patch.object( + manager._refresher_registry, + "get_refresher", + return_value=None, + ): + result, was_refreshed = await manager._refresh_credential(mock_credential) + + assert result == mock_credential + assert was_refreshed is False + + @pytest.mark.asyncio + async def test_is_credential_ready_api_key(self): + """Test _is_credential_ready with API key credential.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is True + + @pytest.mark.asyncio + async def test_is_credential_ready_oauth2(self): + """Test _is_credential_ready with OAuth2 credential (needs processing).""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + + manager = CredentialManager(auth_config) + result = manager._is_credential_ready() + + assert result is False + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_oauth2(self): + """Test _validate_credential with no raw credential for OAuth2.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_openid(self): + """Test _validate_credential with no raw credential for OpenID Connect.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.openIdConnect + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises(ValueError, match="raw_auth_credential is required"): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_validate_credential_no_raw_credential_other_scheme(self): + """Test _validate_credential with no raw credential for other schemes.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.apiKey + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = None + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + await manager._validate_credential() + + # Should return without error for non-OAuth2/OpenID schemes + + @pytest.mark.asyncio + async def test_validate_credential_oauth2_missing_oauth2_field(self): + """Test _validate_credential with OAuth2 credential missing oauth2 field.""" + auth_scheme = Mock() + auth_scheme.type_ = AuthSchemeType.oauth2 + + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + mock_raw_credential.oauth2 = None + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + auth_config.auth_scheme = auth_scheme + + manager = CredentialManager(auth_config) + + with pytest.raises( + ValueError, match="auth_config.raw_credential.oauth2 required" + ): + await manager._validate_credential() + + @pytest.mark.asyncio + async def test_exchange_credentials_service_account(self): + """Test _exchange_credential with service account credential (no exchanger available).""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.SERVICE_ACCOUNT + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = Mock() + + manager = CredentialManager(auth_config) + + # Mock the exchanger registry to return None (no exchanger available) + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=None + ): + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_raw_credential + assert was_exchanged is False + + @pytest.mark.asyncio + async def test_exchange_credential_no_exchanger(self): + """Test _exchange_credential with credential that has no exchanger.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY + + auth_config = Mock(spec=AuthConfig) + + manager = CredentialManager(auth_config) + + # Mock the exchanger registry to return None (no exchanger available) + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=None + ): + result, was_exchanged = await manager._exchange_credential( + mock_raw_credential + ) + + assert result == mock_raw_credential + assert was_exchanged is False + + +# Test fixtures +@pytest.fixture +def oauth2_auth_scheme(): + """Create an OAuth2 auth scheme for testing.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + return OAuth2(flows=flows) + + +@pytest.fixture +def openid_auth_scheme(): + """Create an OpenID Connect auth scheme for testing.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + + +@pytest.fixture +def bearer_auth_scheme(): + """Create a Bearer auth scheme for testing.""" + return HTTPBearer(bearerFormat="JWT") + + +@pytest.fixture +def oauth2_credential(): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + +@pytest.fixture +def service_account_credential(): + """Create service account credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type="service_account", + project_id="test-project", + private_key_id="key-id", + private_key=( + "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE" + " KEY-----\n" + ), + client_email="test@test-project.iam.gserviceaccount.com", + client_id="123456789", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test%40test-project.iam.gserviceaccount.com", + ), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ), + ) + + +@pytest.fixture +def api_key_credential(): + """Create API key credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + api_key="test-api-key", + ) + + +@pytest.fixture +def http_bearer_credential(): + """Create HTTP Bearer credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="bearer-token"), + ), + ) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py new file mode 100644 index 000000000..aba6a9923 --- /dev/null +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock + +from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens + + +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" + + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" + scheme = Mock() # Invalid scheme type + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + + update_credential_with_tokens(credential, tokens) + + assert credential.oauth2.access_token == "new_access_token" + assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.expires_at == int(time.time()) + 3600 + assert credential.oauth2.expires_in == 3600 diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 1721885f3..2139a8c20 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -129,6 +129,7 @@ def _echo(msg: str) -> None: artifact_service = cli.InMemoryArtifactService() session_service = cli.InMemorySessionService() + credential_service = cli.InMemoryCredentialService() dummy_root = types.SimpleNamespace(name="root") session = await cli.run_input_file( @@ -137,6 +138,7 @@ def _echo(msg: str) -> None: root_agent=dummy_root, artifact_service=artifact_service, session_service=session_service, + credential_service=credential_service, input_path=str(input_path), ) @@ -199,9 +201,10 @@ async def test_run_interactively_whitespace_and_exit( ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - svc = cli.InMemorySessionService() - sess = await svc.create_session(app_name="dummy", user_id="u") + session_service = cli.InMemorySessionService() + sess = await session_service.create_session(app_name="dummy", user_id="u") artifact_service = cli.InMemoryArtifactService() + credential_service = cli.InMemoryCredentialService() root_agent = types.SimpleNamespace(name="root") # fake user input: blank -> 'hello' -> 'exit' @@ -212,7 +215,9 @@ async def test_run_interactively_whitespace_and_exit( echoed: list[str] = [] monkeypatch.setattr(click, "echo", lambda msg: echoed.append(msg)) - await cli.run_interactively(root_agent, artifact_service, sess, svc) + await cli.run_interactively( + root_agent, artifact_service, sess, session_service, credential_service + ) # verify: assistant echoed once with 'echo:hello' assert any("echo:hello" in m for m in echoed) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index ad204005e..2b93226db 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -23,6 +23,7 @@ 'GOOGLE_API_KEY': 'fake_google_api_key', 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project', 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location', + 'ADK_ALLOW_WIP_FEATURES': 'true', } ENV_SETUPS = { diff --git a/tests/unittests/evaluation/mock_gcs_utils.py b/tests/unittests/evaluation/mock_gcs_utils.py new file mode 100644 index 000000000..d9ea008c3 --- /dev/null +++ b/tests/unittests/evaluation/mock_gcs_utils.py @@ -0,0 +1,117 @@ +from typing import Optional +from typing import Union + + +class MockBlob: + """Mocks a GCS Blob object. + + This class provides mock implementations for a few common GCS Blob methods, + allowing the user to test code that interacts with GCS without actually + connecting to a real bucket. + """ + + def __init__(self, name: str) -> None: + """Initializes a MockBlob. + + Args: + name: The name of the blob. + """ + self.name = name + self.content: Optional[bytes] = None + self.content_type: Optional[str] = None + self._exists: bool = False + + def upload_from_string( + self, data: Union[str, bytes], content_type: Optional[str] = None + ) -> None: + """Mocks uploading data to the blob (from a string or bytes). + + Args: + data: The data to upload (string or bytes). + content_type: The content type of the data (optional). + """ + if isinstance(data, str): + self.content = data.encode("utf-8") + elif isinstance(data, bytes): + self.content = data + else: + raise TypeError("data must be str or bytes") + + if content_type: + self.content_type = content_type + self._exists = True + + def download_as_text(self) -> str: + """Mocks downloading the blob's content as text. + + Returns: + str: The content of the blob as text. + + Raises: + Exception: If the blob doesn't exist (hasn't been uploaded to). + """ + if self.content is None: + return b"" + return self.content + + def delete(self) -> None: + """Mocks deleting a blob.""" + self.content = None + self.content_type = None + self._exists = False + + def exists(self) -> bool: + """Mocks checking if the blob exists.""" + return self._exists + + +class MockBucket: + """Mocks a GCS Bucket object.""" + + def __init__(self, name: str) -> None: + """Initializes a MockBucket. + + Args: + name: The name of the bucket. + """ + self.name = name + self.blobs: dict[str, MockBlob] = {} + + def blob(self, blob_name: str) -> MockBlob: + """Mocks getting a Blob object (doesn't create it in storage). + + Args: + blob_name: The name of the blob. + + Returns: + A MockBlob instance. + """ + if blob_name not in self.blobs: + self.blobs[blob_name] = MockBlob(blob_name) + return self.blobs[blob_name] + + def list_blobs(self, prefix: Optional[str] = None) -> list[MockBlob]: + """Mocks listing blobs in a bucket, optionally with a prefix.""" + if prefix: + return [ + blob for name, blob in self.blobs.items() if name.startswith(prefix) + ] + return list(self.blobs.values()) + + def exists(self) -> bool: + """Mocks checking if the bucket exists.""" + return True + + +class MockClient: + """Mocks the GCS Client.""" + + def __init__(self) -> None: + """Initializes MockClient.""" + self.buckets: dict[str, MockBucket] = {} + + def bucket(self, bucket_name: str) -> MockBucket: + """Mocks getting a Bucket object.""" + if bucket_name not in self.buckets: + self.buckets[bucket_name] = MockBucket(bucket_name) + return self.buckets[bucket_name] diff --git a/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py b/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py new file mode 100644 index 000000000..7fd0bb97e --- /dev/null +++ b/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py @@ -0,0 +1,191 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.errors.not_found_error import NotFoundError +from google.adk.evaluation._eval_set_results_manager_utils import _sanitize_eval_set_result_name +from google.adk.evaluation._eval_set_results_manager_utils import create_eval_set_result +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetricResult +from google.adk.evaluation.eval_metrics import EvalMetricResultPerInvocation +from google.adk.evaluation.eval_result import EvalCaseResult +from google.adk.evaluation.evaluator import EvalStatus +from google.adk.evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager +from google.genai import types as genai_types +import pytest + +from .mock_gcs_utils import MockBucket +from .mock_gcs_utils import MockClient + + +def _get_test_eval_case_results(): + # Create mock Invocation objects + actual_invocation_1 = Invocation( + invocation_id="actual_1", + user_content=genai_types.Content( + parts=[genai_types.Part(text="input_1")] + ), + ) + expected_invocation_1 = Invocation( + invocation_id="expected_1", + user_content=genai_types.Content( + parts=[genai_types.Part(text="expected_input_1")] + ), + ) + actual_invocation_2 = Invocation( + invocation_id="actual_2", + user_content=genai_types.Content( + parts=[genai_types.Part(text="input_2")] + ), + ) + expected_invocation_2 = Invocation( + invocation_id="expected_2", + user_content=genai_types.Content( + parts=[genai_types.Part(text="expected_input_2")] + ), + ) + + eval_metric_result_1 = EvalMetricResult( + metric_name="metric", + threshold=0.8, + score=1.0, + eval_status=EvalStatus.PASSED, + ) + eval_metric_result_2 = EvalMetricResult( + metric_name="metric", + threshold=0.8, + score=0.5, + eval_status=EvalStatus.FAILED, + ) + eval_metric_result_per_invocation_1 = EvalMetricResultPerInvocation( + actual_invocation=actual_invocation_1, + expected_invocation=expected_invocation_1, + eval_metric_results=[eval_metric_result_1], + ) + eval_metric_result_per_invocation_2 = EvalMetricResultPerInvocation( + actual_invocation=actual_invocation_2, + expected_invocation=expected_invocation_2, + eval_metric_results=[eval_metric_result_2], + ) + return [ + EvalCaseResult( + eval_set_id="eval_set", + eval_id="eval_case_1", + final_eval_status=EvalStatus.PASSED, + overall_eval_metric_results=[eval_metric_result_1], + eval_metric_result_per_invocation=[ + eval_metric_result_per_invocation_1 + ], + session_id="session_1", + ), + EvalCaseResult( + eval_set_id="eval_set", + eval_id="eval_case_2", + final_eval_status=EvalStatus.FAILED, + overall_eval_metric_results=[eval_metric_result_2], + eval_metric_result_per_invocation=[ + eval_metric_result_per_invocation_2 + ], + session_id="session_2", + ), + ] + + +class TestGcsEvalSetResultsManager: + + @pytest.fixture + def gcs_eval_set_results_manager(self, mocker): + mock_storage_client = MockClient() + bucket_name = "test_bucket" + mock_bucket = MockBucket(bucket_name) + mocker.patch.object(mock_storage_client, "bucket", return_value=mock_bucket) + mocker.patch( + "google.cloud.storage.Client", return_value=mock_storage_client + ) + return GcsEvalSetResultsManager(bucket_name=bucket_name) + + def test_save_eval_set_result(self, gcs_eval_set_results_manager, mocker): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_results = _get_test_eval_case_results() + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + blob_name = gcs_eval_set_results_manager._get_eval_set_result_blob_name( + app_name, eval_set_result.eval_set_result_id + ) + mock_write_eval_set_result = mocker.patch.object( + gcs_eval_set_results_manager, + "_write_eval_set_result", + ) + gcs_eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + mock_write_eval_set_result.assert_called_once_with( + blob_name, + eval_set_result, + ) + + def test_get_eval_set_result_not_found( + self, gcs_eval_set_results_manager, mocker + ): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + with pytest.raises(NotFoundError) as e: + gcs_eval_set_results_manager.get_eval_set_result( + app_name, "non_existent_id" + ) + + def test_get_eval_set_result(self, gcs_eval_set_results_manager, mocker): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_results = _get_test_eval_case_results() + gcs_eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + retrieved_eval_set_result = ( + gcs_eval_set_results_manager.get_eval_set_result( + app_name, eval_set_result.eval_set_result_id + ) + ) + assert retrieved_eval_set_result == eval_set_result + + def test_list_eval_set_results(self, gcs_eval_set_results_manager, mocker): + mocker.patch("time.time", return_value=123) + app_name = "test_app" + eval_set_ids = ["test_eval_set_1", "test_eval_set_2", "test_eval_set_3"] + for eval_set_id in eval_set_ids: + eval_case_results = _get_test_eval_case_results() + gcs_eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + retrieved_eval_set_result_ids = ( + gcs_eval_set_results_manager.list_eval_set_results(app_name) + ) + assert retrieved_eval_set_result_ids == [ + "test_app_test_eval_set_1_123", + "test_app_test_eval_set_2_123", + "test_app_test_eval_set_3_123", + ] + + def test_list_eval_set_results_empty(self, gcs_eval_set_results_manager): + app_name = "test_app" + retrieved_eval_set_result_ids = ( + gcs_eval_set_results_manager.list_eval_set_results(app_name) + ) + assert retrieved_eval_set_result_ids == [] diff --git a/tests/unittests/evaluation/test_gcs_eval_sets_manager.py b/tests/unittests/evaluation/test_gcs_eval_sets_manager.py new file mode 100644 index 000000000..bb8e3bd3b --- /dev/null +++ b/tests/unittests/evaluation/test_gcs_eval_sets_manager.py @@ -0,0 +1,404 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.errors.not_found_error import NotFoundError +from google.adk.evaluation.eval_case import EvalCase +from google.adk.evaluation.eval_set import EvalSet +from google.adk.evaluation.gcs_eval_sets_manager import _EVAL_SET_FILE_EXTENSION +from google.adk.evaluation.gcs_eval_sets_manager import GcsEvalSetsManager +import pytest + +from .mock_gcs_utils import MockBlob +from .mock_gcs_utils import MockBucket +from .mock_gcs_utils import MockClient + + +class TestGcsEvalSetsManager: + """Tests for GcsEvalSetsManager.""" + + @pytest.fixture + def gcs_eval_sets_manager(self, mocker): + mock_storage_client = MockClient() + bucket_name = "test_bucket" + mock_bucket = MockBucket(bucket_name) + mocker.patch.object(mock_storage_client, "bucket", return_value=mock_bucket) + mocker.patch( + "google.cloud.storage.Client", return_value=mock_storage_client + ) + return GcsEvalSetsManager(bucket_name=bucket_name) + + def test_gcs_eval_sets_manager_get_eval_set_success( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mock_bucket = gcs_eval_sets_manager.bucket + mock_blob = mock_bucket.blob( + f"{app_name}/evals/eval_sets/{eval_set_id}{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob.upload_from_string(mock_eval_set.model_dump_json()) + + eval_set = gcs_eval_sets_manager.get_eval_set(app_name, eval_set_id) + + assert eval_set == mock_eval_set + + def test_gcs_eval_sets_manager_get_eval_set_not_found( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + eval_set_id = "test_eval_set_not_exist" + eval_set = gcs_eval_sets_manager.get_eval_set(app_name, eval_set_id) + + assert eval_set is None + + def test_gcs_eval_sets_manager_create_eval_set_success( + self, gcs_eval_sets_manager, mocker + ): + mocked_time = 12345678 + mocker.patch("time.time", return_value=mocked_time) + app_name = "test_app" + eval_set_id = "test_eval_set" + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, + "_write_eval_set_to_blob", + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.create_eval_set(app_name, eval_set_id) + + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, + EvalSet( + eval_set_id=eval_set_id, + name=eval_set_id, + eval_cases=[], + creation_timestamp=mocked_time, + ), + ) + + def test_gcs_eval_sets_manager_create_eval_set_invalid_id( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + eval_set_id = "invalid-id" + + with pytest.raises(ValueError, match="Invalid Eval Set Id"): + gcs_eval_sets_manager.create_eval_set(app_name, eval_set_id) + + def test_gcs_eval_sets_manager_list_eval_sets_success( + self, gcs_eval_sets_manager + ): + app_name = "test_app" + mock_blob_1 = MockBlob( + f"test_app/evals/eval_sets/eval_set_1{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob_2 = MockBlob( + f"test_app/evals/eval_sets/eval_set_2{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob_3 = MockBlob("test_app/evals/eval_sets/not_an_eval_set.txt") + mock_bucket = gcs_eval_sets_manager.bucket + mock_bucket.blobs = { + mock_blob_1.name: mock_blob_1, + mock_blob_2.name: mock_blob_2, + mock_blob_3.name: mock_blob_3, + } + + eval_sets = gcs_eval_sets_manager.list_eval_sets(app_name) + + assert eval_sets == ["eval_set_1", "eval_set_2"] + + def test_gcs_eval_sets_manager_add_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) + + assert len(mock_eval_set.eval_cases) == 1 + assert mock_eval_set.eval_cases[0] == mock_eval_case + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, mock_eval_set + ) + + def test_gcs_eval_sets_manager_add_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=None + ) + + with pytest.raises( + NotFoundError, match="Eval set `test_eval_set` not found." + ): + gcs_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) + + def test_gcs_eval_sets_manager_add_eval_case_eval_case_id_exists( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + + with pytest.raises( + ValueError, + match=( + f"Eval id `{eval_case_id}` already exists in `{eval_set_id}` eval" + " set." + ), + ): + gcs_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) + + def test_gcs_eval_sets_manager_get_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + + eval_case = gcs_eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + assert eval_case == mock_eval_case + + def test_gcs_eval_sets_manager_get_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=None + ) + + eval_case = gcs_eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + assert eval_case is None + + def test_gcs_eval_sets_manager_get_eval_case_eval_case_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + + eval_case = gcs_eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + assert eval_case is None + + def test_gcs_eval_sets_manager_update_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase( + eval_id=eval_case_id, conversation=[], creation_timestamp=456 + ) + updated_eval_case = EvalCase( + eval_id=eval_case_id, conversation=[], creation_timestamp=123 + ) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=mock_eval_case + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + assert len(mock_eval_set.eval_cases) == 1 + assert mock_eval_set.eval_cases[0] == updated_eval_case + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, + EvalSet(eval_set_id=eval_set_id, eval_cases=[updated_eval_case]), + ) + + def test_gcs_eval_sets_manager_update_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + updated_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=None + ) + + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + gcs_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + def test_gcs_eval_sets_manager_update_eval_case_eval_case_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + updated_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + + with pytest.raises( + NotFoundError, + match=( + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." + ), + ): + gcs_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + def test_gcs_eval_sets_manager_delete_eval_case_success( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet( + eval_set_id=eval_set_id, eval_cases=[mock_eval_case] + ) + mock_bucket = gcs_eval_sets_manager.bucket + mock_blob = mock_bucket.blob( + f"{app_name}/evals/eval_sets/{eval_set_id}{_EVAL_SET_FILE_EXTENSION}" + ) + mock_blob.upload_from_string(mock_eval_set.model_dump_json()) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=mock_eval_case + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + eval_set_blob_name = gcs_eval_sets_manager._get_eval_set_blob_name( + app_name, eval_set_id + ) + + gcs_eval_sets_manager.delete_eval_case(app_name, eval_set_id, eval_case_id) + + assert len(mock_eval_set.eval_cases) == 0 + mock_write_eval_set_to_blob.assert_called_once_with( + eval_set_blob_name, + EvalSet(eval_set_id=eval_set_id, eval_cases=[]), + ) + + def test_gcs_eval_sets_manager_delete_eval_case_eval_set_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + gcs_eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + mock_write_eval_set_to_blob.assert_not_called() + + def test_gcs_eval_sets_manager_delete_eval_case_eval_case_not_found( + self, gcs_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_set", return_value=mock_eval_set + ) + mocker.patch.object( + gcs_eval_sets_manager, "get_eval_case", return_value=None + ) + mock_write_eval_set_to_blob = mocker.patch.object( + gcs_eval_sets_manager, "_write_eval_set_to_blob" + ) + + with pytest.raises( + NotFoundError, + match=( + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." + ), + ): + gcs_eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + mock_write_eval_set_to_blob.assert_not_called() diff --git a/tests/unittests/evaluation/test_local_eval_set_results_manager.py b/tests/unittests/evaluation/test_local_eval_set_results_manager.py index 038f17abb..3411d9b7a 100644 --- a/tests/unittests/evaluation/test_local_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_local_eval_set_results_manager.py @@ -21,24 +21,17 @@ import time from unittest.mock import patch +from google.adk.errors.not_found_error import NotFoundError +from google.adk.evaluation._eval_set_results_manager_utils import _sanitize_eval_set_result_name from google.adk.evaluation.eval_result import EvalCaseResult from google.adk.evaluation.eval_result import EvalSetResult from google.adk.evaluation.evaluator import EvalStatus from google.adk.evaluation.local_eval_set_results_manager import _ADK_EVAL_HISTORY_DIR from google.adk.evaluation.local_eval_set_results_manager import _EVAL_SET_RESULT_FILE_EXTENSION -from google.adk.evaluation.local_eval_set_results_manager import _sanitize_eval_set_result_name from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager import pytest -def test_sanitize_eval_set_result_name(): - assert _sanitize_eval_set_result_name("app/name") == "app_name" - assert _sanitize_eval_set_result_name("app_name") == "app_name" - assert _sanitize_eval_set_result_name("app/name/with/slashes") == ( - "app_name_with_slashes" - ) - - class TestLocalEvalSetResultsManager: @pytest.fixture(autouse=True) @@ -115,11 +108,9 @@ def test_get_eval_set_result(self, mock_time): def test_get_eval_set_result_not_found(self, mock_time): mock_time.return_value = self.timestamp - with pytest.raises(ValueError) as e: + with pytest.raises(NotFoundError) as e: self.manager.get_eval_set_result(self.app_name, "non_existent_id") - assert "does not exist" in str(e.value) - @patch("time.time") def test_list_eval_set_results(self, mock_time): mock_time.return_value = self.timestamp diff --git a/tests/unittests/evaluation/test_local_eval_sets_manager.py b/tests/unittests/evaluation/test_local_eval_sets_manager.py index 2b919fa83..de00659e2 100644 --- a/tests/unittests/evaluation/test_local_eval_sets_manager.py +++ b/tests/unittests/evaluation/test_local_eval_sets_manager.py @@ -361,8 +361,8 @@ def test_local_eval_sets_manager_create_eval_set_success( app_name = "test_app" eval_set_id = "test_eval_set" mocker.patch("os.path.exists", return_value=False) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) eval_set_file_path = os.path.join( local_eval_sets_manager._agents_dir, @@ -371,7 +371,7 @@ def test_local_eval_sets_manager_create_eval_set_success( ) local_eval_sets_manager.create_eval_set(app_name, eval_set_id) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( eval_set_file_path, EvalSet( eval_set_id=eval_set_id, @@ -420,8 +420,8 @@ def test_local_eval_sets_manager_add_eval_case_success( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_set", return_value=mock_eval_set, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) local_eval_sets_manager.add_eval_case(app_name, eval_set_id, mock_eval_case) @@ -434,7 +434,7 @@ def test_local_eval_sets_manager_add_eval_case_success( eval_set_id + _EVAL_SET_FILE_EXTENSION, ) mock_eval_set.eval_cases.append(mock_eval_case) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( expected_eval_set_file_path, mock_eval_set ) @@ -568,8 +568,8 @@ def test_local_eval_sets_manager_update_eval_case_success( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", return_value=mock_eval_case, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) local_eval_sets_manager.update_eval_case( @@ -583,12 +583,12 @@ def test_local_eval_sets_manager_update_eval_case_success( app_name, eval_set_id + _EVAL_SET_FILE_EXTENSION, ) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( expected_eval_set_file_path, EvalSet(eval_set_id=eval_set_id, eval_cases=[updated_eval_case]), ) - def test_local_eval_sets_manager_update_eval_case_eval_case_not_found( + def test_local_eval_sets_manager_update_eval_case_eval_set_not_found( self, local_eval_sets_manager, mocker ): app_name = "test_app" @@ -601,10 +601,34 @@ def test_local_eval_sets_manager_update_eval_case_eval_case_not_found( return_value=None, ) + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + local_eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + + def test_local_eval_sets_manager_update_eval_case_eval_case_not_found( + self, local_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + updated_eval_case = EvalCase(eval_id=eval_case_id, conversation=[]) + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_set", + return_value=mock_eval_set, + ) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", + return_value=None, + ) with pytest.raises( NotFoundError, match=( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found." + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." ), ): local_eval_sets_manager.update_eval_case( @@ -630,8 +654,8 @@ def test_local_eval_sets_manager_delete_eval_case_success( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", return_value=mock_eval_case, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) local_eval_sets_manager.delete_eval_case( @@ -644,12 +668,12 @@ def test_local_eval_sets_manager_delete_eval_case_success( app_name, eval_set_id + _EVAL_SET_FILE_EXTENSION, ) - mock_write_eval_set.assert_called_once_with( + mock_write_eval_set_to_path.assert_called_once_with( expected_eval_set_file_path, EvalSet(eval_set_id=eval_set_id, eval_cases=[]), ) - def test_local_eval_sets_manager_delete_eval_case_eval_case_not_found( + def test_local_eval_sets_manager_delete_eval_case_eval_set_not_found( self, local_eval_sets_manager, mocker ): app_name = "test_app" @@ -660,18 +684,41 @@ def test_local_eval_sets_manager_delete_eval_case_eval_case_not_found( "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", return_value=None, ) - mock_write_eval_set = mocker.patch( - "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set" + mock_write_eval_set_to_path = mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager._write_eval_set_to_path" ) + with pytest.raises( + NotFoundError, + match=f"Eval set `{eval_set_id}` not found.", + ): + local_eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + + mock_write_eval_set_to_path.assert_not_called() + + def test_local_eval_sets_manager_delete_eval_case_eval_case_not_found( + self, local_eval_sets_manager, mocker + ): + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_id = "test_eval_case" + mock_eval_set = EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_set", + return_value=mock_eval_set, + ) + mocker.patch( + "google.adk.evaluation.local_eval_sets_manager.LocalEvalSetsManager.get_eval_case", + return_value=None, + ) with pytest.raises( NotFoundError, match=( - f"Eval Set `{eval_set_id}` or Eval id `{eval_case_id}` not found." + f"Eval case `{eval_case_id}` not found in eval set `{eval_set_id}`." ), ): local_eval_sets_manager.delete_eval_case( app_name, eval_set_id, eval_case_id ) - - mock_write_eval_set.assert_not_called() diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py new file mode 100644 index 000000000..f3eefb186 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py @@ -0,0 +1,201 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.agents import Agent +from google.adk.agents.live_request_queue import LiveRequest +from google.adk.agents.live_request_queue import LiveRequestQueue +from google.adk.agents.run_config import RunConfig +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.genai import types +import pytest + +from ... import testing_utils + + +class TestBaseLlmFlow(BaseLlmFlow): + """Test implementation of BaseLlmFlow for testing purposes.""" + + pass + + +@pytest.fixture +def test_blob(): + """Test blob for audio data.""" + return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm') + + +@pytest.fixture +def mock_llm_connection(): + """Mock LLM connection for testing.""" + connection = mock.AsyncMock() + connection.send_realtime = mock.AsyncMock() + return connection + + +@pytest.mark.asyncio +async def test_send_to_model_with_disabled_vad(test_blob, mock_llm_connection): + """Test _send_to_model with automatic_activity_detection.disabled=True.""" + # Create LlmRequest with disabled VAD + realtime_input_config = types.RealtimeInputConfig( + automatic_activity_detection=types.AutomaticActivityDetection( + disabled=True + ) + ) + + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, + user_content='', + run_config=RunConfig(realtime_input_config=realtime_input_config), + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_with_enabled_vad(test_blob, mock_llm_connection): + """Test _send_to_model with automatic_activity_detection.disabled=False. + + Custom VAD activity signal is not supported so we should still disable it. + """ + # Create LlmRequest with enabled VAD + realtime_input_config = types.RealtimeInputConfig( + automatic_activity_detection=types.AutomaticActivityDetection( + disabled=False + ) + ) + + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_without_realtime_config( + test_blob, mock_llm_connection +): + """Test _send_to_model without realtime_input_config (default behavior).""" + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_with_none_automatic_activity_detection( + test_blob, mock_llm_connection +): + """Test _send_to_model with automatic_activity_detection=None.""" + # Create LlmRequest with None automatic_activity_detection + realtime_input_config = types.RealtimeInputConfig( + automatic_activity_detection=None + ) + + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, + user_content='', + run_config=RunConfig(realtime_input_config=realtime_input_config), + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send a blob to the queue + live_request = LiveRequest(blob=test_blob) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + mock_llm_connection.send_realtime.assert_called_once_with(test_blob) + + +@pytest.mark.asyncio +async def test_send_to_model_with_text_content(mock_llm_connection): + """Test _send_to_model with text content (not blob).""" + # Create invocation context with live request queue + agent = Agent(name='test_agent', model='mock') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + invocation_context.live_request_queue = LiveRequestQueue() + + # Create flow and start _send_to_model task + flow = TestBaseLlmFlow() + + # Send text content to the queue + content = types.Content( + role='user', parts=[types.Part.from_text(text='Hello')] + ) + live_request = LiveRequest(content=content) + invocation_context.live_request_queue.send(live_request) + invocation_context.live_request_queue.close() + + # Run _send_to_model + await flow._send_to_model(mock_llm_connection, invocation_context) + + # Verify send_content was called instead of send_realtime + mock_llm_connection.send_content.assert_called_once_with(content) + mock_llm_connection.send_realtime.assert_not_called() diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py new file mode 100644 index 000000000..232711503 --- /dev/null +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.models.gemini_llm_connection import GeminiLlmConnection +from google.genai import types +import pytest + + +@pytest.fixture +def mock_gemini_session(): + """Mock Gemini session for testing.""" + return mock.AsyncMock() + + +@pytest.fixture +def gemini_connection(mock_gemini_session): + """GeminiLlmConnection instance with mocked session.""" + return GeminiLlmConnection(mock_gemini_session) + + +@pytest.fixture +def test_blob(): + """Test blob for audio data.""" + return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm') + + +@pytest.mark.asyncio +async def test_send_realtime_default_behavior( + gemini_connection, mock_gemini_session, test_blob +): + """Test send_realtime with default automatic_activity_detection value (True).""" + await gemini_connection.send_realtime(test_blob) + + # Should call send once + mock_gemini_session.send.assert_called_once_with(input=test_blob.model_dump()) + + +@pytest.mark.asyncio +async def test_send_history(gemini_connection, mock_gemini_session): + """Test send_history method.""" + history = [ + types.Content(role='user', parts=[types.Part.from_text(text='Hello')]), + types.Content( + role='model', parts=[types.Part.from_text(text='Hi there!')] + ), + ] + + await gemini_connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + assert 'input' in call_args + assert call_args['input'].turns == history + assert call_args['input'].turn_complete is False # Last message is from model + + +@pytest.mark.asyncio +async def test_send_content_text(gemini_connection, mock_gemini_session): + """Test send_content with text content.""" + content = types.Content( + role='user', parts=[types.Part.from_text(text='Hello')] + ) + + await gemini_connection.send_content(content) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + assert 'input' in call_args + assert call_args['input'].turns == [content] + assert call_args['input'].turn_complete is True + + +@pytest.mark.asyncio +async def test_send_content_function_response( + gemini_connection, mock_gemini_session +): + """Test send_content with function response.""" + function_response = types.FunctionResponse( + name='test_function', response={'result': 'success'} + ) + content = types.Content( + role='user', parts=[types.Part(function_response=function_response)] + ) + + await gemini_connection.send_content(content) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + assert 'input' in call_args + assert call_args['input'].function_responses == [function_response] + + +@pytest.mark.asyncio +async def test_close(gemini_connection, mock_gemini_session): + """Test close method.""" + await gemini_connection.close() + + mock_gemini_session.close.assert_called_once() diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 9278bee54..fb8540bb2 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -341,6 +341,255 @@ async def __aexit__(self, *args): assert connection is mock_connection +@pytest.mark.asyncio +async def test_generate_content_async_with_custom_headers( + gemini_llm, llm_request, generate_content_response +): + """Test that tracking headers are updated when custom headers are provided.""" + # Add custom headers to the request config + custom_headers = {"custom-header": "custom-value"} + for key in gemini_llm._tracking_headers: + custom_headers[key] = "custom " + gemini_llm._tracking_headers[key] + llm_request.config.http_options = types.HttpOptions(headers=custom_headers) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + # Create a mock coroutine that returns the generate_content_response + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + # Verify that the config passed to generate_content contains merged headers + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args + config_arg = call_args.kwargs["config"] + + for key, value in config_arg.http_options.headers.items(): + if key in gemini_llm._tracking_headers: + assert value == gemini_llm._tracking_headers[key] + else: + assert value == custom_headers[key] + + assert len(responses) == 1 + assert isinstance(responses[0], LlmResponse) + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_custom_headers( + gemini_llm, llm_request +): + """Test that tracking headers are updated when custom headers are provided in streaming mode.""" + # Add custom headers to the request config + custom_headers = {"custom-header": "custom-value"} + llm_request.config.http_options = types.HttpOptions(headers=custom_headers) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + # Create mock stream responses + class MockAsyncIterator: + + def __init__(self, seq): + self.iter = iter(seq) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + mock_responses = [ + types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + role="model", parts=[Part.from_text(text="Hello")] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + ] + + async def mock_coro(): + return MockAsyncIterator(mock_responses) + + mock_client.aio.models.generate_content_stream.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=True + ) + ] + + # Verify that the config passed to generate_content_stream contains merged headers + mock_client.aio.models.generate_content_stream.assert_called_once() + call_args = mock_client.aio.models.generate_content_stream.call_args + config_arg = call_args.kwargs["config"] + + expected_headers = custom_headers.copy() + expected_headers.update(gemini_llm._tracking_headers) + assert config_arg.http_options.headers == expected_headers + + assert len(responses) == 2 + + +@pytest.mark.asyncio +async def test_generate_content_async_without_custom_headers( + gemini_llm, llm_request, generate_content_response +): + """Test that tracking headers are not modified when no custom headers exist.""" + # Ensure no http_options exist initially + llm_request.config.http_options = None + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + # Verify that the config passed to generate_content has no http_options + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args + config_arg = call_args.kwargs["config"] + assert config_arg.http_options is None + + assert len(responses) == 1 + + +def test_live_api_version_vertex_ai(gemini_llm): + """Test that _live_api_version returns 'v1beta1' for Vertex AI backend.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + assert gemini_llm._live_api_version == "v1beta1" + + +def test_live_api_version_gemini_api(gemini_llm): + """Test that _live_api_version returns 'v1alpha' for Gemini API backend.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API + ): + assert gemini_llm._live_api_version == "v1alpha" + + +def test_live_api_client_properties(gemini_llm): + """Test that _live_api_client is properly configured with tracking headers and API version.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + client = gemini_llm._live_api_client + + # Verify that the client has the correct headers and API version + http_options = client._api_client._http_options + assert http_options.api_version == "v1beta1" + + # Check that tracking headers are included + tracking_headers = gemini_llm._tracking_headers + for key, value in tracking_headers.items(): + assert key in http_options.headers + assert value in http_options.headers[key] + + +@pytest.mark.asyncio +async def test_connect_with_custom_headers(gemini_llm, llm_request): + """Test that connect method updates tracking headers and API version when custom headers are provided.""" + # Setup request with live connect config and custom headers + custom_headers = {"custom-live-header": "live-value"} + llm_request.live_connect_config = types.LiveConnectConfig( + http_options=types.HttpOptions(headers=custom_headers) + ) + + mock_live_session = mock.AsyncMock() + + # Mock the _live_api_client to return a mock client + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + # Create a mock context manager + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify that tracking headers were merged with custom headers + expected_headers = custom_headers.copy() + expected_headers.update(gemini_llm._tracking_headers) + assert config_arg.http_options.headers == expected_headers + + # Verify that API version was set + assert config_arg.http_options.api_version == gemini_llm._live_api_version + + # Verify that system instruction and tools were set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools + + # Verify connection is properly wrapped + assert isinstance(connection, GeminiLlmConnection) + + +@pytest.mark.asyncio +async def test_connect_without_custom_headers(gemini_llm, llm_request): + """Test that connect method works properly when no custom headers are provided.""" + # Setup request with live connect config but no custom headers + llm_request.live_connect_config = types.LiveConnectConfig() + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify that http_options remains None since no custom headers were provided + assert config_arg.http_options is None + + # Verify that system instruction and tools were still set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools + + assert isinstance(connection, GeminiLlmConnection) + + @pytest.mark.parametrize( ( "api_backend, " diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae..8b43cc48b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1194,11 +1194,11 @@ async def test_generate_content_async_stream( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" mock_completion.assert_called_once() _, kwargs = mock_completion.call_args @@ -1257,11 +1257,11 @@ async def test_generate_content_async_stream_with_usage_metadata( assert responses[2].content.role == "model" assert responses[2].content.parts[0].text == "two:" assert responses[3].content.role == "model" - assert responses[3].content.parts[0].function_call.name == "test_function" - assert responses[3].content.parts[0].function_call.args == { + assert responses[3].content.parts[-1].function_call.name == "test_function" + assert responses[3].content.parts[-1].function_call.args == { "test_arg": "test_value" } - assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" assert responses[3].usage_metadata.prompt_token_count == 10 assert responses[3].usage_metadata.candidates_token_count == 5 diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 92f6a29dd..6a9e0b46a 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -245,8 +245,14 @@ async def async_request( raise ValueError(f'Unsupported http method: {http_method}') -def mock_vertex_ai_session_service(): +def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None): """Creates a mock Vertex AI Session service for testing.""" + if agent_engine_id: + return VertexAiSessionService( + project='test-project', + location='test-location', + agent_engine_id=agent_engine_id, + ) return VertexAiSessionService( project='test-project', location='test-location' ) @@ -265,7 +271,7 @@ def mock_get_api_client(): '2': (MOCK_EVENT_JSON_2, 'my_token'), } with mock.patch( - 'google.adk.sessions.vertex_ai_session_service._get_api_client', + 'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client', return_value=api_client, ): yield @@ -273,8 +279,12 @@ def mock_get_api_client(): @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') -async def test_get_empty_session(): - session_service = mock_vertex_ai_session_service() +@pytest.mark.parametrize('agent_engine_id', [None, '123']) +async def test_get_empty_session(agent_engine_id): + if agent_engine_id: + session_service = mock_vertex_ai_session_service(agent_engine_id) + else: + session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: await session_service.get_session( app_name='123', user_id='user', session_id='0' diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index 1a8ed5233..b1d5ff822 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -56,7 +56,9 @@ def __init__(self, parts: list[types.Part]): super().__init__(role='model', parts=parts) -async def create_invocation_context(agent: Agent, user_content: str = ''): +async def create_invocation_context( + agent: Agent, user_content: str = '', run_config: RunConfig = None +): invocation_id = 'test_id' artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() @@ -73,7 +75,7 @@ async def create_invocation_context(agent: Agent, user_content: str = ''): user_content=types.Content( role='user', parts=[types.Part.from_text(text=user_content)] ), - run_config=RunConfig(), + run_config=run_config or RunConfig(), ) if user_content: append_user_content( @@ -205,13 +207,16 @@ async def run_async(self, new_message: types.ContentUnion) -> list[Event]: events.append(event) return events - def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: + def run_live( + self, live_request_queue: LiveRequestQueue, run_config: RunConfig = None + ) -> list[Event]: collected_responses = [] async def consume_responses(session: Session): run_res = self.runner.run_live( session=session, live_request_queue=live_request_queue, + run_config=run_config or RunConfig(), ) async for response in run_res: diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py index cd37a105e..c9b542e51 100644 --- a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -20,6 +20,7 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult from google.genai.types import FunctionDeclaration from google.genai.types import Schema from google.genai.types import Type @@ -50,7 +51,9 @@ def mock_rest_api_tool(): "required": ["user_id", "page_size", "filter", "connection_name"], } mock_tool._operation_parser = mock_parser - mock_tool.call.return_value = {"status": "success", "data": "mock_data"} + mock_tool.call = mock.AsyncMock( + return_value={"status": "success", "data": "mock_data"} + ) return mock_tool @@ -179,9 +182,6 @@ async def test_run_with_auth_async_none_token( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) # Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None mock_auth_credential_without_token = AuthCredential( auth_type=AuthCredentialTypes.HTTP, @@ -190,8 +190,12 @@ async def test_run_with_auth_async_none_token( credentials=HttpCredentials(token=None), # Token is None ), ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = ( - mock_auth_credential_without_token + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=( + AuthPreparationResult( + state="done", auth_credential=mock_auth_credential_without_token + ) + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance @@ -229,18 +233,18 @@ async def test_run_with_auth_async( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = mock.MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "done" - ) - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", - credentials=HttpCredentials(token="mocked_token"), - ), + + mock_tool_auth_handler_instance.prepare_auth_credentials = mock.AsyncMock( + return_value=AuthPreparationResult( + state="done", + auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="mocked_token"), + ), + ), + ) ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance result = await integration_tool_with_auth.run_async( diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py new file mode 100644 index 000000000..e8b373416 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -0,0 +1,129 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import re +from unittest import mock + +from google.adk.tools.bigquery.client import get_bigquery_client +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2.credentials import Credentials +import pytest + + +def test_bigquery_client_project(): + """Test BigQuery client project.""" + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the client has the desired project set + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_explicit(): + """Test BigQuery client creation does not invoke default auth.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_default_auth(): + """Test BigQuery client creation invokes default auth to set the project.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate credentials + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Simulate output of the default auth + mock_default_auth.return_value = (mock_creds, "test-gcp-project") + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock_creds, + ) + + # Verify that default auth was called once to set the client project + mock_default_auth.assert_called_once() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_project_set_with_env(): + """Test BigQuery client creation sets the project from environment variable.""" + # Let's simulate the project set in environment variables + with mock.patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True + ): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the BigQuery client creation + client = get_bigquery_client( + project=None, + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_bigquery_client_user_agent(): + """Test BigQuery client user agent.""" + with mock.patch( + "google.cloud.bigquery.client.Connection", autospec=True + ) as mock_connection: + # Trigger the BigQuery client creation + get_bigquery_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the tracking user agent was set + client_info_arg = mock_connection.call_args[1].get("client_info") + assert client_info_arg is not None + assert re.search( + r"adk-bigquery-tool google-adk/([0-9A-Za-z._\-+/]+)", + client_info_arg.user_agent, + ) diff --git a/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py new file mode 100644 index 000000000..14ecea558 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from unittest import mock + +from google.adk.tools.bigquery import metadata_tool +from google.auth.exceptions import DefaultCredentialsError +from google.cloud import bigquery +from google.oauth2.credentials import Credentials +import pytest + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_dataset_ids(mock_default_auth, mock_list_datasets): + """Test list_dataset_ids tool invocation.""" + project = "my_project_id" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_datasets.return_value = [ + bigquery.DatasetReference(project, "dataset1"), + bigquery.DatasetReference(project, "dataset2"), + ] + result = metadata_tool.list_dataset_ids(project, mock_credentials) + assert result == ["dataset1", "dataset2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_dataset_info(mock_default_auth, mock_get_dataset): + """Test get_dataset_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_dataset.return_value = mock.create_autospec( + Credentials, instance=True + ) + result = metadata_tool.get_dataset_info( + "my_project_id", "my_dataset_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_list_table_ids(mock_default_auth, mock_list_tables): + """Test list_table_ids tool invocation.""" + project = "my_project_id" + dataset = "my_dataset_id" + dataset_ref = bigquery.DatasetReference(project, dataset) + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_list_tables.return_value = [ + bigquery.TableReference(dataset_ref, "table1"), + bigquery.TableReference(dataset_ref, "table2"), + ] + result = metadata_tool.list_table_ids(project, dataset, mock_credentials) + assert result == ["table1", "table2"] + mock_default_auth.assert_not_called() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_get_table_info(mock_default_auth, mock_get_table): + """Test get_table_info tool invocation.""" + mock_credentials = mock.create_autospec(Credentials, instance=True) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + mock_get_table.return_value = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_table_info( + "my_project_id", "my_dataset_id", "my_table_id", mock_credentials + ) + assert result != { + "status": "ERROR", + "error_details": "Your default credentials were not found", + } + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index fc592c6c8..3cb8c3c4a 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -14,14 +14,20 @@ from __future__ import annotations +import os import textwrap from typing import Optional +from unittest import mock from google.adk.tools import BaseTool from google.adk.tools.bigquery import BigQueryCredentialsConfig from google.adk.tools.bigquery import BigQueryToolset from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode +from google.adk.tools.bigquery.query_tool import execute_sql +from google.auth.exceptions import DefaultCredentialsError +from google.cloud import bigquery +from google.oauth2.credentials import Credentials import pytest @@ -218,3 +224,159 @@ async def test_execute_sql_declaration_write(tool_config): - Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE". - First run "DROP TABLE", followed by "CREATE TABLE". - To insert data into a table, use "INSERT INTO" statement.""") + + +@pytest.mark.parametrize( + ("write_mode",), + [ + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), + ], +) +def test_execute_sql_select_stmt(write_mode): + """Test execute_sql tool for SELECT query when writes are blocked.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + query_result = [{"num": 123}] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = BigQueryToolConfig(write_mode=write_mode) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + # Simulate the result of query_and_wait API + bq_client.query_and_wait.return_value = query_result + + # Test the tool + result = execute_sql(project, query, credentials, tool_config) + assert result == {"status": "SUCCESS", "rows": query_result} + + +@pytest.mark.parametrize( + ("query", "statement_type"), + [ + pytest.param( + "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num", + "CREATE_AS_SELECT", + id="create-as-select", + ), + pytest.param( + "DROP TABLE my_dataset.my_table", + "DROP_TABLE", + id="drop-table", + ), + ], +) +def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): + """Test execute_sql tool for non-SELECT query when writes are blocked.""" + project = "my_project" + query_result = [] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + # Simulate the result of query_and_wait API + bq_client.query_and_wait.return_value = query_result + + # Test the tool + result = execute_sql(project, query, credentials, tool_config) + assert result == {"status": "SUCCESS", "rows": query_result} + + +@pytest.mark.parametrize( + ("query", "statement_type"), + [ + pytest.param( + "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num", + "CREATE_AS_SELECT", + id="create-as-select", + ), + pytest.param( + "DROP TABLE my_dataset.my_table", + "DROP_TABLE", + id="drop-table", + ), + ], +) +def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): + """Test execute_sql tool for non-SELECT query when writes are blocked.""" + project = "my_project" + query_result = [] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + # Simulate the result of query_and_wait API + bq_client.query_and_wait.return_value = query_result + + # Test the tool + result = execute_sql(project, query, credentials, tool_config) + assert result == { + "status": "ERROR", + "error_details": "Read-only mode only supports SELECT statements.", + } + + +@pytest.mark.parametrize( + ("write_mode",), + [ + pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.ALLOWED, id="allowed"), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) +@mock.patch("google.cloud.bigquery.Client.query", autospec=True) +@mock.patch("google.auth.default", autospec=True) +def test_execute_sql_no_default_auth( + mock_default_auth, mock_query, mock_query_and_wait, write_mode +): + """Test execute_sql tool invocation does not involve calling default auth.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + query_result = [{"num": 123}] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = BigQueryToolConfig(write_mode=write_mode) + + # Simulate the behavior of default auth - on purpose throw exception when + # the default auth is called + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + mock_query.return_value = query_job + + # Simulate the result of query_and_wait API + mock_query_and_wait.return_value = query_result + + # Test the tool worked without invoking default auth + result = execute_sql(project, query, credentials, tool_config) + assert result == {"status": "SUCCESS", "rows": query_result} + mock_default_auth.assert_not_called() diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index ea9990b9f..4129dc512 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -96,9 +96,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools): ], ) @pytest.mark.asyncio -async def test_bigquery_toolset_unknown_tool_raises( - selected_tools, returned_tools -): +async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools): """Test BigQuery toolset with filter. This test verifies the behavior of the BigQuery toolset when filter is diff --git a/tests/unittests/tools/mcp_tool/__init__.py b/tests/unittests/tools/mcp_tool/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/tools/mcp_tool/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py new file mode 100644 index 000000000..448d41260 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -0,0 +1,342 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +from io import StringIO +import json +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +import pytest + +# Import real MCP classes +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockClientSession: + """Mock ClientSession for testing.""" + + def __init__(self): + self._read_stream = Mock() + self._write_stream = Mock() + self._read_stream._closed = False + self._write_stream._closed = False + self.initialize = AsyncMock() + + +class MockAsyncExitStack: + """Mock AsyncExitStack for testing.""" + + def __init__(self): + self.aclose = AsyncMock() + self.enter_async_context = AsyncMock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class TestMCPSessionManager: + """Test suite for MCPSessionManager class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_stdio_connection_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=5.0 + ) + + def test_init_with_stdio_server_parameters(self): + """Test initialization with StdioServerParameters (deprecated).""" + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.logger" + ) as mock_logger: + manager = MCPSessionManager(self.mock_stdio_params) + + # Should log deprecation warning + mock_logger.warning.assert_called_once() + assert "StdioServerParameters is not recommended" in str( + mock_logger.warning.call_args + ) + + # Should convert to StdioConnectionParams + assert isinstance(manager._connection_params, StdioConnectionParams) + assert manager._connection_params.server_params == self.mock_stdio_params + assert manager._connection_params.timeout == 5 + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + assert manager._connection_params == self.mock_stdio_connection_params + assert manager._errlog == sys.stderr + assert manager._sessions == {} + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", + headers={"Authorization": "Bearer token"}, + timeout=10.0, + ) + manager = MCPSessionManager(sse_params) + + assert manager._connection_params == sse_params + + def test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", timeout=15.0 + ) + manager = MCPSessionManager(http_params) + + assert manager._connection_params == http_params + + def test_generate_session_key_stdio(self): + """Test session key generation for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # For stdio, headers should be ignored and return constant key + key1 = manager._generate_session_key({"Authorization": "Bearer token"}) + key2 = manager._generate_session_key(None) + + assert key1 == "stdio_session" + assert key2 == "stdio_session" + assert key1 == key2 + + def test_generate_session_key_sse(self): + """Test session key generation for SSE connections.""" + sse_params = SseConnectionParams(url="https://example.com/mcp") + manager = MCPSessionManager(sse_params) + + headers1 = {"Authorization": "Bearer token1"} + headers2 = {"Authorization": "Bearer token2"} + + key1 = manager._generate_session_key(headers1) + key2 = manager._generate_session_key(headers2) + key3 = manager._generate_session_key(headers1) + + # Different headers should generate different keys + assert key1 != key2 + # Same headers should generate same key + assert key1 == key3 + + # Should be deterministic hash + headers_json = json.dumps(headers1, sort_keys=True) + expected_hash = hashlib.md5(headers_json.encode()).hexdigest() + assert key1 == f"session_{expected_hash}" + + def test_merge_headers_stdio(self): + """Test header merging for stdio connections.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Stdio connections don't support headers + headers = manager._merge_headers({"Authorization": "Bearer token"}) + assert headers is None + + def test_merge_headers_sse(self): + """Test header merging for SSE connections.""" + base_headers = {"Content-Type": "application/json"} + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers=base_headers + ) + manager = MCPSessionManager(sse_params) + + # With additional headers + additional = {"Authorization": "Bearer token"} + merged = manager._merge_headers(additional) + + expected = { + "Content-Type": "application/json", + "Authorization": "Bearer token", + } + assert merged == expected + + def test_is_session_disconnected(self): + """Test session disconnection detection.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session + session = MockClientSession() + + # Not disconnected + assert not manager._is_session_disconnected(session) + + # Disconnected - read stream closed + session._read_stream._closed = True + assert manager._is_session_disconnected(session) + + @pytest.mark.asyncio + async def test_create_session_stdio_new(self): + """Test creating a new stdio session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_session = MockClientSession() + mock_exit_stack = MockAsyncExitStack() + + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ) as mock_stdio: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_exit_stack_class: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" + ) as mock_session_class: + + # Setup mocks + mock_exit_stack_class.return_value = mock_exit_stack + mock_stdio.return_value = AsyncMock() + mock_exit_stack.enter_async_context.side_effect = [ + ("read", "write"), # First call returns transports + mock_session, # Second call returns session + ] + mock_session_class.return_value = mock_session + + # Create session + session = await manager.create_session() + + # Verify session creation + assert session == mock_session + assert len(manager._sessions) == 1 + assert "stdio_session" in manager._sessions + + # Verify session was initialized + mock_session.initialize.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_reuse_existing(self): + """Test reusing an existing connected session.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock existing session + existing_session = MockClientSession() + existing_exit_stack = MockAsyncExitStack() + manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) + + # Session is connected + existing_session._read_stream._closed = False + existing_session._write_stream._closed = False + + session = await manager.create_session() + + # Should reuse existing session + assert session == existing_session + assert len(manager._sessions) == 1 + + # Should not create new session + existing_session.initialize.assert_not_called() + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup of all sessions.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + await manager.close() + + # All sessions should be closed + exit_stack1.aclose.assert_called_once() + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + @pytest.mark.asyncio + async def test_close_with_errors(self): + """Test cleanup when some sessions fail to close.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Add mock sessions + session1 = MockClientSession() + exit_stack1 = MockAsyncExitStack() + exit_stack1.aclose.side_effect = Exception("Close error 1") + + session2 = MockClientSession() + exit_stack2 = MockAsyncExitStack() + + manager._sessions["session1"] = (session1, exit_stack1) + manager._sessions["session2"] = (session2, exit_stack2) + + custom_errlog = StringIO() + manager._errlog = custom_errlog + + # Should not raise exception + await manager.close() + + # Good session should still be closed + exit_stack2.aclose.assert_called_once() + assert len(manager._sessions) == 0 + + # Error should be logged + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCP session cleanup" in error_output + assert "Close error 1" in error_output + + +def test_retry_on_closed_resource_decorator(): + """Test the retry_on_closed_resource decorator.""" + + call_count = 0 + + @retry_on_closed_resource + async def mock_function(self): + nonlocal call_count + call_count += 1 + if call_count == 1: + import anyio + + raise anyio.ClosedResourceError("Resource closed") + return "success" + + @pytest.mark.asyncio + async def test_retry(): + nonlocal call_count + call_count = 0 + + mock_self = Mock() + result = await mock_function(mock_self) + + assert result == "success" + assert call_count == 2 # First call fails, second succeeds + + # Run the test + import asyncio + + asyncio.run(test_retry()) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py new file mode 100644 index 000000000..d25a84eac --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -0,0 +1,339 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +import pytest + + +# Mock MCP Tool from mcp.types +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name="test_tool", description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "First parameter"}, + "param2": {"type": "integer", "description": "Second parameter"}, + }, + "required": ["param1"], + } + + +class TestMCPTool: + """Test suite for MCPTool class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_mcp_tool = MockMCPTool() + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization without auth.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test tool description" + assert tool._mcp_tool == self.mock_mcp_tool + assert tool._mcp_session_manager == self.mock_session_manager + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances instead of mocks + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + # The auth config is stored in the parent class _credentials_manager + assert tool._credentials_manager is not None + assert tool._credentials_manager._auth_config.auth_scheme == auth_scheme + assert ( + tool._credentials_manager._auth_config.raw_auth_credential + == auth_credential + ) + + def test_init_with_empty_description(self): + """Test initialization with empty description.""" + mock_tool = MockMCPTool(description=None) + tool = MCPTool( + mcp_tool=mock_tool, + mcp_session_manager=self.mock_session_manager, + ) + + assert tool.description == "" + + def test_get_declaration(self): + """Test function declaration generation.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + declaration = tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_tool" + assert declaration.description == "Test tool description" + assert declaration.parameters is not None + + @pytest.mark.asyncio + async def test_run_async_impl_no_auth(self): + """Test running tool without authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=None + ) + + assert result == expected_response + self.mock_session_manager.create_session.assert_called_once_with( + headers=None + ) + # Fix: call_tool uses 'arguments' parameter, not positional args + self.mock_session.call_tool.assert_called_once_with( + "test_tool", arguments=args + ) + + @pytest.mark.asyncio + async def test_run_async_impl_with_oauth2(self): + """Test running tool with OAuth2 authentication.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create OAuth2 credential + oauth2_auth = OAuth2Auth(access_token="test_access_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + # Mock the session response + expected_response = {"result": "success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + assert result == expected_response + # Check that headers were passed correctly + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == {"Authorization": "Bearer test_access_token"} + + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer bearer_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key(self): + """Test header generation for API Key credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"X-API-Key": "my_api_key"} + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, None) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create service account credential + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should return None as service account credentials are not supported for direct header generation + assert headers is None + + @pytest.mark.asyncio + async def test_run_async_impl_retry_decorator(self): + """Test that the retry decorator is applied correctly.""" + # This is more of an integration test to ensure the decorator is present + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Check that the method has the retry decorator + assert hasattr(tool._run_async_impl, "__wrapped__") + + @pytest.mark.asyncio + async def test_get_headers_http_custom_scheme(self): + """Test header generation for custom HTTP scheme.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "custom custom_token"} + + def test_init_validation(self): + """Test that initialization validates required parameters.""" + # This test ensures that the MCPTool properly handles its dependencies + with pytest.raises(TypeError): + MCPTool() # Missing required parameters + + with pytest.raises(TypeError): + MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py new file mode 100644 index 000000000..0ba29b1da --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import StringIO +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +import pytest + +# Import the real MCP classes for proper instantiation +try: + from mcp import StdioServerParameters +except ImportError: + # Create a mock if MCP is not available + class StdioServerParameters: + + def __init__(self, command="test_command", args=None): + self.command = command + self.args = args or [] + + +class MockMCPTool: + """Mock MCP Tool for testing.""" + + def __init__(self, name, description="Test tool description"): + self.name = name + self.description = description + self.inputSchema = { + "type": "object", + "properties": {"param": {"type": "string"}}, + } + + +class MockListToolsResult: + """Mock ListToolsResult for testing.""" + + def __init__(self, tools): + self.tools = tools + + +class TestMCPToolset: + """Test suite for MCPToolset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session = AsyncMock() + self.mock_session_manager.create_session = AsyncMock( + return_value=self.mock_session + ) + + def test_init_basic(self): + """Test basic initialization with StdioServerParameters.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Note: StdioServerParameters gets converted to StdioConnectionParams internally + assert toolset._errlog == sys.stderr + assert toolset._auth_scheme is None + assert toolset._auth_credential is None + + def test_init_with_stdio_connection_params(self): + """Test initialization with StdioConnectionParams.""" + stdio_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=10.0 + ) + toolset = MCPToolset(connection_params=stdio_params) + + assert toolset._connection_params == stdio_params + + def test_init_with_sse_connection_params(self): + """Test initialization with SseConnectionParams.""" + sse_params = SseConnectionParams( + url="https://example.com/mcp", headers={"Authorization": "Bearer token"} + ) + toolset = MCPToolset(connection_params=sse_params) + + assert toolset._connection_params == sse_params + + def test_init_with_streamable_http_params(self): + """Test initialization with StreamableHTTPConnectionParams.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", + headers={"Content-Type": "application/json"}, + ) + toolset = MCPToolset(connection_params=http_params) + + assert toolset._connection_params == http_params + + def test_init_with_tool_filter_list(self): + """Test initialization with tool filter as list.""" + tool_filter = ["tool1", "tool2"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + + # The tool filter is stored in the parent BaseToolset class + # We can verify it by checking the filtering behavior in get_tools + assert toolset._is_tool_selected is not None + + def test_init_with_auth(self): + """Test initialization with authentication.""" + # Create real auth scheme instances + from fastapi.openapi.models import OAuth2 + + auth_scheme = OAuth2(flows={}) + from google.adk.auth.auth_credential import OAuth2Auth + + auth_credential = AuthCredential( + auth_type="oauth2", + oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + assert toolset._auth_scheme == auth_scheme + assert toolset._auth_credential == auth_credential + + def test_init_missing_connection_params(self): + """Test initialization with missing connection params raises error.""" + with pytest.raises(ValueError, match="Missing connection params"): + MCPToolset(connection_params=None) + + @pytest.mark.asyncio + async def test_get_tools_basic(self): + """Test getting tools without filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 3 + for tool in tools: + assert isinstance(tool, MCPTool) + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + assert tools[2].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_list_filter(self): + """Test getting tools with list-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("tool1"), + MockMCPTool("tool2"), + MockMCPTool("tool3"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + tool_filter = ["tool1", "tool3"] + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=tool_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool3" + + @pytest.mark.asyncio + async def test_get_tools_with_function_filter(self): + """Test getting tools with function-based filtering.""" + # Mock tools from MCP server + mock_tools = [ + MockMCPTool("read_file"), + MockMCPTool("write_file"), + MockMCPTool("list_directory"), + ] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + + def file_tools_filter(tool, context): + """Filter for file-related tools only.""" + return "file" in tool.name + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, tool_filter=file_tools_filter + ) + toolset._mcp_session_manager = self.mock_session_manager + + tools = await toolset.get_tools() + + assert len(tools) == 2 + assert tools[0].name == "read_file" + assert tools[1].name == "write_file" + + @pytest.mark.asyncio + async def test_close_success(self): + """Test successful cleanup.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + await toolset.close() + + self.mock_session_manager.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_exception(self): + """Test cleanup when session manager raises exception.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + # Mock close to raise an exception + self.mock_session_manager.close = AsyncMock( + side_effect=Exception("Cleanup error") + ) + + custom_errlog = StringIO() + toolset._errlog = custom_errlog + + # Should not raise exception + await toolset.close() + + # Should log the error + error_output = custom_errlog.getvalue() + assert "Warning: Error during MCPToolset cleanup" in error_output + assert "Cleanup error" in error_output + + @pytest.mark.asyncio + async def test_get_tools_retry_decorator(self): + """Test that get_tools has retry decorator applied.""" + toolset = MCPToolset(connection_params=self.mock_stdio_params) + + # Check that the method has the retry decorator + assert hasattr(toolset.get_tools, "__wrapped__") diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 303dda69d..c4cbea7b9 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -14,6 +14,7 @@ import json +from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -194,7 +195,8 @@ def test_get_declaration( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_success( + @pytest.mark.asyncio + async def test_call_success( self, mock_request, mock_tool_context, @@ -217,7 +219,7 @@ def test_call_success( ) # Call the method - result = tool.call(args={}, tool_context=mock_tool_context) + result = await tool.call(args={}, tool_context=mock_tool_context) # Check the result assert result == {"result": "success"} @@ -225,7 +227,8 @@ def test_call_success( @patch( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request" ) - def test_call_auth_pending( + @pytest.mark.asyncio + async def test_call_auth_pending( self, mock_request, sample_endpoint, @@ -246,12 +249,14 @@ def test_call_auth_pending( "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" ) as mock_from_tool_context: mock_tool_auth_handler_instance = MagicMock() - mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( - "pending" + mock_prepare_result = MagicMock() + mock_prepare_result.state = "pending" + mock_tool_auth_handler_instance.prepare_auth_credentials = AsyncMock( + return_value=mock_prepare_result ) mock_from_tool_context.return_value = mock_tool_auth_handler_instance - response = tool.call(args={}, tool_context=None) + response = await tool.call(args={}, tool_context=None) assert response == { "pending": True, "message": "Needs your authorization to access your data.", diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index 0a3b8ccce..e405ce5b8 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -14,6 +14,7 @@ from typing import Optional from unittest.mock import MagicMock +from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent @@ -115,7 +116,8 @@ def openid_connect_credential(): return credential -def test_openid_connect_no_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_no_auth_response( openid_connect_scheme, openid_connect_credential ): # Setup Mock exchanger @@ -131,12 +133,13 @@ def test_openid_connect_no_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'pending' assert result.auth_credential == openid_connect_credential -def test_openid_connect_with_auth_response( +@pytest.mark.asyncio +async def test_openid_connect_with_auth_response( openid_connect_scheme, openid_connect_credential, monkeypatch ): mock_exchanger = MockOpenIdConnectCredentialExchanger( @@ -147,10 +150,11 @@ def test_openid_connect_with_auth_response( tool_context = create_mock_tool_context() mock_auth_handler = MagicMock() - mock_auth_handler.get_auth_response.return_value = AuthCredential( + returned_credentail = AuthCredential( auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'), ) + mock_auth_handler.get_auth_response.return_value = returned_credentail mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler' monkeypatch.setattr( mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler @@ -164,7 +168,7 @@ def test_openid_connect_with_auth_response( credential_exchanger=mock_exchanger, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP assert 'test_access_token' in result.auth_credential.http.credentials.token @@ -172,11 +176,12 @@ def test_openid_connect_with_auth_response( stored_credential = credential_store.get_credential( openid_connect_scheme, openid_connect_credential ) - assert stored_credential == result.auth_credential + assert stored_credential == returned_credentail mock_auth_handler.get_auth_response.assert_called_once() -def test_openid_connect_existing_token( +@pytest.mark.asyncio +async def test_openid_connect_existing_token( openid_connect_scheme, openid_connect_credential ): _, existing_credential = token_to_scheme_credential( @@ -196,6 +201,77 @@ def test_openid_connect_existing_token( openid_connect_credential, credential_store=credential_store, ) - result = handler.prepare_auth_credentials() + result = await handler.prepare_auth_credentials() assert result.state == 'done' assert result.auth_credential == existing_credential + + +@patch( + 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher' +) +@pytest.mark.asyncio +async def test_openid_connect_existing_oauth2_token_refresh( + mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential +): + """Test that OAuth2 tokens are refreshed when existing credentials are found.""" + # Create existing OAuth2 credential + existing_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='existing_token', + refresh_token='refresh_token', + ), + ) + + # Mock the refreshed credential + refreshed_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='refreshed_token', + refresh_token='new_refresh_token', + ), + ) + + # Setup mock OAuth2CredentialRefresher + from unittest.mock import AsyncMock + + mock_refresher_instance = MagicMock() + mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential) + mock_oauth2_refresher.return_value = mock_refresher_instance + + tool_context = create_mock_tool_context() + credential_store = ToolContextCredentialStore(tool_context=tool_context) + + # Store the existing credential + key = credential_store.get_credential_key( + openid_connect_scheme, openid_connect_credential + ) + credential_store.store_credential(key, existing_credential) + + handler = ToolAuthHandler( + tool_context, + openid_connect_scheme, + openid_connect_credential, + credential_store=credential_store, + ) + + result = await handler.prepare_auth_credentials() + + # Verify OAuth2CredentialRefresher was called for refresh + mock_oauth2_refresher.assert_called_once() + + mock_refresher_instance.is_refresh_needed.assert_called_once_with( + existing_credential + ) + mock_refresher_instance.refresh.assert_called_once_with( + existing_credential, openid_connect_scheme + ) + + assert result.state == 'done' + # The result should contain the refreshed credential after exchange + assert result.auth_credential is not None diff --git a/tests/unittests/tools/test_authenticated_function_tool.py b/tests/unittests/tools/test_authenticated_function_tool.py new file mode 100644 index 000000000..88454032a --- /dev/null +++ b/tests/unittests/tools/test_authenticated_function_tool.py @@ -0,0 +1,541 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool +from google.adk.tools.tool_context import ToolContext +import pytest + +# Test functions for different scenarios + + +def sync_function_no_credential(arg1: str, arg2: int) -> str: + """Test sync function without credential parameter.""" + return f"sync_result_{arg1}_{arg2}" + + +async def async_function_no_credential(arg1: str, arg2: int) -> str: + """Test async function without credential parameter.""" + return f"async_result_{arg1}_{arg2}" + + +def sync_function_with_credential(arg1: str, credential: AuthCredential) -> str: + """Test sync function with credential parameter.""" + return f"sync_cred_result_{arg1}_{credential.auth_type.value}" + + +async def async_function_with_credential( + arg1: str, credential: AuthCredential +) -> str: + """Test async function with credential parameter.""" + return f"async_cred_result_{arg1}_{credential.auth_type.value}" + + +def sync_function_with_tool_context( + arg1: str, tool_context: ToolContext +) -> str: + """Test sync function with tool_context parameter.""" + return f"sync_context_result_{arg1}" + + +async def async_function_with_both( + arg1: str, tool_context: ToolContext, credential: AuthCredential +) -> str: + """Test async function with both tool_context and credential parameters.""" + return f"async_both_result_{arg1}_{credential.auth_type.value}" + + +def function_with_optional_args( + arg1: str, arg2: str = "default", credential: AuthCredential = None +) -> str: + """Test function with optional arguments.""" + cred_type = credential.auth_type.value if credential else "none" + return f"optional_result_{arg1}_{arg2}_{cred_type}" + + +class MockCallable: + """Test callable class for testing.""" + + def __init__(self): + self.__name__ = "MockCallable" + self.__doc__ = "Test callable documentation" + + def __call__(self, arg1: str, credential: AuthCredential) -> str: + return f"callable_result_{arg1}_{credential.auth_type.value}" + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + # Create a mock auth_type that returns the expected value + mock_auth_type = Mock() + mock_auth_type.value = "oauth2" + credential.auth_type = mock_auth_type + return credential + + +class TestAuthenticatedFunctionTool: + """Test suite for AuthenticatedFunctionTool.""" + + def test_init_with_sync_function(self): + """Test initialization with synchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, + auth_config=auth_config, + response_for_auth_required="Please authenticate", + ) + + assert tool.name == "sync_function_no_credential" + assert ( + tool.description == "Test sync function without credential parameter." + ) + assert tool.func == sync_function_no_credential + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == "Please authenticate" + assert "credential" in tool._ignore_params + + def test_init_with_async_function(self): + """Test initialization with asynchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=async_function_no_credential, auth_config=auth_config + ) + + assert tool.name == "async_function_no_credential" + assert ( + tool.description == "Test async function without credential parameter." + ) + assert tool.func == async_function_no_credential + assert tool._response_for_auth_required is None + + def test_init_with_callable(self): + """Test initialization with callable object.""" + auth_config = _create_mock_auth_config() + test_callable = MockCallable() + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + + assert tool.name == "MockCallable" + assert tool.description == "Test callable documentation" + assert tool.func == test_callable + + def test_init_no_auth_config(self): + """Test initialization without auth_config.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + + assert tool._credentials_manager is None + + @pytest.mark.asyncio + async def test_run_async_sync_function_no_credential_manager(self): + """Test run_async with sync function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_async_function_no_credential_manager(self): + """Test run_async with async function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=async_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "async_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"sync_cred_result_test_{credential.auth_type.value}" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_async_function_with_credential(self): + """Test run_async with async function that expects credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_cred_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, + auth_config=auth_config, + response_for_auth_required="Custom auth required", + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Custom auth required" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_default_message(self): + """Test run_async when no credential is available with default message.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." + + @pytest.mark.asyncio + async def test_run_async_function_without_credential_param(self): + """Test run_async with function that doesn't have credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Credential should not be passed to function since it doesn't have the parameter + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_function_with_tool_context(self): + """Test run_async with function that has tool_context parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_tool_context, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_context_result_test" + + @pytest.mark.asyncio + async def test_run_async_function_with_both_params(self): + """Test run_async with function that has both tool_context and credential parameters.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_both, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_both_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_function_with_optional_credential(self): + """Test run_async with function that has optional credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=function_with_optional_args, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert ( + result == f"optional_result_test_default_{credential.auth_type.value}" + ) + + @pytest.mark.asyncio + async def test_run_async_callable_object(self): + """Test run_async with callable object.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + test_callable = MockCallable() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"callable_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_propagates_function_exception(self): + """Test that run_async propagates exceptions from the wrapped function.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + def failing_function(arg1: str, credential: AuthCredential) -> str: + raise ValueError("Function failed") + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=failing_function, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(ValueError, match="Function failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_missing_required_args(self): + """Test run_async with missing required arguments.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} # Missing arg2 + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should return error dict indicating missing parameters + assert isinstance(result, dict) + assert "error" in result + assert "arg2" in result["error"] + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_credential_in_ignore_params(self): + """Test that 'credential' is added to ignore_params during initialization.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + assert "credential" in tool._ignore_params + + @pytest.mark.asyncio + async def test_run_async_with_none_credential(self): + """Test run_async when credential is None but function expects it.""" + tool = AuthenticatedFunctionTool(func=function_with_optional_args) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "optional_result_test_default_none" + + def test_signature_inspection(self): + """Test that the tool correctly inspects function signatures.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + signature = inspect.signature(tool.func) + assert "credential" in signature.parameters + assert "arg1" in signature.parameters + + @pytest.mark.asyncio + async def test_args_to_call_modification(self): + """Test that args_to_call is properly modified with credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + # Create a spy function to check what arguments are passed + original_args = {} + + def spy_function(arg1: str, credential: AuthCredential) -> str: + nonlocal original_args + original_args = {"arg1": arg1, "credential": credential} + return "spy_result" + + tool = AuthenticatedFunctionTool(func=spy_function, auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "spy_result" + assert original_args is not None + assert original_args["arg1"] == "test" + assert original_args["credential"] == credential diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py new file mode 100644 index 000000000..55454224d --- /dev/null +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -0,0 +1,343 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestAuthenticatedTool(BaseAuthenticatedTool): + """Test implementation of BaseAuthenticatedTool for testing purposes.""" + + def __init__( + self, + name="test_auth_tool", + description="Test authenticated tool", + auth_config=None, + unauthenticated_response=None, + ): + super().__init__( + name=name, + description=description, + auth_config=auth_config, + response_for_auth_required=unauthenticated_response, + ) + self.run_impl_called = False + self.run_impl_result = "test_result" + + async def _run_async_impl(self, *, args, tool_context, credential): + """Test implementation of the abstract method.""" + self.run_impl_called = True + self.last_args = args + self.last_tool_context = tool_context + self.last_credential = credential + return self.run_impl_result + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + credential.auth_type = AuthCredentialTypes.OAUTH2 + return credential + + +class TestBaseAuthenticatedTool: + """Test suite for BaseAuthenticatedTool.""" + + def test_init_with_auth_config(self): + """Test initialization with auth_config.""" + auth_config = _create_mock_auth_config() + unauthenticated_response = {"error": "Not authenticated"} + + tool = _TestAuthenticatedTool( + name="test_tool", + description="Test description", + auth_config=auth_config, + unauthenticated_response=unauthenticated_response, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test description" + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == unauthenticated_response + + def test_init_with_no_auth_config(self): + """Test initialization without auth_config.""" + tool = _TestAuthenticatedTool() + + assert tool.name == "test_auth_tool" + assert tool.description == "Test authenticated tool" + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._credentials_manager is None + + def test_init_with_default_unauthenticated_response(self): + """Test initialization with default unauthenticated response.""" + auth_config = _create_mock_auth_config() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._response_for_auth_required is None + + @pytest.mark.asyncio + async def test_run_async_no_credentials_manager(self): + """Test run_async when no credentials manager is configured.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential is None + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential == credential + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_custom_response(self): + """Test run_async when no credential is available with custom response.""" + auth_config = _create_mock_auth_config() + custom_response = { + "status": "authentication_required", + "message": "Please login", + } + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_string_response(self): + """Test run_async when no credential is available with string response.""" + auth_config = _create_mock_auth_config() + custom_response = "Custom authentication required message" + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + + @pytest.mark.asyncio + async def test_run_async_propagates_impl_exception(self): + """Test that run_async propagates exceptions from _run_async_impl.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + # Make the implementation raise an exception + async def failing_impl(*, args, tool_context, credential): + raise ValueError("Implementation failed") + + tool._run_async_impl = failing_impl + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(ValueError, match="Implementation failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_with_different_args_types(self): + """Test run_async with different argument types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + + # Test with empty args + result = await tool.run_async(args={}, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == {} + + # Test with complex args + complex_args = { + "string_param": "test", + "number_param": 42, + "list_param": [1, 2, 3], + "dict_param": {"nested": "value"}, + } + result = await tool.run_async(args=complex_args, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == complex_args + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_abstract_nature(self): + """Test that BaseAuthenticatedTool cannot be instantiated directly.""" + with pytest.raises(TypeError): + # This should fail because _run_async_impl is abstract + BaseAuthenticatedTool(name="test", description="test") + + @pytest.mark.asyncio + async def test_run_async_return_values(self): + """Test run_async with different return value types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {} + + # Test with None return + tool.run_impl_result = None + result = await tool.run_async(args=args, tool_context=tool_context) + assert result is None + + # Test with dict return + tool.run_impl_result = {"key": "value"} + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == {"key": "value"} + + # Test with list return + tool.run_impl_result = [1, 2, 3] + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == [1, 2, 3] diff --git a/tests/unittests/utils/test_feature_decorator.py b/tests/unittests/utils/test_feature_decorator.py index aa03fc746..e2f16446a 100644 --- a/tests/unittests/utils/test_feature_decorator.py +++ b/tests/unittests/utils/test_feature_decorator.py @@ -1,3 +1,5 @@ +import os +import tempfile import warnings from google.adk.utils.feature_decorator import experimental @@ -11,25 +13,201 @@ def run(self): return "running" +@working_in_progress("function not ready") +def wip_function(): + return "executing" + + @experimental("api may have breaking change in the future.") def experimental_fn(): return "executing" -def test_working_in_progress_class_warns(): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") +@experimental("class may change") +class ExperimentalClass: + + def run(self): + return "running experimental" + + +# Test classes/functions for new usage patterns +@experimental +class ExperimentalClassNoParens: + + def run(self): + return "running experimental without parens" + +@experimental() +class ExperimentalClassEmptyParens: + + def run(self): + return "running experimental with empty parens" + + +@experimental +def experimental_fn_no_parens(): + return "executing without parens" + + +@experimental() +def experimental_fn_empty_parens(): + return "executing with empty parens" + + +def test_working_in_progress_class_raises_error(): + """Test that WIP class raises RuntimeError by default.""" + # Ensure environment variable is not set + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + try: feature = IncompleteFeature() + assert False, "Expected RuntimeError to be raised" + except RuntimeError as e: + assert "[WIP] IncompleteFeature:" in str(e) + assert "don't use yet" in str(e) - assert feature.run() == "running" - assert len(w) == 1 - assert issubclass(w[0].category, UserWarning) - assert "[WIP] IncompleteFeature:" in str(w[0].message) - assert "don't use yet" in str(w[0].message) +def test_working_in_progress_function_raises_error(): + """Test that WIP function raises RuntimeError by default.""" + # Ensure environment variable is not set + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + try: + result = wip_function() + assert False, "Expected RuntimeError to be raised" + except RuntimeError as e: + assert "[WIP] wip_function:" in str(e) + assert "function not ready" in str(e) + + +def test_working_in_progress_class_bypassed_with_env_var(): + """Test that WIP class works without warnings when env var is set.""" + # Set the bypass environment variable + os.environ["ADK_ALLOW_WIP_FEATURES"] = "true" + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + feature = IncompleteFeature() + result = feature.run() + + assert result == "running" + # Should have no warnings when bypassed + assert len(w) == 0 + finally: + # Clean up environment variable + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_function_bypassed_with_env_var(): + """Test that WIP function works without warnings when env var is set.""" + # Set the bypass environment variable + os.environ["ADK_ALLOW_WIP_FEATURES"] = "true" + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = wip_function() + + assert result == "executing" + # Should have no warnings when bypassed + assert len(w) == 0 + finally: + # Clean up environment variable + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_env_var_case_insensitive(): + """Test that WIP bypass works with different case values.""" + test_cases = ["true", "True", "TRUE", "tRuE"] + + for case in test_cases: + os.environ["ADK_ALLOW_WIP_FEATURES"] = case + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = wip_function() + + assert result == "executing" + assert len(w) == 0 + finally: + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_env_var_false_values(): + """Test that WIP still raises errors with false-like env var values.""" + false_values = ["false", "False", "FALSE", "0", "", "anything_else"] + + for false_val in false_values: + os.environ["ADK_ALLOW_WIP_FEATURES"] = false_val + + try: + result = wip_function() + assert False, f"Expected RuntimeError with env var '{false_val}'" + except RuntimeError as e: + assert "[WIP] wip_function:" in str(e) + finally: + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_working_in_progress_loads_from_dotenv_file(): + """Test that WIP decorator can load environment variables from .env file.""" + # Skip test if dotenv is not available + try: + from dotenv import load_dotenv + except ImportError: + import pytest + + pytest.skip("python-dotenv not available") + + # Ensure environment variable is not set in os.environ + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] -def test_experimental_method_warns(): + # Create a temporary .env file in current directory + dotenv_path = ".env.test" + + try: + # Write the env file + with open(dotenv_path, "w") as f: + f.write("ADK_ALLOW_WIP_FEATURES=true\n") + + # Load the environment variables from the file + load_dotenv(dotenv_path) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # This should work because the .env file contains ADK_ALLOW_WIP_FEATURES=true + result = wip_function() + + assert result == "executing" + # Should have no warnings when bypassed via .env file + assert len(w) == 0 + + finally: + # Clean up + try: + os.unlink(dotenv_path) + except FileNotFoundError: + pass + if "ADK_ALLOW_WIP_FEATURES" in os.environ: + del os.environ["ADK_ALLOW_WIP_FEATURES"] + + +def test_experimental_function_warns(): + """Test that experimental function shows warnings (unchanged behavior).""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -40,3 +218,84 @@ def test_experimental_method_warns(): assert issubclass(w[0].category, UserWarning) assert "[EXPERIMENTAL] experimental_fn:" in str(w[0].message) assert "breaking change in the future" in str(w[0].message) + + +def test_experimental_class_warns(): + """Test that experimental class shows warnings (unchanged behavior).""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + exp_class = ExperimentalClass() + result = exp_class.run() + + assert result == "running experimental" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] ExperimentalClass:" in str(w[0].message) + assert "class may change" in str(w[0].message) + + +def test_experimental_class_no_parens_warns(): + """Test that experimental class without parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + exp_class = ExperimentalClassNoParens() + result = exp_class.run() + + assert result == "running experimental without parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] ExperimentalClassNoParens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) + + +def test_experimental_class_empty_parens_warns(): + """Test that experimental class with empty parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + exp_class = ExperimentalClassEmptyParens() + result = exp_class.run() + + assert result == "running experimental with empty parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] ExperimentalClassEmptyParens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) + + +def test_experimental_function_no_parens_warns(): + """Test that experimental function without parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = experimental_fn_no_parens() + + assert result == "executing without parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] experimental_fn_no_parens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + ) + + +def test_experimental_function_empty_parens_warns(): + """Test that experimental function with empty parentheses shows default warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + result = experimental_fn_empty_parens() + + assert result == "executing with empty parens" + assert len(w) == 1 + assert issubclass(w[0].category, UserWarning) + assert "[EXPERIMENTAL] experimental_fn_empty_parens:" in str(w[0].message) + assert "This feature is experimental and may change or be removed" in str( + w[0].message + )