@@ -55,7 +55,13 @@ async def acquire(self) -> QuerySessionAsync:
55
55
pass
56
56
57
57
if session is None and self ._current_size == self ._size :
58
- _ , session = await self ._queue .get ()
58
+ queue_get = asyncio .ensure_future (self ._queue .get ())
59
+ task_stop = asyncio .ensure_future (asyncio .ensure_future (self ._should_stop .wait ()))
60
+ done , _ = await asyncio .wait ((queue_get , task_stop ), return_when = asyncio .FIRST_COMPLETED )
61
+ if task_stop in done :
62
+ queue_get .cancel ()
63
+ return await self ._create_new_session () # TODO: not sure why
64
+ _ , session = queue_get .result ()
59
65
60
66
if session is not None :
61
67
if session ._state .attached :
@@ -70,23 +76,6 @@ async def acquire(self) -> QuerySessionAsync:
70
76
self ._current_size += 1
71
77
return session
72
78
73
- async def acquire_wih_timeout (self , timeout : float ):
74
- if self ._should_stop .is_set ():
75
- logger .error ("An attempt to take session from closed session pool." )
76
- raise RuntimeError ("An attempt to take session from closed session pool." )
77
-
78
- try :
79
- task_wait = asyncio .ensure_future (asyncio .wait_for (self .acquire (), timeout = timeout ))
80
- task_stop = asyncio .ensure_future (asyncio .ensure_future (self ._should_stop .wait ()))
81
- done , _ = await asyncio .wait ((task_wait , task_stop ), return_when = asyncio .FIRST_COMPLETED )
82
- if task_stop in done :
83
- task_wait .cancel ()
84
- return await self ._create_new_session () # TODO: not sure why
85
- session = task_wait .result ()
86
- return session
87
- except asyncio .TimeoutError :
88
- raise issues .SessionPoolEmpty ("Timeout on acquire session" )
89
-
90
79
async def release (self , session : QuerySessionAsync ) -> None :
91
80
self ._queue .put_nowait ((1 , session ))
92
81
logger .debug ("Session returned to queue: %s" , session ._state .session_id )
@@ -170,14 +159,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
170
159
171
160
172
161
class SimpleQuerySessionCheckoutAsync :
173
- def __init__ (self , pool : QuerySessionPoolAsync , timeout : Optional [float ]):
162
+ def __init__ (self , pool : QuerySessionPoolAsync , timeout : Optional [float ] = None ):
174
163
self ._pool = pool
175
164
self ._timeout = timeout
176
165
self ._session = None
177
166
178
167
async def __aenter__ (self ) -> QuerySessionAsync :
179
168
if self ._timeout and self ._timeout > 0 :
180
- self ._session = await self ._pool .acquire_wih_timeout (self ._timeout )
169
+ self ._session = await self ._pool .acquire_with_timeout (self ._timeout )
181
170
else :
182
171
self ._session = await self ._pool .acquire ()
183
172
return self ._session
0 commit comments