Skip to content

Commit 3d2f13c

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Update the retry logic of create session polling
This should slightly increase the timeout also reduce the polling frequency. PiperOrigin-RevId: 778323416
1 parent 9af2394 commit 3d2f13c

File tree

1 file changed

+60
-42
lines changed

1 file changed

+60
-42
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525

2626
from dateutil import parser
2727
from google.genai.errors import ClientError
28+
from tenacity import retry
29+
from tenacity import retry_if_result
30+
from tenacity import RetryError
31+
from tenacity import stop_after_attempt
32+
from tenacity import wait_exponential
2833
from typing_extensions import override
2934

3035
from google import genai
@@ -64,6 +69,20 @@ def __init__(
6469
self._location = location
6570
self._agent_engine_id = agent_engine_id
6671

72+
async def _get_session_api_response(
73+
self,
74+
reasoning_engine_id: str,
75+
session_id: str,
76+
api_client: genai.ApiClient,
77+
):
78+
get_session_api_response = await api_client.async_request(
79+
http_method='GET',
80+
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
81+
request_dict={},
82+
)
83+
get_session_api_response = _convert_api_response(get_session_api_response)
84+
return get_session_api_response
85+
6786
@override
6887
async def create_session(
6988
self,
@@ -95,66 +114,68 @@ async def create_session(
95114

96115
session_id = api_response['name'].split('/')[-3]
97116
operation_id = api_response['name'].split('/')[-1]
98-
99-
max_retry_attempt = 5
100-
101117
if _is_vertex_express_mode(self._project, self._location):
102118
# Express mode doesn't support LRO, so we need to poll
103119
# the session resource.
104120
# TODO: remove this once LRO polling is supported in Express mode.
105-
for i in range(max_retry_attempt):
121+
@retry(
122+
stop=stop_after_attempt(5),
123+
wait=wait_exponential(multiplier=1, min=1, max=3),
124+
retry=retry_if_result(lambda response: not response),
125+
reraise=True,
126+
)
127+
async def _poll_session_resource():
106128
try:
107-
await api_client.async_request(
108-
http_method='GET',
109-
path=(
110-
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
111-
),
112-
request_dict={},
129+
return await self._get_session_api_response(
130+
reasoning_engine_id, session_id, api_client
113131
)
114-
break
115-
except ClientError as e:
116-
logger.info('Polling for session %s: %s', session_id, e)
117-
# Add slight exponential backoff to avoid excessive polling.
118-
await asyncio.sleep(1 + 0.5 * i)
119-
else:
120-
raise TimeoutError('Session creation failed.')
132+
except ClientError:
133+
logger.info(f'Polling session resource')
134+
return None
135+
136+
try:
137+
await _poll_session_resource()
138+
except Exception as exc:
139+
raise ValueError('Failed to create session.') from exc
121140
else:
122-
lro_response = None
123-
for _ in range(max_retry_attempt):
141+
142+
@retry(
143+
stop=stop_after_attempt(5),
144+
wait=wait_exponential(multiplier=1, min=1, max=3),
145+
retry=retry_if_result(
146+
lambda response: not response.get('done', False),
147+
),
148+
reraise=True,
149+
)
150+
async def _poll_lro():
124151
lro_response = await api_client.async_request(
125152
http_method='GET',
126153
path=f'operations/{operation_id}',
127154
request_dict={},
128155
)
129156
lro_response = _convert_api_response(lro_response)
157+
return lro_response
130158

131-
if lro_response.get('done', None):
132-
break
133-
134-
await asyncio.sleep(1)
135-
136-
if lro_response is None or not lro_response.get('done', None):
159+
try:
160+
await _poll_lro()
161+
except RetryError as exc:
137162
raise TimeoutError(
138163
f'Timeout waiting for operation {operation_id} to complete.'
139-
)
164+
) from exc
165+
except Exception as exc:
166+
raise ValueError('Failed to create session.') from exc
140167

141-
# Get session resource
142-
get_session_api_response = await api_client.async_request(
143-
http_method='GET',
144-
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
145-
request_dict={},
168+
get_session_api_response = await self._get_session_api_response(
169+
reasoning_engine_id, session_id, api_client
146170
)
147-
get_session_api_response = _convert_api_response(get_session_api_response)
148-
149-
update_timestamp = isoparse(
150-
get_session_api_response['updateTime']
151-
).timestamp()
152171
session = Session(
153172
app_name=str(app_name),
154173
user_id=str(user_id),
155174
id=str(session_id),
156175
state=get_session_api_response.get('sessionState', {}),
157-
last_update_time=update_timestamp,
176+
last_update_time=isoparse(
177+
get_session_api_response['updateTime']
178+
).timestamp(),
158179
)
159180
return session
160181

@@ -171,12 +192,9 @@ async def get_session(
171192
api_client = self._get_api_client()
172193

173194
# Get session resource
174-
get_session_api_response = await api_client.async_request(
175-
http_method='GET',
176-
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
177-
request_dict={},
195+
get_session_api_response = await self._get_session_api_response(
196+
reasoning_engine_id, session_id, api_client
178197
)
179-
get_session_api_response = _convert_api_response(get_session_api_response)
180198

181199
if get_session_api_response['userId'] != user_id:
182200
raise ValueError(f'Session not found: {session_id}')

0 commit comments

Comments
 (0)