|
15 | 15 |
|
16 | 16 | import asyncio
|
17 | 17 | import logging
|
| 18 | +import os |
18 | 19 | import re
|
19 | 20 | from typing import Any
|
20 | 21 | from typing import Dict
|
|
25 | 26 | from typing_extensions import override
|
26 | 27 |
|
27 | 28 | from google import genai
|
| 29 | +from google.genai.errors import ClientError |
28 | 30 |
|
29 | 31 | from . import _session_util
|
30 | 32 | from ..events.event import Event
|
|
34 | 36 | from .base_session_service import ListSessionsResponse
|
35 | 37 | from .session import Session
|
36 | 38 |
|
| 39 | + |
37 | 40 | isoparse = parser.isoparse
|
38 | 41 | logger = logging.getLogger('google_adk.' + __name__)
|
39 | 42 |
|
@@ -93,24 +96,47 @@ async def create_session(
|
93 | 96 | operation_id = api_response['name'].split('/')[-1]
|
94 | 97 |
|
95 | 98 | max_retry_attempt = 5
|
96 |
| - lro_response = None |
97 |
| - while max_retry_attempt >= 0: |
98 |
| - lro_response = await api_client.async_request( |
99 |
| - http_method='GET', |
100 |
| - path=f'operations/{operation_id}', |
101 |
| - request_dict={}, |
102 |
| - ) |
103 |
| - |
104 |
| - if lro_response.get('done', None): |
105 |
| - break |
106 | 99 |
|
107 |
| - await asyncio.sleep(1) |
108 |
| - max_retry_attempt -= 1 |
109 |
| - |
110 |
| - if lro_response is None or not lro_response.get('done', None): |
111 |
| - raise TimeoutError( |
112 |
| - f'Timeout waiting for operation {operation_id} to complete.' |
113 |
| - ) |
| 100 | + if _is_vertex_express_mode(): |
| 101 | + # Express mode doesn't support LRO, so we need to poll |
| 102 | + # the session resource. |
| 103 | + # TODO: remove this once LRO polling is supported in Express mode. |
| 104 | + while max_retry_attempt >= 0: |
| 105 | + try: |
| 106 | + await api_client.async_request( |
| 107 | + http_method='GET', |
| 108 | + path=( |
| 109 | + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' |
| 110 | + ), |
| 111 | + request_dict={}, |
| 112 | + ) |
| 113 | + break |
| 114 | + except ClientError as e: |
| 115 | + logger.info('Polling for session %s: %s', session_id, e) |
| 116 | + await asyncio.sleep(1) |
| 117 | + max_retry_attempt -= 1 |
| 118 | + continue |
| 119 | + if max_retry_attempt < 0: |
| 120 | + raise TimeoutError('Session creation failed.') |
| 121 | + else: |
| 122 | + lro_response = None |
| 123 | + while max_retry_attempt >= 0: |
| 124 | + lro_response = await api_client.async_request( |
| 125 | + http_method='GET', |
| 126 | + path=f'operations/{operation_id}', |
| 127 | + request_dict={}, |
| 128 | + ) |
| 129 | + |
| 130 | + if lro_response.get('done', None): |
| 131 | + break |
| 132 | + |
| 133 | + await asyncio.sleep(1) |
| 134 | + max_retry_attempt -= 1 |
| 135 | + |
| 136 | + if lro_response is None or not lro_response.get('done', None): |
| 137 | + raise TimeoutError( |
| 138 | + f'Timeout waiting for operation {operation_id} to complete.' |
| 139 | + ) |
114 | 140 |
|
115 | 141 | # Get session resource
|
116 | 142 | get_session_api_response = await api_client.async_request(
|
@@ -303,6 +329,15 @@ def _get_api_client(self):
|
303 | 329 | return client._api_client
|
304 | 330 |
|
305 | 331 |
|
def _is_vertex_express_mode() -> bool:
  """Report whether the environment selects Vertex AI Express Mode.

  Express Mode is assumed when Vertex AI is switched on through the
  ``GOOGLE_GENAI_USE_VERTEXAI`` environment variable (``'true'`` or ``'1'``,
  compared case-insensitively) and an API key is supplied through
  ``GOOGLE_API_KEY``.

  Returns:
    True when both conditions hold, False otherwise.
  """
  use_vertexai = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower()
  # An exported-but-empty GOOGLE_API_KEY carries no credential, so it is
  # treated the same as an unset variable instead of enabling Express Mode.
  api_key = os.environ.get('GOOGLE_API_KEY')
  return use_vertexai in ('true', '1') and bool(api_key)
| 339 | + |
| 340 | + |
306 | 341 | def _convert_event_to_json(event: Event) -> Dict[str, Any]:
|
307 | 342 | metadata_json = {
|
308 | 343 | 'partial': event.partial,
|
|
0 commit comments