8000 chore: add base toolbox tests · liunix61/adk-python@070dc93 · GitHub
[go: up one dir, main page]

Skip to content

Commit 070dc93

Browse files
chore: add base toolbox tests
1 parent a4a0859 commit 070dc93

File tree

2 files changed

+272
-1
lines changed

2 files changed

+272
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ test = [
8686
"langgraph>=0.2.60", # For LangGraphAgent
8787
"litellm>=1.63.11", # For LiteLLM tests
8888
"llama-index-readers-file>=0.4.0", # For retrieval tests
89-
89+
"pytest-aioresponses>=0.3.0", # For MCP Toolbox tests
9090
"pytest-asyncio>=0.25.0",
9191
"pytest-mock>=3.14.0",
9292
"pytest-xdist>=3.6.1",

tests/unittests/tools/test_toolbox.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
from typing import Any, Callable, Mapping, Optional
17+
18+
from aioresponses import aioresponses
19+
from google.adk.tools import toolbox
20+
import pytest
21+
22+
TEST_BASE_URL = "http://toolbox.example.com"
23+
24+
25+
@pytest.fixture
26+
def sync_client_environment():
27+
"""
28+
Ensures a clean environment for ToolboxSyncClient class-level resources.
29+
It resets the class-level event loop and thread before the test
30+
and stops them after the test. This is crucial for test isolation
31+
due to ToolboxSyncClient's use of class-level loop/thread.
32+
"""
33+
# Save current state if any (more of a defensive measure)
34+
original_loop = getattr(toolbox.ToolboxSyncClient, "_ToolboxSyncClient__loop", None)
35+
original_thread = getattr(
36+
toolbox.ToolboxSyncClient, "_ToolboxSyncClient__thread", None
37+
)
38+
39+
# Force reset class state before the test.
40+
# This ensures any client created will start a new loop/thread.
41+
42+
# Ensure no loop/thread is running from a previous misbehaving test or setup
43+
assert original_loop is None or not original_loop.is_running()
44+
assert original_thread is None or not original_thread.is_alive()
45+
46+
toolbox.ToolboxSyncClient._ToolboxSyncClient__loop = None
47+
toolbox.ToolboxSyncClient._ToolboxSyncClient__thread = None
48+
49+
yield
50+
51+
# Teardown: stop the loop and join the thread created *during* the test.
52+
test_loop = getattr(toolbox.ToolboxSyncClient, "_ToolboxSyncClient__loop", None)
53+
test_thread = getattr(toolbox.ToolboxSyncClient, "_ToolboxSyncClient__thread", None)
54+
55+
if test_loop and test_loop.is_running():
56+
test_loop.call_soon_threadsafe(test_loop.stop)
57+
if test_thread and test_thread.is_alive():
58+
test_thread.join(timeout=5)
59+
60+
# Explicitly set to None to ensure a clean state for the next fixture use/test.
61+
toolbox.ToolboxSyncClient._ToolboxSyncClient__loop = None
62+
toolbox.ToolboxSyncClient._ToolboxSyncClient__thread = None
63+
64+
65+
@pytest.fixture
66+
def sync_client(sync_client_environment):
67+
"""
68+
Provides a ToolboxSyncClient instance within an isolated environment.
69+
The client's underlying async session is automatically closed after the test.
70+
The class-level loop/thread are managed by sync_client_environment.
71+
"""
72+
client = toolbox.ToolboxSyncClient(TEST_BASE_URL)
73+
74+
yield client
75+
76+
client.close() # Closes the async_client's session.
77+
# Loop/thread shutdown is handled by sync_client_environment's teardown.
78+
79+
80+
@pytest.fixture()
81+
def test_tool_str_schema():
82+
return toolbox.protocol.ToolSchema(
83+
description="Test Tool with String input",
84+
parameters=[
85+
toolbox.protocol.ParameterSchema(
86+
name="param1", type="string", description="Description of Param1"
87+
)
88+
],
89+
)
90+
91+
92+
@pytest.fixture()
93+
def test_tool_int_bool_schema():
94+
return toolbox.protocol.ToolSchema(
95+
description="Test Tool with Int, Bool",
96+
parameters=[
97+
toolbox.protocol.ParameterSchema(
98+
name="argA", type="integer", description="Argument A"
99+
),
100+
toolbox.protocol.ParameterSchema(
101+
name="argB", type="boolean", description="Argument B"
102+
),
103+
],
104+
)
105+
106+
107+
@pytest.fixture()
108+
def test_tool_auth_schema():
109+
return toolbox.protocol.ToolSchema(
110+
description="Test Tool with Int,Bool+Auth",
111+
parameters=[
112+
toolbox.protocol.ParameterSchema(
113+
name="argA", type="integer", description="Argument A"
114+
),
115+
toolbox.protocol.ParameterSchema(
116+
name="argB",
117+
type="boolean",
118+
description="Argument B",
119+
authSources=["my-auth-service"],
120+
),
121+
],
122+
)
123+
124+
125+
@pytest.fixture
126+
def tool_schema_minimal():
127+
return toolbox.protocol.ToolSchema(description="Minimal Test Tool", parameters=[])
128+
129+
130+
# --- Helper Functions for Mocking ---
131+
def mock_tool_load(
132+
aio_resp: aioresponses,
133+
tool_name: str,
134+
tool_schema: toolbox.protocol.ToolSchema,
135+
base_url: str = TEST_BASE_URL,
136+
server_version: str = "0.0.0",
137+
status: int = 200,
138+
callback: Optional[Callable] = None,
139+
payload_override: Optional[Any] = None,
140+
):
141+
url = f"{base_url}/api/tool/{tool_name}"
142+
payload_data = {}
143+
if payload_override is not None:
144+
payload_data = payload_override
145+
else:
146+
manifest = toolbox.protocol.ManifestSchema(
147+
serverVersion=server_version, tools={tool_name: tool_schema}
148+
)
149+
payload_data = manifest.model_dump()
150+
aio_resp.get(url, payload=payload_data, status=status, callback=callback)
151+
152+
153+
def mock_toolset_load(
154+
aio_resp: aioresponses,
155+
toolset_name: str,
156+
tools_dict: Mapping[str, toolbox.protocol.ToolSchema],
157+
base_url: str = TEST_BASE_URL,
158+
server_version: str = "0.0.0",
159+
status: int = 200,
160+
callback: Optional[Callable] = None,
161+
):
162+
url_path = f"toolset/{toolset_name}" if toolset_name else "toolset/"
163+
url = f"{base_url}/api/{url_path}"
164+
manifest = toolbox.protocol.ManifestSchema(
165+
serverVersion=server_version, tools=tools_dict
166+
)
167+
aio_resp.get(url, payload=manifest.model_dump(), status=status, callback=callback)
168+
169+
170+
def mock_tool_invoke(
171+
aio_resp: aioresponses,
172+
tool_name: str,
173+
base_url: str = TEST_BASE_URL,
174+
response_payload: Any = {"result": "ok"},
175+
status: int = 200,
176+
callback: Optional[Callable] = None,
177+
):
178+
url = f"{base_url}/api/tool/{tool_name}/invoke"
179+
aio_resp.post(url, payload=response_payload, status=status, callback=callback)
180+
181+
182+
# --- Tests for General ToolboxSyncClient Functionality ---
183+
184+
185+
def test_sync_load_tool_success(aioresponses, test_tool_str_schema, sync_client):
186+
TOOL_NAME = "test_tool_sync_1"
187+
mock_tool_load(aioresponses, TOOL_NAME, test_tool_str_schema)
188+
mock_tool_invoke(
189+
aioresponses, TOOL_NAME, response_payload={"result": "sync_tool_ok"}
190+
)
191+
192+
loaded_tool = sync_client.load_tool(TOOL_NAME)
193+
194+
assert callable(loaded_tool)
195+
assert isinstance(loaded_tool, toolbox.sync_tool.ToolboxSyncTool)
196+
assert loaded_tool.__name__ == TOOL_NAME
197+
assert test_tool_str_schema.description in loaded_tool.__doc__
198+
sig = inspect.signature(loaded_tool)
199+
assert list(sig.parameters.keys()) == [
200+
p.name for p in test_tool_str_schema.parameters
201+
]
202+
result = loaded_tool(param1="some value")
203+
assert result == "sync_tool_ok"
204+
205+
206+
def test_sync_load_toolset_success(
207+
aioresponses, test_tool_str_schema, test_tool_int_bool_schema, sync_client
208+
):
209+
TOOLSET_NAME = "my_sync_toolset"
210+
TOOL1_NAME = "sync_tool1"
211+
TOOL2_NAME = "sync_tool2"
212+
tools_definition = {
213+
TOOL1_NAME: test_tool_str_schema,
214+
TOOL2_NAME: test_tool_int_bool_schema,
215+
}
216+
mock_toolset_load(aioresponses, TOOLSET_NAME, tools_definition)
217+
mock_tool_invoke(
218+
aioresponses, TOOL1_NAME, response_payload={"result": f"{TOOL1_NAME}_ok"}
219+
)
220+
mock_tool_invoke(
221+
aioresponses, TOOL2_NAME, response_payload={"result": f"{TOOL2_NAME}_ok"}
222+
)
223+
224+
tools = sync_client.load_toolset(TOOLSET_NAME)
225+
226+
assert isinstance(tools, list)
227+
assert len(tools) == len(tools_definition)
228+
assert all(isinstance(t, toolbox.sync_tool.ToolboxSyncTool) for t in tools)
229+
assert {t.__name__ for t in tools} == tools_definition.keys()
230+
tool1 = next(t for t in tools if t.__name__ == TOOL1_NAME)
231+
result1 = tool1(param1="hello")
232+
assert result1 == f"{TOOL1_NAME}_ok"
233+
234+
235+
def test_sync_invoke_tool_server_error(aioresponses, test_tool_str_schema, sync_client):
236+
TOOL_NAME = "sync_server_error_tool"
237+
ERROR_MESSAGE = "Simulated Server Error for Sync Client"
238+
mock_tool_load(aioresponses, TOOL_NAME, test_tool_str_schema)
239+
mock_tool_invoke(
240+
aioresponses, TOOL_NAME, response_payload={"error": ERROR_MESSAGE}, status=500
241+
)
242+
243+
loaded_tool = sync_client.load_tool(TOOL_NAME)
244+
with pytest.raises(Exception, match=ERROR_MESSAGE):
245+
loaded_tool(param1="some input")
246+
247+
248+
def test_sync_load_tool_not_found_in_manifest(
249+
aioresponses, test_tool_str_schema, sync_client
250+
):
251+
ACTUAL_TOOL_IN_MANIFEST = "actual_tool_sync_abc"
252+
REQUESTED_TOOL_NAME = "non_existent_tool_sync_xyz"
253+
mismatched_manifest_payload = toolbox.protocol.ManifestSchema(
254+
serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str_schema}
255+
).model_dump()
256+
mock_tool_load(
257+
aio_resp=aioresponses,
258+
tool_name=REQUESTED_TOOL_NAME,
259+
tool_schema=test_tool_str_schema,
260+
payload_override=mismatched_manifest_payload,
261+
)
262+
263+
with pytest.raises(
264+
Exception,
265+
match=f"Tool '{REQUESTED_TOOL_NAME}' not found!",
266+
):
267+
sync_client.load_tool(REQUESTED_TOOL_NAME)
268+
aioresponses.assert_called_once_with(
269+
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}",
270+
method="GET",
271+
)

0 commit comments

Comments
 (0)
0