Add async methods to QuerySessionPool #689

Merged
merged 1 commit on Jul 1, 2025
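This PR adds futures-based async variants of the pool's retry helpers (execute_with_retries_async, retry_operation_async, retry_tx_async). Each one submits the corresponding sync implementation to an internal ThreadPoolExecutor (sized by the new workers_threads_count argument) and returns a concurrent.futures.Future. Below is a minimal usage sketch assuming a reachable YDB instance; the endpoint and database values are placeholders, not part of this PR.

    import ydb

    # Placeholder connection parameters; adjust for your environment.
    driver = ydb.Driver(endpoint="grpc://localhost:2136", database="/local")
    driver.wait(timeout=5)

    # workers_threads_count controls the size of the internal executor (new in this PR).
    pool = ydb.QuerySessionPool(driver, size=100, workers_threads_count=4)

    # Returns a concurrent.futures.Future immediately; retries run in a worker thread.
    fut = pool.execute_with_retries_async("SELECT 1;")
    result_sets = fut.result()  # blocks until the retried execution completes

    # stop() shuts the executor down with wait=True, so submitted work finishes first.
    pool.stop()
    driver.stop()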
56 changes: 55 additions & 1 deletion tests/query/test_query_session_pool.py
@@ -1,5 +1,7 @@
import pytest
import ydb
import time
from concurrent import futures

from typing import Optional

@@ -132,7 +134,7 @@ def test_pool_recreates_bad_sessions(self, pool: QuerySessionPool):

def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPool):
pool.stop()
with pytest.raises(RuntimeError):
with pytest.raises(ydb.SessionPoolClosed):
pool.acquire(1)

def test_no_session_leak(self, driver_sync, docker_project):
@@ -146,3 +148,55 @@ def test_no_session_leak(self, driver_sync, docker_project):

docker_project.start()
pool.stop()

def test_execute_with_retries_async(self, pool: QuerySessionPool):
fut = pool.execute_with_retries_async("select 1;")
res = fut.result()
assert len(res) == 1

def test_retry_operation_async(self, pool: QuerySessionPool):
def callee(session: QuerySession):
with session.transaction() as tx:
iterator = tx.execute("select 1;", commit_tx=True)
return [result_set for result_set in iterator]

fut = pool.retry_operation_async(callee)
res = fut.result()
assert len(res) == 1

def test_retry_tx_async(self, pool: QuerySessionPool):
retry_no = 0

def callee(tx: QueryTxContext):
nonlocal retry_no
if retry_no < 2:
retry_no += 1
raise ydb.Unavailable("Fake fast backoff error")
result_stream = tx.execute("SELECT 1")
return [result_set for result_set in result_stream]

result = pool.retry_tx_async(callee=callee).result()
assert len(result) == 1
assert retry_no == 2

def test_execute_with_retries_async_many_calls(self, pool: QuerySessionPool):
futs = [pool.execute_with_retries_async("select 1;") for _ in range(10)]
results = [f.result() for f in futures.as_completed(futs)]
assert all(len(r) == 1 for r in results)

def test_future_waits_on_stop(self, pool: QuerySessionPool):
def callee(session: QuerySession):
time.sleep(0.1)
with session.transaction() as tx:
it = tx.execute("select 1;", commit_tx=True)
return [rs for rs in it]

fut = pool.retry_operation_async(callee)
pool.stop()
assert fut.done()
assert len(fut.result()) == 1

def test_async_methods_after_stop_raise(self, pool: QuerySessionPool):
pool.stop()
with pytest.raises(ydb.SessionPoolClosed):
pool.execute_with_retries_async("select 1;")
8 changes: 8 additions & 0 deletions ydb/issues.py
@@ -51,6 +51,7 @@ class StatusCode(enum.IntEnum):

UNAUTHENTICATED = _CLIENT_STATUSES_FIRST + 30
SESSION_POOL_EMPTY = _CLIENT_STATUSES_FIRST + 40
SESSION_POOL_CLOSED = _CLIENT_STATUSES_FIRST + 50


# TODO: convert from proto IssueMessage
@@ -179,6 +180,13 @@ class SessionPoolEmpty(Error, queue.Empty):
status = StatusCode.SESSION_POOL_EMPTY


class SessionPoolClosed(Error):
status = StatusCode.SESSION_POOL_CLOSED

def __init__(self):
super().__init__("Session pool is closed.")


class ClientInternalError(Error):
status = StatusCode.CLIENT_INTERNAL_ERROR

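Acquiring from a stopped pool previously raised a bare RuntimeError; with this PR it raises the dedicated ydb.SessionPoolClosed error (status code SESSION_POOL_CLOSED), and the same check guards the sync and async retry helpers. A sketch of how calling code might handle it; the try_query helper name is illustrative only:

    import ydb

    def try_query(pool, text: str):
        """Run a query through the pool, returning None if the pool was already closed."""
        try:
            return pool.execute_with_retries(text)
        except ydb.SessionPoolClosed:
            # Raised once stop() has been called on the pool.
            return None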
70 changes: 69 additions & 1 deletion ydb/query/pool.py
@@ -1,4 +1,5 @@
import logging
from concurrent import futures
from typing import (
Callable,
Optional,
@@ -36,14 +37,17 @@ def __init__(
size: int = 100,
*,
query_client_settings: Optional[QueryClientSettings] = None,
workers_threads_count: int = 4,
):
"""
:param driver: A driver instance.
:param size: Max size of Session Pool.
:param query_client_settings: ydb.QueryClientSettings object to configure QueryService behavior
:param workers_threads_count: Number of threads in the executor used for the *_async methods
"""

self._driver = driver
self._tp = futures.ThreadPoolExecutor(workers_threads_count)
self._queue = queue.Queue()
self._current_size = 0
self._size = size
@@ -72,7 +76,7 @@ def acquire(self, timeout: Optional[float] = None) -> QuerySession:
try:
if self._should_stop.is_set():
logger.error("An attempt to take session from closed session pool.")
raise RuntimeError("An attempt to take session from closed session pool.")
raise issues.SessionPoolClosed()

session = None
try:
@@ -132,6 +136,9 @@ def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetryS
:return: Result sets or exception in case of execution errors.
"""

if self._should_stop.is_set():
raise issues.SessionPoolClosed()

retry_settings = RetrySettings() if retry_settings is None else retry_settings

def wrapped_callee():
@@ -140,6 +147,38 @@ def wrapped_callee():

return retry_operation_sync(wrapped_callee, retry_settings)

def retry_tx_async(
self,
callee: Callable,
tx_mode: Optional[BaseQueryTxMode] = None,
retry_settings: Optional[RetrySettings] = None,
*args,
**kwargs,
) -> futures.Future:
"""Asynchronously execute a transaction in a retriable way."""

if self._should_stop.is_set():
raise issues.SessionPoolClosed()

return self._tp.submit(
self.retry_tx_sync,
callee,
tx_mode,
retry_settings,
*args,
**kwargs,
)

def retry_operation_async(
self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs
) -> futures.Future:
"""Asynchronously execute a retryable operation."""

if self._should_stop.is_set():
raise issues.SessionPoolClosed()

return self._tp.submit(self.retry_operation_sync, callee, retry_settings, *args, **kwargs)

def retry_tx_sync(
self,
callee: Callable,
@@ -161,6 +200,9 @@ def retry_tx_sync(
:return: Result sets or exception in case of execution errors.
"""

if self._should_stop.is_set():
raise issues.SessionPoolClosed()

tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite()
retry_settings = RetrySettings() if retry_settings is None else retry_settings

@@ -194,6 +236,9 @@ def execute_with_retries(
:return: Result sets or exception in case of execution errors.
"""

if self._should_stop.is_set():
raise issues.SessionPoolClosed()

retry_settings = RetrySettings() if retry_settings is None else retry_settings

def wrapped_callee():
@@ -203,11 +248,34 @@ def wrapped_callee():

return retry_operation_sync(wrapped_callee, retry_settings)

def execute_with_retries_async(
self,
query: str,
parameters: Optional[dict] = None,
retry_settings: Optional[RetrySettings] = None,
*args,
**kwargs,
) -> futures.Future:
"""Asynchronously execute a query with retries."""

if self._should_stop.is_set():
raise issues.SessionPoolClosed()

return self._tp.submit(
self.execute_with_retries,
query,
parameters,
retry_settings,
*args,
**kwargs,
)

def stop(self, timeout=None):
acquire_timeout = timeout if timeout is not None else -1
acquired = self._lock.acquire(timeout=acquire_timeout)
try:
self._should_stop.set()
Copilot AI commented on lines +271 to 277 (Jul 1, 2025):

Consider reviewing the thread pool shutdown behavior to determine if any pending async tasks should be canceled when the pool is stopped. Adding a mechanism to cancel or track pending tasks might help avoid unexpected execution after the pool is marked as closed.

Suggested change

Current:
)
def stop(self, timeout=None):
acquire_timeout = timeout if timeout is not None else -1
acquired = self._lock.acquire(timeout=acquire_timeout)
try:
self._should_stop.set()

Suggested:
)
self._pending_tasks.add(future)
future.add_done_callback(lambda f: self._pending_tasks.discard(f))
return future
def stop(self, timeout=None):
acquire_timeout = timeout if timeout is not None else -1
acquired = self._lock.acquire(timeout=acquire_timeout)
try:
self._should_stop.set()
for task in list(self._pending_tasks):
task.cancel()


self._tp.shutdown(wait=True)
while True:
try:
session = self._queue.get_nowait()
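The Copilot comment above suggests tracking pending futures so that stop() can cancel queued work that has not started yet. That mechanism is not part of this PR; the sketch below shows one way it could look, with _PendingTaskTracker, _pending_tasks, and _submit as hypothetical names layered on the executor-based design from the diff.

    from concurrent import futures
    import threading

    class _PendingTaskTracker:
        """Sketch: track submitted futures so queued (not yet running) work can be canceled."""

        def __init__(self, workers_threads_count: int = 4):
            self._tp = futures.ThreadPoolExecutor(workers_threads_count)
            self._pending_tasks = set()            # hypothetical attribute, not in the PR
            self._pending_lock = threading.Lock()

        def _submit(self, fn, *args, **kwargs) -> futures.Future:
            # Wrap ThreadPoolExecutor.submit so every future is tracked until it completes.
            future = self._tp.submit(fn, *args, **kwargs)
            with self._pending_lock:
                self._pending_tasks.add(future)
            future.add_done_callback(lambda f: self._pending_tasks.discard(f))
            return future

        def stop(self):
            # Cancel futures that have not started; running ones are allowed to finish.
            with self._pending_lock:
                pending = list(self._pending_tasks)
            for task in pending:
                task.cancel()
            self._tp.shutdown(wait=True)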