Skip to content

Commit 00cc8cd

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add Vertex Express mode compatibility for VertexAiSessionService
PiperOrigin-RevId: 775317848
1 parent 9597a44 commit 00cc8cd

File tree

1 file changed

+53
-18
lines changed

1 file changed

+53
-18
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
import asyncio
1717
import json
1818
import logging
19+
import os
1920
import re
2021
from typing import Any
2122
from typing import Dict
2223
from typing import Optional
2324
import urllib.parse
2425

2526
from dateutil import parser
27+
from google.genai.errors import ClientError
2628
from typing_extensions import override
2729

2830
from google import genai
@@ -95,25 +97,46 @@ async def create_session(
9597
operation_id = api_response['name'].split('/')[-1]
9698

9799
max_retry_attempt = 5
98-
lro_response = None
99-
while max_retry_attempt >= 0:
100-
lro_response = await api_client.async_request(
101-
http_method='GET',
102-
path=f'operations/{operation_id}',
103-
request_dict={},
104-
)
105-
lro_response = _convert_api_response(lro_response)
106100

107-
if lro_response.get('done', None):
108-
break
109-
110-
await asyncio.sleep(1)
111-
max_retry_attempt -= 1
112-
113-
if lro_response is None or not lro_response.get('done', None):
114-
raise TimeoutError(
115-
f'Timeout waiting for operation {operation_id} to complete.'
116-
)
101+
if _is_vertex_express_mode(self._project, self._location):
102+
# Express mode doesn't support LRO, so we need to poll
103+
# the session resource.
104+
# TODO: remove this once LRO polling is supported in Express mode.
105+
for i in range(max_retry_attempt):
106+
try:
107+
await api_client.async_request(
108+
http_method='GET',
109+
path=(
110+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
111+
),
112+
request_dict={},
113+
)
114+
break
115+
except ClientError as e:
116+
logger.info('Polling for session %s: %s', session_id, e)
117+
# Add slight exponential backoff to avoid excessive polling.
118+
await asyncio.sleep(1 + 0.5 * i)
119+
else:
120+
raise TimeoutError('Session creation failed.')
121+
else:
122+
lro_response = None
123+
for _ in range(max_retry_attempt):
124+
lro_response = await api_client.async_request(
125+
http_method='GET',
126+
path=f'operations/{operation_id}',
127+
request_dict={},
128+
)
129+
lro_response = _convert_api_response(lro_response)
130+
131+
if lro_response.get('done', None):
132+
break
133+
134+
await asyncio.sleep(1)
135+
136+
if lro_response is None or not lro_response.get('done', None):
137+
raise TimeoutError(
138+
f'Timeout waiting for operation {operation_id} to complete.'
139+
)
117140

118141
# Get session resource
119142
get_session_api_response = await api_client.async_request(
@@ -312,6 +335,18 @@ def _get_api_client(self):
312335
return client._api_client
313336

314337

338+
def _is_vertex_express_mode(
339+
project: Optional[str], location: Optional[str]
340+
) -> bool:
341+
"""Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode."""
342+
return (
343+
os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in ['true', '1']
344+
and os.environ.get('GOOGLE_API_KEY', None) is not None
345+
and project is None
346+
and location is None
347+
)
348+
349+
315350
def _convert_api_response(api_response):
316351
"""Converts the API response to a JSON object based on the type."""
317352
if hasattr(api_response, 'body'):

0 commit comments

Comments
 (0)