16
16
import asyncio
17
17
import logging
18
18
import re
19
- import time
20
19
from typing import Any
20
+ from typing import Dict
21
21
from typing import Optional
22
22
import urllib .parse
23
23
@@ -50,9 +50,6 @@ def __init__(
50
50
self .project = project
51
51
self .location = location
52
52
53
- client = genai .Client (vertexai = True , project = project , location = location )
54
- self .api_client = client ._api_client
55
-
56
53
@override
57
54
async def create_session (
58
55
self ,
@@ -86,6 +83,7 @@ async def create_session(
86
83
operation_id = api_response ['name' ].split ('/' )[- 1 ]
87
84
88
85
max_retry_attempt = 5
86
+ lro_response = None
89
87
while max_retry_attempt >= 0 :
90
88
lro_response = await api_client .async_request (
91
89
http_method = 'GET' ,
@@ -99,6 +97,11 @@ async def create_session(
99
97
await asyncio .sleep (1 )
100
98
max_retry_attempt -= 1
101
99
100
+ if lro_response is None or not lro_response .get ('done' , None ):
101
+ raise TimeoutError (
102
+ f'Timeout waiting for operation { operation_id } to complete.'
103
+ )
104
+
102
105
# Get session resource
103
106
get_session_api_response = await api_client .async_request (
104
107
http_method = 'GET' ,
@@ -235,11 +238,15 @@ async def delete_session(
235
238
) -> None :
236
239
reasoning_engine_id = _parse_reasoning_engine_id (app_name )
237
240
api_client = _get_api_client (self .project , self .location )
238
- await api_client .async_request (
239
- http_method = 'DELETE' ,
240
- path = f'reasoningEngines/{ reasoning_engine_id } /sessions/{ session_id } ' ,
241
- request_dict = {},
242
- )
241
+ try :
242
+ await api_client .async_request (
243
+ http_method = 'DELETE' ,
244
+ path = f'reasoningEngines/{ reasoning_engine_id } /sessions/{ session_id } ' ,
245
+ request_dict = {},
246
+ )
247
+ except Exception as e :
248
+ logger .error (f'Error deleting session { session_id } : { e } ' )
249
+ raise e
243
250
244
251
@override
245
252
async def append_event (self , session : Session , event : Event ) -> Event :
@@ -266,7 +273,7 @@ def _get_api_client(project: str, location: str):
266
273
return client ._api_client
267
274
268
275
269
- def _convert_event_to_json (event : Event ):
276
+ def _convert_event_to_json (event : Event ) -> Dict [ str , Any ] :
270
277
metadata_json = {
271
278
'partial' : event .partial ,
272
279
'turn_complete' : event .turn_complete ,
@@ -318,7 +325,7 @@ def _convert_event_to_json(event: Event):
318
325
return event_json
319
326
320
327
321
- def _from_api_event (api_event : dict ) -> Event :
328
+ def _from_api_event (api_event : Dict [ str , Any ] ) -> Event :
322
329
event_actions = EventActions ()
323
330
if api_event .get ('actions' , None ):
324
331
event_actions = EventActions (
0 commit comments