8000 chore: fix unit tests · stevenchendan/adk-python@d40df2e · GitHub
[go: up one dir, main page]

Skip to content

Commit d40df2e

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: fix unit tests
PiperOrigin-RevId: 764107186
1 parent a66f122 commit d40df2e

File tree

3 files changed

+62
-53
lines changed

3 files changed

+62
-53
lines changed

tests/unittests/fast_api/test_fast_api.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,14 @@
1414

1515
import asyncio
1616
import logging
17-
import os
18-
import sys
1917
import time
20-
import types as ptypes
2118
from unittest.mock import MagicMock
2219
from unittest.mock import patch
2320

2421
from fastapi.testclient import TestClient
2522
from google.adk.agents.base_agent import BaseAgent
2623
from google.adk.agents.run_config import RunConfig
2724
from google.adk.cli.fast_api import get_fast_api_app
28-
from google.adk.cli.utils import envs
2925
from google.adk.events import Event
3026
from google.adk.runners import Runner
3127
from google.adk.sessions.base_session_service import ListSessionsResponse
@@ -48,22 +44,7 @@ def __init__(self, name):
4844
self.sub_agents = []
4945

5046

51-
# Set up dummy module and add to sys.modules
52-
dummy_module = ptypes.ModuleType("test_agent")
53-
dummy_module.agent = ptypes.SimpleNamespace(
54-
root_agent=DummyAgent(name="dummy_agent")
55-
)
56-
sys.modules["test_app"] = dummy_module
57-
58-
# Try to load environment variables, with a fallback for testing
59-
try:
60-
envs.load_dotenv_for_agent("test_app", ".")
61-
except Exception as e:
62-
logger.warning(f"Could not load environment variables: {e}")
63-
# Create a basic .env file if needed
64-
if not os.path.exists(".env"):
65-
with open(".env", "w") as f:
66-
f.write("# Test environment variables\n")
47+
root_agent = DummyAgent(name="dummy_agent")
6748

6849

6950
# Create sample events that our mocked runner will return
@@ -150,6 +131,20 @@ def test_session_info():
150131
}
151132

152133

134+
@pytest.fixture
135+
def mock_agent_loader():
136+
137+
class MockAgentLoader:
138+
139+
def __init__(self, agents_dir: str):
140+
pass
141+
142+
def load_agent(self, app_name):
143+
return root_agent
144+
145+
return MockAgentLoader(".")
146+
147+
153148
@pytest.fixture
154149
def mock_session_service():
155150
"""Create a mock session service that uses an in-memory dictionary."""
@@ -287,24 +282,33 @@ def mock_memory_service():
287282

288283

289284
@pytest.fixture
290-
def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
285+
def test_app(
286+
mock_session_service,
287+
mock_artifact_service,
288+
mock_memory_service,
289+
mock_agent_loader,
290+
):
291291
"""Create a TestClient for the FastAPI app without starting a server."""
292292

293293
# Patch multiple services and signal handlers
294294
with (
295295
patch("signal.signal", return_value=None),
296296
patch(
297-
"google.adk.cli.fast_api.InMemorySessionService", # Changed this line
297+
"google.adk.cli.fast_api.InMemorySessionService",
298298
return_value=mock_session_service,
299299
),
300300
patch(
301-
"google.adk.cli.fast_api.InMemoryArtifactService", # Make consistent
301+
"google.adk.cli.fast_api.InMemoryArtifactService",
302302
return_value=mock_artifact_service,
303303
),
304304
patch(
305-
"google.adk.cli.fast_api.InMemoryMemoryService", # Make consistent
305+
"google.adk.cli.fast_api.InMemoryMemoryService",
306306
return_value=mock_memory_service,
307307
),
308+
patch(
309+
"google.adk.cli.fast_api.AgentLoader",
310+
return_value=mock_agent_loader,
311+
),
308312
):
309313
# Get the FastAPI app, but don't actually run it
310314
app = get_fast_api_app(

tests/unittests/sessions/test_session_service.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,9 @@ async def test_append_event_complete(service_type):
315315

316316

317317
@pytest.mark.asyncio
318-
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE])
318+
@pytest.mark.parametrize(
319+
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
320+
)
319321
async def test_get_session_with_config(service_type):
320322
session_service = get_session_service(service_type)
321323
app_name = 'my_app'

tests/unittests/test_telemetry.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Any
22
from typing import Optional
33

4-
from google.adk.sessions import InMemorySessionService
54
from google.adk.agents.invocation_context import InvocationContext
65
from google.adk.agents.llm_agent import LlmAgent
76
from google.adk.models.llm_request import LlmRequest
87
from google.adk.models.llm_response import LlmResponse
8+
from google.adk.sessions import InMemorySessionService
99
from google.adk.telemetry import trace_call_llm
1010
from google.genai import types
1111
import pytest
@@ -16,10 +16,10 @@ async def _create_invocation_context(
1616
) -> InvocationContext:
1717
session_service = InMemorySessionService()
1818
session = await session_service.create_session(
19-
app_name='test_app', user_id='test_user', state=state
19+
app_name="test_app", user_id="test_user", state=state
2020
)
2121
invocation_context = InvocationContext(
22-
invocation_id='test_id',
22+
invocation_id="test_id",
2323
agent=agent,
2424
session=session,
2525
session_service=session_service,
@@ -29,34 +29,37 @@ async def _create_invocation_context(
2929

3030
@pytest.mark.asyncio
3131
async def test_trace_call_llm_function_response_includes_part_from_bytes():
32-
agent = LlmAgent(name='test_agent')
32+
agent = LlmAgent(name="test_agent")
3333
invocation_context = await _create_invocation_context(agent)
3434
llm_request = LlmRequest(
35-
contents=[
36-
types.Content(
37-
role="user",
38-
parts=[
39-
types.Part.from_function_response(
40-
name="test_function_1",
41-
response={
42-
"result": b"test_data",
43-
},
35+
contents=[
36+
types.Content(
37+
role="user",
38+
parts=[
39+
types.Part.from_function_response(
40+
name="test_function_1",
41+
response={
42+
"result": b"test_data",
43+
},
44+
),
45+
],
4446
),
45-
],
46-
),
47-
types.Content(
48-
role="user",
49-
parts=[
50-
types.Part.from_function_response(
51-
name="test_function_2",
52-
response={
53-
"result": types.Part.from_bytes(data=b"test_data", mime_type="application/octet-stream"),
54-
},
47+
types.Content(
48+
role="user",
49+
parts=[
50+
types.Part.from_function_response(
51+
name="test_function_2",
52+
response={
53+
"result": types.Part.from_bytes(
54+
data=b"test_data",
55+
mime_type="application/octet-stream",
56+
),
57+
},
58+
),
59+
],
5560
),
56-
],
57-
),
58-
],
59-
config=types.GenerateContentConfig(system_instruction=""),
61+
],
62+
config=types.GenerateContentConfig(system_instruction=""),
6063
)
6164
llm_response = LlmResponse(turn_complete=True)
62-
trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response)
65+
trace_call_llm(invocation_context, "test_event_id", llm_request, llm_response)

0 commit comments

Comments
 (0)
0