8000 test: Add unittest suites for testing HITL confirmation flow on runne… · codenamenam/adk-python@4dbec15 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4dbec15

Browse files
XinranTangcopybara-github
authored andcommitted
test: Add unittest suites for testing HITL confirmation flow on runner level
PiperOrigin-RevId: 807327997
1 parent 402f362 commit 4dbec15

File tree

2 files changed

+378
-0
lines changed

2 files changed

+378
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 365 additions & 0 deletions
6D38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for HITL flows with different agent structures."""
16+
17+
import copy
18+
from unittest import mock
19+
20+
from google.adk.agents.base_agent import BaseAgent
21+
from google.adk.agents.llm_agent import LlmAgent
22+
from google.adk.agents.loop_agent import LoopAgent
23+
from google.adk.agents.parallel_agent import ParallelAgent
24+
from google.adk.agents.sequential_agent import SequentialAgent
25+
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
26+
from google.adk.tools.function_tool import FunctionTool
27+
from google.adk.tools.tool_context import ToolContext
28+
from google.genai.types import FunctionCall
29+
from google.genai.types import FunctionResponse
30+
from google.genai.types import GenerateContentResponse
31+
from google.genai.types import Part
32+
import pytest
33+
34+
from .. import testing_utils
35+
36+
37+
def _create_llm_response_from_tools(
38+
tools: list[FunctionTool],
39+
) -> GenerateContentResponse:
40+
"""Creates a mock LLM response containing a function call."""
41+
parts = [
42+
Part(function_call=FunctionCall(name=tool.name, args={}))
43+
for tool in tools
44+
]
45+
return testing_utils.LlmResponse(
46+
content=testing_utils.ModelContent(parts=parts)
47+
)
48+
49+
50+
def _create_llm_response_from_text(text: str) -> GenerateContentResponse:
51+
"""Creates a mock LLM response containing text."""
52+
return testing_utils.LlmResponse(
53+
content=testing_utils.ModelContent(parts=[Part(text=text)])
54+
)
55+
56+
57+
def _test_request_confirmation_function(
58+
tool_context: ToolContext,
59+
) -> dict[str, str]:
60+
"""A test tool function that requests confirmation."""
61+
if not tool_context.tool_confirmation:
62+
tool_context.request_confirmation(hint="test hint for request_confirmation")
63+
return {"error": "test error for request_confirmation"}
64+
return {"result": f"confirmed={tool_context.tool_confirmation.confirmed}"}
65+
66+
67+
def _test_request_confirmation_function_with_custom_schema(
68+
tool_context: ToolContext,
69+
) -> dict[str, str]:
70+
"""A test tool function that requests confirmation, but with a custom payload schema."""
71+
if not tool_context.tool_confirmation:
72+
tool_context.request_confirmation(
73+
hint="test hint for request_confirmation with custom payload schema",
74+
payload={
75+
"test_custom_payload": {
76+
"int_field": 0,
77+
"str_field": "",
78+
"bool_field": False,
79+
}
80+
},
81+
)
82+
return {"error": "test error for request_confirmation"}
83+
return {
84+
"result": f"confirmed={tool_context.tool_confirmation.confirmed}",
85+
"custom_payload": tool_context.tool_confirmation.payload,
86+
}
87+
88+
89+
class BaseHITLTest:
90+
"""Base class for HITL tests with common fixtures."""
91+
92+
@pytest.fixture
93+
def runner(self, agent: BaseAgent) -> testing_utils.InMemoryRunner:
94+
"""Provides an in-memory runner for the agent."""
95+
return testing_utils.InMemoryRunner(root_agent=agent)
96+
97+
98+
class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest):
99+
"""Tests the HITL confirmation flow with a single LlmAgent."""
100+
101+
@pytest.fixture
102+
def tools(self) -> list[FunctionTool]:
103+
"""Provides the tools for the agent."""
104+
return [FunctionTool(func=_test_request_confirmation_function)]
105+
106+
@pytest.fixture
107+
def llm_responses(
108+
self, tools: list[FunctionTool]
109+
) -> list[GenerateContentResponse]:
110+
"""Provides mock LLM responses for the tests."""
111+
return [
112+
_create_llm_response_from_tools(tools),
113+
_create_llm_response_from_text("test llm response after tool call"),
114+
_create_llm_response_from_text(
115+
"test llm response after final tool call"
116+
),
117+
]
118+
119+
@pytest.fixture
120+
def mock_model(
121+
self, llm_responses: list[GenerateContentResponse]
122+
) -> testing_utils.MockModel:
123+
"""Provides a mock model with predefined responses."""
124+
return testing_utils.MockModel(responses=llm_responses)
125+
126+
@pytest.fixture
127+
def agent(
128+
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
129+
) -> LlmAgent:
130+
"""Provides a single LlmAgent for the test."""
131+
return LlmAgent(name="root_agent", model=mock_model, tools=tools)
132+
133+
@pytest.mark.asyncio
134+
@pytest.mark.parametrize("tool_call_confirmed", [True, False])
135+
async def test_confirmation_flow(
136+
self,
137+
runner: testing_utils.InMemoryRunner,
138+
agent: LlmAgent,
139+
tool_call_confirmed: bool,
140+
):
141+
"""Tests HITL flow where all tool calls are confirmed."""
142+
user_query = testing_utils.UserContent("test user query")
143+
events = await runner.run_async(user_query)
144+
tools = agent.tools
145+
146+
expected_parts = [
147+
(
148+
agent.name,
149+
Part(function_call=FunctionCall(name=tools[0].name, args={})),
150+
),
151+
(
152+
agent.name,
153+
Part(
154+
function_call=FunctionCall(
155+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
156+
args={
157+
"originalFunctionCall": {
158+
"name": tools[0].name,
159+
"id": mock.ANY,
160+
"args": {},
161+
},
162+
"toolConfirmation": {
163+
"hint": "test hint for request_confirmation",
164+
"confirmed": False,
165+
},
166+
},
167+
)
168+
),
169+
),
170+
(
171+
agent.name,
172+
Part(
173+
function_response=FunctionResponse(
174+
name=tools[0].name,
175+
response={"error": "test error for request_confirmation"},
176+
)
177+
),
178+
),
179+
(agent.name, "test llm response after tool call"),
180+
]
181+
182+
simplified = testing_utils.simplify_events(copy.deepcopy(events))
183+
for i, (agent_name, part) in enumerate(expected_parts):
184+
assert simplified[i][0] == agent_name
185+
assert simplified[i][1] == part
186+
187+
ask_for_confirmation_function_call_id = (
188+
events[1].content.parts[0].function_call.id
189+
)
190+
user_confirmation = testing_utils.UserContent(
191+
Part(
192+
function_response=FunctionResponse(
193+
id=ask_for_confirmation_function_call_id,
194+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
195+
response={"confirmed": tool_call_confirmed},
196+
)
197+
)
198+
)
199+
events = await runner.run_async(user_confirmation)
200+
201+
expected_parts_final = [
202+
(
203+
agent.name,
204+
Part(
205+
function_response=FunctionResponse(
206+
name=tools[0].name,
207+
response={"result": f"confirmed={tool_call_confirmed}"},
208+
)
209+
),
210+
),
211+
(agent.name, "test llm response after final tool call"),
212+
]
213+
assert (
214+
testing_utils.simplify_events(copy.deepcopy(events))
215+
== expected_parts_final
216+
)
217+
218+
219+
class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
220+
"""Tests the HITL confirmation flow with a single agent, for custom confirmation payload schema."""
221+
222+
@pytest.fixture
223+
def tools(self) -> list[FunctionTool]:
224+
"""Provides the tools for the agent."""
225+
return [
226+
FunctionTool(
227+
func=_test_request_confirmation_function_with_custom_schema
228+
)
229+
]
230+
231+
@pytest.fixture
232+
def llm_responses(
233+
self, tools: list[FunctionTool]
234+
) -> list[GenerateContentResponse]:
235+
"""Provides mock LLM responses for the tests."""
236+
return [
237+
_create_llm_response_from_tools(tools),
238+
_create_llm_response_from_text("test llm response after tool call"),
239+
_create_llm_response_from_text(
240+
"test llm response after final tool call"
241+
),
242+
]
243+
244+
@pytest.fixture
245+
def mock_model(
246+
self, llm_responses: list[GenerateContentResponse]
247+
) -> testing_utils.MockModel:
248+
"""Provides a mock model with predefined responses."""
249+
return testing_utils.MockModel(responses=llm_responses)
250+
251+
@pytest.fixture
252+
def agent(
253+
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
254+
) -> LlmAgent:
255+
"""Provides a single LlmAgent for the test."""
256+
return LlmAgent(name="root_agent", model=mock_model, tools=tools)
257+
258+
@pytest.mark.asyncio
259+
@pytest.mark.parametrize("tool_call_confirmed", [True, False])
260+
async def test_confirmation_flow(
261+
self,
262+
runner: testing_utils.InMemoryRunner,
263+
agent: LlmAgent,
264+
tool_call_confirmed: bool,
265+
):
266+
"""Tests HITL flow with custom payload schema."""
267+
tools = agent.tools
268+
user_query = testing_utils.UserContent("test user query")
269+
events = await runner.run_async(user_query)
270+
271+
expected_parts = [
272+
(
273+
agent.name,
274+
Part(function_call=FunctionCall(name=tools[0].name, args={})),
275+
),
276+
(
277+
agent.name,
278+
Part(
279+
function_call=FunctionCall(
280+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
281+
args={
282+
"originalFunctionCall": {
283+
"name": tools[0].name,
284+
"id": mock.ANY,
285+
"args": {},
286+
},
287+
"toolConfirmation": {
288+
"hint": (
289+
"test hint for request_confirmation with"
290+
" custom payload schema"
291+
),
292+
"confirmed": False,
293+
"payload": {
294+
"test_custom_payload": {
295+
"int_field": 0,
296+
"str_field": "",
297+
"bool_field": False,
298+
}
299+
},
300+
},
301+
},
302+
)
303+
),
304+
),
305+
(
306+
agent.name,
307+
Part(
308+
function_response=FunctionResponse(
309+
name=tools[0].name,
310+
response={"error": "test error for request_confirmation"},
311+
)
312+
),
313+
),
314+
(agent.name, "test llm response after tool call"),
315+
]
316+
317+
simplified = testing_utils.simplify_events(copy.deepcopy(events))
318+
for i, (agent_name, part) in enumerate(expected_parts):
319+
assert simplified[i][0] == agent_name
320+
assert simplified[i][1] == part
321+
322+
ask_for_confirmation_function_call_id = (
323+
events[1].content.parts[0].function_call.id
324+
)
325+
custom_payload = {
326+
"test_custom_payload": {
327+
"int_field": 123,
328+
"str_field": "test_str",
329+
"bool_field": True,
330+
}
331+
}
332+
user_confirmation = testing_utils.UserContent(
333+
Part(
334+
function_response=FunctionResponse(
335+
id=ask_for_confirmation_function_call_id,
336+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
337+
response={
338+
"confirmed": tool_call_confirmed,
339+
"payload": custom_payload,
340+
},
341+
)
342+
)
343+
)
344+
events = await runner.run_async(user_confirmation)
345+
346+
expected_response = {
347+
"result": f"confirmed={tool_call_confirmed}",
348+
"custom_payload": custom_payload,
349+
}
350+
expected_parts_final = [
351+
(
352+
agent.name,
353+
Part(
354+
function_response=FunctionResponse(
355+
name=tools[0].name,
356+
response=expected_response,
357+
)
358+
),
359+
),
360+
(agent.name, "test llm response after final tool call"),
361+
]
362+
assert (
363+
testing_utils.simplify_events(copy.deepcopy(events))
364+
== expected_parts_final
365+
)

0 commit comments

Comments
 (0)
0