10000 Currently if a model calls a FunctionTool without all the mandatory p… · jsondai/adk-python@f872577 · GitHub
[go: up one dir, main page]

Skip to content

Commit f872577

Browse files
ankursharmascopybara-github
authored andcommitted
Currently if a model calls a FunctionTool without all the mandatory parameters, the code will just break. This change basically adds the capability for the FunctionTool to identify if the model is missing required arguments, and in that case, instead of breaking the execution, it provides a error message to the model so it could fix the request and retry.
PiperOrigin-RevId: 751023475
1 parent e6109b1 commit f872577

File tree

2 files changed

+280
-0
lines changed

2 files changed

+280
-0
lines changed

src/google/adk/tools/function_tool.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ async def run_async(
5959
if 'tool_context' in signature.parameters:
6060
args_to_call['tool_context'] = tool_context
6161

62+
# Before invoking the function, we check for if the list of args passed in
63+
# has all the mandatory arguments or not.
64+
# If the check fails, then we don't invoke the tool and let the Agent know
65+
# that there was a missing a input parameter. This will basically help
66+
# the underlying model fix the issue and retry.
67+
mandatory_args = self._get_mandatory_args()
68+
missing_mandatory_args = [
69+
arg for arg in mandatory_args if arg not in args_to_call
70+
]
71+
72+
if missing_mandatory_args:
73+
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
74+
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
75+
{missing_mandatory_args_str}
76+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
77+
return {'error': error_str}
78+
6279
if inspect.iscoroutinefunction(self.func):
6380
return await self.func(**args_to_call) or {}
6481
else:
@@ -85,3 +102,28 @@ async def _call_live(
85102
args_to_call['tool_context'] = tool_context
86103
async for item in self.func(**args_to_call):
87104
yield item
105+
106+
def _get_mandatory_args(
107+
self,
108+
) -> list[str]:
109+
"""Identifies mandatory parameters (those without default values) for a function.
110+
111+
Returns:
112+
A list of strings, where each string is the name of a mandatory parameter.
113+
"""
114+
signature = inspect.signature(self.func)
115+
mandatory_params = []
116+
117+
for name, param in signature.parameters.items():
118+
# A parameter is mandatory if:
119+
# 1. It has no default value (param.default is inspect.Parameter.empty)
120+
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
121+
#
122+
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
123+
if param.default == inspect.Parameter.empty and param.kind not in (
124+
inspect.Parameter.VAR_POSITIONAL,
125+
inspect.Parameter.VAR_KEYWORD,
126+
):
127+
mandatory_params.append(name)
128+
129+
return mandatory_params
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import MagicMock
16+
17+
from google.adk.tools.function_tool import FunctionTool
18+
import pytest
19+
20+
21+
def function_for_testing_with_no_args():
22+
"""Function for testing with no args."""
23+
pass
24+
25+
26+
async def async_function_for_testing_with_1_arg_and_tool_context(
27+
arg1, tool_context
28+
):
29+
"""Async function for testing with 1 arge and tool context."""
30+
assert arg1
31+
assert tool_context
32+
return arg1
33+
34+
35+
async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
36+
"""Async function for testing with 2 arge and no tool context."""
37+
assert arg1
38+
assert arg2
39+
return arg1
40+
41+
42+
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
43+
"""Function for testing with 1 arge and tool context."""
44+
assert arg1
45+
assert tool_context
46+
return arg1
47+
48+
49+
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
50+
"""Function for testing with 2 arge and no tool context."""
51+
assert arg1
52+
assert arg2
53+
return arg1
54+
55+
56+
async def async_function_for_testing_with_4_arg_and_no_tool_context(
57+
arg1, arg2, arg3, arg4
58+
):
59+
"""Async function for testing with 4 args."""
60+
pass
61+
62+
63+
def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4):
64+
"""Function for testing with 4 args."""
65+
pass
66+
67+
68+
def test_init():
69+
"""Test that the FunctionTool is initialized correctly."""
70+
tool = FunctionTool(function_for_testing_with_no_args)
71+
assert tool.name == "function_for_testing_with_no_args"
72+
assert tool.description == "Function for testing with no args."
73+
assert tool.func == function_for_testing_with_no_args
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_run_async_with_tool_context_async_func():
78+
"""Test that run_async calls the function with tool_context when tool_context is in signature (async function)."""
79+
80+
tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context)
81+
args = {"arg1": "test_value_1"}
82+
result = await tool.run_async(args=args, tool_context=MagicMock())
83+
assert result == "test_value_1"
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_run_async_without_tool_context_async_func():
88+
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
89+
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
90+
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
91+
result = await tool.run_async(args=args, tool_context=MagicMock())
92+
assert result == "test_value_1"
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_run_async_with_tool_context_sync_func():
97+
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
98+
tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context)
99+
args = {"arg1": "test_value_1"}
100+
result = await tool.run_async(args=args, tool_context=MagicMock())
101+
assert result == "test_value_1"
102+
103+
104+
@pytest.mark.asyncio
105+
async def test_run_async_without_tool_context_sync_func():
106+
"""Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function)."""
107+
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
108+
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
109+
result = await tool.run_async(args=args, tool_context=MagicMock())
110+
assert result == "test_value_1"
111+
112+
113+
@pytest.mark.asyncio
114+
async def test_run_async_1_missing_arg_sync_func():
115+
"""Test that run_async calls the function with 1 missing arg in signature (synchronous function)."""
116+
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
117+
args = {"arg1": "test_value_1"}
118+
result = await tool.run_async(args=args, tool_context=MagicMock())
119+
assert result == {
120+
"error": (
121+
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
122+
arg2
123+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
124+
)
125+
}
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_run_async_1_missing_arg_async_func():
130+
"""Test that run_async calls the function with 1 missing arg in signature (async function)."""
131+
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
132+
args = {"arg2": "test_value_1"}
133+
result = await tool.run_async(args=args, tool_context=MagicMock())
134+
assert result == {
135+
"error": (
136+
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
137+
arg1
138+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
139+
)
140+
}
141+
142+
143+
@pytest.mark.asyncio
144+
async def test_run_async_3_missing_arg_sync_func():
145+
"""Test that run_async calls the function with 3 missing args in signature (synchronous function)."""
146+
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
147+
args = {"arg2": "test_value_1"}
148+
result = await tool.run_async(args=args, tool_context=MagicMock())
149+
assert result == {
150+
"error": (
151+
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
152+
arg1
153+
arg3
154+
arg4
155+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
156+
)
157+
}
158+
159+
160+
@pytest.mark.asyncio
161+
async def test_run_async_3_missing_arg_async_func():
162+
"""Test that run_async calls the function with 3 missing args in signature (async function)."""
163+
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
164+
args = {"arg3": "test_value_1"}
165+
result = await tool.run_async(args=args, tool_context=MagicMock())
166+
assert result == {
167+
"error": (
168+
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
169+
arg1
170+
arg2
171+
arg4
172+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
173+
)
174+
}
175+
176+
177+
@pytest.mark.asyncio
178+
async def test_run_async_missing_all_arg_sync_func():
179+
"""Test that run_async calls the function with all missing args in signature (synchronous function)."""
180+
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
181+
args = {}
182+
result = await tool.run_async(args=args, tool_context=MagicMock())
183+
assert result == {
184+
"error": (
185+
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
186+
arg1
187+
arg2
188+
arg3
189+
arg4
190+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
191+
)
192+
}
193+
194+
195+
@pytest.mark.asyncio
196+
async def test_run_async_missing_all_arg_async_func():
197+
"""Test that run_async calls the function with all missing args in signature (async function)."""
198+
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
199+
args = {}
200+
result = await tool.run_async(args=args, tool_context=MagicMock())
201+
assert result == {
202+
"error": (
203+
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
204+
arg1
205+
arg2
206+
arg3
207+
arg4
208+
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
209+
)
210+
}
211+
212+
213+
@pytest.mark.asyncio
214+
async def test_run_async_with_optional_args_not_set_sync_func():
215+
"""Test that run_async calls the function for sync funciton with optional args not set."""
216+
217+
def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
218+
return f"{arg1},{arg3}"
219+
220+
tool = FunctionTool(func_with_optional_args)
221+
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
222+
result = await tool.run_async(args=args, tool_context=MagicMock())
223+
assert result == "test_value_1,test_value_3"
224+
225+
226+
@pytest.mark.asyncio
227+
async def test_run_async_with_optional_args_not_set_async_func():
228+
"""Test that run_async calls the function for async funciton with optional args not set."""
229+
230+
async def async_func_with_optional_args(
231+
arg1, arg2=None, *, arg3, arg4=None, **kwargs
232+
):
233+
return f"{arg1},{arg3}"
234+
235+
tool = FunctionTool(async_func_with_optional_args)
236+
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
237+
result = await tool.run_async(args=args, tool_context=MagicMock())
238+
assert result == "test_value_1,test_value_3"

0 commit comments

Comments
 (0)
0