10000 fix: support Callable that has __call__ as coroutine function in Func… · Jay-flow/adk-python@f67ccf3 · GitHub
[go: up one dir, main page]

Skip to content

Commit f67ccf3

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: support Callable that has __call__ as coroutine function in FunctionTool
PiperOrigin-RevId: 760913537
1 parent 5115474 commit f67ccf3

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

src/google/adk/tools/function_tool.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,17 @@ class FunctionTool(BaseTool):
3333
"""
3434

3535
def __init__(self, func: Callable[..., Any]):
36-
super().__init__(name=func.__name__, description=func.__doc__)
36+
"""Extract metadata from a callable object."""
37+
if inspect.isfunction(func) or inspect.ismethod(func):
38+
# Handle regular functions and methods
39+
name = func.__name__
40+
doc = func.__doc__ or ''
41+
else:
42+
# Handle objects with __call__ method
43+
call_method = func.__call__
44+
name = func.__class__.__name__
45+
doc = call_method.__doc__ or func.__doc__ or ''
46+
super().__init__(name=name, description=doc)
3747
self.func = func
3848

3949
@override
@@ -76,7 +86,14 @@ async def run_async(
7686
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
7787
return {'error': error_str}
7888

79-
if inspect.iscoroutinefunction(self.func):
89+
# Functions are callable objects, but not all callable objects are functions
90+
# checking coroutine function is not enough. We also need to check whether
91+
# Callable's __call__ function is a coroutine funciton
92+
if (
93+
inspect.iscoroutinefunction(self.func)
94+
or hasattr(self.func, '__call__')
95+
and inspect.iscoroutinefunction(self.func.__call__)
96+
):
8097
return await self.func(**args_to_call) or {}
8198
else:
8299
return self.func(**args_to_call) or {}

tests/unittests/tools/test_function_tool.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,29 @@ async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
3939
return arg1
4040

4141

42+
class AsyncCallableWith2ArgsAndNoToolContext:
43+
44+
async def __call__(self, arg1, arg2):
45+
assert arg1
46+
assert arg2
47+
return arg1
48+
49+
4250
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
4351
"""Function for testing with 1 arge and tool context."""
4452
assert arg1
4553
assert tool_context
4654
return arg1
4755

4856

57+
class AsyncCallableWith1ArgAndToolContext:
58+
59+
async def __call__(self, arg1, tool_context):
60+
assert arg1
61+
assert tool_context
62+
return arg1
63+
64+
4965
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
5066
"""Function for testing with 2 arge and no tool context."""
5167
assert arg1
@@ -83,6 +99,16 @@ async def test_run_async_with_tool_context_async_func():
8399
assert result == "test_value_1"
84100

85101

102+
@pytest.mark.asyncio
103+
async def test_run_async_with_tool_context_async_callable():
104+
"""Test that run_async calls the callable with tool_context when tool_context is in signature (async callable)."""
105+
106+
tool = FunctionTool(AsyncCallableWith1ArgAndToolContext())
107+
args = {"arg1": "test_value_1"}
108+
result = await tool.run_async(args=args, tool_context=MagicMock())
109+
assert result == "test_value_1"
110+
111+
86112
@pytest.mark.asyncio
87113
async def test_run_async_without_tool_context_async_func():
88114
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
@@ -92,6 +118,15 @@ async def test_run_async_without_tool_context_async_func():
92118
assert result == "test_value_1"
93119

94120

121+
@pytest.mark.asyncio
122+
async def test_run_async_without_tool_context_async_callable():
123+
"""Test that run_async calls the callable without tool_context when tool_context is not in signature (async callable)."""
124+
tool = FunctionTool(AsyncCallableWith2ArgsAndNoToolContext())
125+
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
126+
result = await tool.run_async(args=args, tool_context=MagicMock())
127+
assert result == "test_value_1"
128+
129+
95130
@pytest.mark.asyncio
96131
async def test_run_async_with_tool_context_sync_func():
97132
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""

0 commit comments

Comments
 (0)
0