|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import inspect |
| 16 | +from typing import Any, Callable, Mapping, Optional |
| 17 | + |
| 18 | +from aioresponses import aioresponses |
| 19 | +from google.adk.tools import toolbox |
| 20 | +import pytest |
| 21 | + |
| 22 | +TEST_BASE_URL = "http://toolbox.example.com" |
| 23 | + |
| 24 | + |
| 25 | +@pytest.fixture |
| 26 | +def sync_client_environment(): |
| 27 | + """ |
| 28 | + Ensures a clean environment for ToolboxSyncClient class-level resources. |
| 29 | + It resets the class-level event loop and thread before the test |
| 30 | + and stops them after the test. This is crucial for test isolation |
| 31 | + due to ToolboxSyncClient's use of class-level loop/thread. |
| 32 | + """ |
| 33 | + # Save current state if any (more of a defensive measure) |
| 34 | + original_loop = getattr(toolbox.ToolboxSyncClient, "_ToolboxSyncClient__loop", None) |
| 35 | + original_thread = getattr( |
| 36 | + toolbox.ToolboxSyncClient, "_ToolboxSyncClient__thread", None |
| 37 | + ) |
| 38 | + |
| 39 | + # Force reset class state before the test. |
| 40 | + # This ensures any client created will start a new loop/thread. |
| 41 | + |
| 42 | + # Ensure no loop/thread is running from a previous misbehaving test or setup |
| 43 | + assert original_loop is None or not original_loop.is_running() |
| 44 | + assert original_thread is None or not original_thread.is_alive() |
| 45 | + |
| 46 | + toolbox.ToolboxSyncClient._ToolboxSyncClient__loop = None |
| 47 | + toolbox.ToolboxSyncClient._ToolboxSyncClient__thread = None |
| 48 | + |
| 49 | + yield |
| 50 | + |
| 51 | + # Teardown: stop the loop and join the thread created *during* the test. |
| 52 | + test_loop = getattr(toolbox.ToolboxSyncClient, "_ToolboxSyncClient__loop", None) |
| 53 | + test_thread = getattr(toolbox.ToolboxSyncClient, "_ToolboxSyncClient__thread", None) |
| 54 | + |
| 55 | + if test_loop and test_loop.is_running(): |
| 56 | + test_loop.call_soon_threadsafe(test_loop.stop) |
| 57 | + if test_thread and test_thread.is_alive(): |
| 58 | + test_thread.join(timeout=5) |
| 59 | + |
| 60 | + # Explicitly set to None to ensure a clean state for the next fixture use/test. |
| 61 | + toolbox.ToolboxSyncClient._ToolboxSyncClient__loop = None |
| 62 | + toolbox.ToolboxSyncClient._ToolboxSyncClient__thread = None |
| 63 | + |
| 64 | + |
| 65 | +@pytest.fixture |
| 66 | +def sync_client(sync_client_environment): |
| 67 | + """ |
| 68 | + Provides a ToolboxSyncClient instance within an isolated environment. |
| 69 | + The client's underlying async session is automatically closed after the test. |
| 70 | + The class-level loop/thread are managed by sync_client_environment. |
| 71 | + """ |
| 72 | + client = toolbox.ToolboxSyncClient(TEST_BASE_URL) |
| 73 | + |
| 74 | + yield client |
| 75 | + |
| 76 | + client.close() # Closes the async_client's session. |
| 77 | + # Loop/thread shutdown is handled by sync_client_environment's teardown. |
| 78 | + |
| 79 | + |
| 80 | +@pytest.fixture() |
| 81 | +def test_tool_str_schema(): |
| 82 | + return toolbox.protocol.ToolSchema( |
| 83 | + description="Test Tool with String input", |
| 84 | + parameters=[ |
| 85 | + toolbox.protocol.ParameterSchema( |
| 86 | + name="param1", type="string", description="Description of Param1" |
| 87 | + ) |
| 88 | + ], |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +@pytest.fixture() |
| 93 | +def test_tool_int_bool_schema(): |
| 94 | + return toolbox.protocol.ToolSchema( |
| 95 | + description="Test Tool with Int, Bool", |
| 96 | + parameters=[ |
| 97 | + toolbox.protocol.ParameterSchema( |
| 98 | + name="argA", type="integer", description="Argument A" |
| 99 | + ), |
| 100 | + toolbox.protocol.ParameterSchema( |
| 101 | + name="argB", type="boolean", description="Argument B" |
| 102 | + ), |
| 103 | + ], |
| 104 | + ) |
| 105 | + |
| 106 | + |
| 107 | +@pytest.fixture() |
| 108 | +def test_tool_auth_schema(): |
| 109 | + return toolbox.protocol.ToolSchema( |
| 110 | + description="Test Tool with Int,Bool+Auth", |
| 111 | + parameters=[ |
| 112 | + toolbox.protocol.ParameterSchema( |
| 113 | + name="argA", type="integer", description="Argument A" |
| 114 | + ), |
| 115 | + toolbox.protocol.ParameterSchema( |
| 116 | + name="argB", |
| 117 | + type="boolean", |
| 118 | + description="Argument B", |
| 119 | + authSources=["my-auth-service"], |
| 120 | + ), |
| 121 | + ], |
| 122 | + ) |
| 123 | + |
| 124 | + |
| 125 | +@pytest.fixture |
| 126 | +def tool_schema_minimal(): |
| 127 | + return toolbox.protocol.ToolSchema(description="Minimal Test Tool", parameters=[]) |
| 128 | + |
| 129 | + |
| 130 | +# --- Helper Functions for Mocking --- |
| 131 | +def mock_tool_load( |
| 132 | + aio_resp: aioresponses, |
| 133 | + tool_name: str, |
| 134 | + tool_schema: toolbox.protocol.ToolSchema, |
| 135 | + base_url: str = TEST_BASE_URL, |
| 136 | + server_version: str = "0.0.0", |
| 137 | + status: int = 200, |
| 138 | + callback: Optional[Callable] = None, |
| 139 | + payload_override: Optional[Any] = None, |
| 140 | +): |
| 141 | + url = f"{base_url}/api/tool/{tool_name}" |
| 142 | + payload_data = {} |
| 143 | + if payload_override is not None: |
| 144 | + payload_data = payload_override |
| 145 | + else: |
| 146 | + manifest = toolbox.protocol.ManifestSchema( |
| 147 | + serverVersion=server_version, tools={tool_name: tool_schema} |
| 148 | + ) |
| 149 | + payload_data = manifest.model_dump() |
| 150 | + aio_resp.get(url, payload=payload_data, status=status, callback=callback) |
| 151 | + |
| 152 | + |
| 153 | +def mock_toolset_load( |
| 154 | + aio_resp: aioresponses, |
| 155 | + toolset_name: str, |
| 156 | + tools_dict: Mapping[str, toolbox.protocol.ToolSchema], |
| 157 | + base_url: str = TEST_BASE_URL, |
| 158 | + server_version: str = "0.0.0", |
| 159 | + status: int = 200, |
| 160 | + callback: Optional[Callable] = None, |
| 161 | +): |
| 162 | + url_path = f"toolset/{toolset_name}" if toolset_name else "toolset/" |
| 163 | + url = f"{base_url}/api/{url_path}" |
| 164 | + manifest = toolbox.protocol.ManifestSchema( |
| 165 | + serverVersion=server_version, tools=tools_dict |
| 166 | + ) |
| 167 | + aio_resp.get(url, payload=manifest.model_dump(), status=status, callback=callback) |
| 168 | + |
| 169 | + |
| 170 | +def mock_tool_invoke( |
| 171 | + aio_resp: aioresponses, |
| 172 | + tool_name: str, |
| 173 | + base_url: str = TEST_BASE_URL, |
| 174 | + response_payload: Any = {"result": "ok"}, |
| 175 | + status: int = 200, |
| 176 | + callback: Optional[Callable] = None, |
| 177 | +): |
| 178 | + url = f"{base_url}/api/tool/{tool_name}/invoke" |
| 179 | + aio_resp.post(url, payload=response_payload, status=status, callback=callback) |
| 180 | + |
| 181 | + |
| 182 | +# --- Tests for General ToolboxSyncClient Functionality --- |
| 183 | + |
| 184 | + |
| 185 | +def test_sync_load_tool_success(aioresponses, test_tool_str_schema, sync_client): |
| 186 | + TOOL_NAME = "test_tool_sync_1" |
| 187 | + mock_tool_load(aioresponses, TOOL_NAME, test_tool_str_schema) |
| 188 | + mock_tool_invoke( |
| 189 | + aioresponses, TOOL_NAME, response_payload={"result": "sync_tool_ok"} |
| 190 | + ) |
| 191 | + |
| 192 | + loaded_tool = sync_client.load_tool(TOOL_NAME) |
| 193 | + |
| 194 | + assert callable(loaded_tool) |
| 195 | + assert isinstance(loaded_tool, toolbox.sync_tool.ToolboxSyncTool) |
| 196 | + assert loaded_tool.__name__ == TOOL_NAME |
| 197 | + assert test_tool_str_schema.description in loaded_tool.__doc__ |
| 198 | + sig = inspect.signature(loaded_tool) |
| 199 | + assert list(sig.parameters.keys()) == [ |
| 200 | + p.name for p in test_tool_str_schema.parameters |
| 201 | + ] |
| 202 | + result = loaded_tool(param1="some value") |
| 203 | + assert result == "sync_tool_ok" |
| 204 | + |
| 205 | + |
| 206 | +def test_sync_load_toolset_success( |
| 207 | + aioresponses, test_tool_str_schema, test_tool_int_bool_schema, sync_client |
| 208 | +): |
| 209 | + TOOLSET_NAME = "my_sync_toolset" |
| 210 | + TOOL1_NAME = "sync_tool1" |
| 211 | + TOOL2_NAME = "sync_tool2" |
| 212 | + tools_definition = { |
| 213 | + TOOL1_NAME: test_tool_str_schema, |
| 214 | + TOOL2_NAME: test_tool_int_bool_schema, |
| 215 | + } |
| 216 | + mock_toolset_load(aioresponses, TOOLSET_NAME, tools_definition) |
| 217 | + mock_tool_invoke( |
| 218 | + aioresponses, TOOL1_NAME, response_payload={"result": f"{TOOL1_NAME}_ok"} |
| 219 | + ) |
| 220 | + mock_tool_invoke( |
| 221 | + aioresponses, TOOL2_NAME, response_payload={"result": f"{TOOL2_NAME}_ok"} |
| 222 | + ) |
| 223 | + |
| 224 | + tools = sync_client.load_toolset(TOOLSET_NAME) |
| 225 | + |
| 226 | + assert isinstance(tools, list) |
| 227 | + assert len(tools) == len(tools_definition) |
| 228 | + assert all(isinstance(t, toolbox.sync_tool.ToolboxSyncTool) for t in tools) |
| 229 | + assert {t.__name__ for t in tools} == tools_definition.keys() |
| 230 | + tool1 = next(t for t in tools if t.__name__ == TOOL1_NAME) |
| 231 | + result1 = tool1(param1="hello") |
| 232 | + assert result1 == f"{TOOL1_NAME}_ok" |
| 233 | + |
| 234 | + |
| 235 | +def test_sync_invoke_tool_server_error(aioresponses, test_tool_str_schema, sync_client): |
| 236 | + TOOL_NAME = "sync_server_error_tool" |
| 237 | + ERROR_MESSAGE = "Simulated Server Error for Sync Client" |
| 238 | + mock_tool_load(aioresponses, TOOL_NAME, test_tool_str_schema) |
| 239 | + mock_tool_invoke( |
| 240 | + aioresponses, TOOL_NAME, response_payload={"error": ERROR_MESSAGE}, status=500 |
| 241 | + ) |
| 242 | + |
| 243 | + loaded_tool = sync_client.load_tool(TOOL_NAME) |
| 244 | + with pytest.raises(Exception, match=ERROR_MESSAGE): |
| 245 | + loaded_tool(param1="some input") |
| 246 | + |
| 247 | + |
| 248 | +def test_sync_load_tool_not_found_in_manifest( |
| 249 | + aioresponses, test_tool_str_schema, sync_client |
| 250 | +): |
| 251 | + ACTUAL_TOOL_IN_MANIFEST = "actual_tool_sync_abc" |
| 252 | + REQUESTED_TOOL_NAME = "non_existent_tool_sync_xyz" |
| 253 | + mismatched_manifest_payload = toolbox.protocol.ManifestSchema( |
| 254 | + serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str_schema} |
| 255 | + ).model_dump() |
| 256 | + mock_tool_load( |
| 257 | + aio_resp=aioresponses, |
| 258 | + tool_name=REQUESTED_TOOL_NAME, |
| 259 | + tool_schema=test_tool_str_schema, |
| 260 | + payload_override=mismatched_manifest_payload, |
| 261 | + ) |
| 262 | + |
| 263 | + with pytest.raises( |
| 264 | + Exception, |
| 265 | + match=f"Tool '{REQUESTED_TOOL_NAME}' not found!", |
| 266 | + ): |
| 267 | + sync_client.load_tool(REQUESTED_TOOL_NAME) |
| 268 | + aioresponses.assert_called_once_with( |
| 269 | + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", |
| 270 | + method="GET", |
| 271 | + ) |
0 commit comments