Skip to content

Commit 5103dba

Browse files
committed
aio query session pool redesign
1 parent c5cec05 commit 5103dba

File tree

3 files changed

+16
-27
lines changed

3 files changed

+16
-27
lines changed

tests/aio/query/test_query_session_pool.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import pytest
23
import ydb
34
from ydb.aio.query.pool import QuerySessionPoolAsync
@@ -61,17 +62,17 @@ async def test_pool_size_limit_logic(self, pool: QuerySessionPoolAsync):
6162
ids = set()
6263

6364
for i in range(1, target_size + 1):
64-
session = await pool.acquire_wih_timeout(timeout=0.5)
65+
session = await pool.acquire()
6566
assert pool._current_size == i
6667
assert session._state.session_id not in ids
6768
ids.add(session._state.session_id)
6869

69-
with pytest.raises(ydb.SessionPoolEmpty):
70-
await pool.acquire_wih_timeout(timeout=0.5)
70+
with pytest.raises(asyncio.TimeoutError):
71+
await asyncio.wait_for(pool.acquire(), timeout=0.5)
7172

7273
await pool.release(session)
7374

74-
session = await pool.acquire_wih_timeout(timeout=0.5)
75+
session = await pool.acquire()
7576
assert pool._current_size == target_size
7677
assert session._state.session_id in ids
7778

@@ -105,14 +106,14 @@ async def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPoolAsync
105106
async def test_acquire_with_timeout_from_closed_pool_raises(self, pool: QuerySessionPoolAsync):
106107
await pool.stop()
107108
with pytest.raises(RuntimeError):
108-
await pool.acquire_wih_timeout(timeout=0.5)
109+
await asyncio.wait_for(pool.acquire(), timeout=0.5)
109110

110111
@pytest.mark.asyncio
111112
async def test_no_session_leak(self, driver, docker_project):
112113
pool = ydb.aio.QuerySessionPoolAsync(driver, 1)
113114
docker_project.stop()
114115
try:
115-
await pool.acquire_wih_timeout(timeout=0.5)
116+
await asyncio.wait_for(pool.acquire(), timeout=0.5)
116117
except ydb.Error:
117118
pass
118119
assert pool._current_size == 0

ydb/aio/query/pool.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ async def acquire(self) -> QuerySessionAsync:
5555
pass
5656

5757
if session is None and self._current_size == self._size:
58-
_, session = await self._queue.get()
58+
queue_get = asyncio.ensure_future(self._queue.get())
59+
task_stop = asyncio.ensure_future(asyncio.ensure_future(self._should_stop.wait()))
60+
done, _ = await asyncio.wait((queue_get, task_stop), return_when=asyncio.FIRST_COMPLETED)
61+
if task_stop in done:
62+
queue_get.cancel()
63+
return await self._create_new_session() # TODO: not sure why
64+
_, session = queue_get.result()
5965

6066
if session is not None:
6167
if session._state.attached:
@@ -70,23 +76,6 @@ async def acquire(self) -> QuerySessionAsync:
7076
self._current_size += 1
7177
return session
7278

73-
async def acquire_wih_timeout(self, timeout: float):
74-
if self._should_stop.is_set():
75-
logger.error("An attempt to take session from closed session pool.")
76-
raise RuntimeError("An attempt to take session from closed session pool.")
77-
78-
try:
79-
task_wait = asyncio.ensure_future(asyncio.wait_for(self.acquire(), timeout=timeout))
80-
task_stop = asyncio.ensure_future(asyncio.ensure_future(self._should_stop.wait()))
81-
done, _ = await asyncio.wait((task_wait, task_stop), return_when=asyncio.FIRST_COMPLETED)
82-
if task_stop in done:
83-
task_wait.cancel()
84-
return await self._create_new_session() # TODO: not sure why
85-
session = task_wait.result()
86-
return session
87-
except asyncio.TimeoutError:
88-
raise issues.SessionPoolEmpty("Timeout on acquire session")
89-
9079
async def release(self, session: QuerySessionAsync) -> None:
9180
self._queue.put_nowait((1, session))
9281
logger.debug("Session returned to queue: %s", session._state.session_id)
@@ -170,14 +159,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
170159

171160

172161
class SimpleQuerySessionCheckoutAsync:
173-
def __init__(self, pool: QuerySessionPoolAsync, timeout: Optional[float]):
162+
def __init__(self, pool: QuerySessionPoolAsync, timeout: Optional[float] = None):
174163
self._pool = pool
175164
self._timeout = timeout
176165
self._session = None
177166

178167
async def __aenter__(self) -> QuerySessionAsync:
179168
if self._timeout and self._timeout > 0:
180-
self._session = await self._pool.acquire_wih_timeout(self._timeout)
169+
self._session = await self._pool.acquire_with_timeout(self._timeout)
181170
else:
182171
self._session = await self._pool.acquire()
183172
return self._session

ydb/query/session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import abc
22
import enum
33
import logging
4-
import time
54
import threading
65
from typing import (
76
Iterable,

0 commit comments

Comments
 (0)