8000 chore: Fixes test_fast_api.py (part I for circular deps). · mindpower/adk-python@9324801 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9324801

Browse files
Jacksunweicopybara-github
authored andcommitted
chore: Fixes test_fast_api.py (part I for circular deps).
It still fails due to signal used not in main thread. It will be fixed later. PiperOrigin-RevId: 760050504
1 parent f592de4 commit 9324801

File tree

2 files changed

+64
-39
lines changed

2 files changed

+64
-39
lines changed

src/google/adk/memory/vertex_ai_rag_memory_service.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,29 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
16+
from __future__ import annotations
17+
1518
from collections import OrderedDict
1619
import json
1720
import os
1821
import tempfile
1922
from typing import Optional
23+
from typing import TYPE_CHECKING
2024

2125
from google.genai import types
2226
from typing_extensions import override
2327
from vertexai.preview import rag
2428

25-
from ..events.event import Event
26-
from ..sessions.session import Session
2729
from . import _utils
2830
from .base_memory_service import BaseMemoryService
2931
from .base_memory_service import SearchMemoryResponse
3032
from .memory_entry import MemoryEntry
3133

34+
if TYPE_CHECKING:
35+
from ..events.event import Event
36+
from ..sessions.session import Session
37+
3238

3339
class VertexAiRagMemoryService(BaseMemoryService):
3440
"""A memory service that uses Vertex AI RAG for storage and retrieval."""
@@ -103,6 +109,8 @@ async def search_memory(
103109
self, *, app_name: str, user_id: str, query: str
104110
) -> SearchMemoryResponse:
105111
"""Searches for sessions that match the query using rag.retrieval_query."""
112+
from ..events.event import Event
113+
106114
response = rag.retrieval_query(
107115
text=query,
108116
rag_resources=self._vertex_rag_store.rag_resources,

tests/unittests/fast_api/test_fast_api.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,33 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import asyncio
1618
import json
1719
import sys
1820
import threading
1921
import time
2022
import types as ptypes
2123
from typing import AsyncGenerator
24+
from typing import TYPE_CHECKING
2225

23-
from google.adk.agents import BaseAgent
24-
from google.adk.agents import LiveRequest
26+
from google.adk.agents.base_agent import BaseAgent
27+
from google.adk.agents.live_request_queue import LiveRequest
2528
from google.adk.agents.run_config import RunConfig
2629
from google.adk.cli.fast_api import AgentRunRequest
2730
from google.adk.cli.fast_api import get_fast_api_app
2831
from google.adk.cli.utils import envs
29-
from google.adk.events import Event
3032
from google.adk.runners import Runner
3133
from google.genai import types
3234
import httpx
3335
import pytest
3436
from uvicorn.main import run as uvicorn_run
3537
import websockets
3638

39+
if TYPE_CHECKING:
40+
from google.adk.events import Event
41+
3742

3843
# Here we “fake” the agent module that get_fast_api_app expects.
3944
# The server code does: `agent_module = importlib.import_module(agent_name)`
@@ -49,33 +54,45 @@ class DummyAgent(BaseAgent):
4954
sys.modules["test_app"] = dummy_module
5055
envs.load_dotenv_for_agent("test_app", ".")
5156

52-
event1 = Event(
53-
author="dummy agent",
54-
invocation_id="invocation_id",
55-
content=types.Content(
56-
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
57-
),
58-
)
5957

60-
event2 = Event(
61-
author="dummy agent",
62-
invocation_id="invocation_id",
63-
content=types.Content(
64-
role="model",
65-
parts=[
66-
types.Part(
67-
text=None,
68-
inline_data=types.Blob(
69-
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
70-
),
71-
)
72-
],
73-
),
74-
)
58+
def _event_1():
59+
from google.adk.events import Event
7560

76-
event3 = Event(
77-
author="dummy agent", invocation_id="invocation_id", interrupted=True
78-
)
61+
return Event(
62+
author="dummy agent",
63+
invocation_id="invocation_id",
64+
content=types.Content(
65+
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
66+
),
67+
)
68+
69+
70+
def _event_2():
71+
from google.adk.events import Event
72+
73+
return Event(
74+
author="dummy agent",
75+
invocation_id="invocation_id",
76+
content=types.Content(
77+
role="model",
78+
parts=[
79+
types.Part(
80+
text=None,
81+
inline_data=types.Blob(
82+
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
83+
),
84+
)
85+
],
86+
),
87+
)
88+
89+
90+
def _event_3():
91+
from google.adk.events import Event
92+
93+
return Event(
94+
author="dummy agent", invocation_id="invocation_id", interrupted=True
95+
)
7996

8097

8198
# For simplicity, we patch Runner.run_live to yield dummy events.
@@ -84,13 +101,13 @@ async def dummy_run_live(
84101
self, session, live_request_queue
85102
) -> AsyncGenerator[Event, None]:
86103
# Immediately yield a dummy event with a text reply.
87-
yield event1
104+
yield _event_1()
88105
await asyncio.sleep(0)
89106

90-
yield event2
107+
yield _event_2()
91108
await asyncio.sleep(0)
92109

93-
yield event3
110+
yield _event_3()
94111

95112
raise Exception()
96113

@@ -103,13 +120,13 @@ async def dummy_run_async(
103120
run_config: RunConfig = RunConfig(),
104121
) -> AsyncGenerator[Event, None]:
105122
# Immediately yield a dummy event with a text reply.
106-
yield event1
123+
yield _event_1()
107124
await asyncio.sleep(0)
108125

109-
yield event2
126+
yield _event_2()
110127
await asyncio.sleep(0)
111128

112-
yield event3
129+
yield _event_3()
113130

114131
return
115132

@@ -199,15 +216,15 @@ async def test_sse_endpoint():
199216
if event_data:
200217
event_count += 1
201218
if event_count == 1:
202-
assert event_data == event1.model_dump_json(
219+
assert event_data == _event_1().model_dump_json(
203220
exclude_none=True, by_alias=True
204221
)
205222
elif event_count == 2:
206-
assert event_data == event2.model_dump_json(
223+
assert event_data == _event_2().model_dump_json(
207224
exclude_none=True, by_alias=True
208225
)
209226
elif event_count == 3:
210-
assert event_data == event3.model_dump_json(
227+
assert event_data == _event_3().model_dump_json(
211228
exclude_none=True, by_alias=True
212229
)
213230
else:

0 commit comments

Comments
 (0)
0