8000 Support async instruction and global instruction provider · calvingiles/adk-python@4c4cfb7 · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit 4c4cfb7

Browse files
selcukguncopybara-github
authored andcommitted
Support async instruction and global instruction provider
PiperOrigin-RevId: 757808335
1 parent 812485f commit 4c4cfb7

File tree

4 files changed

+165
-7
lines changed

4 files changed

+165
-7
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import inspect
1718
import logging
1819
from typing import (
1920
Any,
@@ -96,7 +97,9 @@
9697
list[_SingleAfterToolCallback],
9798
]
9899

99-
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
100+
InstructionProvider: TypeAlias = Callable[
101+
[ReadonlyContext], Union[str, Awaitable[str]]
102+
]
100103

101104
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
102105
ExamplesUnion = Union[list[Example], BaseExampleProvider]
@@ -302,25 +305,31 @@ def canonical_model(self) -> BaseLlm:
302305
ancestor_agent = ancestor_agent.parent_agent
303306
raise ValueError(f'No model found for {self.name}.')
304307

305-
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
308+
async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
306309
"""The resolved self.instruction field to construct instruction for this agent.
307310
308311
This method is only for use by Agent Development Kit.
309312
"""
310313
if isinstance(self.instruction, str):
311314
return self.instruction
312315
else:
313-
return self.instruction(ctx)
316+
instruction = self.instruction(ctx)
317+
if inspect.isawaitable(instruction):
318+
instruction = await instruction
319+
return instruction
314320

315-
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
321+
async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
316322
"""The resolved self.instruction field to construct global instruction.
317323
318324
This method is only for use by Agent Development Kit.
319325
"""
320326
if isinstance(self.global_instruction, str):
321327
return self.global_instruction
322328
else:
323-
return self.global_instruction(ctx)
329+
global_instruction = self.global_instruction(ctx)
330+
if inspect.isawaitable(global_instruction):
331+
global_instruction = await global_instruction
332+
return global_instruction
324333

325334
async def canonical_tools(
326335
self, ctx: ReadonlyContext = None

src/google/adk/flows/llm_flows/instructions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,17 @@ async def run_async(
5353
if (
5454
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
5555
): # not empty str
56-
raw_si = root_agent.canonical_global_instruction(
56+
raw_si = await root_agent.canonical_global_instruction(
5757
ReadonlyContext(invocation_context)
5858
)
5959
si = await _populate_values(raw_si, invocation_context)
6060
llm_request.append_instructions([si])
6161

6262
# Appends agent instructions if set.
6363
if agent.instruction: # not empty str
64-
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
64+
raw_si = await agent.canonical_instruction(
65+
ReadonlyContext(invocation_context)
66+
)
6567
si = await _populate_values(raw_si, invocation_context)
6668
llm_request.append_instructions([si])
6769

tests/unittests/agents/test_llm_agent_fields.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ def _instruction_provider(ctx: ReadonlyContext) -> str:
9292
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
9393

9494

95+
def test_async_canonical_instruction():
96+
async def _instruction_provider(ctx: ReadonlyContext) -> str:
97+
return f'instruction: {ctx.state["state_var"]}'
98+
99+
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
100+
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
101+
102+
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
103+
104+
95105
def test_canonical_global_instruction_str():
96106
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
97107
ctx = _create_readonly_context(agent)
@@ -114,6 +124,21 @@ def _global_instruction_provider(ctx: ReadonlyContext) -> str:
114124
)
115125

116126

127+
def test_async_canonical_global_instruction():
128+
async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
129+
return f'global instruction: {ctx.state["state_var"]}'
130+
131+
agent = LlmAgent(
132+
name='test_agent', global_instruction=_global_instruction_provider
133+
)
134+
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
135+
136+
assert (
137+
agent.canonical_global_instruction(ctx)
138+
== 'global instruction: state_value'
139+
)
140+
141+
117142
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
118143
with caplog.at_level('WARNING'):
119144

tests/unittests/flows/llm_flows/test_instructions.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,44 @@ def build_function_instruction(readonly_context: ReadonlyContext) -> str:
9292
)
9393

9494

95+
@pytest.mark.asyncio
96+
async def test_async_function_system_instruction():
97+
async def build_function_instruction(
98+
readonly_context: ReadonlyContext,
99+
) -> str:
100+
return (
101+
"This is the function agent instruction for invocation:"
102+
f" {readonly_ 10000 context.invocation_id}."
103+
)
104+
105+
request = LlmRequest(
106+
model="gemini-1.5-flash",
107+
config=types.GenerateContentConfig(system_instruction=""),
108+
)
109+
agent = Agent(
110+
model="gemini-1.5-flash",
111+
name="agent",
112+
instruction=build_function_instruction,
113+
)
114+
invocation_context = utils.create_invocation_context(agent=agent)
115+
invocation_context.session = Session(
116+
app_name="test_app",
117+
user_id="test_user",
118+
id="test_id",
119+
state={"customerId": "1234567890", "customer_int": 30},
120+
)
121+
122+
async for _ in instructions.request_processor.run_async(
123+
invocation_context,
124+
request,
125+
):
126+
pass
127+
128+
assert request.config.system_instruction == (
129+
"This is the function agent instruction for invocation: test_id."
130+
)
131+
132+
95133
@pytest.mark.asyncio
96134
async def test_global_system_instruction():
97135
sub_agent = Agent(
@@ -128,6 +166,90 @@ async def test_global_system_instruction():
128166
)
129167

130168

169+
@pytest.mark.asyncio
170+
async def test_function_global_system_instruction():
171+
def sub_agent_si(readonly_context: ReadonlyContext) -> str:
172+
return "This is the sub agent instruction."
173+
174+
def root_agent_gi(readonly_context: ReadonlyContext) -> str:
175+
return "This is the global instruction."
176+
177+
sub_agent = Agent(
178+
model="gemini-1.5-flash",
179+
name="sub_agent",
180+
instruction=sub_agent_si,
181+
)
182+
root_agent = Agent(
183+
model="gemini-1.5-flash",
184+
name="root_agent",
185+
global_instruction=root_agent_gi,
186+
sub_agents=[sub_agent],
187+
)
188+
request = LlmRequest(
189+
model="gemini-1.5-flash",
190+
config=types.GenerateContentConfig(system_instruction=""),
191+
)
192+
invocation_context = utils.create_invocation_context(agent=sub_agent)
193+
invocation_context.session = Session(
194+
app_name="test_app",
195+
user_id="test_user",
196+
id="test_id",
197+
state={"customerId": "1234567890", "customer_int": 30},
198+
)
199+
200+
async for _ in instructions.request_processor.run_async(
201+
invocation_context,
202+
request,
203+
):
204+
pass
205+
206+
assert request.config.system_instruction == (
207+
"This is the global instruction.\n\nThis is the sub agent instruction."
208+
)
209+
210+
211+
@pytest.mark.asyncio
212+
async def test_async_function_global_system_instruction():
213+
async def sub_agent_si(readonly_context: ReadonlyContext) -> str:
214+
return "This is the sub agent instruction."
215+
216+
async def root_agent_gi(readonly_context: ReadonlyContext) -> str:
217+
return "This is the global instruction."
218+
219+
sub_agent = Agent(
220+
model="gemini-1.5-flash",
221+
name="sub_agent",
222+
instruction=sub_agent_si,
223+
)
224+
root_agent = Agent(
225+
model="gemini-1.5-flash",
226+
name="root_agent",
227+
global_instruction=root_agent_gi,
228+
sub_agents=[sub_agent],
229+
)
230+
request = LlmRequest(
231+
model="gemini-1.5-flash",
232+
config=types.GenerateContentConfig(system_instruction=""),
233+
)
234+
invocation_context = utils.create_invocation_context(agent=sub_agent)
235+
invocation_context.session = Session(
236+
app_name="test_app",
237+
user_id="test_user",
238+
id="test_id",
239+
state={"customerId": "1234567890", "customer_int": 30},
240+
)
241+
242+
async for _ in instructions.request_processor.run_async(
243+
invocation_context,
244+
request,
245+
):
246+
pass
247+
248+
assert request.config.system_instruction == (
249+
"This is the global instruction.\n\nThis is the sub agent instruction."
250+
)
251+
252+
131253
@pytest.mark.asyncio
132254
async def test_build_system_instruction_with_namespace():
133255
request = LlmRequest(

0 commit comments

Comments
 (0)
0