diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index a3f49cc4..6c1bc3e8 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -1,4 +1,9 @@ import pytest +import threading +import time +from concurrent.futures import _base as b +from unittest import mock + from ydb.query.session import QuerySession @@ -100,3 +105,38 @@ def test_two_results(self, session: QuerySession): res.append(list(result_set.rows[0].values())) assert res == [[1], [2]] + + def test_thread_leaks(self, session: QuerySession): + session.create() + thread_names = [t.name for t in threading.enumerate()] + assert "first response attach stream thread" not in thread_names + assert "attach stream thread" in thread_names + + def test_first_resp_timeout(self, session: QuerySession): + class FakeStream: + def __iter__(self): + return self + + def __next__(self): + time.sleep(10) + return 1 + + def cancel(self): + pass + + fake_stream = mock.Mock(spec=FakeStream) + + session._attach_call = mock.MagicMock(return_value=fake_stream) + assert session._attach_call() == fake_stream + + session._create_call() + with pytest.raises(b.TimeoutError): + session._attach(0.1) + + fake_stream.cancel.assert_called() + + thread_names = [t.name for t in threading.enumerate()] + assert "first response attach stream thread" not in thread_names + assert "attach stream thread" not in thread_names + + _check_session_state_empty(session) diff --git a/ydb/_utilities.py b/ydb/_utilities.py index e89b0af3..117c7407 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -182,3 +182,21 @@ def inc_and_get(self) -> int: with self._lock: self._value += 1 return self._value + + +def get_first_message_with_timeout(status_stream: SyncResponseIterator, timeout: int): + waiter = future() + + def get_first_response(waiter): + first_response = next(status_stream) + waiter.set_result(first_response) + + thread = threading.Thread( + target=get_first_response, + args=(waiter,), + name="first response attach stream thread", + daemon=True, + ) + thread.start() + + return waiter.result(timeout=timeout) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 454378b0..5bd0f1a0 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -1,3 +1,6 @@ +import asyncio + + class AsyncResponseIterator(object): def __init__(self, it, wrapper): self.it = it.__aiter__() @@ -21,3 +24,10 @@ async def next(self): async def __anext__(self): return await self._next() + + +async def get_first_message_with_timeout(stream: AsyncResponseIterator, timeout: int): + async def get_first_response(): + return await stream.next() + + return await asyncio.wait_for(get_first_response(), timeout) diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py index 779eb3f0..0561de8c 100644 --- a/ydb/aio/query/session.py +++ b/ydb/aio/query/session.py @@ -15,6 +15,7 @@ from ...query import base from ...query.session import ( BaseQuerySession, + DEFAULT_ATTACH_FIRST_RESP_TIMEOUT, QuerySessionStateEnum, ) @@ -43,9 +44,17 @@ async def _attach(self) -> None: lambda response: common_utils.ServerStatus.from_proto(response), ) - first_response = await self._status_stream.next() - if first_response.status != issues.StatusCode.SUCCESS: - pass + try: + first_response = await _utilities.get_first_message_with_timeout( + self._status_stream, + DEFAULT_ATTACH_FIRST_RESP_TIMEOUT, + ) + if first_response.status != issues.StatusCode.SUCCESS: + raise RuntimeError("Failed to attach session") + except Exception as e: + self._state.reset() + self._status_stream.cancel() + raise e self._state.set_attached(True) self._state._change_state(QuerySessionStateEnum.CREATED) diff --git a/ydb/query/session.py b/ydb/query/session.py index 0165f821..382c922d 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -22,6 +22,10 @@ logger = logging.getLogger(__name__) +DEFAULT_ATTACH_FIRST_RESP_TIMEOUT = 600 +DEFAULT_ATTACH_LONG_TIMEOUT = 31536000 # year + + class QuerySessionStateEnum(enum.Enum): NOT_INITIALIZED = "NOT_INITIALIZED" CREATED = "CREATED" @@ -136,6 +140,12 @@ def __init__(self, driver: common_utils.SupportedDriverType, settings: Optional[ self._driver = driver self._settings = self._get_client_settings(driver, settings) self._state = QuerySessionState(settings) + self._attach_settings: BaseRequestSettings = ( + BaseRequestSettings() + .with_operation_timeout(DEFAULT_ATTACH_LONG_TIMEOUT) + .with_cancel_after(DEFAULT_ATTACH_LONG_TIMEOUT) + .with_timeout(DEFAULT_ATTACH_LONG_TIMEOUT) + ) def _get_client_settings( self, @@ -168,12 +178,12 @@ def _delete_call(self, settings: Optional[BaseRequestSettings] = None) -> "BaseQ settings=settings, ) - def _attach_call(self, settings: Optional[BaseRequestSettings] = None) -> Iterable[_apis.ydb_query.SessionState]: + def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]: return self._driver( _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, _apis.QueryService.AttachSession, - settings=settings, + settings=self._attach_settings, ) def _execute_call( @@ -213,16 +223,24 @@ class QuerySession(BaseQuerySession): _stream = None - def _attach(self, settings: Optional[BaseRequestSettings] = None) -> None: - self._stream = self._attach_call(settings=settings) + def _attach(self, first_resp_timeout: int = DEFAULT_ATTACH_FIRST_RESP_TIMEOUT) -> None: + self._stream = self._attach_call() status_stream = _utilities.SyncResponseIterator( self._stream, lambda response: common_utils.ServerStatus.from_proto(response), ) - first_response = next(status_stream) - if first_response.status != issues.StatusCode.SUCCESS: - pass + try: + first_response = _utilities.get_first_message_with_timeout( + status_stream, + first_resp_timeout, + ) + if first_response.status != issues.StatusCode.SUCCESS: + raise RuntimeError("Failed to attach session") + except Exception as e: + self._state.reset() + status_stream.cancel() + raise e self._state.set_attached(True) self._state._change_state(QuerySessionStateEnum.CREATED) @@ -230,7 +248,7 @@ def _attach(self, settings: Optional[BaseRequestSettings] = None) -> None: threading.Thread( target=self._check_session_status_loop, args=(status_stream,), - name="check session status thread", + name="attach stream thread", daemon=True, ).start()