@@ -43,7 +43,7 @@ async def _create_new_session(self):
43
43
logger .debug (f"New session was created for pool. Session id: { session ._state .session_id } " )
44
44
return session
45
45
46
- async def acquire (self , timeout : float ) -> QuerySessionAsync :
46
+ async def acquire (self ) -> QuerySessionAsync :
47
47
if self ._should_stop .is_set ():
48
48
logger .error ("An attempt to take session from closed session pool." )
49
49
raise RuntimeError ("An attempt to take session from closed session pool." )
@@ -55,13 +55,7 @@ async def acquire(self, timeout: float) -> QuerySessionAsync:
55
55
pass
56
56
57
57
if session is None and self ._current_size == self ._size :
58
- try :
59
- self ._waiters += 1
60
- session = await self ._get_session_with_timeout (timeout )
61
- except asyncio .TimeoutError :
62
- raise issues .SessionPoolEmpty ("Timeout on acquire session" )
63
- finally :
64
- self ._waiters -= 1
58
+ _ , session = await self ._queue .get ()
65
59
66
60
if session is not None :
67
61
if session ._state .attached :
@@ -76,21 +70,28 @@ async def acquire(self, timeout: float) -> QuerySessionAsync:
76
70
self ._current_size += 1
77
71
return session
78
72
79
- async def _get_session_with_timeout (self , timeout : float ):
80
- task_wait = asyncio .ensure_future (asyncio .wait_for (self ._queue .get (), timeout = timeout ))
81
- task_stop = asyncio .ensure_future (asyncio .ensure_future (self ._should_stop .wait ()))
82
- done , _ = await asyncio .wait ((task_wait , task_stop ), return_when = asyncio .FIRST_COMPLETED )
83
- if task_stop in done :
84
- task_wait .cancel ()
85
- return await self ._create_new_session () # TODO: not sure why
86
- _ , session = task_wait .result ()
87
- return session
73
+ async def acquire_wih_timeout (self , timeout : float ):
74
+ if self ._should_stop .is_set ():
75
+ logger .error ("An attempt to take session from closed session pool." )
76
+ raise RuntimeError ("An attempt to take session from closed session pool." )
77
+
78
+ try :
79
+ task_wait = asyncio .ensure_future (asyncio .wait_for (self .acquire (), timeout = timeout ))
80
+ task_stop = asyncio .ensure_future (asyncio .ensure_future (self ._should_stop .wait ()))
81
+ done , _ = await asyncio .wait ((task_wait , task_stop ), return_when = asyncio .FIRST_COMPLETED )
82
+ if task_stop in done :
83
+ task_wait .cancel ()
84
+ return await self ._create_new_session () # TODO: not sure why
85
+ session = task_wait .result ()
86
+ return session
87
+ except asyncio .TimeoutError :
88
+ raise issues .SessionPoolEmpty ("Timeout on acquire session" )
88
89
89
90
async def release (self , session : QuerySessionAsync ) -> None :
90
91
self ._queue .put_nowait ((1 , session ))
91
92
logger .debug ("Session returned to queue: %s" , session ._state .session_id )
92
93
93
- def checkout (self , timeout : float = 10 ) -> "SimpleQuerySessionCheckoutAsync" :
94
+ def checkout (self , timeout : Optional [ float ] = None ) -> "SimpleQuerySessionCheckoutAsync" :
94
95
"""WARNING: This API is experimental and could be changed.
95
96
Return a Session context manager, that opens session on enter and closes session on exit.
96
97
"""
@@ -169,13 +170,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
169
170
170
171
171
172
class SimpleQuerySessionCheckoutAsync :
172
- def __init__ (self , pool : QuerySessionPoolAsync , timeout : float ):
173
+ def __init__ (self , pool : QuerySessionPoolAsync , timeout : Optional [ float ] ):
173
174
self ._pool = pool
174
175
self ._timeout = timeout
175
176
self ._session = None
176
177
177
178
async def __aenter__ (self ) -> QuerySessionAsync :
178
- self ._session = await self ._pool .acquire (self ._timeout )
179
+ if self ._timeout and self ._timeout > 0 :
180
+ self ._session = await self ._pool .acquire_wih_timeout (self ._timeout )
181
+ else :
182
+ self ._session = await self ._pool .acquire ()
179
183
return self ._session
180
184
181
185
async def __aexit__ (self , exc_type , exc_val , exc_tb ):
0 commit comments