8000 Provide inject_session_state as public util method · Jay-flow/adk-python@c5a0437 · GitHub
[go: up one dir, main page]

Skip to content

Commit c5a0437

Browse files
selcukguncopybara-github
authored andcommitted
Provide inject_session_state as public util method
This is useful for injecting artifacts and session state variable into instruction template typically in instruction providers. PiperOrigin-RevId: 761595473
1 parent e060344 commit c5a0437

File tree

3 files changed

+234
-77
lines changed

3 files changed

+234
-77
lines changed

src/google/adk/flows/llm_flows/instructions.py

Lines changed: 7 additions & 77 deletions
-
"""Populates values in the instruction template, e.g. state, artifact, etc."""
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ...agents.readonly_context import ReadonlyContext
2727
from ...events.event import Event
2828
from ...sessions.state import State
29+
from ...utils import instructions_utils
2930
from ._base_llm_processor import BaseLlmRequestProcessor
3031

3132
if TYPE_CHECKING:
@@ -60,7 +61,9 @@ async def run_async(
6061
)
6162
si = raw_si
6263
if not bypass_state_injection:
63-
si = await _populate_values(raw_si, invocation_context)
64+
si = await instructions_utils.inject_session_state(
65+
raw_si, ReadonlyContext(invocation_context)
66+
)
6467
llm_request.append_instructions([si])
6568

6669
# Appends agent instructions if set.
@@ -70,7 +73,9 @@ async def run_async(
7073
)
7174
si = raw_si
7275
if not bypass_state_injection:
73-
si = await _populate_values(raw_si, invocation_context)
76+
si = await instructions_utils.inject_session_state(
77+
raw_si, ReadonlyContext(invocation_context)
78+
)
7479
llm_request.append_instructions([si])
7580

7681
# Maintain async generator behavior
@@ -79,78 +84,3 @@ async def run_async(
7984

8085

8186
request_processor = _InstructionsLlmRequestProcessor()
82-
83-
84-
async def _populate_values(
85-
instruction_template: str,
86-
context: InvocationContext,
87-
) -> str:
88
89-
90-
async def _async_sub(pattern, repl_async_fn, string) -> str:
91-
result = []
92-
last_end = 0
93-
for match in re.finditer(pattern, string):
94-
result.append(string[last_end : match.start()])
95-
replacement = await repl_async_fn(match)
96-
result.append(replacement)
97-
last_end = match.end()
98-
result.append(string[last_end:])
99-
return ''.join(result)
100-
101-
async def _replace_match(match) -> str:
102-
var_name = match.group().lstrip('{').rstrip('}').strip()
103-
optional = False
104-
if var_name.endswith('?'):
105-
optional = True
106-
var_name = var_name.removesuffix('?')
107-
if var_name.startswith('artifact.'):
108-
var_name = var_name.removeprefix('artifact.')
109-
if context.artifact_service is None:
110-
raise ValueError('Artifact service is not initialized.')
111-
artifact = await context.artifact_service.load_artifact(
112-
app_name=context.session.app_name,
113-
user_id=context.session.user_id,
114-
session_id=context.session.id,
115-
filename=var_name,
116-
)
117-
if not var_name:
118-
raise KeyError(f'Artifact {var_name} not found.')
119-
return str(artifact)
120-
else:
121-
if not _is_valid_state_name(var_name):
122-
return match.group()
123-
if var_name in context.session.state:
124-
return str(context.session.state[var_name])
125-
else:
126-
if optional:
127-
return ''
128-
else:
129-
raise KeyError(f'Context variable not found: `{var_name}`.')
130-
131-
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
132-
133-
134-
def _is_valid_state_name(var_name):
135-
"""Checks if the variable name is a valid state name.
136-
137-
Valid state is either:
138-
- Valid identifier
139-
- <Valid prefix>:<Valid identifier>
140-
All the others will just return as it is.
141-
142-
Args:
143-
var_name: The variable name to check.
144-
145-
Returns:
146-
True if the variable name is a valid state name, False otherwise.
147-
"""
148-
parts = var_name.split(':')
149-
if len(parts) == 1:
150-
return var_name.isidentifier()
151-
152-
if len(parts) == 2:
153-
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
154-
if (parts[0] + ':') in prefixes:
155-
return parts[1].isidentifier()
156-
return False
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
from google.adk.agents import Agent
2+
from google.adk.agents.invocation_context import InvocationContext
3+
from google.adk.agents.readonly_context import ReadonlyContext
4+
from google.adk.sessions import Session
5+
from google.adk.utils import instructions_utils
6+
import pytest
7+
8+
from .. import utils
9+
10+
11+
class MockArtifactService:
12+
13+
def __init__(self, artifacts: dict):
14+
self.artifacts = artifacts
15+
16+
async def load_artifact(self, app_name, user_id, session_id, filename):
17+
if filename in self.artifacts:
18+
return self.artifacts[filename]
19+
else:
20+
raise KeyError(f"Artifact '{filename}' not found.")
21+
22+
23+
async def _create_test_readonly_context(
24+
state: dict = None,
25+
artifact_service: MockArtifactService = None,
26+
app_name: str = "test_app",
27+
user_id: str = "test_user",
28+
session_id: str = "test_session_id",
29+
) -> ReadonlyContext:
30+
agent = Agent(
31+
model="gemini-2.0-flash",
32+
name="agent",
33+
instruction="test",
34+
)
35+
invocation_context = await utils.create_invocation_context(agent=agent)
36+
invocation_context.session = Session(
37+
state=state if state else {},
38+
app_name=app_name,
39+
user_id=user_id,
40+
id=session_id,
41+
)
42+
43+
invocation_context.artifact_service = artifact_service
44+
return ReadonlyContext(invocation_context)
45+
46+
47+
@pytest.mark.asyncio
48+
async def test_inject_session_state():
49+
instruction_template = "Hello {user_name}, you are in {app_state} state."
50+
invocation_context = await _create_test_readonly_context(
51+
state={"user_name": "Foo", "app_state": "active"}
52+
)
53+
54+
populated_instruction = await instructions_utils.inject_session_state(
55+
instruction_template, invocation_context
56+
)
57+
assert populated_instruction == "Hello Foo, you are in active state."
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_inject_session_state_with_artifact():
62+
instruction_template = "The artifact content is: {artifact.my_file}"
63+
mock_artifact_service = MockArtifactService(
64+
{"my_file": "This is my artifact content."}
65+
)
66+
invocation_context = await _create_test_readonly_context(
67+
artifact_service=mock_artifact_service
68+
)
69+
70+
populated_instruction = await instructions_utils.inject_session_state(
71+
instruction_template, invocation_context
72+
)
73+
assert (
74+
populated_instruction
75+
== "The artifact content is: This is my artifact content."
76+
)
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_inject_session_state_with_optional_state():
81+
instruction_template = "Optional value: {optional_value?}"
82+
invocation_context = await _create_test_readonly_context()
83+
84+
populated_instruction = await instructions_utils.inject_session_state(
85+
instruction_template, invocation_context
86+
)
87+
assert populated_instruction == "Optional value: "
88+
89+
90+
@pytest.mark.asyncio
91+
async def test_inject_session_state_with_missing_state_raises_key_error():
92+
instruction_template = "Hello {missing_key}!"
93+
invocation_context = await _create_test_readonly_context(
94+
state={"user_name": "Foo"}
95+
)
96+
97+
with pytest.raises(
98+
KeyError, match="Context variable not found: `missing_key`."
99+
):
100+
await instructions_utils.inject_session_state(
101+
instruction_template, invocation_context
102+
)
103+
104+
105+
@pytest.mark.asyncio
106+
async def test_inject_session_state_with_missing_artifact_raises_key_error():
107+
instruction_template = "The artifact content is: {artifact.missing_file}"
108+
mock_artifact_service = MockArtifactService(
109+
{"my_file": "This is my artifact content."}
110+
)
111+
invocation_context = await _create_test_readonly_context(
112+
artifact_service=mock_artifact_service
113+
)
114+
115+
with pytest.raises(KeyError, match="Artifact 'missing_file' not found."):
116+
await instructions_utils.inject_session_state(
117+
instruction_template, invocation_context
118+
)
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_inject_session_state_with_invalid_state_name_returns_original():
123+
instruction_template = "Hello {invalid-key}!"
124+
invocation_context = await _create_test_readonly_context(
125+
state={"user_name": "Foo"}
126+
)
127+
128+
populated_instruction = await instructions_utils.inject_session_state(
129+
instruction_template, invocation_context
130+
)
131+
assert populated_instruction == "Hello {invalid-key}!"
132+
133+
134+
@pytest.mark.asyncio
135+
async def test_inject_session_state_with_invalid_prefix_state_name_returns_original():
136+
instruction_template = "Hello {invalid:key}!"
137+
invocation_context = await _create_test_readonly_context(
138+
state={"user_name": "Foo"}
139+
)
140+
141+
populated_instruction = await instructions_utils.inject_session_state(
142+
instruction_template, invocation_context
143+
)
144+
assert populated_instruction == "Hello {invalid:key}!"
145+
146+
147+
@pytest.mark.asyncio
148+
async def test_inject_session_state_with_valid_prefix_state():
149+
instruction_template = "Hello {app:user_name}!"
150+
invocation_context = await _create_test_readonly_context(
151+
state={"app:user_name": "Foo"}
152+
)
153+
154+
populated_instruction = await instructions_utils.inject_session_state(
155+
instruction_template, invocation_context
156+
)
157+
assert populated_instruction == "Hello Foo!"
158+
159+
160+
@pytest.mark.asyncio
161+
async def test_inject_session_state_with_multiple_variables_and_artifacts():
162+
instruction_template = """
163+
Hello {user_name},
164+
You are {user_age} years old.
165+
Your favorite color is {favorite_color?}.
166+
The artifact says: {artifact.my_file}
167+
And another optional artifact: {artifact.other_file}
168+
"""
169+
mock_artifact_service = MockArtifactService({
170+
"my_file": "This is my artifact content.",
171+
"other_file": "This is another artifact content.",
172+
})
173+
invocation_context = await _create_test_readonly_context(
174+
state={"user_name": "Foo", "user_age": 30, "favorite_color": "blue"},
175+
artifact_service=mock_artifact_service,
176+
)
177+
178+
populated_instruction = await instructions_utils.inject_session_state(
179+
instruction_template, invocation_context
180+
)
181+
expected_instruction = """
182+
Hello Foo,
183+
You are 30 years old.
184+
Your favorite color is blue.
185+
The artifact says: This is my artifact content.
186+
And another optional artifact: This is another artifact content.
187+
"""
188+
assert populated_instruction == expected_instruction
189+
190+
191+
@pytest.mark.asyncio
192+
async def test_inject_session_state_with_empty_artifact_name_raises_key_error():
193+
instruction_template = "The artifact content is: {artifact.}"
194+
mock_artifact_service = MockArtifactService(
195+
{"my_file": "This is my artifact content."}
196+
)
197+
invocation_context = await _create_test_readonly_context(
198+
artifact_service=mock_artifact_service
199+
)
200+
201+
with pytest.raises(KeyError, match="Artifact '' not found."):
202+
await instructions_utils.inject_session_state(
203+
instruction_template, invocation_context
204+
)
205+
206+
207+
@pytest.mark.asyncio
208+
async def test_inject_session_state_artifact_service_not_initialized_raises_value_error():
209+
instruction_template = "The artifact content is: {artifact.my_file}"
210+
invocation_context = await _create_test_readonly_context()
211+
with pytest.raises(ValueError, match="Artifact service is not initialized."):
212+
await instructions_utils.inject_session_state(
213+
instruction_template, invocation_context
214+
)

0 commit comments

Comments
 (0)
0