feat(live): Support live mode of sequential agent · AIGoBig/adk-python@4188673 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4188673

Browse files
hangfeicopybara-github
authored andcommitted
feat(live): Support live mode of sequential agent
Add a `task_completed` function to the agent so that when a model has finished the task, it can send a signal and the program knows it can go to the next agent. This CL includes: * Implements `_run_live_impl` in `sequential_agent` so it can handle the live case. * Adds an example for the sequential agent. * Improves the error message for unimplemented `_run_live_impl` in other agents. Note: 1. Compared to the non-live case, live agents process continuous streams of audio or video, so they don't have a native way to tell whether they are finished and should pass to the next agent or not. So we introduce a task_completed() function so the model can call this function to signal that it has finished the task and we can move on to the next agent. 2. Live agents don't seem to be very useful or natural in parallel or loop agents, so we don't implement them for now. If there is user demand, we can implement it easily using a similar approach. PiperOrigin-RevId: 758315430
1 parent 39f78dc commit 4188673

File tree

7 files changed

+181
-20
lines changed

7 files changed

+181
-20
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
from google.adk.agents.llm_agent import LlmAgent
18+
from google.adk.agents.sequential_agent import SequentialAgent
19+
from google.genai import types
20+
21+
22+
# --- Roll Die Sub-Agent ---
23+
def roll_die(sides: int) -> int:
    """Roll a die and return the rolled result.

    Args:
        sides: Number of faces on the die. The value is coerced with
            ``int()`` so that a numeric string or whole-number float coming
            from a tool call is accepted, mirroring how ``check_prime``
            coerces its inputs.

    Returns:
        A random integer between 1 and ``sides``, both inclusive.

    Raises:
        ValueError: If ``sides`` is smaller than 1 (there is no face to land
            on) or cannot be converted to an integer.
    """
    sides = int(sides)
    if sides < 1:
        # Fail with a message the caller can act on instead of the generic
        # "empty range" error that random.randint would raise.
        raise ValueError(f"A die needs at least 1 side, got {sides}.")
    return random.randint(1, sides)
26+
27+
28+
roll_agent = LlmAgent(
29+
name="roll_agent",
30+
description="Handles rolling dice of different sizes.",
31+
model="gemini-2.0-flash-exp",
32+
instruction="""
33+
You are responsible for rolling dice based on the user's request.
34+
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
35+
""",
36+
tools=[roll_die],
37+
generate_content_config=types.GenerateContentConfig(
38+
safety_settings=[
39+
types.SafetySetting( # avoid false alarm about rolling dice.
40+
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
41+
threshold=types.HarmBlockThreshold.OFF,
42+
),
43+
]
44+
),
45+
)
46+
47+
48+
def check_prime(nums: list[int]) -> str:
    """Check if a given list of numbers are prime.

    Args:
        nums: Numbers to test. Each entry is coerced with ``int()``; values
            less than or equal to 1 are ignored because they are not prime.

    Returns:
        ``"No prime numbers found."`` when none of the inputs is prime.
        Otherwise the distinct primes in ascending order, comma-separated,
        followed by ``" are prime numbers."``.
    """
    primes = set()
    for raw in nums:
        number = int(raw)
        if number <= 1:
            continue
        # Trial division up to sqrt(number). Comparing squares keeps the
        # bound exact even for integers too large for a float square root.
        divisor = 2
        while divisor * divisor <= number:
            if number % divisor == 0:
                break
            divisor += 1
        else:
            # Loop finished without finding a divisor, so number is prime.
            primes.add(number)

    if not primes:
        return "No prime numbers found."
    # Sort so the reply does not depend on set iteration order.
    return f"{', '.join(str(num) for num in sorted(primes))} are prime numbers."
67+
68+
69+
prime_agent = LlmAgent(
70+
name="prime_agent",
71+
description="Handles checking if numbers are prime.",
72+
model="gemini-2.0-flash-exp",
73+
instruction="""
74+
You are responsible for checking whether numbers are prime.
75+
When asked to check primes, you must call the check_prime tool with a list of integers.
76+
Never attempt to determine prime numbers manually.
77+
Return the prime number results to the root agent.
78+
""",
79+
tools=[check_prime],
80+
generate_content_config=types.GenerateContentConfig(
81+
safety_settings=[
82+
types.SafetySetting( # avoid false alarm about rolling dice.
83+
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
84+
threshold=types.HarmBlockThreshold.OFF,
85+
),
86+
]
87+
),
88+
)
89+
90+
root_agent = SequentialAgent(
91+
name="code_pipeline_agent",
92+
sub_agents=[roll_agent, prime_agent],
93+
# The agents will run in the order provided: roll_agent -> prime_agent
94+
)

src/google/adk/agents/loop_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,5 @@ async def _run_async_impl(
5858
async def _run_live_impl(
5959
self, ctx: InvocationContext
6060
) -> AsyncGenerator[Event, None]:
61-
raise NotImplementedError('The behavior for run_live is not defined yet.')
61+
raise NotImplementedError('This is not supported yet for LoopAgent.')
6262
yield # AsyncGenerator requires having at least one yield statement

src/google/adk/agents/parallel_agent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,10 @@ async def _run_async_impl(
9494
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
9595
async for event in _merge_agent_run(agent_runs):
9696
yield event
97+
98+
@override
99+
async def _run_live_impl(
100+
self, ctx: InvocationContext
101+
) -> AsyncGenerator[Event, None]:
102+
raise NotImplementedError("This is not supported yet for ParallelAgent.")
103+
yield # AsyncGenerator requires having at least one yield statement

src/google/adk/agents/sequential_agent.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..agents.invocation_context import InvocationContext
2424
from ..events.event import Event
2525
from .base_agent import BaseAgent
26+
from .llm_agent import LlmAgent
2627

2728

2829
class SequentialAgent(BaseAgent):
@@ -40,6 +41,36 @@ async def _run_async_impl(
4041
async def _run_live_impl(
    self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
    """Implementation for live SequentialAgent.

    Compared to the non-live case, live agents process a continuous stream of
    audio or video, so a sub-agent has no native way to tell whether it is
    finished and should hand over to the next agent. To bridge that, every
    LlmAgent sub-agent is given a task_completed() tool that the model can
    call to signal that it has finished its task and the sequence can move
    on to the next agent.

    Args:
        ctx: The invocation context of the agent.

    Yields:
        The events produced by each sub-agent's ``run_live``, one sub-agent
        after another in the order of ``self.sub_agents``.
    """

    # Defined once rather than per loop iteration: the tool is stateless, so
    # every sub-agent can share the same function object.
    def task_completed():
        """
        Signals that the model has successfully completed the user's question
        or task.
        """
        return "Task completion signaled."

    tool_name = task_completed.__name__

    # There is no way to know if it's using live during init phase so we have
    # to init it here.
    for sub_agent in self.sub_agents:
        if not isinstance(sub_agent, LlmAgent):
            continue
        # Use the function name to dedupe. Names are compared with names:
        # checking the name string for membership in the tool list itself
        # never matches, which would append the tool and the instruction
        # suffix again on every live run.
        # NOTE(review): tools are assumed to be plain callables (``__name__``)
        # or tool objects exposing a ``name`` attribute - confirm.
        already_registered = any(
            getattr(tool, "__name__", getattr(tool, "name", None)) == tool_name
            for tool in sub_agent.tools
        )
        if already_registered:
            continue
        sub_agent.tools.append(task_completed)
        sub_agent.instruction += (
            "\nIf you finished the user's request\n"
            f"according to its description, call {tool_name} function\n"
            "to exit so the next agents can take over. When calling this"
            " function,\n"
            "do not generate any text other than the function call."
        )

    for sub_agent in self.sub_agents:
        async for event in sub_agent.run_live(ctx):
            yield event

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,18 @@ async def run_live(
135135
# cancel the tasks that belongs to the closed connection.
136136
send_task.cancel()
137137
await llm_connection.close()
138+
if (
139+
event.content
140+
and event.content.parts
141+
and event.content.parts[0].function_response
142+
and event.content.parts[0].function_response.name
143+
== 'task_completed'
144+
):
145+
# this is used for sequential agent to signal the end of the agent.
146+
await asyncio.sleep(1)
147+
# cancel the tasks that belongs to the closed connection.
148+
send_task.cancel()
149+
return
138150
finally:
139151
# Clean up
140152
if not send_task.done():
@@ -237,7 +249,7 @@ def get_author_for_event(llm_response):
237249
if (
238250
event.content
239251
and event.content.parts
240-
and event.content.parts[0].text
252+
and event.content.parts[0].inline_data is None
241253
and not event.partial
242254
):
243255
# This can be either user data or transcription data.

src/google/adk/runners.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,13 @@ async def run_live(
254254
"""Runs the agent in live mode (experimental feature).
255255
256256
Args:
257-
session: The session to use. This parameter is deprecated, please use
258-
`user_id` and `session_id` instead.
259257
user_id: The user ID for the session. Required if `session` is None.
260258
session_id: The session ID for the session. Required if `session` is
261259
None.
262260
live_request_queue: The queue for live requests.
263261
run_config: The run config for the agent.
262+
session: The session to use. This parameter is deprecated, please use
263+
`user_id` and `session_id` instead.
264264
265265
Yields:
266266
AsyncGenerator[Event, None]: An asynchronous generator that yields
@@ -302,22 +302,24 @@ async def run_live(
302302

303303
invocation_context.active_streaming_tools = {}
304304
# TODO(hangfei): switch to use canonical_tools.
305-
for tool in invocation_context.agent.tools:
306-
# replicate a LiveRequestQueue for streaming tools that relis on
307-
# LiveRequestQueue
308-
from typing import get_type_hints
309-
310-
type_hints = get_type_hints(tool)
311-
for arg_type in type_hints.values():
312-
if arg_type is LiveRequestQueue:
313-
if not invocation_context.active_streaming_tools:
314-
invocation_context.active_streaming_tools = {}
315-
active_streaming_tools = ActiveStreamingTool(
316-
stream=LiveRequestQueue()
317-
)
318-
invocation_context.active_streaming_tools[tool.__name__] = (
319-
active_streaming_tools
320-
)
305+
# for shell agents, there is no tools associated with it so we should skip.
306+
if hasattr(invocation_context.agent, 'tools'):
307+
for tool in invocation_context.agent.tools:
308+
# replicate a LiveRequestQueue for streaming tools that rely on
309+
# LiveRequestQueue
310+
from typing import get_type_hints
311+
312+
type_hints = get_type_hints(tool)
313+
for arg_type in type_hints.values():
314+
if arg_type is LiveRequestQueue:
315+
if not invocation_context.active_streaming_tools:
316+
invocation_context.active_streaming_tools = {}
317+
active_streaming_tools = ActiveStreamingTool(
318+
stream=LiveRequestQueue()
319+
)
320+
invocation_context.active_streaming_tools[tool.__name__] = (
321+
active_streaming_tools
322+
)
321323

322324
async for event in invocation_context.agent.run_live(invocation_context):
323325
self.session_service.append_event(session=session, event=event)

0 commit comments

Comments
 (0)
0