8000 Copybara import of the project: · GeekyprogrammerEJ/adk-python@27ce65f · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit 27ce65f

Browse files
AlankritVerma01copybara-github
authored andcommitted
Copybara import of the project:
-- 2173606 by Alankrit Verma <alankrit386@gmail.com>: feat(llm_flows): support async before/after tool callbacks Previously, callbacks were treated as purely synchronous, so passing an async coroutine caused “was never awaited” errors and Pydantic serialization failures. Now we detect awaitable return values from before_tool_callback and after_tool_callback, and `await` them if necessary. Fixes: google#380 -- 08ac9a1 by Alankrit Verma <alankrit386@gmail.com>: Refactor function callback handling and update type signatures - Simplify variable names in `functions.py`: always use `function_response` and `altered_function_response` - Update LlmAgent callback type aliases to support async: - Import `Awaitable` - Change `BeforeToolCallback` and `AfterToolCallback` signatures to return `Awaitable[Optional[dict]]` - Ensure `after_tool_callback` uses `await` when necessary -- fcbf574 by Alankrit Verma <alankrit386@gmail.com>: refactor: update callback type signatures to support sync and async responses COPYBARA_INTEGRATE_REVIEW=google#434 from AlankritVerma01:support-async-tool-callbacks 926b0ef PiperOrigin-RevId: 753005846
1 parent dbbeb19 commit 27ce65f

File tree

4 files changed

+150
-23
lines changed

4 files changed

+150
-23
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,4 @@ line_length = 200
147147
[tool.pytest.ini_options]
148148
testpaths = ["tests"]
149149
asyncio_default_fixture_loop_scope = "function"
150+
asyncio_mode = "auto"

src/google/adk/agents/llm_agent.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import Any
19-
from typing import AsyncGenerator
20-
from typing import Callable
21-
from typing import Literal
22-
from typing import Optional
23-
from typing import Union
18+
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
2419

2520
from google.genai import types
2621
from pydantic import BaseModel
@@ -62,11 +57,11 @@
6257
]
6358
BeforeToolCallback: TypeAlias = Callable[
6459
[BaseTool, dict[str, Any], ToolContext],
65-
Optional[dict],
60+
Union[Awaitable[Optional[dict]], Optional[dict]],
6661
]
6762
AfterToolCallback: TypeAlias = Callable[
6863
[BaseTool, dict[str, Any], ToolContext, dict],
69-
Optional[dict],
64+
Union[Awaitable[Optional[dict]], Optional[dict]],
7065
]
7166

7267
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]

src/google/adk/flows/llm_flows/functions.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -151,28 +151,33 @@ async def handle_function_calls_async(
151151
# do not use "args" as the variable name, because it is a reserved keyword
152152
# in python debugger.
153153
function_args = function_call.args or {}
154-
function_response = None
155-
# Calls the tool if before_tool_callback does not exist or returns None.
154+
function_response: Optional[dict] = None
155+
156+
# before_tool_callback (sync or async)
156157
if agent.before_tool_callback:
157158
function_response = agent.before_tool_callback(
158159
tool=tool, args=function_args, tool_context=tool_context
159160
)
161+
if inspect.isawaitable(function_response):
162+
function_response = await function_response
160163

161164
if not function_response:
162165
function_response = await __call_tool_async(
163166
tool, args=function_args, tool_context=tool_context
164167
)
165168

166-
# Calls after_tool_callback if it exists.
169+
# after_tool_callback (sync or async)
167170
if agent.after_tool_callback:
168-
new_response = agent.after_tool_callback(
171+
altered_function_response = agent.after_tool_callback(
169172
tool=tool,
170173
args=function_args,
171174
tool_context=tool_context,
172175
tool_response=function_response,
173176
)
174-
if 67E6 new_response:
175-
function_response = new_response
177+
if inspect.isawaitable(altered_function_response):
178+
altered_function_response = await altered_function_response
179+
if altered_function_response is not None:
180+
function_response = altered_function_response
176181

177182
if tool.is_long_running:
178183
# Allow long running function to return None to not provide function response.
@@ -223,27 +228,44 @@ async def handle_function_calls_live(
223228
# in python debugger.
224229
function_args = function_call.args or {}
225230
function_response = None
226-
# Calls the tool if before_tool_callback does not exist or returns None.
231+
# # Calls the tool if before_tool_callback does not exist or returns None.
232+
# if agent.before_tool_callback:
233+
# function_response = agent.before_tool_callback(
234+
# tool, function_args, tool_context
235+
# )
227236
if agent.before_tool_callback:
228237
function_response = agent.before_tool_callback(
229-
tool, function_args, tool_context
238+
tool=tool, args=function_args, tool_context=tool_context
230239
)
240+
if inspect.isawaitable(function_response):
241+
function_response = await function_response
231242

232243
if not function_response:
233244
function_response = await _process_function_live_helper(
234245
tool, tool_context, function_call, function_args, invocation_context
235246
)
236247

237248
# Calls after_tool_callback if it exists.
249+
# if agent.after_tool_callback:
250+
# new_response = agent.after_tool_callback(
251+
# tool,
252+
# function_args,
253+
# tool_context,
254+
# function_response,
255+
# )
256+
# if new_response:
257+
# function_response = new_response
238258
if agent.after_tool_callback:
239-
new_response = agent.after_tool_callback(
240-
tool,
241-
function_args,
242-
tool_context,
243-
function_response,
259+
altered_function_response = agent.after_tool_callback(
260+
tool=tool,
261+
args=function_args,
262+
tool_context=tool_context,
263+
tool_response=function_response,
244264
)
245-
if new_response:
246-
function_response = new_response
265+
if inspect.isawaitable(altered_function_response):
266+
altered_function_response = await altered_function_response
267+
if altered_function_response is not None:
268+
function_response = altered_function_response
247269

248270
if tool.is_long_running:
249271
# Allow async function to return None to not provide function response.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Optional
16+
17+
import pytest
18+
19+
from google.adk.agents import Agent
20+
from google.adk.tools.function_tool import FunctionTool
21+
from google.adk.tools.tool_context import ToolContext
22+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
23+
from google.adk.events.event import Event
24+
from google.genai import types
25+
26+
from ... import utils
27+
28+
29+
class AsyncBeforeToolCallback:
30+
31+
def __init__(self, mock_response: Dict[str, Any]):
32+
self.mock_response = mock_response
33+
34+
async def __call__(
35+
self,
36+
tool: FunctionTool,
37+
args: Dict[str, Any],
38+
tool_context: ToolContext,
39+
) -> Optional[Dict[str, Any]]:
40+
return self.mock_response
41+
42+
43+
class AsyncAfterToolCallback:
44+
45+
def __init__(self, mock_response: Dict[str, Any]):
46+
self.mock_response = mock_response
47+
48+
async def __call__(
49+
self,
50+
tool: FunctionTool,
51+
args: Dict[str, Any],
52+
tool_context: ToolContext,
53+
tool_response: Dict[str, Any],
54+
) -> Optional[Dict[str, Any]]:
55+
return self.mock_response
56+
57+
58+
async def invoke_tool_with_callbacks(
59+
before_cb=None, after_cb=None
60+
) -> Optional[Event]:
61+
def simple_fn(**kwargs) -> Dict[str, Any]:
62+
return {"initial": "response"}
63+
64+
tool = FunctionTool(simple_fn)
65+
model = utils.MockModel.create(responses=[])
66+
agent = Agent(
67+
name="agent",
68+
model=model,
69+
tools=[tool],
70+
before_tool_callback=before_cb,
71+
after_tool_callback=after_cb,
72+
)
73+
invocation_context = utils.create_invocation_context(
74+
agent=agent, user_content=""
75+
)
76+
# Build function call event
77+
function_call = types.FunctionCall(name=tool.name, args={})
78+
content = types.Content(parts=[types.Part(function_call=function_call)])
79+
event = Event(
80+
invocation_id=invocation_context.invocation_id,
81+
author=agent.name,
82+
content=content,
83+
)
84+
tools_dict = {tool.name: tool}
85+
return await handle_function_calls_async(
86+
invocation_context,
87+
event,
88+
tools_dict,
89+
)
90+
91+
92+
@pytest.mark.asyncio
93+
async def test_async_before_tool_callback():
94+
mock_resp = {"test": "before_tool_callback"}
95+
before_cb = AsyncBeforeToolCallback(mock_resp)
96+
result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
97+
assert result_event is not None
98+
part = result_event.content.parts[0]
99+
assert part.function_response.response == mock_resp
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_async_after_tool_callback():
104+
mock_resp = {"test": "after_tool_callback"}
105+
after_cb = AsyncAfterToolCallback(mock_resp)
106+
result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
107+
assert result_event is not None
108+
part = result_event.content.parts[0]
109+
assert part.function_response.response == mock_resp

0 commit comments

Comments
 (0)
0