8000 fix: Fix filtering by user_id for vertex ai session service listing · stevenchendan/adk-python@9d4ca4e · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d4ca4e

Browse files
dpark27copybara-github
authored andcommitted
fix: Fix filtering by user_id for vertex ai session service listing
When the user id contains special characters (i.e. an email), we have added in extra url parsing to address those characters. We have also added an if statement to use the correct url when there is no user_id supplied. Copybara import of the project: -- ef84990 by Danny Park <dpark@calicolabs.com>: -- 773cd2b by Danny Park <dpark@calicolabs.com>: COPYBARA_INTEGRATE_REVIEW=google#996 from dpark27:fix/list_vertex_ai_sessions d351d7f PiperOrigin-RevId: 764522026
1 parent fc3e374 commit 9d4ca4e

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import asyncio
1717
import logging
1818
import re
19+
import time
1920
from typing import Any
2021
from typing import Optional
22+
import urllib.parse
2123

2224
from dateutil import parser
2325
from google.genai import types
@@ -186,10 +188,15 @@ async def list_sessions(
186188
) -> ListSessionsResponse:
187189
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
188190

191+
path = f"reasoningEngines/{reasoning_engine_id}/sessions"
192+
if user_id:
193+
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe="")
194+
path = path + f"?filter=user_id={parsed_user_id}"
195+
189196
api_client = _get_api_client(self.project, self.location)
190197
api_response = await api_client.async_request(
191198
http_method='GET',
192-
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
199+
path=path,
193200
request_dict={},
194201
)
195202

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@
111111

112112

113113
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
114-
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
114+
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' # %22 represents double-quotes in a URL-encoded string
115115
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
116116
LRO_REGEX = r'^operations/([^/]+)$'
117117

@@ -127,7 +127,7 @@ def __init__(self) -> None:
127127
async def async_request(
128128
self, http_method: str, path: str, request_dict: dict[str, Any]
129129
):
130-
"""Mocks the API Client request method."""
130+
"""Mocks the API Client request method"""
131131
if http_method == 'GET':
132132
if re.match(SESSION_REGEX, path):
133133
match = re.match(SESSION_REGEX, path)
@@ -149,20 +149,14 @@ async def async_request(
149149
elif re.match(EVENTS_REGEX, path):
150150
match = re.match(EVENTS_REGEX, path)
151151
if match:
152-
return {
153-
'sessionEvents': (
154-
self.event_dict[match.group(2)]
155-
if match.group(2) in self.event_dict
156-
else []
157-
)
158-
}
152+
session_id = match.group(2)
153+
return {'sessionEvents': self.event_dict.get(session_id, [])}
159154
elif re.match(LRO_REGEX, path):
155+
# Mock long-running operation as completed
160156
return {
161-
'name': (
162-
'projects/test-project/locations/test-location/'
163-
'reasoningEngines/123/sessions/4'
164-
),
157+
'name': path,
165158
'done': True,
159+
'response': self.session_dict['4'] # Return the created session
166160
}
167161
else:
168162
raise ValueError(f'Unsupported path: {path}')
@@ -225,10 +219,10 @@ def mock_get_api_client():
225219
async def test_get_empty_session():
226220
session_service = mock_vertex_ai_session_service()
227221
with pytest.raises(ValueError) as excinfo:
228-
assert await session_service.get_session(
222+
await session_service.get_session(
229223
app_name='123', user_id='user', session_id='0'
230224
)
231-
assert str(excinfo.value) == 'Session not found: 0'
225+
assert str(excinfo.value) == 'Session not found: 0'
232226

233227

234228
@pytest.mark.asyncio
@@ -247,10 +241,10 @@ async def test_get_and_delete_session():
247241
app_name='123', user_id='user', session_id='1'
248242
)
249243
with pytest.raises(ValueError) as excinfo:
250-
assert await session_service.get_session(
244+
await session_service.get_session(
251245
app_name='123', user_id='user', session_id='1'
252246
)
253-
assert str(excinfo.value) == 'Session not found: 1'
247+
assert str(excinfo.value) == 'Session not found: 1'
254248

255249

256250
@pytest.mark.asyncio
@@ -292,6 +286,6 @@ async def test_create_session_with_custom_session_id():
292286
await session_service.create_session(
293287
app_name='123', user_id='user', session_id='1'
294288
)
295-
assert str(excinfo.value) == (
296-
'User-provided Session id is not supported for VertexAISessionService.'
297-
)
289+
assert str(excinfo.value) == (
290+
'User-provided Session id is not supported for VertexAISessionService.'
291+
)

0 commit comments

Comments
 (0)
0