diff --git a/tests/aio/query/test_query_session_pool.py b/tests/aio/query/test_query_session_pool.py index f86ff3ed..2cd0d4b9 100644 --- a/tests/aio/query/test_query_session_pool.py +++ b/tests/aio/query/test_query_session_pool.py @@ -162,3 +162,20 @@ async def test_no_session_leak(self, driver, docker_project): docker_project.start() await pool.stop() + + @pytest.mark.asyncio + async def test_acquire_no_race_condition(self, driver): + ids = set() + async with ydb.aio.QuerySessionPool(driver, 1) as pool: + + async def acquire_session(): + session = await pool.acquire() + ids.add(session._state.session_id) + await pool.release(session) + + tasks = [acquire_session() for _ in range(10)] + + await asyncio.gather(*tasks) + + assert len(ids) == 1 + assert pool._current_size == 1 diff --git a/tox.ini b/tox.ini index b7d712f9..df029d2a 100644 --- a/tox.ini +++ b/tox.ini @@ -32,14 +32,14 @@ deps = [testenv:py-proto5] commands = - pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} --ignore=tests/topics + pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} deps = -r{toxinidir}/test-requirements.txt protobuf<6.0.0 [testenv:py-proto4] commands = - pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} --ignore=tests/topics + pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} deps = -r{toxinidir}/test-requirements.txt protobuf<5.0.0 @@ -55,7 +55,7 @@ deps = [testenv:py-proto3] commands = - pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} --ignore=tests/topics + pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} deps = -r{toxinidir}/test-requirements.txt protobuf<4.0.0 diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py index f6a84eb0..947db658 100644 --- a/ydb/aio/query/pool.py +++ b/ydb/aio/query/pool.py @@ -90,8 +90,15 @@ async def acquire(self) -> QuerySession: logger.debug(f"Acquired dead session from queue: {session._state.session_id}") logger.debug(f"Session pool is not large enough: {self._current_size} < {self._size}, will create new one.") - session = await self._create_new_session() + self._current_size += 1 + try: + session = await self._create_new_session() + except Exception as e: + logger.error("Failed to create new session") + self._current_size -= 1 + raise e + return session async def release(self, session: QuerySession) -> None: