diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py index 6b5a7bf4..5901c8c8 100644 --- a/tests/query/test_query_session_pool.py +++ b/tests/query/test_query_session_pool.py @@ -1,5 +1,7 @@ import pytest import ydb +import time +from concurrent import futures from typing import Optional @@ -132,7 +134,7 @@ def test_pool_recreates_bad_sessions(self, pool: QuerySessionPool): def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPool): pool.stop() - with pytest.raises(RuntimeError): + with pytest.raises(ydb.SessionPoolClosed): pool.acquire(1) def test_no_session_leak(self, driver_sync, docker_project): @@ -146,3 +148,55 @@ def test_no_session_leak(self, driver_sync, docker_project): docker_project.start() pool.stop() + + def test_execute_with_retries_async(self, pool: QuerySessionPool): + fut = pool.execute_with_retries_async("select 1;") + res = fut.result() + assert len(res) == 1 + + def test_retry_operation_async(self, pool: QuerySessionPool): + def callee(session: QuerySession): + with session.transaction() as tx: + iterator = tx.execute("select 1;", commit_tx=True) + return [result_set for result_set in iterator] + + fut = pool.retry_operation_async(callee) + res = fut.result() + assert len(res) == 1 + + def test_retry_tx_async(self, pool: QuerySessionPool): + retry_no = 0 + + def callee(tx: QueryTxContext): + nonlocal retry_no + if retry_no < 2: + retry_no += 1 + raise ydb.Unavailable("Fake fast backoff error") + result_stream = tx.execute("SELECT 1") + return [result_set for result_set in result_stream] + + result = pool.retry_tx_async(callee=callee).result() + assert len(result) == 1 + assert retry_no == 2 + + def test_execute_with_retries_async_many_calls(self, pool: QuerySessionPool): + futs = [pool.execute_with_retries_async("select 1;") for _ in range(10)] + results = [f.result() for f in futures.as_completed(futs)] + assert all(len(r) == 1 for r in results) + + def test_future_waits_on_stop(self, pool: QuerySessionPool): + def callee(session: QuerySession): + time.sleep(0.1) + with session.transaction() as tx: + it = tx.execute("select 1;", commit_tx=True) + return [rs for rs in it] + + fut = pool.retry_operation_async(callee) + pool.stop() + assert fut.done() + assert len(fut.result()) == 1 + + def test_async_methods_after_stop_raise(self, pool: QuerySessionPool): + pool.stop() + with pytest.raises(ydb.SessionPoolClosed): + pool.execute_with_retries_async("select 1;") diff --git a/ydb/issues.py b/ydb/issues.py index 8b098667..7337c428 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -51,6 +51,7 @@ class StatusCode(enum.IntEnum): UNAUTHENTICATED = _CLIENT_STATUSES_FIRST + 30 SESSION_POOL_EMPTY = _CLIENT_STATUSES_FIRST + 40 + SESSION_POOL_CLOSED = _CLIENT_STATUSES_FIRST + 50 # TODO: convert from proto IssueMessage @@ -179,6 +180,13 @@ class SessionPoolEmpty(Error, queue.Empty): status = StatusCode.SESSION_POOL_EMPTY +class SessionPoolClosed(Error): + status = StatusCode.SESSION_POOL_CLOSED + + def __init__(self): + super().__init__("Session pool is closed.") + + class ClientInternalError(Error): status = StatusCode.CLIENT_INTERNAL_ERROR diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 1cf95ac0..fc05950c 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -1,4 +1,5 @@ import logging +from concurrent import futures from typing import ( Callable, Optional, @@ -36,14 +37,17 @@ def __init__( size: int = 100, *, query_client_settings: Optional[QueryClientSettings] = None, + workers_threads_count: int = 4, ): """ :param driver: A driver instance. :param size: Max size of Session Pool. :param query_client_settings: ydb.QueryClientSettings object to configure QueryService behavior + :param workers_threads_count: A number of threads in executor used for *_async methods """ self._driver = driver + self._tp = futures.ThreadPoolExecutor(workers_threads_count) self._queue = queue.Queue() self._current_size = 0 self._size = size @@ -72,7 +76,7 @@ def acquire(self, timeout: Optional[float] = None) -> QuerySession: try: if self._should_stop.is_set(): logger.error("An attempt to take session from closed session pool.") - raise RuntimeError("An attempt to take session from closed session pool.") + raise issues.SessionPoolClosed() session = None try: @@ -132,6 +136,9 @@ def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetryS :return: Result sets or exception in case of execution errors. """ + if self._should_stop.is_set(): + raise issues.SessionPoolClosed() + retry_settings = RetrySettings() if retry_settings is None else retry_settings def wrapped_callee(): @@ -140,6 +147,38 @@ def wrapped_callee(): return retry_operation_sync(wrapped_callee, retry_settings) + def retry_tx_async( + self, + callee: Callable, + tx_mode: Optional[BaseQueryTxMode] = None, + retry_settings: Optional[RetrySettings] = None, + *args, + **kwargs, + ) -> futures.Future: + """Asynchronously execute a transaction in a retriable way.""" + + if self._should_stop.is_set(): + raise issues.SessionPoolClosed() + + return self._tp.submit( + self.retry_tx_sync, + callee, + tx_mode, + retry_settings, + *args, + **kwargs, + ) + + def retry_operation_async( + self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs + ) -> futures.Future: + """Asynchronously execute a retryable operation.""" + + if self._should_stop.is_set(): + raise issues.SessionPoolClosed() + + return self._tp.submit(self.retry_operation_sync, callee, retry_settings, *args, **kwargs) + def retry_tx_sync( self, callee: Callable, @@ -161,6 +200,9 @@ def retry_tx_sync( :return: Result sets or exception in case of execution errors. """ + if self._should_stop.is_set(): + raise issues.SessionPoolClosed() + tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() retry_settings = RetrySettings() if retry_settings is None else retry_settings @@ -194,6 +236,9 @@ def execute_with_retries( :return: Result sets or exception in case of execution errors. """ + if self._should_stop.is_set(): + raise issues.SessionPoolClosed() + retry_settings = RetrySettings() if retry_settings is None else retry_settings def wrapped_callee(): @@ -203,11 +248,34 @@ def wrapped_callee(): return retry_operation_sync(wrapped_callee, retry_settings) + def execute_with_retries_async( + self, + query: str, + parameters: Optional[dict] = None, + retry_settings: Optional[RetrySettings] = None, + *args, + **kwargs, + ) -> futures.Future: + """Asynchronously execute a query with retries.""" + + if self._should_stop.is_set(): + raise issues.SessionPoolClosed() + + return self._tp.submit( + self.execute_with_retries, + query, + parameters, + retry_settings, + *args, + **kwargs, + ) + def stop(self, timeout=None): acquire_timeout = timeout if timeout is not None else -1 acquired = self._lock.acquire(timeout=acquire_timeout) try: self._should_stop.set() + self._tp.shutdown(wait=True) while True: try: session = self._queue.get_nowait()