|
16 | 16 | import asyncio
|
17 | 17 | import json
|
18 | 18 | import logging
|
| 19 | +import os |
19 | 20 | import re
|
20 | 21 | from typing import Any
|
21 | 22 | from typing import Dict
|
22 | 23 | from typing import Optional
|
23 | 24 | import urllib.parse
|
24 | 25 |
|
25 | 26 | from dateutil import parser
|
| 27 | +from google.genai.errors import ClientError |
26 | 28 | from typing_extensions import override
|
27 | 29 |
|
28 | 30 | from google import genai
|
@@ -95,25 +97,46 @@ async def create_session(
|
95 | 97 | operation_id = api_response['name'].split('/')[-1]
|
96 | 98 |
|
97 | 99 | max_retry_attempt = 5
|
98 |
| - lro_response = None |
99 |
| - while max_retry_attempt >= 0: |
100 |
| - lro_response = await api_client.async_request( |
101 |
| - http_method='GET', |
102 |
| - path=f'operations/{operation_id}', |
103 |
| - request_dict={}, |
104 |
| - ) |
105 |
| - lro_response = _convert_api_response(lro_response) |
106 | 100 |
|
107 |
| - if lro_response.get('done', None): |
108 |
| - break |
109 |
| - |
110 |
| - await asyncio.sleep(1) |
111 |
| - max_retry_attempt -= 1 |
112 |
| - |
113 |
| - if lro_response is None or not lro_response.get('done', None): |
114 |
| - raise TimeoutError( |
115 |
| - f'Timeout waiting for operation {operation_id} to complete.' |
116 |
| - ) |
| 101 | + if _is_vertex_express_mode(self._project, self._location): |
| 102 | + # Express mode doesn't support LRO, so we need to poll |
| 103 | + # the session resource. |
| 104 | + # TODO: remove this once LRO polling is supported in Express mode. |
| 105 | + for i in range(max_retry_attempt): |
| 106 | + try: |
| 107 | + await api_client.async_request( |
| 108 | + http_method='GET', |
| 109 | + path=( |
| 110 | + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' |
| 111 | + ), |
| 112 | + request_dict={}, |
| 113 | + ) |
| 114 | + break |
| 115 | + except ClientError as e: |
| 116 | + logger.info('Polling for session %s: %s', session_id, e) |
| 117 | + # Add a slight linear backoff to avoid excessive polling. |
| 118 | + await asyncio.sleep(1 + 0.5 * i) |
| 119 | + else: |
| 120 | + raise TimeoutError('Session creation failed.') |
| 121 | + else: |
| 122 | + lro_response = None |
| 123 | + for _ in range(max_retry_attempt): |
| 124 | + lro_response = await api_client.async_request( |
| 125 | + http_method='GET', |
| 126 | + path=f'operations/{operation_id}', |
| 127 | + request_dict={}, |
| 128 | + ) |
| 129 | + lro_response = _convert_api_response(lro_response) |
| 130 | + |
| 131 | + if lro_response.get('done', None): |
| 132 | + break |
| 133 | + |
| 134 | + await asyncio.sleep(1) |
| 135 | + |
| 136 | + if lro_response is None or not lro_response.get('done', None): |
| 137 | + raise TimeoutError( |
| 138 | + f'Timeout waiting for operation {operation_id} to complete.' |
| 139 | + ) |
117 | 140 |
|
118 | 141 | # Get session resource
|
119 | 142 | get_session_api_response = await api_client.async_request(
|
@@ -312,6 +335,18 @@ def _get_api_client(self):
|
312 | 335 | return client._api_client
|
313 | 336 |
|
314 | 337 |
|
def _is_vertex_express_mode(
    project: Optional[str], location: Optional[str]
) -> bool:
  """Returns whether the environment selects Vertex AI Express Mode.

  Express Mode is assumed when all of the following hold:
    * ``GOOGLE_GENAI_USE_VERTEXAI`` is set to ``true`` or ``1``
      (case-insensitive),
    * ``GOOGLE_API_KEY`` is set to a non-empty value, and
    * neither a project nor a location has been configured, i.e. the API key
      takes the place of the project/location pair.

  Args:
    project: The configured project, or None when no project is set.
    location: The configured location, or None when no location is set.

  Returns:
    True if every condition above is satisfied, False otherwise.
  """
  # An explicit project or location means the regular Vertex AI setup is used.
  if project is not None or location is not None:
    return False

  use_vertexai = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0')
  if use_vertexai.lower() not in ('true', '1'):
    return False

  # An empty key cannot authenticate a request, so it is treated the same as
  # an unset variable instead of being reported as Express Mode.
  return bool(os.environ.get('GOOGLE_API_KEY'))
| 348 | + |
| 349 | + |
315 | 350 | def _convert_api_response(api_response):
|
316 | 351 | """Converts the API response to a JSON object based on the type."""
|
317 | 352 | if hasattr(api_response, 'body'):
|
|
0 commit comments