8000 Move callback samples to dedicated directory · calvingiles/adk-python@812485f · GitHub
[go: up one dir, main page]

Skip to content

Commit 812485f

Browse files
selcukguncopybara-github
authored andcommitted
Move callback samples to dedicated directory
Sample chained callback logs can we seen running asyncio_run.py PiperOrigin-RevId: 757793406
1 parent cdb4cac commit 812485f

File tree

4 files changed

+356
-87
lines changed

4 files changed

+356
-87
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
from google.adk import Agent
18+
from google.adk.planners import BuiltInPlanner
19+
from google.adk.planners import PlanReActPlanner
20+
from google.adk.tools.tool_context import ToolContext
21+
from google.genai import types
22+
23+
24+
def roll_die(sides: int, tool_context: ToolContext) -> int:
25+
"""Roll a die and return the rolled result.
26+
27+
Args:
28+
sides: The integer number of sides the die has.
29+
30+
Returns:
31+
An integer of the result of rolling the die.
32+
"""
33+
result = random.randint(1, sides)
34+
if not 'rolls' in tool_context.state:
35+
tool_context.state['rolls'] = []
36+
37+
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
38+
return result
39+
40+
41+
async def check_prime(nums: list[int]) -> str:
42+
"""Check if a given list of numbers are prime.
43+
44+
Args:
45+
nums: The list of numbers to check.
46+
47+
Returns:
48+
A str indicating which number is prime.
49+
"""
50+
primes = set()
51+
for number in nums:
52+
number = int(number)
53+
if number <= 1:
54+
continue
55+
is_prime = True
56+
for i in range(2, int(number**0.5) + 1):
57+
if number % i == 0:
58+
is_prime = False
59+
break
60+
if is_prime:
61+
primes.add(number)
62+
return (
63+
'No prime numbers found.'
64+
if not primes
65+
else f"{', '.join(str(num) for num in primes)} are prime numbers."
66+
)
67+
68+
69+
async def before_agent_callback(callback_context):
70+
print('@before_agent_callback')
71+
return None
72+
73+
74+
async def after_agent_callback(callback_context):
75+
print('@after_agent_callback')
76+
return None
77+
78+
79+
async def before_model_callback(callback_context, llm_request):
80+
print('@before_model_callback')
81+
return None
82+
83+
84+
async def after_model_callback(callback_context, llm_response):
85+
print('@after_model_callback')
86+
return None
87+
88+
89+
def after_agent_cb1(callback_context):
90+
print('@after_agent_cb1')
91+
92+
93+
def after_agent_cb2(callback_context):
94+
print('@after_agent_cb2')
95+
return types.Content(
96+
parts=[
97+
types.Part(
98+
text='(stopped) after_agent_cb2',
99+
),
100+
],
101+
)
102+
103+
104+
def after_agent_cb3(callback_context):
105+
print('@after_agent_cb3')
106+
107+
108+
def before_agent_cb1(callback_context):
109+
print('@before_agent_cb1')
110+
111+
112+
def before_agent_cb2(callback_context):
113+
print('@before_agent_cb2')
114+
115+
116+
def before_agent_cb3(callback_context):
117+
print('@before_agent_cb3')
118+
119+
120+
def before_tool_cb1(tool, args, tool_context):
121+
print('@before_tool_cb1')
122+
123+
124+
def before_tool_cb2(tool, args, tool_context):
125+
print('@before_tool_cb2')
126+
127+
128+
def before_tool_cb3(tool, args, tool_context):
129+
print('@before_tool_cb3')
130+
131+
132+
def after_tool_cb1(tool, args, tool_context, tool_response):
133+
print('@after_tool_cb1')
134+
135+
136+
def after_tool_cb2(tool, args, tool_context, tool_response):
137+
print('@after_tool_cb2')
138+
return {'test': 'after_tool_cb2', 'response': tool_response}
139+
140+
141+
def after_tool_cb3(tool, args, tool_context, tool_response):
142+
print('@after_tool_cb3')
143+
144+
145+
root_agent = Agent(
146+
model='gemini-2.0-flash-exp',
147+
name='data_processing_agent',
148+
description=(
149+
'hello world agent that can roll a dice of 8 sides and check prime'
150+
' numbers.'
151+
),
152+
instruction="""
153+
You roll dice and answer questions about the outcome of the dice rolls.
154+
You can roll dice of different sizes.
155+
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
156+
It is ok to discuss previous dice roles, and comment on the dice rolls.
157+
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
158+
You should never roll a die on your own.
159+
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
160+
You should not check prime numbers before calling the tool.
161+
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
162+
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
163+
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
164+
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
165+
3. When you respond, you must include the roll_die result from step 1.
166+
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
167+
You should not rely on the previous history on prime results.
168+
""",
169+
tools=[
170+
roll_die,
171+
check_prime,
172+
],
173+
# planner=BuiltInPlanner(
174+
# thinking_config=types.ThinkingConfig(
175+
# include_thoughts=True,
176+
# ),
177+
# ),
178+
generate_content_config=types.GenerateContentConfig(
179+
safety_settings=[
180+
types.SafetySetting( # avoid false alarm about rolling dice.
181+
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
182+
threshold=types.HarmBlockThreshold.OFF,
183+
),
184+
]
185+
),
186+
before_agent_callback=[
187+
before_agent_cb1,
188+
before_agent_cb2,
189+
before_agent_cb3,
190+
],
191+
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
192+
before_model_callback=before_model_callback,
193+
after_model_callback=after_model_callback,
194+
before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3],
195+
after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3],
196+
)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import time
17+
import warnings
18+
19+
import agent
20+
from dotenv import load_dotenv
21+
from google.adk import Runner
22+
from google.adk.agents.run_config import RunConfig
23+
from google.adk.artifacts import InMemoryArtifactService
24+
from google.adk.cli.utils import logs
25+
from google.adk.sessions import InMemorySessionService
26+
from google.adk.sessions import Session
27+
from google.genai import types
28+
29+
load_dotenv(override=True)
30+
warnings.filterwarnings('ignore', category=UserWarning)
31+
logs.log_to_tmp_folder()
32+
33+
34+
async def main():
35+
app_name = 'my_app'
36+
user_id_1 = 'user1'
37+
session_service = InMemorySessionService()
38+
artifact_service = InMemoryArtifactService()
39+
runner = Runner(
40+
app_name=app_name,
41+
agent=agent.root_agent,
42+
artifact_service=artifact_service,
43+
session_service=session_service,
44+
)
45+
session_11 = session_service.create_session(
46+
app_name=app_name, user_id=user_id_1
47+
)
48+
49+
async def run_prompt(session: Session, new_message: str):
50+
content = types.Content(
51+
role='user', parts=[types.Part.from_text(text=new_message)]
52+
)
53+
print('** User says:', content.model_dump(exclude_none=True))
54+
async for event in runner.run_async(
55+
user_id=user_id_1,
56+
session_id=session.id,
57+
new_message=content,
58+
):
59+
if event.content.parts and event.content.parts[0].text:
60+
print(f'** {event.author}: {event.content.parts[0].text}')
61+
62+
async def run_prompt_bytes(session: Session, new_message: str):
63+
content = types.Content(
64+
role='user',
65+
parts=[
66+
types.Part.from_bytes(
67+
data=str.encode(new_message), mime_type='text/plain'
68+
)
69+
],
70+
)
71+
print('** User says:', content.model_dump(exclude_none=True))
72+
async for event in runner.run_async(
73+
user_id=user_id_1,
74+
session_id=session.id,
75+
new_message=content,
76+
run_config=RunConfig(save_input_blobs_as_artifacts=True),
77+
):
78+
if event.content.parts and event.content.parts[0].text:
79+
print(f'** {event.author}: {event.content.parts[0].text}')
80+
81+
start_time = time.time()
82+
print('Start time:', start_time)
83+
print('------------------------------------')
84+
await run_prompt(session_11, 'Hi')
85+
await run_prompt(session_11, 'Roll a die with 100 sides')
86+
await run_prompt(session_11, 'Roll a die again with 100 sides.')
87+
await run_prompt(session_11, 'What numbers did I got?')
88+
await run_prompt_bytes(session_11, 'Hi bytes')
89+
print(
90+
await artifact_service.list_artifact_keys(
91+
app_name=app_name, user_id=user_id_1, session_id=session_11.id
92+
)
93+
)
94+
end_time = time.time()
95+
print('------------------------------------')
96+
print('End time:', end_time)
97+
print('Total time:', end_time - start_time)
98+
99+
100+
def main_sync():
101+
app_name = 'my_app'
102+
user_id_1 = 'user1'
103+
session_service = InMemorySessionService()
104+
artifact_service = InMemoryArtifactService()
105+
runner = Runner(
106+
app_name=app_name,
107+
agent=agent.root_agent,
108+
artifact_service=artifact_service,
109+
session_service=session_service,
110+
)
111+
session_11 = session_service.create_session(
112+
app_name=app_name, user_id=user_id_1
113+
)
114+
115+
def run_prompt(session: Session, new_message: str):
116+
content = types.Content(
117+
role='user', parts=[types.Part.from_text(text=new_message)]
118+
)
119+
print('** User says:', content.model_dump(exclude_none=True))
120+
for event in runner.run(
121+
user_id=user_id_1,
122+
session_id=session.id,
123+
new_message=content,
124+
):
125+
if event.content.parts and event.content.parts[0].text:
126+
print(f'** {event.author}: {event.content.parts[0].text}')
127+
128+
start_time = time.time()
129+
print('Start time:', start_time)
130+
print('------------------------------------')
131+
run_prompt(session_11, 'Hi')
132+
run_prompt(session_11, 'Roll a die with 100 sides.')
133+
run_prompt(session_11, 'Roll a die again with 100 sides.')
134+
run_prompt(session_11, 'What numbers did I got?')
135+
end_time = time.time()
136+
print('------------------------------------')
137+
print('End time:', end_time)
138+
print('Total time:', end_time - start_time)
139+
140+
141+
if __name__ == '__main__':
142+
print('--------------ASYNC--------------------')
143+
asyncio.run(main())
144+
print('--------------SYNC--------------------')
145+
main_sync()

0 commit comments

Comments
 (0)
0