diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 258dcd933..bd1345162 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +import json import logging import re from typing import Any @@ -87,6 +88,7 @@ async def create_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions', request_dict=session_json_dict, ) + api_response = _convert_api_response(api_response) logger.info(f'Create Session response {api_response}') session_id = api_response['name'].split('/')[-3] @@ -100,6 +102,7 @@ async def create_session( path=f'operations/{operation_id}', request_dict={}, ) + lro_response = _convert_api_response(lro_response) if lro_response.get('done', None): break @@ -118,6 +121,7 @@ async def create_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) + get_session_api_response = _convert_api_response(get_session_api_response) update_timestamp = isoparse( get_session_api_response['updateTime'] @@ -149,6 +153,7 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, ) + get_session_api_response = _convert_api_response(get_session_api_response) session_id = get_session_api_response['name'].split('/')[-1] update_timestamp = isoparse( @@ -167,9 +172,12 @@ async def get_session( path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', request_dict={}, ) + list_events_api_response = _convert_api_response(list_events_api_response) # Handles empty response case - if list_events_api_response.get('httpHeaders', None): + if not list_events_api_response or list_events_api_response.get( + 'httpHeaders', None + ): return session session.events += [ @@ -226,9 +234,10 @@ async def list_sessions( path=path, request_dict={}, ) + api_response = _convert_api_response(api_response) # Handles empty response case - if api_response.get('httpHeaders', None): + if not api_response or api_response.get('httpHeaders', None): return ListSessionsResponse() sessions = [] @@ -303,6 +312,13 @@ def _get_api_client(self): return client._api_client +def _convert_api_response(api_response): + """Converts the API response to a JSON object based on the type.""" + if hasattr(api_response, 'body'): + return json.loads(api_response.body) + return api_response + + def _convert_event_to_json(event: Event) -> Dict[str, Any]: metadata_json = { 'partial': event.partial,