1
+ import asyncio
1
2
import logging
2
3
from typing import (
3
4
Callable ,
8
9
from .session import (
9
10
QuerySessionAsync ,
10
11
)
12
+ from ... import issues
11
13
from ...retries import (
12
14
RetrySettings ,
13
15
retry_operation_async ,
21
23
class QuerySessionPoolAsync :
22
24
"""QuerySessionPoolAsync is an object to simplify operations with sessions of Query Service."""
23
25
24
- def __init__ (self , driver : common_utils .SupportedDriverType ):
26
+ def __init__ (self , driver : common_utils .SupportedDriverType , size : int = 10 ):
25
27
"""
26
28
:param driver: A driver instance
29
+ :param size: Size of session pool
27
30
"""
28
31
29
32
logger .warning ("QuerySessionPoolAsync is an experimental API, which could be changed." )
30
33
self ._driver = driver
31
-
32
- def checkout (self ) -> "SimpleQuerySessionCheckoutAsync" :
34
+ self ._size = size
35
+ self ._should_stop = asyncio .Event ()
36
+ self ._queue = asyncio .PriorityQueue ()
37
+ self ._current_size = 0
38
+ self ._waiters = 0
39
+
40
+ async def _create_new_session (self ):
41
+ session = QuerySessionAsync (self ._driver )
42
+ await session .create ()
43
+ logger .debug (f"New session was created for pool. Session id: { session ._state .session_id } " )
44
+ return session
45
+
46
+ async def acquire (self , timeout : float ) -> QuerySessionAsync :
47
+ if self ._should_stop .is_set ():
48
+ logger .error ("An attempt to take session from closed session pool." )
49
+ raise RuntimeError ("An attempt to take session from closed session pool." )
50
+
51
+ try :
52
+ _ , session = self ._queue .get_nowait ()
53
+ logger .debug (f"Acquired active session from queue: { session ._state .session_id } " )
54
+ return session
55
+ except asyncio .QueueEmpty :
56
+ pass
57
+
58
+ if self ._current_size < self ._size :
59
+ logger .debug (f"Session pool is not large enough: { self ._current_size } < { self ._size } , will create new one." )
60
+ session = await self ._create_new_session ()
61
+ self ._current_size += 1
62
+ return session
63
+
64
+ try :
65
+ self ._waiters += 1
66
+ session = await self ._get_session_with_timeout (timeout )
67
+ return session
68
+ except asyncio .TimeoutError :
69
+ raise issues .SessionPoolEmpty ("Timeout on acquire session" )
70
+ finally :
71
+ self ._waiters -= 1
72
+
73
+ async def _get_session_with_timeout (self , timeout : float ):
74
+ task_wait = asyncio .ensure_future (asyncio .wait_for (self ._queue .get (), timeout = timeout ))
75
+ task_stop = asyncio .ensure_future (asyncio .ensure_future (self ._should_stop .wait ()))
76
+ done , _ = await asyncio .wait ((task_wait , task_stop ), return_when = asyncio .FIRST_COMPLETED )
77
+ if task_stop in done :
78
+ task_wait .cancel ()
79
+ return await self ._create_new_session () # TODO: not sure why
80
+ _ , session = task_wait .result ()
81
+ return session
82
+
83
+ async def release (self , session : QuerySessionAsync ) -> None :
84
+ self ._queue .put_nowait ((1 , session ))
85
+ logger .debug ("Session returned to queue: %s" , session ._state .session_id )
86
+
87
+ def checkout (self , timeout : float = 10 ) -> "SimpleQuerySessionCheckoutAsync" :
33
88
"""WARNING: This API is experimental and could be changed.
34
89
Return a Session context manager, that opens session on enter and closes session on exit.
35
90
"""
36
91
37
- return SimpleQuerySessionCheckoutAsync (self )
92
+ return SimpleQuerySessionCheckoutAsync (self , timeout )
38
93
39
94
async def retry_operation_async (
40
95
self , callee : Callable , retry_settings : Optional [RetrySettings ] = None , * args , ** kwargs
@@ -86,7 +141,19 @@ async def wrapped_callee():
86
141
return await retry_operation_async (wrapped_callee , retry_settings )
87
142
88
143
async def stop (self , timeout = None ):
89
- pass # TODO: implement
144
+ self ._should_stop .set ()
145
+
146
+ tasks = []
147
+ while True :
148
+ try :
149
+ _ , session = self ._queue .get_nowait ()
150
+ tasks .append (session .delete ())
151
+ except asyncio .QueueEmpty :
152
+ break
153
+
154
+ await asyncio .gather (* tasks )
155
+
156
+ logger .debug ("All session were deleted." )
90
157
91
158
async def __aenter__ (self ):
92
159
return self
@@ -96,13 +163,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
96
163
97
164
98
165
class SimpleQuerySessionCheckoutAsync :
99
- def __init__ (self , pool : QuerySessionPoolAsync ):
166
+ def __init__ (self , pool : QuerySessionPoolAsync , timeout : float ):
100
167
self ._pool = pool
101
- self ._session = QuerySessionAsync (pool ._driver )
168
+ self ._timeout = timeout
169
+ self ._session = None
102
170
103
171
async def __aenter__ (self ) -> QuerySessionAsync :
104
- await self ._session . create ( )
172
+ self . _session = await self ._pool . acquire ( self . _timeout )
105
173
return self ._session
106
174
107
175
async def __aexit__ (self , exc_type , exc_val , exc_tb ):
108
- await self ._session . delete ( )
176
+ await self ._pool . release ( self . _session )
0 commit comments