Skip to content

Commit b7ef95c

Browse files
committed
AsyncIO query session pool
1 parent 5b2ad25 commit b7ef95c

File tree

1 file changed

+77
-9
lines changed

1 file changed

+77
-9
lines changed

ydb/aio/query/pool.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from typing import (
34
Callable,
@@ -8,6 +9,7 @@
89
from .session import (
910
QuerySessionAsync,
1011
)
12+
from ... import issues
1113
from ...retries import (
1214
RetrySettings,
1315
retry_operation_async,
@@ -21,20 +23,73 @@
2123
class QuerySessionPoolAsync:
2224
"""QuerySessionPoolAsync is an object to simplify operations with sessions of Query Service."""
2325

24-
def __init__(self, driver: common_utils.SupportedDriverType):
26+
def __init__(self, driver: common_utils.SupportedDriverType, size: int = 10):
2527
"""
2628
:param driver: A driver instance
29+
:param size: Size of session pool
2730
"""
2831

2932
logger.warning("QuerySessionPoolAsync is an experimental API, which could be changed.")
3033
self._driver = driver
31-
32-
def checkout(self) -> "SimpleQuerySessionCheckoutAsync":
34+
self._size = size
35+
self._should_stop = asyncio.Event()
36+
self._queue = asyncio.PriorityQueue()
37+
self._current_size = 0
38+
self._waiters = 0
39+
40+
async def _create_new_session(self):
41+
session = QuerySessionAsync(self._driver)
42+
await session.create()
43+
logger.debug(f"New session was created for pool. Session id: {session._state.session_id}")
44+
return session
45+
46+
async def acquire(self, timeout: float) -> QuerySessionAsync:
47+
if self._should_stop.is_set():
48+
logger.error("An attempt to take session from closed session pool.")
49+
raise RuntimeError("An attempt to take session from closed session pool.")
50+
51+
try:
52+
_, session = self._queue.get_nowait()
53+
logger.debug(f"Acquired active session from queue: {session._state.session_id}")
54+
return session
55+
except asyncio.QueueEmpty:
56+
pass
57+
58+
if self._current_size < self._size:
59+
logger.debug(f"Session pool is not large enough: {self._current_size} < {self._size}, will create new one.")
60+
session = await self._create_new_session()
61+
self._current_size += 1
62+
return session
63+
64+
try:
65+
self._waiters += 1
66+
session = await self._get_session_with_timeout(timeout)
67+
return session
68+
except asyncio.TimeoutError:
69+
raise issues.SessionPoolEmpty("Timeout on acquire session")
70+
finally:
71+
self._waiters -= 1
72+
73+
async def _get_session_with_timeout(self, timeout: float):
74+
task_wait = asyncio.ensure_future(asyncio.wait_for(self._queue.get(), timeout=timeout))
75+
task_stop = asyncio.ensure_future(asyncio.ensure_future(self._should_stop.wait()))
76+
done, _ = await asyncio.wait((task_wait, task_stop), return_when=asyncio.FIRST_COMPLETED)
77+
if task_stop in done:
78+
task_wait.cancel()
79+
return await self._create_new_session() # TODO: not sure why
80+
_, session = task_wait.result()
81+
return session
82+
83+
async def release(self, session: QuerySessionAsync) -> None:
84+
self._queue.put_nowait((1, session))
85+
logger.debug("Session returned to queue: %s", session._state.session_id)
86+
87+
def checkout(self, timeout: float = 10) -> "SimpleQuerySessionCheckoutAsync":
3388
"""WARNING: This API is experimental and could be changed.
3489
Return a Session context manager, that opens session on enter and closes session on exit.
3590
"""
3691

37-
return SimpleQuerySessionCheckoutAsync(self)
92+
return SimpleQuerySessionCheckoutAsync(self, timeout)
3893

3994
async def retry_operation_async(
4095
self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs
@@ -86,7 +141,19 @@ async def wrapped_callee():
86141
return await retry_operation_async(wrapped_callee, retry_settings)
87142

88143
async def stop(self, timeout=None):
89-
pass # TODO: implement
144+
self._should_stop.set()
145+
146+
tasks = []
147+
while True:
148+
try:
149+
_, session = self._queue.get_nowait()
150+
tasks.append(session.delete())
151+
except asyncio.QueueEmpty:
152+
break
153+
154+
await asyncio.gather(*tasks)
155+
156+
logger.debug("All session were deleted.")
90157

91158
async def __aenter__(self):
92159
return self
@@ -96,13 +163,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
96163

97164

98165
class SimpleQuerySessionCheckoutAsync:
99-
def __init__(self, pool: QuerySessionPoolAsync):
166+
def __init__(self, pool: QuerySessionPoolAsync, timeout: float):
100167
self._pool = pool
101-
self._session = QuerySessionAsync(pool._driver)
168+
self._timeout = timeout
169+
self._session = None
102170

103171
async def __aenter__(self) -> QuerySessionAsync:
104-
await self._session.create()
172+
self._session = await self._pool.acquire(self._timeout)
105173
return self._session
106174

107175
async def __aexit__(self, exc_type, exc_val, exc_tb):
108-
await self._session.delete()
176+
await self._pool.release(self._session)

0 commit comments

Comments
 (0)