8000 chore: Set agent_engine_id in the service constructor, also use the a… · devevignesh/adk-python@fc65873 · GitHub
[go: up one dir, main page]

Skip to content

Commit fc65873

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Set agent_engine_id in the service constructor, also use the agent_engine_id field instead of overriding app_name in FastAPI endpoint
PiperOrigin-RevId: 770427903
1 parent a7ea374 commit fc65873

File tree

3 files changed

+74
-73
lines changed

3 files changed

+74
-73
lines changed

src/google/adk/cli/fast_api.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ async def internal_lifespan(app: FastAPI):
276276
memory_service = InMemoryMemoryService()
277277

278278
# Build the Session service
279-
agent_engine_id = ""
280279
if session_service_uri:
281280
if session_service_uri.startswith("agentengine://"):
282281
# Create vertex session service
@@ -285,8 +284,9 @@ async def internal_lifespan(app: FastAPI):
285284
raise click.ClickException("Agent engine id can not be empty.")
286285
envs.load_dotenv_for_agent("", agents_dir)
287286
session_service = VertexAiSessionService(
288-
os.environ["GOOGLE_CLOUD_PROJECT"],
289-
os.environ["GOOGLE_CLOUD_LOCATION"],
287+
project=os.environ["GOOGLE_CLOUD_PROJECT"],
288+
location=os.environ["GOOGLE_CLOUD_LOCATION"],
289+
agent_engine_id=agent_engine_id,
290290
)
291291
else:
292292
session_service = DatabaseSessionService(db_url=session_service_uri)
@@ -357,8 +357,6 @@ def get_session_trace(session_id: str) -> Any:
357357
async def get_session(
358358
app_name: str, user_id: str, session_id: str
359359
) -> Session:
360-
# Connect to managed session if agent_engine_id is set.
361-
app_name = agent_engine_id if agent_engine_id else app_name
362360
session = await session_service.get_session(
363361
app_name=app_name, user_id=user_id, session_id=session_id
364362
)
@@ -371,8 +369,6 @@ async def get_session(
371369
response_model_exclude_none=True,
372370
)
373371
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
374-
# Connect to managed session if agent_engine_id is set.
375-
app_name = agent_engine_id if agent_engine_id else app_name
376372
list_sessions_response = await session_service.list_sessions(
377373
app_name=app_name, user_id=user_id
378374
)
@@ -393,8 +389,6 @@ async def create_session_with_id(
393389
session_id: str,
394390
state: Optional[dict[str, Any]] = None,
395391
) -> Session:
396-
# Connect to managed session if agent_engine_id is set.
397-
app_name = agent_engine_id if agent_engine_id else app_name
398392
if (
399393
await session_service.get_session(
400394
app_name=app_name, user_id=user_id, session_id=session_id
@@ -419,8 +413,6 @@ async def create_session(
419413
user_id: str,
420414
state: Optional[dict[str, Any]] = None,
421415
) -> Session:
422-
# Connect to managed session if agent_engine_id is set.
423-
app_name = agent_engine_id if agent_engine_id else app_name
424416
logger.info("New session created")
425417
return await session_service.create_session(
426418
app_name=app_name, user_id=user_id, state=state
@@ -660,8 +652,6 @@ def list_eval_results(app_name: str) -> list[str]:
660652

661653
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
662654
async def delete_session(app_name: str, user_id: str, session_id: str):
663-
# Connect to managed session if agent_engine_id is set.
664-
app_name = agent_engine_id if agent_engine_id else app_name
665655
await session_service.delete_session(
666656
app_name=app_name, user_id=user_id, session_id=session_id
667657
)
@@ -677,7 +667,6 @@ async def load_artifact(
677667
artifact_name: str,
678668
version: Optional[int] = Query(None),
679669
) -> Optional[types.Part]:
680-
app_name = agent_engine_id if agent_engine_id else app_name
681670
artifact = await artifact_service.load_artifact(
682671
app_name=app_name,
683672
user_id=user_id,
@@ -700,7 +689,6 @@ async def load_artifact_version(
700689
artifact_name: str,
701690
version_id: int,
702691
) -> Optional[types.Part]:
703-
app_name = agent_engine_id if agent_engine_id else app_name
704692
artifact = await artifact_service.load_artifact(
705693
app_name=app_name,
706694
user_id=user_id,
@@ -719,7 +707,6 @@ async def load_artifact_version(
719707
async def list_artifact_names(
720708
app_name: str, user_id: str, session_id: str
721709
) -> list[str]:
722-
app_name = agent_engine_id if agent_engine_id else app_name
723710
return await artifact_service.list_artifact_keys(
724711
app_name=app_name, user_id=user_id, session_id=session_id
725712
)
@@ -731,7 +718,6 @@ async def list_artifact_names(
731718
async def list_artifact_versions(
732719
app_name: str, user_id: str, session_id: str, artifact_name: str
733720
) -> list[int]:
734-
app_name = agent_engine_id if agent_engine_id else app_name
735721
return await artifact_service.list_versions(
736722
app_name=app_name,
737723
user_id=user_id,
@@ -745,7 +731,6 @@ async def list_artifact_versions(
745731
async def delete_artifact(
746732
app_name: str, user_id: str, session_id: str, artifact_name: str
747733
):
748-
app_name = agent_engine_id if agent_engine_id else app_name
749734
await artifact_service.delete_artifact(
750735
app_name=app_name,
751736
user_id=user_id,
@@ -755,10 +740,8 @@ async def delete_artifact(
755740

756741
@app.post("/run", response_model_exclude_none=True)
757742
async def agent_run(req: AgentRunRequest) -> list[Event]:
758-
# Connect to managed session if agent_engine_id is set.
759-
app_name = agent_engine_id if agent_engine_id else req.app_name
760743
session = await session_service.get_session(
761-
app_name=app_name, user_id=req.user_id, session_id=req.session_id
744+
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
762745
)
763746
if not session:
764747
raise HTTPException(status_code=404, detail="Session not found")
@@ -776,11 +759,9 @@ async def agent_run(req: AgentRunRequest) -> list[Event]:
776759

777760
@app.post("/run_sse")
778761
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
779-
# Connect to managed session if agent_engine_id is set.
780-
app_name = agent_engine_id if agent_engine_id else req.app_name
781762
# SSE endpoint
782763
session = await session_service.get_session(
783-
app_name=app_name, user_id=req.user_id, session_id=req.session_id
764+
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
784765
)
785766
if not session:
786767
raise HTTPException(status_code=404, detail="Session not found")
@@ -818,8 +799,6 @@ async def event_generator():
818799
async def get_event_graph(
819800
app_name: str, user_id: str, session_id: str, event_id: str
820801
):
821-
# Connect to managed session if agent_engine_id is set.
822-
app_name = agent_engine_id if agent_engine_id else app_name
823802
session = await session_service.get_session(
824803
app_name=app_name, user_id=user_id, session_id=session_id
825804
)
@@ -875,8 +854,6 @@ async def agent_live_run(
875854
) -> None:
876855
await websocket.accept()
877856

878-
# Connect to managed session if agent_engine_id is set.
879-
app_name = agent_engine_id if agent_engine_id else app_name
880857
session = await session_service.get_session(
881858
app_name=app_name, user_id=user_id, session_id=session_id
882859
)
@@ -940,7 +917,7 @@ async def _get_runner_async(app_name: str) -> Runner:
940917
return runner_dict[app_name]
941918
root_agent = agent_loader.load_agent(app_name)
942919
runner = Runner(
943-
app_name=agent_engine_id if agent_engine_id else app_name,
920+
app_name=app_name,
944921
agent=root_agent,
945922
artifact_service=artifact_service,
946923
session_service=session_service,

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import urllib.parse
2323

2424
from dateutil import parser
25-
from google.genai import types
2625
from typing_extensions import override
2726

2827
from google import genai
@@ -40,15 +39,27 @@
4039

4140

4241
class VertexAiSessionService(BaseSessionService):
43-
"""Connects to the managed Vertex AI Session Service."""
42+
"""Connects to the Vertex AI Agent Engine Session Service using GenAI API client.
43+
44+
https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview
45+
"""
4446

4547
def __init__(
4648
self,
47-
project: str = None,
48-
location: str = None,
49+
project: Optional[str] = None,
50+
location: Optional[str] = None,
51+
agent_engine_id: Optional[str] = None,
4952
):
50-
self.project = project
51-
self.location = location
53+
"""Initializes the VertexAiSessionService.
54+
55+
Args:
56+
project: The project id of the project to use.
57+
location: The location of the project to use.
58+
agent_engine_id: The resource ID of the agent engine to use.
59+
"""
60+
self._project = project
61+
self._location = location
62+
self._agent_engine_id = agent_engine_id
5263

5364
@override
5465
async def create_session(
@@ -64,14 +75,13 @@ async def create_session(
6475
'User-provided Session id is not supported for'
6576
' VertexAISessionService.'
6677
)
67-
68-
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
78+
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
79+
api_client = self._get_api_client()
6980

7081
session_json_dict = {'user_id': user_id}
7182
if state:
7283
session_json_dict['session_state'] = state
7384

74-
api_client = _get_api_client(self.project, self.location)
7585
api_response = await api_client.async_request(
7686
http_method='POST',
7787
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
@@ -130,10 +140,10 @@ async def get_session(
130140
session_id: str,
131141
config: Optional[GetSessionConfig] = None,
132142
) -> Optional[Session]:
133-
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
143+
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
144+
api_client = self._get_api_client()
134145

135146
# Get session resource
136-
api_client = _get_api_client(self.project, self.location)
137147
get_session_api_response = await api_client.async_request(
138148
http_method='GET',
139149
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
@@ -203,14 +213,14 @@ async def get_session(
203213
async def list_sessions(
204214
self, *, app_name: str, user_id: str
205215
) -> ListSessionsResponse:
206-
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
216+
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
217+
api_client = self._get_api_client()
207218

208219
path = f'reasoningEngines/{reasoning_engine_id}/sessions'
209220
if user_id:
210221
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
211222
path = path + f'?filter=user_id={parsed_user_id}'
212223

213-
api_client = _get_api_client(self.project, self.location)
214224
api_response = await api_client.async_request(
215225
http_method='GET',
216226
path=path,
@@ -236,8 +246,9 @@ async def list_sessions(
236246
async def delete_session(
237247
self, *, app_name: str, user_id: str, session_id: str
238248
) -> None:
239-
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
240-
api_client = _get_api_client(self.project, self.location)
249+
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
250+
api_client = self._get_api_client()
251+
241252
try:
242253
await api_client.async_request(
243254
< 10000 span class=pl-s1>http_method='DELETE',
@@ -253,24 +264,43 @@ async def append_event(self, session: Session, event: Event) -> Event:
253264
# Update the in-memory session.
254265
await super().append_event(session=session, event=event)
255266

256-
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
257-
api_client = _get_api_client(self.project, self.location)
267+
reasoning_engine_id = self._get_reasoning_engine_id(session.app_name)
268+
api_client = self._get_api_client()
258269
await api_client.async_request(
259270
http_method='POST',
260271
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
261272
request_dict=_convert_event_to_json(event),
262273
)
263274
return event
264275

276+
def _get_reasoning_engine_id(self, app_name: str):
277+
if self._agent_engine_id:
278+
return self._agent_engine_id
265279

266-
def _get_api_client(project: str, location: str):
267-
"""Instantiates an API client for the given project and location.
280+
if app_name.isdigit():
281+
return app_name
268282

269-
It needs to be instantiated inside each request so that the event loop
270-
management.
271-
"""
272-
client = genai.Client(vertexai=True, project=project, location=location)
273-
return client._api_client
283+
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
284+
match = re.fullmatch(pattern, app_name)
285+
286+
if not bool(match):
287+
raise ValueError(
288+
f'App name {app_name} is not valid. It should either be the full'
289+
' ReasoningEngine resource name, or the reasoning engine id.'
290+
)
291+
292+
return match.groups()[-1]
293+
294+
def _get_api_client(self):
295+
"""Instantiates an API client for the given project and location.
296+
297+
It needs to be instantiated inside each request so that the event loop
298+
management can be properly propagated.
299+
"""
300+
client = genai.Client(
301+
vertexai=True, project=self._project, location=self._location
302+
)
303+
return client._api_client
274304

275305

276306
def _convert_event_to_json(event: Event) -> Dict[str, Any]:
@@ -366,19 +396,3 @@ def _from_api_event(api_event: Dict[str, Any]) -> Event:
366396
)
367397

368398
return event
369-
370-
371-
def _parse_reasoning_engine_id(app_name: str):
372-
if app_name.isdigit():
373-
return app_name
374-
375-
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
376-
match = re.fullmatch(pattern, app_name)
377-
378-
if not bool(match):
379-
raise ValueError(
380-
f'App name {app_name} is not valid. It should either be the full'
381-
' ReasoningEngine resource name, or the reasoning engine id.'
382-
)
383-
384-
return match.groups()[-1]

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,14 @@ async def async_request(
245245
raise ValueError(f'Unsupported http method: {http_method}')
246246

247247

248-
def mock_vertex_ai_session_service():
248+
def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):
249249
"""Creates a mock Vertex AI Session service for testing."""
250+
if agent_engine_id:
251+
return VertexAiSessionService(
252+
project='test-project',
253+
location='test-location',
254+
agent_engine_id=agent_engine_id,
255+
)
250256
return VertexAiSessionService(
251257
project='test-project', location='test-location'
252258
)
@@ -265,16 +271,20 @@ def mock_get_api_client():
265271
'2': (MOCK_EVENT_JSON_2, 'my_token'),
266272
}
267273
with mock.patch(
268-
'google.adk.sessions.vertex_ai_session_service._get_api_client',
274+
'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client',
269275
return_value=api_client,
270276
):
271277
yield
272278

273279

274280
@pytest.mark.asyncio
275281
@pytest.mark.usefixtures('mock_get_api_client')
276-
async def test_get_empty_session():
277-
session_service = mock_vertex_ai_session_service()
282+
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
283+
async def test_get_empty_session(agent_engine_id):
284+
if agent_engine_id:
285+
session_service = mock_vertex_ai_session_service(agent_engine_id)
286+
else:
287+
session_service = mock_vertex_ai_session_service()
278288
with pytest.raises(ValueError) as excinfo:
279289
await session_service.get_session(
280290
app_name='123', user_id='user', session_id='0'

0 commit comments

Comments
 (0)
0