8000 Support chaining for model callbacks · aphraz/adk-python@e4317c9 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit e4317c9

Browse files
selcukguncopybara-github
authored andcommitted
Support chaining for model callbacks
(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 755565583
1 parent 794a70e commit e4317c9

File tree

5 files changed

+731
-23
lines changed

5 files changed

+731
-23
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,26 @@
4747

4848
logger = logging.getLogger(__name__)
4949

50-
51-
BeforeModelCallback: TypeAlias = Callable[
50+
_SingleBeforeModelCallback: TypeAlias = Callable[
5251
[CallbackContext, LlmRequest],
5352
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5453
]
55-
AfterModelCallback: TypeAlias = Callable[
54+
55+
BeforeModelCallback: TypeAlias = Union[
56+
_SingleBeforeModelCallback,
57+
list[_SingleBeforeModelCallback],
58+
]
59+
60+
_SingleAfterModelCallback: TypeAlias = Callable[
5661
[CallbackContext, LlmResponse],
5762
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5863
]
64+
65+
AfterModelCallback: TypeAlias = Union[
66+
_SingleAfterModelCallback,
67+
list[_SingleAfterModelCallback],
68+
]
69+
5970
BeforeToolCallback: TypeAlias = Callable[
6071
[BaseTool, dict[str, Any], ToolContext],
6172
Union[Awaitable[Optional[dict]], Optional[dict]],
@@ -174,7 +185,11 @@ class LlmAgent(BaseAgent):
174185

175186
# Callbacks - Start
176187
before_model_callback: Optional[BeforeModelCallback] = None
177-
"""Called before calling the LLM.
188+
"""Callback or list of callbacks to be called before calling the LLM.
189+
190+
When a list of callbacks is provided, the callbacks will be called in the
191+
order they are listed until a callback does not return None.
192+
178193
Args:
179194
callback_context: CallbackContext,
180195
llm_request: LlmRequest, The raw model request. Callback can mutate the
@@ -185,7 +200,10 @@ class LlmAgent(BaseAgent):
185200
skipped and the provided content will be returned to user.
186201
"""
187202
after_model_callback: Optional[AfterModelCallback] = None
188-
"""Called after calling LLM.
203+
"""Callback or list of callbacks to be called after calling the LLM.
204+
205+
When a list of callbacks is provided, the callbacks will be called in the
206+
order they are listed until a callback does not return None.
189207
190208
Args:
191209
callback_context: CallbackContext,
@@ -285,6 +303,32 @@ def canonical_tools(self) -> list[BaseTool]:
285303
"""
286304
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
287305

306+
@property
307+
def canonical_before_model_callbacks(
308+
self,
309+
) -> list[_SingleBeforeModelCallback]:
310+
"""The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.
311+
312+
This method is only for use by Agent Development Kit.
313+
"""
314+
if not self.before_model_callback:
315+
return []
316+
if isinstance(self.before_model_callback, list):
317+
return self.before_model_callback
318+
return [self.before_model_callback]
319+
320+
@property
321+
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
322+
"""The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.
323+
324+
This method is only for use by Agent Development Kit.
325+
"""
326+
if not self.after_model_callback:
327+
return []
328+
if isinstance(self.after_model_callback, list):
329+
return self.after_model_callback
330+
return [self.after_model_callback]
331+
288332
@property
289333
def _llm_flow(self) -> BaseLlmFlow:
290334
if (

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ async def _receive_from_model(
193193
"""Receive data from model and process events using BaseLlmConnection."""
194194
def get_author(llm_response):
195195
"""Get the author of the event.
196-
197-
When the model returns transcription, the author is "user". Otherwise, the author is the agent.
196+
197+
When the model returns transcription, the author is "user". Otherwise, the
198+
author is the agent.
198199
"""
199200
if llm_response and llm_response.content and llm_response.content.role == "user":
200201
return "user"
@@ -509,20 +510,21 @@ async def _handle_before_model_callback(
509510
if not isinstance(agent, LlmAgent):
510511
return
511512

512-
if not agent.before_model_callback:
513+
if not agent.canonical_before_model_callbacks:
513514
return
514515

515516
callback_context = CallbackContext(
516517
invocation_context, event_actions=model_response_event.actions
517518
)
518-
before_model_callback_content = agent.before_model_callback(
519-
callback_context=callback_context, llm_request=llm_request
520-
)
521519

522-
if inspect.isawaitable(before_model_callback_content):
523-
before_model_callback_content = await before_model_callback_content
524-
525-
return before_model_callback_content
520+
for callback in agent.canonical_before_model_callbacks:
521+
before_model_callback_content = callback(
522+
callback_context=callback_context, llm_request=llm_request
523+
)
524+
if inspect.isawaitable(before_model_callback_content):
525+
before_model_callback_content = await before_model_callback_content
526+
if before_model_callback_content:
527+
return before_model_callback_content
526528

527529
async def _handle_after_model_callback(
528530
self,
@@ -536,20 +538,21 @@ async def _handle_after_model_callback(
536538
if not isinstance(agent, LlmAgent):
537539
return
538540

539-
if not agent.after_model_callback:
541+
if not agent.canonical_after_model_callbacks:
540542
return
541543

542544
callback_context = CallbackContext(
543545
invocation_context, event_actions=model_response_event.actions
544546
)
545-
after_model_callback_content = agent.after_model_callback(
546-
callback_context=callback_context, llm_response=llm_response
547-
)
548-
549-
if inspect.isawaitable(after_model_callback_content):
550-
after_model_callback_content = await after_model_callback_content
551547

552-
return after_model_callback_content
548+
for callback in agent.canonical_after_model_callbacks:
549+
after_model_callback_content = callback(
550+
callback_context=callback_context, llm_response=llm_response
551+
)
552+
if inspect.isawaitable(after_model_callback_content):
553+
after_model_callback_content = await after_model_callback_content
554+
if after_model_callback_content:
555+
return after_model_callback_content
553556

554557
def _finalize_model_response_event(
555558
self,
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any
16+
from typing import Optional
17+
18+
from google.adk.agents.callback_context import CallbackContext
19+
from google.adk.agents.llm_agent import Agent
20+
from google.adk.models import LlmRequest
21+
from google.adk.models import LlmResponse
22+
from google.genai import types
23+
from google.genai import types
24+
from pydantic import BaseModel
25+
import pytest
26+
27+
from .. import utils
28+
29+
30+
class MockAgentCallback(BaseModel):
31+
mock_response: str
32+
33+
def __call__(
34+
self,
35+
callback_context: CallbackContext,
36+
) -> types.Content:
37+
return types.Content(parts=[types.Part(text=self.mock_response)])
38+
39+
40+
class MockAsyncAgentCallback(BaseModel):
41+
mock_response: str
42+
43+
async def __call__(
44+
self,
45+
callback_context: CallbackContext,
46+
) -> types.Content:
47+
return types.Content(parts=[types.Part(text=self.mock_response)])
48+
49+
50+
def noop_callback(**kwargs) -> Optional[LlmResponse]:
51+
pass
52+
53+
54+
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
55+
pass
56+
57+
58+
@pytest.mark.asyncio
59+
async def test_before_agent_callback():
60+
responses = ['agent_response']
61+
mock_model = utils.MockModel.create(responses=responses)
62+
agent = Agent(
63+
name='root_agent',
64+
model=mock_model,
65+
before_agent_callback=MockAgentCallback(
66+
mock_response='before_agent_callback'
67+
),
68+
)
69+
70+
runner = utils.TestInMemoryRunner(agent)
71+
assert utils.simplify_events(
72+
await runner.run_async_with_new_session('test')
73+
) == [
74+
('root_agent', 'before_agent_callback'),
75+
]
76+
77+
78+
@pytest.mark.asyncio
79+
async def test_after_agent_callback():
80+
responses = ['agent_response']
81+
mock_model = utils.MockModel.create(responses=responses)
82+
agent = Agent(
83+
name='root_agent',
84+
model=mock_model,
85+
after_agent_callback=MockAgentCallback(
86+
mock_response='after_agent_callback'
87+
),
88+
)
89+
90+
runner = utils.TestInMemoryRunner(agent)
91+
assert utils.simplify_events(
92+
await runner.run_async_with_new_session('test')
93+
) == [
94+
('root_agent', 'agent_response'),
95+
('root_agent', 'after_agent_callback'),
96+
]
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_before_agent_callback_noop():
101+
responses = ['agent_response']
102+
mock_model = utils.MockModel.create(responses=responses)
103+
agent = Agent(
104+
name='root_agent',
105+
model=mock_model,
106+
before_agent_callback=noop_callback,
107+
)
108+
109+
runner = utils.TestInMemoryRunner(agent)
110+
assert utils.simplify_events(
111+
await runner.run_async_with_new_session('test')
112+
) == [
113+
('root_agent', 'agent_response'),
114+
]
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_after_agent_callback_noop():
119+
responses = ['agent_response']
120+
mock_model = utils.MockModel.create(responses=responses)
121+
agent = Agent(
122+
name='root_agent',
123+
model=mock_model,
124+
before_agent_callback=noop_callback,
125+
)
126+
127+
runner = utils.TestInMemoryRunner(agent)
128+
assert utils.simplify_events(
129+
await runner.run_async_with_new_session('test')
130+
) == [
131+
('root_agent', 'agent_response'),
132+
]
133+
134+
135+
@pytest.mark.asyncio
136+
async def test_async_before_agent_callback():
137+
responses = ['agent_response']
138+
mock_model = utils.MockModel.create(responses=responses)
139+
agent = Agent(
140+
name='root_agent',
141+
model=mock_model,
142+
before_agent_callback=MockAsyncAgentCallback(
143+
mock_response='async_before_agent_callback'
144+
),
145+
)
146+
147+
runner = utils.TestInMemoryRunner(agent)
148+
assert utils.simplify_events(
149+
await runner.run_async_with_new_session('test')
150+
) == [
151+
('root_agent', 'async_before_agent_callback'),
152+
]
153+
154+
155+
@pytest.mark.asyncio
156+
async def test_async_after_agent_callback():
157+
responses = ['agent_response']
158+
mock_model = utils.MockModel.create(responses=responses)
159+
agent = Agent(
160+
name='root_agent',
161+
model=mock_model,
162+
after_agent_callback=MockAsyncAgentCallback(
163+
mock_response='async_after_agent_callback'
164+
),
165+
)
166+
167+
runner = utils.TestInMemoryRunner(agent)
168+
assert utils.simplify_events(
169+
await runner.run_async_with_new_session('test')
170+
) == [
171+
('root_agent', 'agent_response'),
172+
('root_agent', 'async_after_agent_callback'),
173+
]
174+
175+
176+
@pytest.mark.asyncio
177+
async def test_async_before_agent_callback_noop():
178+
responses = ['agent_response']
179+
mock_model = utils.MockModel.create(responses=responses)
180+
agent = Agent(
181+
name='root_agent',
182+
model=mock_model,
183+
before_agent_callback=async_noop_callback,
184+
)
185+
186+
runner = utils.TestInMemoryRunner(agent)
187+
assert utils.simplify_events(
188+
await runner.run_async_with_new_session('test')
189+
) == [
190+
('root_agent', 'agent_response'),
191+
]
192+
193+
194+
@pytest.mark.asyncio
862F 195+
async def test_async_after_agent_callback_noop():
196+
responses = ['agent_response']
197+
mock_model = utils.MockModel.create(responses=responses)
198+
agent = Agent(
199+
name='root_agent',
200+
model=mock_model,
201+
before_agent_callback=async_noop_callback,
202+
)
203+
204+
runner = utils.TestInMemoryRunner(agent)
205+
assert utils.simplify_events(
206+
await runner.run_async_with_new_session('test')
207+
) == [
208+
('root_agent', 'agent_response'),
209+
]

0 commit comments

Comments
 (0)
0