Skip to content

Commit 2f6f9b3

Browse files
committed
refactor asyncio pool timeout
1 parent a812257 commit 2f6f9b3

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

tests/aio/query/test_query_session_pool.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,17 @@ async def test_pool_size_limit_logic(self, pool: QuerySessionPoolAsync):
6161
ids = set()
6262

6363
for i in range(1, target_size + 1):
64-
session = await pool.acquire(timeout=0.5)
64+
session = await pool.acquire_wih_timeout(timeout=0.5)
6565
assert pool._current_size == i
6666
assert session._state.session_id not in ids
6767
ids.add(session._state.session_id)
6868

6969
with pytest.raises(ydb.SessionPoolEmpty):
70-
await pool.acquire(timeout=0.5)
70+
await pool.acquire_wih_timeout(timeout=0.5)
7171

7272
await pool.release(session)
7373

74-
session = await pool.acquire(timeout=0.5)
74+
session = await pool.acquire_wih_timeout(timeout=0.5)
7575
assert pool._current_size == target_size
7676
assert session._state.session_id in ids
7777

@@ -99,14 +99,20 @@ async def test_pool_recreates_bad_sessions(self, pool: QuerySessionPoolAsync):
9999
async def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPoolAsync):
100100
await pool.stop()
101101
with pytest.raises(RuntimeError):
102-
await pool.acquire(1)
102+
await pool.acquire()
103+
104+
@pytest.mark.asyncio
105+
async def test_acquire_with_timeout_from_closed_pool_raises(self, pool: QuerySessionPoolAsync):
106+
await pool.stop()
107+
with pytest.raises(RuntimeError):
108+
await pool.acquire_wih_timeout(timeout=0.5)
103109

104110
@pytest.mark.asyncio
105111
async def test_no_session_leak(self, driver, docker_project):
106112
pool = ydb.aio.QuerySessionPoolAsync(driver, 1)
107113
docker_project.stop()
108114
try:
109-
await pool.acquire(timeout=0.5)
115+
await pool.acquire_wih_timeout(timeout=0.5)
110116
except ydb.Error:
111117
pass
112118
assert pool._current_size == 0

ydb/aio/query/pool.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def _create_new_session(self):
4343
logger.debug(f"New session was created for pool. Session id: {session._state.session_id}")
4444
return session
4545

46-
async def acquire(self, timeout: float) -> QuerySessionAsync:
46+
async def acquire(self) -> QuerySessionAsync:
4747
if self._should_stop.is_set():
4848
logger.error("An attempt to take session from closed session pool.")
4949
raise RuntimeError("An attempt to take session from closed session pool.")
@@ -55,13 +55,7 @@ async def acquire(self, timeout: float) -> QuerySessionAsync:
5555
pass
5656

5757
if session is None and self._current_size == self._size:
58-
try:
59-
self._waiters += 1
60-
session = await self._get_session_with_timeout(timeout)
61-
except asyncio.TimeoutError:
62-
raise issues.SessionPoolEmpty("Timeout on acquire session")
63-
finally:
64-
self._waiters -= 1
58+
_, session = await self._queue.get()
6559

6660
if session is not None:
6761
if session._state.attached:
@@ -76,21 +70,28 @@ async def acquire(self, timeout: float) -> QuerySessionAsync:
7670
self._current_size += 1
7771
return session
7872

79-
async def _get_session_with_timeout(self, timeout: float):
80-
task_wait = asyncio.ensure_future(asyncio.wait_for(self._queue.get(), timeout=timeout))
81-
task_stop = asyncio.ensure_future(asyncio.ensure_future(self._should_stop.wait()))
82-
done, _ = await asyncio.wait((task_wait, task_stop), return_when=asyncio.FIRST_COMPLETED)
83-
if task_stop in done:
84-
task_wait.cancel()
85-
return await self._create_new_session() # TODO: not sure why
86-
_, session = task_wait.result()
87-
return session
73+
async def acquire_wih_timeout(self, timeout: float):
74+
if self._should_stop.is_set():
75+
logger.error("An attempt to take session from closed session pool.")
76+
raise RuntimeError("An attempt to take session from closed session pool.")
77+
78+
try:
79+
task_wait = asyncio.ensure_future(asyncio.wait_for(self.acquire(), timeout=timeout))
80+
task_stop = asyncio.ensure_future(asyncio.ensure_future(self._should_stop.wait()))
81+
done, _ = await asyncio.wait((task_wait, task_stop), return_when=asyncio.FIRST_COMPLETED)
82+
if task_stop in done:
83+
task_wait.cancel()
84+
return await self._create_new_session() # TODO: not sure why
85+
session = task_wait.result()
86+
return session
87+
except asyncio.TimeoutError:
88+
raise issues.SessionPoolEmpty("Timeout on acquire session")
8889

8990
async def release(self, session: QuerySessionAsync) -> None:
9091
self._queue.put_nowait((1, session))
9192
logger.debug("Session returned to queue: %s", session._state.session_id)
9293

93-
def checkout(self, timeout: float = 10) -> "SimpleQuerySessionCheckoutAsync":
94+
def checkout(self, timeout: Optional[float] = None) -> "SimpleQuerySessionCheckoutAsync":
9495
"""WARNING: This API is experimental and could be changed.
9596
Return a Session context manager, that opens session on enter and closes session on exit.
9697
"""
@@ -169,13 +170,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
169170

170171

171172
class SimpleQuerySessionCheckoutAsync:
172-
def __init__(self, pool: QuerySessionPoolAsync, timeout: float):
173+
def __init__(self, pool: QuerySessionPoolAsync, timeout: Optional[float]):
173174
self._pool = pool
174175
self._timeout = timeout
175176
self._session = None
176177

177178
async def __aenter__(self) -> QuerySessionAsync:
178-
self._session = await self._pool.acquire(self._timeout)
179+
if self._timeout and self._timeout > 0:
180+
self._session = await self._pool.acquire_wih_timeout(self._timeout)
181+
else:
182+
self._session = await self._pool.acquire()
179183
return self._session
180184

181185
async def __aexit__(self, exc_type, exc_val, exc_tb):

0 commit comments

Comments
 (0)