8000 Support async agent and model callbacks · aphraz/adk-python@794a70e · GitHub
[go: up one dir, main page]

Skip to content

Commit 794a70e

Browse files
selcukguncopybara-github
authored andcommitted
Support async agent and model callbacks
PiperOrigin-RevId: 755542756
1 parent f96cdc6 commit 794a70e

File tree

25 files changed

+359
-105
lines changed

25 files changed

+359
-105
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
17+
import inspect
18+
from typing import Any, Awaitable, Union
1819
from typing import AsyncGenerator
1920
from typing import Callable
2021
from typing import final
@@ -37,10 +38,15 @@
3738

3839
tracer = trace.get_tracer('gcp.vertex.agent')
3940

40-
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
41+
BeforeAgentCallback = Callable[
42+
[CallbackContext],
43+
Union[Awaitable[Optional[t 8000 ypes.Content]], Optional[types.Content]],
44+
]
4145

42-
43-
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
46+
AfterAgentCallback = Callable[
47+
[CallbackContext],
48+
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
49+
]
4450

4551

4652
class BaseAgent(BaseModel):
@@ -119,7 +125,7 @@ async def run_async(
119125
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
120126
ctx = self._create_invocation_context(parent_context)
121127

122-
if event := self.__handle_before_agent_callback(ctx):
128+
if event := await self.__handle_before_agent_callback(ctx):
123129
yield event
124130
if ctx.end_invocation:
125131
return
@@ -130,7 +136,7 @@ async def run_async(
130136
if ctx.end_invocation:
131137
return
132138

133-
if event := self.__handle_after_agent_callback(ctx):
139+
if event := await self.__handle_after_agent_callback(ctx):
134140
yield event
135141

136142
@final
@@ -230,7 +236,7 @@ def _create_invocation_context(
230236
invocation_context.branch = f'{parent_context.branch}.{self.name}'
231237
return invocation_context
232238

233-
def __handle_before_agent_callback(
239+
async def __handle_before_agent_callback(
234240
self, ctx: InvocationContext
235241
) -> Optional[Event]:
236242
"""Runs the before_agent_callback if it exists.
@@ -248,6 +254,9 @@ def __handle_before_agent_callback(
248254
callback_context=callback_context
249255
)
250256

257+
if inspect.isawaitable(before_agent_callback_content):
258+
before_agent_callback_content = await before_agent_callback_content
259+
251260
if before_agent_callback_content:
252261
ret_event = Event(
253262
invocation_id=ctx.invocation_id,
@@ -269,7 +278,7 @@ def __handle_before_agent_callback(
269278

270279
return ret_event
271280

272-
def __handle_after_agent_callback(
281+
async def __handle_after_agent_callback(
273282
self, invocation_context: InvocationContext
274283
) -> Optional[Event]:
275284
"""Runs the after_agent_callback if it exists.
@@ -287,6 +296,9 @@ def __handle_after_agent_callback(
287296
callback_context=callback_context
288297
)
289298

299+
if inspect.isawaitable(after_agent_callback_content):
300+
after_agent_callback_content = await after_agent_callback_content
301+
290302
if after_agent_callback_content or callback_context.state.has_delta():
291303
ret_event = Event(
292304
invocation_id=invocation_context.invocation_id,

src/google/adk/agents/llm_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,12 @@
4949

5050

5151
BeforeModelCallback: TypeAlias = Callable[
52-
[CallbackContext, LlmRequest], Optional[LlmResponse]
52+
[CallbackContext, LlmRequest],
53+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5354
]
5455
AfterModelCallback: TypeAlias = Callable[
5556
[CallbackContext, LlmResponse],
56-
Optional[LlmResponse],
57+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5758
]
5859
BeforeToolCallback: TypeAlias = Callable[
5960
[BaseTool, dict[str, Any], ToolContext],

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/integration/fixture/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/integration/fixture/callback_agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from . import agent
15+
from . import agent

tests/integration/models/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/integration/test_evalute_agent_in_fixture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.adk.evaluation import AgentEvaluator
2020
import pytest
2121

22+
2223
def agent_eval_artifacts_in_fixture():
2324
"""Get all agents from fixture folder."""
2425
agent_eval_artifacts = []

tests/integration/tools/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

tests/unittests/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

0 commit comments

Comments
 (0)
0