25
25
26
26
from dateutil import parser
27
27
from google .genai .errors import ClientError
28
+ from tenacity import retry
29
+ from tenacity import retry_if_result
30
+ from tenacity import RetryError
31
+ from tenacity import stop_after_attempt
32
+ from tenacity import wait_exponential
28
33
from typing_extensions import override
29
34
30
35
from google import genai
@@ -64,6 +69,20 @@ def __init__(
64
69
self ._location = location
65
70
self ._agent_engine_id = agent_engine_id
66
71
72
+ async def _get_session_api_response (
73
+ self ,
74
+ reasoning_engine_id : str ,
75
+ session_id : str ,
76
+ api_client : genai .ApiClient ,
77
+ ):
78
+ get_session_api_response = await api_client .async_request (
79
+ http_method = 'GET' ,
80
+ path = f'reasoningEngines/{ reasoning_engine_id } /sessions/{ session_id } ' ,
81
+ request_dict = {},
82
+ )
83
+ get_session_api_response = _convert_api_response (get_session_api_response )
84
+ return get_session_api_response
85
+
67
86
@override
68
87
async def create_session (
69
88
self ,
@@ -95,66 +114,68 @@ async def create_session(
95
114
96
115
session_id = api_response ['name' ].split ('/' )[- 3 ]
97
116
operation_id = api_response ['name' ].split ('/' )[- 1 ]
98
-
99
- max_retry_attempt = 5
100
-
101
117
if _is_vertex_express_mode (self ._project , self ._location ):
102
118
# Express mode doesn't support LRO, so we need to poll
103
119
# the session resource.
104
120
# TODO: remove this once LRO polling is supported in Express mode.
105
- for i in range (max_retry_attempt ):
121
+ @retry (
122
+ stop = stop_after_attempt (5 ),
123
+ wait = wait_exponential (multiplier = 1 , min = 1 , max = 3 ),
124
+ retry = retry_if_result (lambda response : not response ),
125
+ reraise = True ,
126
+ )
127
+ async def _poll_session_resource ():
106
128
try :
107
- await api_client .async_request (
108
- http_method = 'GET' ,
109
- path = (
110
- f'reasoningEngines/{ reasoning_engine_id } /sessions/{ session_id } '
111
- ),
112
- request_dict = {},
129
+ return await self ._get_session_api_response (
130
+ reasoning_engine_id , session_id , api_client
113
131
)
114
- break
115
- except ClientError as e :
116
- logger .info ('Polling for session %s: %s' , session_id , e )
117
- # Add slight exponential backoff to avoid excessive polling.
118
- await asyncio .sleep (1 + 0.5 * i )
119
- else :
120
- raise TimeoutError ('Session creation failed.' )
132
+ except ClientError :
133
+ logger .info (f'Polling session resource' )
134
+ return None
135
+
136
+ try :
137
+ await _poll_session_resource ()
138
+ except Exception as exc :
139
+ raise ValueError ('Failed to create session.' ) from exc
121
140
else :
122
- lro_response = None
123
- for _ in range (max_retry_attempt ):
141
+
142
+ @retry (
143
+ stop = stop_after_attempt (5 ),
144
+ wait = wait_exponential (multiplier = 1 , min = 1 , max = 3 ),
145
+ retry = retry_if_result (
146
+ lambda response : not response .get ('done' , False ),
147
+ ),
148
+ reraise = True ,
149
+ )
150
+ async def _poll_lro ():
124
151
lro_response = await api_client .async_request (
125
152
http_method = 'GET' ,
126
153
path = f'operations/{ operation_id } ' ,
127
154
request_dict = {},
128
155
)
129
156
lro_response = _convert_api_response (lro_response )
157
+ return lro_response
130
158
131
- if lro_response .get ('done' , None ):
132
- break
133
-
134
- await asyncio .sleep (1 )
135
-
136
- if lro_response is None or not lro_response .get ('done' , None ):
159
+ try :
160
+ await _poll_lro ()
161
+ except RetryError as exc :
137
162
raise TimeoutError (
138
163
f'Timeout waiting for operation { operation_id } to complete.'
139
- )
164
+ ) from exc
165
+ except Exception as exc :
166
+ raise ValueError ('Failed to create session.' ) from exc
140
167
141
- # Get session resource
142
- get_session_api_response = await api_client .async_request (
143
- http_method = 'GET' ,
144
- path = f'reasoningEngines/{ reasoning_engine_id } /sessions/{ session_id } ' ,
145
- request_dict = {},
168
+ get_session_api_response = await self ._get_session_api_response (
169
+ reasoning_engine_id , session_id , api_client
146
170
)
147
- get_session_api_response = _convert_api_response (get_session_api_response )
148
-
149
- update_timestamp = isoparse (
150
- get_session_api_response ['updateTime' ]
151
- ).timestamp ()
152
171
session = Session (
153
172
app_name = str (app_name ),
154
173
user_id = str (user_id ),
155
174
id = str (session_id ),
156
175
state = get_session_api_response .get ('sessionState' , {}),
157
- last_update_time = update_timestamp ,
176
+ last_update_time = isoparse (
177
+ get_session_api_response ['updateTime' ]
178
+ ).timestamp (),
158
179
)
159
180
return session
160
181
@@ -171,12 +192,9 @@ async def get_session(
171
192
api_client = self ._get_api_client ()
172
193
173
194
# Get session resource
174
- get_session_api_response = await api_client .async_request (
175
- http_method = 'GET' ,
176
- path = f'reasoningEngines/{ reasoning_engine_id } /sessions/{ session_id } ' ,
177
- request_dict = {},
195
+ get_session_api_response = await self ._get_session_api_response (
196
+ reasoning_engine_id , session_id , api_client
178
197
)
179
- get_session_api_response = _convert_api_response (get_session_api_response )
180
198
181
199
if get_session_api_response ['userId' ] != user_id :
182
200
raise ValueError (f'Session not found: { session_id } ' )
0 commit comments