|
15 | 15 |
|
16 | 16 | import asyncio
|
17 | 17 | import logging
|
| 18 | +import os |
18 | 19 | import re
|
19 | 20 | from typing import Any
|
20 | 21 | from typing import Dict
|
21 | 22 | from typing import Optional
|
22 | 23 | import urllib.parse
|
23 | 24 |
|
24 | 25 | from dateutil import parser
|
| 26 | +from google.genai.errors import ClientError |
25 | 27 | from typing_extensions import override
|
26 | 28 |
|
27 | 29 | from google import genai
|
@@ -93,24 +95,47 @@ async def create_session(
|
93 | 95 | operation_id = api_response['name'].split('/')[-1]
|
94 | 96 |
|
95 | 97 | max_retry_attempt = 5
|
96 |
| - lro_response = None |
97 |
| - while max_retry_attempt >= 0: |
98 |
| - lro_response = await api_client.async_request( |
99 |
| - http_method='GET', |
100 |
| - path=f'operations/{operation_id}', |
101 |
| - request_dict={}, |
102 |
| - ) |
103 |
| - |
104 |
| - if lro_response.get('done', None): |
105 |
| - break |
106 |
| - |
107 |
| - await asyncio.sleep(1) |
108 |
| - max_retry_attempt -= 1 |
109 | 98 |
|
110 |
| - if lro_response is None or not lro_response.get('done', None): |
111 |
| - raise TimeoutError( |
112 |
| - f'Timeout waiting for operation {operation_id} to complete.' |
113 |
| - ) |
| 99 | + if _is_vertex_express_mode(self._project, self._location): |
| 100 | + # Express mode doesn't support LRO, so we need to poll |
| 101 | + # the session resource. |
| 102 | + # TODO: remove this once LRO polling is supported in Express mode. |
| 103 | + while max_retry_attempt >= 0: |
| 104 | + try: |
| 105 | + await api_client.async_request( |
| 106 | + http_method='GET', |
| 107 | + path=( |
| 108 | + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' |
| 109 | + ), |
| 110 | + request_dict={}, |
| 111 | + ) |
| 112 | + break |
| 113 | + except ClientError as e: |
| 114 | + logger.info('Polling for session %s: %s', session_id, e) |
| 115 | + await asyncio.sleep(1) |
| 116 | + max_retry_attempt -= 1 |
| 117 | + continue |
| 118 | + if max_retry_attempt < 0: |
| 119 | + raise TimeoutError('Session creation failed.') |
| 120 | + else: |
| 121 | + lro_response = None |
| 122 | + while max_retry_attempt >= 0: |
| 123 | + lro_response = await api_client.async_request( |
| 124 | + http_method='GET', |
| 125 | + path=f'operations/{operation_id}', |
| 126 | + request_dict={}, |
| 127 | + ) |
| 128 | + |
| 129 | + if lro_response.get('done', None): |
| 130 | + break |
| 131 | + |
| 132 | + await asyncio.sleep(1) |
| 133 | + max_retry_attempt -= 1 |
| 134 | + |
| 135 | + if lro_response is None or not lro_response.get('done', None): |
| 136 | + raise TimeoutError( |
| 137 | + f'Timeout waiting for operation {operation_id} to complete.' |
| 138 | + ) |
114 | 139 |
|
115 | 140 | # Get session resource
|
116 | 141 | get_session_api_response = await api_client.async_request(
|
@@ -300,9 +325,24 @@ def _get_api_client(self):
|
300 | 325 | client = genai.Client(
|
301 | 326 | vertexai=True, project=self._project, location=self._location
|
302 | 327 | )
|
| 328 | + client._api_client._http_options.base_url = ( |
| 329 | + 'https://staging-aiplatform.sandbox.googleapis.com' |
| 330 | + ) |
303 | 331 | return client._api_client
|
304 | 332 |
|
305 | 333 |
|
| 334 | +def _is_vertex_express_mode( |
| 335 | + project
8326
: Optional[str], location: Optional[str] |
| 336 | +) -> bool: |
| 337 | + """Check if Vertex AI and API key are both enabled, meaning the user is using the Vertex Express Mode.""" |
| 338 | + return ( |
| 339 | + os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in ['true', '1'] |
| 340 | + and os.environ.get('GOOGLE_API_KEY', None) is not None |
| 341 | + and project is None |
| 342 | + and location is None |
| 343 | + ) |
| 344 | + |
| 345 | + |
306 | 346 | def _convert_event_to_json(event: Event) -> Dict[str, Any]:
|
307 | 347 | metadata_json = {
|
308 | 348 | 'partial': event.partial,
|
|
0 commit comments