From f4ebefc932e59ec85b9732350ecf9fd099396f0f Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 01/57] Add query service to apis --- ydb/_apis.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ydb/_apis.py b/ydb/_apis.py index 8c0b1164..2a9a14e8 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -10,6 +10,7 @@ ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, ydb_topic_v1_pb2_grpc, + ydb_query_v1_pb2_grpc, ) from ._grpc.v4.protos import ( @@ -20,6 +21,7 @@ ydb_value_pb2, ydb_operation_pb2, ydb_common_pb2, + ydb_query_pb2, ) else: @@ -30,6 +32,7 @@ ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, ydb_topic_v1_pb2_grpc, + ydb_query_v1_pb2_grpc, ) from ._grpc.common.protos import ( @@ -40,6 +43,7 @@ ydb_value_pb2, ydb_operation_pb2, ydb_common_pb2, + ydb_query_pb2, ) @@ -51,6 +55,7 @@ ydb_table = ydb_table_pb2 ydb_discovery = ydb_discovery_pb2 ydb_operation = ydb_operation_pb2 +ydb_query = ydb_query_pb2 class CmsService(object): @@ -111,3 +116,19 @@ class TopicService(object): DropTopic = "DropTopic" StreamRead = "StreamRead" StreamWrite = "StreamWrite" + + +class QueryService(object): + Stub = ydb_query_v1_pb2_grpc.QueryServiceStub + + CreateSession = "CreateSession" + DeleteSession = "DeleteSession" + AttachSession = "AttachSession" + + BeginTransaction = "BeginTransaction" + CommitTransaction = "CommitTransaction" + RollbackTransaction = "RollbackTransaction" + + ExecuteQuery = "ExecuteQuery" + ExecuteScript = "ExecuteScript" + FetchScriptResults = "FetchScriptResults" From 4a576b927118b919ea13f5db71a8f8bd0ae71102 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 02/57] session create --- ydb/_grpc/grpcwrapper/common_utils.py | 13 ++++ ydb/_grpc/grpcwrapper/ydb_query.py | 40 ++++++++++++ ydb/query/__init__.py | 0 ydb/query/session.py | 93 +++++++++++++++++++++++++++ 4 files changed, 146 insertions(+) create mode 100644 ydb/_grpc/grpcwrapper/ydb_query.py create mode 100644 ydb/query/__init__.py create mode 100644 ydb/query/session.py diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index a7febd5b..966a1ada 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -256,6 +256,19 @@ def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): return res +ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType) + + +def create_result_wrapper( + result_type: typing.Type[ResultType], +) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: + def wrapper(rpc_state, response_pb, driver=None): + # issues._process_response(response_pb.operation) + return result_type.from_proto(response_pb) + + return wrapper + + def callback_from_asyncio(callback: Union[Callable, Coroutine]) -> [asyncio.Future, asyncio.Task]: loop = asyncio.get_running_loop() diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py new file mode 100644 index 00000000..bd58e147 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +import typing +from typing import Optional + +from google.protobuf.message import Message + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_query_pb2 +else: + from ..common.protos import ydb_query_pb2 + +from .common_utils import ( + IFromProto, + IFromProtoWithProtoType, + IToProto, + IToPublic, + IFromPublic, + ServerStatus, + UnknownGrpcMessageError, + proto_duration_from_timedelta, + proto_timestamp_from_datetime, + datetime_from_proto_timestamp, + timedelta_from_proto_duration, +) + +@dataclass +class CreateSessionResponse(IFromProto): + status: Optional[ServerStatus] + session_id: str + node_id: int + + @staticmethod + def from_proto(msg: ydb_query_pb2.CreateSessionResponse) -> "CreateSessionResponse": + return CreateSessionResponse( + status=ServerStatus(msg.status, msg.issues), + session_id=msg.session_id, + node_id=msg.node_id, + ) + diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ydb/query/session.py b/ydb/query/session.py new file mode 100644 index 00000000..8bff8ff0 --- /dev/null +++ b/ydb/query/session.py @@ -0,0 +1,93 @@ +import abc +from abc import abstractmethod + +from .. import _apis, issues +from .._grpc.grpcwrapper import common_utils +from .._grpc.grpcwrapper import ydb_query as _ydb_query + + +class ISession(abc.ABC): + + @abc.abstractmethod + def create(self): + pass + + @abc.abstractmethod + def delete(self): + pass + + @property + @abstractmethod + def session_id(self): + pass + +class SessionState(object): + def __init__(self, settings=None): + self._settings = settings + self._session_id = None + self._node_id = None + self._is_closed = False + + @property + def session_id(self): + return self._session_id + + @property + def node_id(self): + return self._node_id + + def set_id(self, session_id): + self._session_id = session_id + return self + + def set_node_id(self, node_id): + self._node_id = node_id + return self + + + +class QuerySession(ISession): + def __init__(self, driver, settings=None): + self._driver = driver + self._state = SessionState(settings) + + @property + def session_id(self): + return self._state.session_id + + def create(self): + if self._state.session_id is not None: + return self + + # TODO: check what is settings + + res = self._driver( + _apis.ydb_query.CreateSessionRequest(), + _apis.QueryService.Stub, + _apis.QueryService.CreateSession, + common_utils.create_result_wrapper(_ydb_query.CreateSessionResponse), + ) + + self._state.set_id(res.session_id).set_node_id(res.node_id) + + return None + + def delete(self): + pass + + +if __name__ == "__main__": + + from ..driver import Driver + + endpoint = "grpc://localhost:2136" + database = "/local" + + with Driver(endpoint=endpoint, database=database) as driver: + driver.wait(timeout=5) + session = QuerySession(driver) + print(session.session_id) + + session.create() + + print(session.session_id) From 415c567132a5ae1ade8fa5602712af7fcf118ea2 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 03/57] session delete --- ydb/_grpc/grpcwrapper/ydb_query.py | 16 +++++++++++ ydb/query/session.py | 43 ++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index bd58e147..70008b9e 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -38,3 +38,19 @@ def from_proto(msg: ydb_query_pb2.CreateSessionResponse) -> "CreateSessionRespon node_id=msg.node_id, ) + +@dataclass +class DeleteSessionRequest(IToProto): + session_id: str + + def to_proto(self) -> ydb_query_pb2.DeleteSessionRequest: + return ydb_query_pb2.DeleteSessionRequest(session_id=self.session_id) + + +@dataclass +class DeleteSessionResponse(IFromProto): + status: Optional[ServerStatus] + + @staticmethod + def from_proto(msg: ydb_query_pb2.DeleteSessionResponse) -> "DeleteSessionResponse": + return DeleteSessionResponse(status=ServerStatus(msg.status, msg.issues)) diff --git a/ydb/query/session.py b/ydb/query/session.py index 8bff8ff0..1e8f3384 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -1,11 +1,15 @@ import abc from abc import abstractmethod +import logging from .. import _apis, issues from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper import ydb_query as _ydb_query +logger = logging.getLogger(__name__) + + class ISession(abc.ABC): @abc.abstractmethod @@ -24,9 +28,12 @@ def session_id(self): class SessionState(object): def __init__(self, settings=None): self._settings = settings + self.reset() + + def reset(self): self._session_id = None self._node_id = None - self._is_closed = False + self._is_attached = False @property def session_id(self): @@ -44,6 +51,9 @@ def set_node_id(self, node_id): self._node_id = node_id return self + def attached(self): + return self._is_attached + class QuerySession(ISession): @@ -55,6 +65,10 @@ def __init__(self, driver, settings=None): def session_id(self): return self._state.session_id + @property + def node_id(self): + return self._state.node_id + def create(self): if self._state.session_id is not None: return self @@ -68,12 +82,29 @@ def create(self): common_utils.create_result_wrapper(_ydb_query.CreateSessionResponse), ) + logging.info("session.create: success") + self._state.set_id(res.session_id).set_node_id(res.node_id) return None def delete(self): - pass + + if self._state.session_id is None: + return None + + res = self._driver( + _apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id), + _apis.QueryService.Stub, + _apis.QueryService.DeleteSession, + common_utils.create_result_wrapper(_ydb_query.DeleteSessionResponse), + ) + logging.info("session.delete: success") + + self._state.reset() + + return None + if __name__ == "__main__": @@ -87,7 +118,15 @@ def delete(self): driver.wait(timeout=5) session = QuerySession(driver) print(session.session_id) + print(session.node_id) session.create() print(session.session_id) + print(session.node_id) + + session.delete() + + print(session.session_id) + print(session.node_id) + From 8f345ac27473337ba340cebe4e14a9896cea1358 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 04/57] temp --- ydb/_grpc/grpcwrapper/ydb_query.py | 10 +++++----- ydb/query/session.py | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 70008b9e..a7bf8f4c 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -39,12 +39,12 @@ def from_proto(msg: ydb_query_pb2.CreateSessionResponse) -> "CreateSessionRespon ) -@dataclass -class DeleteSessionRequest(IToProto): - session_id: str +# @dataclass +# class DeleteSessionRequest(IToProto): +# session_id: str - def to_proto(self) -> ydb_query_pb2.DeleteSessionRequest: - return ydb_query_pb2.DeleteSessionRequest(session_id=self.session_id) +# def to_proto(self) -> ydb_query_pb2.DeleteSessionRequest: +# return ydb_query_pb2.DeleteSessionRequest(session_id=self.session_id) @dataclass diff --git a/ydb/query/session.py b/ydb/query/session.py index 1e8f3384..f025f10b 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -2,7 +2,7 @@ from abc import abstractmethod import logging -from .. import _apis, issues +from .. import _apis, issues, _utilities from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper import ydb_query as _ydb_query @@ -105,6 +105,23 @@ def delete(self): return None + # def attach(self): + # if self._state.attached(): + # return self + + # stream_it = self._driver( + # _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), + # _apis.QueryService.Stub, + # _apis.QueryService.AttachSession, + # common_utils.create_result_wrapper(_ydb_query.AttachSessionResponse), + # ) + + # it = _utilities.SyncResponseIterator( + + # ) + + # return None + if __name__ == "__main__": From 0ffe5453f9d2c92ab9134672e05d4bdd35aa5b65 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 05/57] Refactor wrappers to split UnaryStream and StreamStream wrappers --- ydb/_grpc/grpcwrapper/common_utils.py | 74 ++++++++++++++++------- ydb/_topic_common/common_test.py | 8 +-- ydb/_topic_reader/topic_reader_asyncio.py | 4 +- ydb/_topic_writer/topic_writer_asyncio.py | 4 +- 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 966a1ada..5d71f4d0 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -145,7 +145,7 @@ def close(self): SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] -class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): +class AbstractGrpcWrapperAsyncIO(IGrpcWrapperAsyncIO, abc.ABC): from_client_grpc: asyncio.Queue from_server_grpc: AsyncIterator convert_server_grpc_to_wrapper: Callable[[Any], Any] @@ -163,13 +163,6 @@ def __init__(self, convert_server_grpc_to_wrapper): def __del__(self): self._clean_executor(wait=False) - async def start(self, driver: SupportedDriverType, stub, method): - if asyncio.iscoroutinefunction(driver.__call__): - await self._start_asyncio_driver(driver, stub, method) - else: - await self._start_sync_driver(driver, stub, method) - self._connection_state = "started" - def close(self): self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) if self._stream_call: @@ -181,6 +174,35 @@ def _clean_executor(self, wait: bool): if self._wait_executor: self._wait_executor.shutdown(wait) + async def receive(self) -> Any: + # todo handle grpc exceptions and convert it to internal exceptions + try: + grpc_message = await self.from_server_grpc.__anext__() + except (grpc.RpcError, grpc.aio.AioRpcError) as e: + raise connection._rpc_error_handler(self._connection_state, e) + + issues._process_response(grpc_message) + + if self._connection_state != "has_received_messages": + self._connection_state = "has_received_messages" + + # print("rekby, grpc, received", grpc_message) + return self.convert_server_grpc_to_wrapper(grpc_message) + + def write(self, wrap_message: IToProto): + grpc_message = wrap_message.to_proto() + # print("rekby, grpc, send", grpc_message) + self.from_client_grpc.put_nowait(grpc_message) + + +class GrpcWrapperStreamStreamAsyncIO(AbstractGrpcWrapperAsyncIO): + async def start(self, driver: SupportedDriverType, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, stub, method) + else: + await self._start_sync_driver(driver, stub, method) + self._connection_state = "started" + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) stream_call = await driver( @@ -199,25 +221,31 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): self._stream_call = stream_call self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) - async def receive(self) -> Any: - # todo handle grpc exceptions and convert it to internal exceptions - try: - grpc_message = await self.from_server_grpc.__anext__() - except (grpc.RpcError, grpc.aio.AioRpcError) as e: - raise connection._rpc_error_handler(self._connection_state, e) - issues._process_response(grpc_message) +class GrpcWrapperUnaryStreamAsyncIO(AbstractGrpcWrapperAsyncIO): + async def start(self, driver: SupportedDriverType, request, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, request, stub, method) + else: + await self._start_sync_driver(driver, request, stub, method) + self._connection_state = "started" - if self._connection_state != "has_received_messages": - self._connection_state = "has_received_messages" + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, request, stub, method): + stream_call = await driver( + request, + stub, + method, + ) + self._stream_call = stream_call + self.from_server_grpc = stream_call.__aiter__() - # print("rekby, grpc, received", grpc_message) - return self.convert_server_grpc_to_wrapper(grpc_message) + async def _start_sync_driver(self, driver: ydb.Driver, request, stub, method): + self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + stream_call = await to_thread(driver, request, stub, method, executor=self._wait_executor) + self._stream_call = stream_call + self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) - def write(self, wrap_message: IToProto): - grpc_message = wrap_message.to_proto() - # print("rekby, grpc, send", grpc_message) - self.from_client_grpc.put_nowait(grpc_message) @dataclass(init=False) diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index b31f9af9..1dadaa04 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -8,7 +8,7 @@ from .common import CallFromSyncToAsync from .._grpc.grpcwrapper.common_utils import ( - GrpcWrapperAsyncIO, + GrpcWrapperStreamStreamAsyncIO, ServerStatus, callback_from_asyncio, ) @@ -77,7 +77,7 @@ async def async_failed(): @pytest.mark.asyncio -class TestGrpcWrapperAsyncIO: +class TestGrpcWrapperStreamStreamAsyncIO: async def test_convert_grpc_errors_to_ydb(self): class TestError(grpc.RpcError, grpc.Call): def __init__(self): @@ -93,7 +93,7 @@ class FromServerMock: async def __anext__(self): raise TestError() - wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None) wrapper.from_server_grpc = FromServerMock() with pytest.raises(issues.Unauthenticated): @@ -107,7 +107,7 @@ async def __anext__(self): issues=[], ) - wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None) wrapper.from_server_grpc = FromServerMock() with pytest.raises(issues.Overloaded): diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 81c6d9f4..8cc48a1d 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -18,7 +18,7 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - GrpcWrapperAsyncIO, + GrpcWrapperStreamStreamAsyncIO, ) from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, @@ -308,7 +308,7 @@ async def create( driver: SupportedDriverType, settings: topic_reader.PublicReaderSettings, ) -> "ReaderStream": - stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) + stream = GrpcWrapperStreamStreamAsyncIO(StreamReadMessage.FromServer.from_proto) await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 007c8a54..064f19ce 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -39,7 +39,7 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - GrpcWrapperAsyncIO, + GrpcWrapperStreamStreamAsyncIO, ) logger = logging.getLogger(__name__) @@ -613,7 +613,7 @@ async def create( init_request: StreamWriteMessage.InitRequest, update_token_interval: Optional[Union[int, float]] = None, ) -> "WriterAsyncIOStream": - stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) + stream = GrpcWrapperStreamStreamAsyncIO(StreamWriteMessage.FromServer.from_proto) await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) From f266cbb58a73ce92db6b7eca545f6a5c000e7684 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 06/57] attach session --- ydb/_grpc/grpcwrapper/ydb_query.py | 16 ++++ ydb/query/session.py | 133 ++++++++++++++++++++++------- 2 files changed, 119 insertions(+), 30 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index a7bf8f4c..36ddfeba 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -54,3 +54,19 @@ class DeleteSessionResponse(IFromProto): @staticmethod def from_proto(msg: ydb_query_pb2.DeleteSessionResponse) -> "DeleteSessionResponse": return DeleteSessionResponse(status=ServerStatus(msg.status, msg.issues)) + + +@dataclass +class AttachSessionRequest(IToProto): + session_id: str + + def to_proto(self) -> ydb_query_pb2.AttachSessionRequest: + return ydb_query_pb2.AttachSessionRequest(session_id=self.session_id) + +# @dataclass +# class SessionState(IFromProto): +# status: Optional[ServerStatus] + +# @staticmethod +# def from_proto(msg: ydb_query_pb2.SessionState) -> "SessionState": +# return SessionState(status=ServerStatus(msg.status, msg.issues)) \ No newline at end of file diff --git a/ydb/query/session.py b/ydb/query/session.py index f025f10b..1667a943 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -1,6 +1,10 @@ import abc from abc import abstractmethod +import asyncio import logging +from typing import ( + Set, +) from .. import _apis, issues, _utilities from .._grpc.grpcwrapper import common_utils @@ -25,7 +29,7 @@ def delete(self): def session_id(self): pass -class SessionState(object): +class SessionState: def __init__(self, settings=None): self._settings = settings self.reset() @@ -51,6 +55,10 @@ def set_node_id(self, node_id): self._node_id = node_id return self + def set_attached(self, is_attached): + self._is_attached = is_attached + + @property def attached(self): return self._is_attached @@ -69,13 +77,13 @@ def session_id(self): def node_id(self): return self._state.node_id - def create(self): + async def create(self): if self._state.session_id is not None: return self # TODO: check what is settings - res = self._driver( + res = await self._driver( _apis.ydb_query.CreateSessionRequest(), _apis.QueryService.Stub, _apis.QueryService.CreateSession, @@ -88,12 +96,12 @@ def create(self): return None - def delete(self): + async def delete(self): if self._state.session_id is None: return None - res = self._driver( + res = await self._driver( _apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, _apis.QueryService.DeleteSession, @@ -102,48 +110,113 @@ def delete(self): logging.info("session.delete: success") self._state.reset() + if self._stream is not None: + await self._stream.close() + self._stream = None return None - # def attach(self): - # if self._state.attached(): - # return self + async def attach(self): + self._stream = await SessionStateReaderStream.create(self._driver, self._state) - # stream_it = self._driver( - # _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), - # _apis.QueryService.Stub, - # _apis.QueryService.AttachSession, - # common_utils.create_result_wrapper(_ydb_query.AttachSessionResponse), - # ) + print(self._state.attached) - # it = _utilities.SyncResponseIterator( - # ) - # return None +class SessionStateReaderStream: + _started: bool + _stream: common_utils.IGrpcWrapperAsyncIO + _session: QuerySession + _background_tasks: Set[asyncio.Task] + def __init__(self, session_state: SessionState): + self._session_state = session_state + self._background_tasks = set() + self._started = False -if __name__ == "__main__": - from ..driver import Driver + @staticmethod + async def create(driver: common_utils.SupportedDriverType, session_state: SessionState): + stream = common_utils.GrpcWrapperUnaryStreamAsyncIO(common_utils.ServerStatus.from_proto) + await stream.start( + driver, + _ydb_query.AttachSessionRequest(session_id=session_state.session_id).to_proto(), + _apis.QueryService.Stub, + _apis.QueryService.AttachSession + ) + + reader = SessionStateReaderStream(session_state) + + await reader._start(stream) + + return reader + + async def _start(self, stream: common_utils.IGrpcWrapperAsyncIO): + if self._started: + return # TODO: error + + self._started = True + self._stream = stream + + response = await self._stream.receive() + + if response.is_success(): + self._session_state.set_attached(True) + else: + raise common_utils.YdbError(response.error) + + self._background_tasks.add(asyncio.create_task(self._update_session_state_loop(), name="update_session_state_loop")) + + return response + + async def _update_session_state_loop(self): + while True: + response = await self._stream.receive() + + if response.is_success(): + pass + else: + self._session_state.set_attached(False) + + async def close(self): + self._stream.close() + for task in self._background_tasks: + task.cancel() + + if self._background_tasks: + await asyncio.wait(self._background_tasks) + + +async def main(): + from ..aio.driver import Driver endpoint = "grpc://localhost:2136" database = "/local" - with Driver(endpoint=endpoint, database=database) as driver: - driver.wait(timeout=5) - session = QuerySession(driver) - print(session.session_id) - print(session.node_id) + driver = Driver(endpoint=endpoint, database=database) # Creating new database driver to execute queries - session.create() + await driver.wait(timeout=10) # Wait until driver can execute calls - print(session.session_id) - print(session.node_id) + session = QuerySession(driver) - session.delete() + print(session.session_id) + print(session.node_id) - print(session.session_id) - print(session.node_id) + await session.create() + + print(session.session_id) + print(session.node_id) + + + await session.attach() + + await session.delete() + + print(session.session_id) + print(session.node_id) + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) From 1f9c7283532bf92af499f4b2c8de58015a1bf7f9 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 07/57] Added base module and wrappers --- ydb/_grpc/grpcwrapper/ydb_query.py | 64 +++++++++-- .../grpcwrapper/ydb_query_public_types.py | 68 +++++++++++ ydb/query/base.py | 106 ++++++++++++++++++ 3 files changed, 229 insertions(+), 9 deletions(-) create mode 100644 ydb/_grpc/grpcwrapper/ydb_query_public_types.py create mode 100644 ydb/query/base.py diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 36ddfeba..49cae011 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -10,6 +10,8 @@ else: from ..common.protos import ydb_query_pb2 +from . import ydb_query_public_types as public_types + from .common_utils import ( IFromProto, IFromProtoWithProtoType, @@ -39,14 +41,6 @@ def from_proto(msg: ydb_query_pb2.CreateSessionResponse) -> "CreateSessionRespon ) -# @dataclass -# class DeleteSessionRequest(IToProto): -# session_id: str - -# def to_proto(self) -> ydb_query_pb2.DeleteSessionRequest: -# return ydb_query_pb2.DeleteSessionRequest(session_id=self.session_id) - - @dataclass class DeleteSessionResponse(IFromProto): status: Optional[ServerStatus] @@ -69,4 +63,56 @@ def to_proto(self) -> ydb_query_pb2.AttachSessionRequest: # @staticmethod # def from_proto(msg: ydb_query_pb2.SessionState) -> "SessionState": -# return SessionState(status=ServerStatus(msg.status, msg.issues)) \ No newline at end of file +# return SessionState(status=ServerStatus(msg.status, msg.issues)) + + +@dataclass +class TransactionMeta(IFromProto): + tx_id: str + + @staticmethod + def from_proto(msg: ydb_query_pb2.TransactionMeta) -> "TransactionMeta": + return TransactionMeta(tx_id=msg.id) + + +@dataclass +class TransactionSettings(IFromPublic, IToProto): + tx_mode: public_types.BaseQueryTxMode + + @staticmethod + def from_public(tx_mode: public_types.BaseQueryTxMode) -> "TransactionSettings": + return TransactionSettings(tx_mode=tx_mode) + + def to_proto(self) -> ydb_query_pb2.TransactionSettings: + if self.tx_mode.name == 'snapshot_read_only': + return ydb_query_pb2.TransactionSettings(snapshot_read_only=self.tx_mode.to_proto()) + if self.tx_mode.name == 'serializable_read_write': + return ydb_query_pb2.TransactionSettings(serializable_read_write=self.tx_mode.to_proto()) + if self.tx_mode.name == 'online_read_only': + return ydb_query_pb2.TransactionSettings(online_read_only=self.tx_mode.to_proto()) + if self.tx_mode.name == 'stale_read_only': + return ydb_query_pb2.TransactionSettings(stale_read_only=self.tx_mode.to_proto()) + # TODO: add exception + +@dataclass +class BeginTransactionRequest(IToProto): + session_id: str + tx_settings: TransactionSettings + + def to_proto(self) -> ydb_query_pb2.BeginTransactionRequest: + return ydb_query_pb2.BeginTransactionRequest( + session_id=self.session_id, + tx_settings=self.tx_settings + ) + +@dataclass +class BeginTransactionResponse(IFromProto): + status: Optional[ServerStatus] + tx_meta: TransactionMeta + + @staticmethod + def from_proto(msg: ydb_query_pb2.BeginTransactionResponse) -> "BeginTransactionResponse": + return BeginTransactionResponse( + status=ServerStatus(msg.status, msg.issues), + tx_meta=TransactionMeta.from_proto(msg.tx_meta), + ) \ No newline at end of file diff --git a/ydb/_grpc/grpcwrapper/ydb_query_public_types.py b/ydb/_grpc/grpcwrapper/ydb_query_public_types.py new file mode 100644 index 00000000..27d1e917 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_query_public_types.py @@ -0,0 +1,68 @@ +import abc +import typing + +from google.protobuf.message import Message + +from .common_utils import IToProto + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_query_pb2 +else: + from ..common.protos import ydb_query_pb2 + + +class BaseQueryTxMode(IToProto): + @property + @abc.abstractmethod + def name(self) -> str: + pass + + +class QuerySnapshotReadOnly(BaseQueryTxMode): + def __init__(self): + self._name = "snapshot_read_only" + + @property + def name(self) -> str: + return self._name + + def to_proto(self) -> ydb_query_pb2.SnapshotModeSettings: + return ydb_query_pb2.SnapshotModeSettings() + + +class QuerySerializableReadWrite(BaseQueryTxMode): + def __init__(self): + self._name = "serializable_read_write" + + @property + def name(self) -> str: + return self._name + + def to_proto(self) -> ydb_query_pb2.SerializableModeSettings: + return ydb_query_pb2.SerializableModeSettings() + + +class QueryOnlineReadOnly(BaseQueryTxMode): + def __init__(self, allow_inconsistent_reads: bool = False): + self.allow_inconsistent_reads = allow_inconsistent_reads + self._name = "online_read_only" + + @property + def name(self): + return self._name + + def to_proto(self) -> ydb_query_pb2.OnlineModeSettings: + return ydb_query_pb2.OnlineModeSettings(allow_inconsistent_reads=self.allow_inconsistent_reads) + + +class QueryStaleReadOnly(BaseQueryTxMode): + def __init__(self): + self._name = "stale_read_only" + + @property + def name(self): + return self._name + + def to_proto(self) -> ydb_query_pb2.StaleModeSettings: + return ydb_query_pb2.StaleModeSettings() diff --git a/ydb/query/base.py b/ydb/query/base.py new file mode 100644 index 00000000..db4f36fd --- /dev/null +++ b/ydb/query/base.py @@ -0,0 +1,106 @@ +import abc + +from typing import ( + Optional, +) + +from .._grpc.grpcwrapper.common_utils import ( + SupportedDriverType, +) + + +class QueryClientSettings: ... + + +class IQueryTxContext: ... + + +class QuerySessionState: + _session_id: Optional[str] + _node_id: Optional[int] + _attached: bool = False + _settings: Optional[QueryClientSettings] + + def __init__(self, settings: QueryClientSettings = None): + self._settings = settings + self.reset() + + def reset(self) -> None: + self._session_id = None + self._node_id = None + self._attached = False + + @property + def session_id(self) -> Optional[str]: + return self._session_id + + def set_session_id(self, session_id: str) -> "QuerySessionState": + self._session_id = session_id + return self + + @property + def node_id(self) -> Optional[int]: + return self._node_id + + def set_node_id(self, node_id: int) -> "QuerySessionState": + self._node_id = node_id + return self + + @property + def attached(self) -> bool: + return self._attached + + def set_attached(self, attached: bool) -> None: + self._attached = attached + + +class IQuerySession(abc.ABC): + def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = None): + pass + + @abc.abstractmethod + def create(self) -> None: + pass + + @abc.abstractmethod + def delete(self) -> None: + pass + + @abc.abstractmethod + def transaction(self) -> IQueryTxContext: + pass + + +# class BaseQuerySession(IQuerySession): +# _driver: SupportedDriverType +# _session_state: QuerySessionState +# _settings = QueryClientSettings + +# def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = None): +# self._driver = driver +# self._session_state = QuerySessionState(settings) +# self._settings = settings + +# @abc.abstractmethod +# def create(self) -> None: +# pass + +# @abc.abstractmethod +# def delete(self) -> None: +# pass + +# @abc.abstractmethod +# def transaction(self) -> IQueryTxContext: +# pass + + +class IQueryClient(abc.ABC): + def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): + pass + + @abc.abstractmethod + def session(self) -> IQuerySession: + pass + + + From 1b0d58ee79307fa17cb83b7b53902be768e2ad11 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 08/57] refactor session --- ydb/query/base.py | 27 +---- ydb/query/session.py | 261 ++++++++++++++----------------------------- 2 files changed, 85 insertions(+), 203 deletions(-) diff --git a/ydb/query/base.py b/ydb/query/base.py index db4f36fd..6fcd6b63 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -8,6 +8,8 @@ SupportedDriverType, ) +from .._grpc.grpcwrapper.ydb_query_public_types import BaseQueryTxMode + class QueryClientSettings: ... @@ -67,33 +69,10 @@ def delete(self) -> None: pass @abc.abstractmethod - def transaction(self) -> IQueryTxContext: + def transaction(self, tx_mode: BaseQueryTxMode) -> IQueryTxContext: pass -# class BaseQuerySession(IQuerySession): -# _driver: SupportedDriverType -# _session_state: QuerySessionState -# _settings = QueryClientSettings - -# def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = None): -# self._driver = driver -# self._session_state = QuerySessionState(settings) -# self._settings = settings - -# @abc.abstractmethod -# def create(self) -> None: -# pass - -# @abc.abstractmethod -# def delete(self) -> None: -# pass - -# @abc.abstractmethod -# def transaction(self) -> IQueryTxContext: -# pass - - class IQueryClient(abc.ABC): def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): pass diff --git a/ydb/query/session.py b/ydb/query/session.py index 1667a943..013f108a 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -1,222 +1,125 @@ import abc from abc import abstractmethod import asyncio +import concurrent import logging +import threading from typing import ( + Any, + Optional, Set, ) +from . import base + from .. import _apis, issues, _utilities from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper import ydb_query as _ydb_query +from .transaction import BaseTxContext -logger = logging.getLogger(__name__) - - -class ISession(abc.ABC): - - @abc.abstractmethod - def create(self): - pass - - @abc.abstractmethod - def delete(self): - pass - - @property - @abstractmethod - def session_id(self): - pass -class SessionState: - def __init__(self, settings=None): - self._settings = settings - self.reset() - - def reset(self): - self._session_id = None - self._node_id = None - self._is_attached = False - - @property - def session_id(self): - return self._session_id - - @property - def node_id(self): - return self._node_id +logger = logging.getLogger(__name__) - def set_id(self, session_id): - self._session_id = session_id - return self - def set_node_id(self, node_id): - self._node_id = node_id - return self +def wrapper_create_session(rpc_state, response_pb, session_state, session): + #TODO: process response + message = _ydb_query.CreateSessionResponse.from_proto(response_pb) + session_state.set_id(message.session_id).set_node_id(message.node_id) + return session - def set_attached(self, is_attached): - self._is_attached = is_attached - @property - def attached(self): - return self._is_attached +def wrapper_delete_session(rpc_state, response_pb, session_state, session): + #TODO: process response + message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) + session_state.reset() + return session +class BaseQuerySession(base.IQuerySession): + _driver: base.SupportedDriverType + _settings: Optional[base.QueryClientSettings] + _state: base.QuerySessionState -class QuerySession(ISession): - def __init__(self, driver, settings=None): + def __init__(self, driver: base.SupportedDriverType, settings: base.QueryClientSettings = None): self._driver = driver - self._state = SessionState(settings) - - @property - def session_id(self): - return self._state.session_id - - @property - def node_id(self): - return self._state.node_id - - async def create(self): - if self._state.session_id is not None: - return self - - # TODO: check what is settings + self._settings = settings + self._state = base.QuerySessionState(settings) - res = await self._driver( + def _create_call(self): + return self._driver( _apis.ydb_query.CreateSessionRequest(), _apis.QueryService.Stub, _apis.QueryService.CreateSession, - common_utils.create_result_wrapper(_ydb_query.CreateSessionResponse), + wrap_result=wrapper_create_session, + wrap_args=(self._state, self), ) - logging.info("session.create: success") - - self._state.set_id(res.session_id).set_node_id(res.node_id) - - return None - - async def delete(self): - - if self._state.session_id is None: - return None - - res = await self._driver( + def _delete_call(self): + return self._driver( _apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, _apis.QueryService.DeleteSession, - common_utils.create_result_wrapper(_ydb_query.DeleteSessionResponse), + wrap_result=wrapper_delete_session, + wrap_args=(self._state, self), ) - logging.info("session.delete: success") - - self._state.reset() - if self._stream is not None: - await self._stream.close() - self._stream = None - - return None - - async def attach(self): - self._stream = await SessionStateReaderStream.create(self._driver, self._state) - - print(self._state.attached) - - - -class SessionStateReaderStream: - _started: bool - _stream: common_utils.IGrpcWrapperAsyncIO - _session: QuerySession - _background_tasks: Set[asyncio.Task] - - def __init__(self, session_state: SessionState): - self._session_state = session_state - self._background_tasks = set() - self._started = False - - - @staticmethod - async def create(driver: common_utils.SupportedDriverType, session_state: SessionState): - stream = common_utils.GrpcWrapperUnaryStreamAsyncIO(common_utils.ServerStatus.from_proto) - await stream.start( - driver, - _ydb_query.AttachSessionRequest(session_id=session_state.session_id).to_proto(), + def _attach_call(self): + return self._driver( + _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, - _apis.QueryService.AttachSession + _apis.QueryService.AttachSession, ) - reader = SessionStateReaderStream(session_state) - - await reader._start(stream) - - return reader +class QuerySessionSync(BaseQuerySession): + _stream = None - async def _start(self, stream: common_utils.IGrpcWrapperAsyncIO): - if self._started: - return # TODO: error - - self._started = True - self._stream = stream - - response = await self._stream.receive() - - if response.is_success(): - self._session_state.set_attached(True) - else: - raise common_utils.YdbError(response.error) - - self._background_tasks.add(asyncio.create_task(self._update_session_state_loop(), name="update_session_state_loop")) - - return response - - async def _update_session_state_loop(self): - while True: - response = await self._stream.receive() + def _attach(self): + self._stream = self._attach_call() + status_stream = _utilities.SyncResponseIterator( + self._stream, + lambda response: common_utils.ServerStatus.from_proto(response), + ) - if response.is_success(): + first_response = next(status_stream) + if first_response.status != issues.StatusCode.SUCCESS: + pass + # raise common_utils.YdbStatusError(first_response) + + self._state.set_attached(True) + + threading.Thread( + target=self._chech_session_status_loop, + args=(status_stream,), + name="check session status thread", + daemon=True, + ).start() + + def _chech_session_status_loop(self, status_stream): + print("CHECK STATUS") + try: + for status in status_stream: + if status.status != issues.StatusCode.SUCCESS: + print("STATUS NOT SUCCESS") + self._state.reset(False) + except Exception as e: pass - else: - self._session_state.set_attached(False) - - async def close(self): - self._stream.close() - for task in self._background_tasks: - task.cancel() - - if self._background_tasks: - await asyncio.wait(self._background_tasks) - - -async def main(): - from ..aio.driver import Driver - - endpoint = "grpc://localhost:2136" - database = "/local" - - driver = Driver(endpoint=endpoint, database=database) # Creating new database driver to execute queries - - await driver.wait(timeout=10) # Wait until driver can execute calls - - session = QuerySession(driver) - - print(session.session_id) - print(session.node_id) - - await session.create() - - print(session.session_id) - print(session.node_id) - - - await session.attach() + print("CHECK STATUS STOP") - await session.delete() - print(session.session_id) - print(session.node_id) + def delete(self) -> None: + if not self._state.session_id: + return + self._delete_call() + self._stream.cancel() -if __name__ == "__main__": - import asyncio - asyncio.run(main()) + def create(self) -> None: + if self._state.session_id: + return + self._create_call() + self._attach() + def transaction(self, tx_mode: base.BaseQueryTxMode) -> base.IQueryTxContext: + if not self._state.session_id: + return + return BaseTxContext(tx_mode) From 8e80cac78d7ea312f50b23c79862eb4a30a8f6d6 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 09/57] added basic test for session --- tests/query/__init__.py | 0 tests/query/test_query_session.py | 24 ++++++++++++++++++++++++ ydb/query/session.py | 6 +++--- 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 tests/query/__init__.py create mode 100644 tests/query/test_query_session.py diff --git a/tests/query/__init__.py b/tests/query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py new file mode 100644 index 00000000..1bf713b9 --- /dev/null +++ b/tests/query/test_query_session.py @@ -0,0 +1,24 @@ +import pytest + +import ydb.query.session + +def _check_session_state_empty(session): + assert session._state.session_id == None + assert session._state.node_id == None + assert session._state.attached == False + +def _check_session_state_full(session): + assert session._state.session_id != None + assert session._state.node_id != None + assert session._state.attached == True + +class TestQuerySession: + def test_session_normal_lifecycle(self, driver_sync): + session = ydb.query.session.QuerySessionSync(driver_sync) + _check_session_state_empty(session) + + session.create() + _check_session_state_full(session) + + session.delete() + _check_session_state_empty(session) diff --git a/ydb/query/session.py b/ydb/query/session.py index 013f108a..6bd5bd39 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -22,14 +22,14 @@ logger = logging.getLogger(__name__) -def wrapper_create_session(rpc_state, response_pb, session_state, session): +def wrapper_create_session(rpc_state, response_pb, session_state: base.QuerySessionState, session): #TODO: process response message = _ydb_query.CreateSessionResponse.from_proto(response_pb) - session_state.set_id(message.session_id).set_node_id(message.node_id) + session_state.set_session_id(message.session_id).set_node_id(message.node_id) return session -def wrapper_delete_session(rpc_state, response_pb, session_state, session): +def wrapper_delete_session(rpc_state, response_pb, session_state: base.QuerySessionState, session): #TODO: process response message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) session_state.reset() From c4fb819510681dec437e9da885a89194557d7901 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 10/57] added test to double session create --- tests/query/test_query_session.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index 1bf713b9..ef0c26f5 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -22,3 +22,17 @@ def test_session_normal_lifecycle(self, driver_sync): session.delete() _check_session_state_empty(session) + + def test_second_create_do_nothing(self, driver_sync): + session = ydb.query.session.QuerySessionSync(driver_sync) + session.create() + _check_session_state_full(session) + + session_id_before = session._state.session_id + node_id_before = session._state.node_id + + session.create() + _check_session_state_full(session) + + assert session._state.session_id == session_id_before + assert session._state.node_id == node_id_before From ff59697c8455bb22522d211e389a47d535b89f0c Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:25 +0300 Subject: [PATCH 11/57] some more wrappers --- ydb/_grpc/grpcwrapper/ydb_query.py | 45 +++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 49cae011..a5c4f78e 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -115,4 +115,47 @@ def from_proto(msg: ydb_query_pb2.BeginTransactionResponse) -> "BeginTransaction return BeginTransactionResponse( status=ServerStatus(msg.status, msg.issues), tx_meta=TransactionMeta.from_proto(msg.tx_meta), - ) \ No newline at end of file + ) + +@dataclass +class QueryContent(IFromPublic, IToProto): + text: str + syntax: Optional[str] + + @staticmethod + def from_public(query: str) -> "QueryContent": + return QueryContent(text=query) + + def to_proto(self) -> ydb_query_pb2.QueryContent: + return ydb_query_pb2.QueryContent(text=self.text, syntax=self.syntax) + + +@dataclass +class TransactionControl(IToProto): + begin_tx: Optional[TransactionSettings] + commit_tx: Optional[bool] + tx_id: Optional[str] + + def to_proto(self) -> ydb_query_pb2.TransactionControl: + if self.tx_id: + return ydb_query_pb2.TransactionControl( + tx_id=self.tx_id, + commit_tx=self.commit_tx, + ) + return ydb_query_pb2.TransactionControl( + begin_tx=self.begin_tx, + commit_tx=self.commit_tx + ) + + +@dataclass +class ExecuteQueryRequest: + exec_mode: Optional[str] + concurrent_result_sets: bool = False + parameters: Optional[dict] + query_content: QueryContent + session_id: str + stats_mode: Optional[str] + tx_control: TransactionControl + + From 7a4f4e73ec134da2009ced6209285054a487c68f Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:26 +0300 Subject: [PATCH 12/57] simple execute on session --- examples/query-service/basic_example.py | 30 +++++++++++++++++++++ ydb/_grpc/grpcwrapper/ydb_query.py | 36 +++++++++++++++---------- ydb/query/base.py | 27 ++++++++++++++++--- ydb/query/session.py | 32 +++++++++++++++++++--- 4 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 examples/query-service/basic_example.py diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py new file mode 100644 index 00000000..584a5a05 --- /dev/null +++ b/examples/query-service/basic_example.py @@ -0,0 +1,30 @@ +import ydb + +from ydb.query.session import QuerySessionSync + + +def main(): + driver_config = ydb.DriverConfig( + endpoint="grpc://localhost:2136", + database="/local", + # credentials=ydb.credentials_from_env_variables(), + # root_certificates=ydb.load_ydb_root_certificate(), + ) + try: + driver = ydb.Driver(driver_config) + driver.wait(timeout=5) + except TimeoutError: + raise RuntimeError("Connect failed to YDB") + + session = QuerySessionSync(driver) + session.create() + + it = session.execute("select 1; select 2;") + for result_set in it: + print(f"columns: {str(result_set.columns)}") + print(f"rows: {str(result_set.rows)}") + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index a5c4f78e..542aa158 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -92,7 +92,6 @@ def to_proto(self) -> ydb_query_pb2.TransactionSettings: return ydb_query_pb2.TransactionSettings(online_read_only=self.tx_mode.to_proto()) if self.tx_mode.name == 'stale_read_only': return ydb_query_pb2.TransactionSettings(stale_read_only=self.tx_mode.to_proto()) - # TODO: add exception @dataclass class BeginTransactionRequest(IToProto): @@ -102,7 +101,7 @@ class BeginTransactionRequest(IToProto): def to_proto(self) -> ydb_query_pb2.BeginTransactionRequest: return ydb_query_pb2.BeginTransactionRequest( session_id=self.session_id, - tx_settings=self.tx_settings + tx_settings=self.tx_settings.to_proto(), ) @dataclass @@ -120,7 +119,7 @@ def from_proto(msg: ydb_query_pb2.BeginTransactionResponse) -> "BeginTransaction @dataclass class QueryContent(IFromPublic, IToProto): text: str - syntax: Optional[str] + syntax: Optional[str] = None @staticmethod def from_public(query: str) -> "QueryContent": @@ -132,9 +131,9 @@ def to_proto(self) -> ydb_query_pb2.QueryContent: @dataclass class TransactionControl(IToProto): - begin_tx: Optional[TransactionSettings] - commit_tx: Optional[bool] - tx_id: Optional[str] + begin_tx: Optional[TransactionSettings] = None + commit_tx: Optional[bool] = None + tx_id: Optional[str] = None def to_proto(self) -> ydb_query_pb2.TransactionControl: if self.tx_id: @@ -143,19 +142,28 @@ def to_proto(self) -> ydb_query_pb2.TransactionControl: commit_tx=self.commit_tx, ) return ydb_query_pb2.TransactionControl( - begin_tx=self.begin_tx, + begin_tx=self.begin_tx.to_proto(), commit_tx=self.commit_tx ) @dataclass -class ExecuteQueryRequest: - exec_mode: Optional[str] - concurrent_result_sets: bool = False - parameters: Optional[dict] - query_content: QueryContent +class ExecuteQueryRequest(IToProto): session_id: str - stats_mode: Optional[str] + query_content: QueryContent tx_control: TransactionControl + concurrent_result_sets: Optional[bool] = False + exec_mode: Optional[str] = None + parameters: Optional[dict] = None + stats_mode: Optional[str] = None - + def to_proto(self) -> ydb_query_pb2.ExecuteQueryRequest: + return ydb_query_pb2.ExecuteQueryRequest( + session_id=self.session_id, + tx_control=self.tx_control.to_proto(), + query_content=self.query_content.to_proto(), + exec_mode=ydb_query_pb2.EXEC_MODE_EXECUTE, + stats_mode=self.stats_mode, + concurrent_result_sets=self.concurrent_result_sets, + parameters=self.parameters, + ) diff --git a/ydb/query/base.py b/ydb/query/base.py index 6fcd6b63..a494ad9d 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -7,9 +7,12 @@ from .._grpc.grpcwrapper.common_utils import ( SupportedDriverType, ) - -from .._grpc.grpcwrapper.ydb_query_public_types import BaseQueryTxMode - +from .._grpc.grpcwrapper import ydb_query +from .._grpc.grpcwrapper.ydb_query_public_types import ( + BaseQueryTxMode, + QuerySerializableReadWrite +) +from .. import convert class QueryClientSettings: ... @@ -82,4 +85,22 @@ def session(self) -> IQuerySession: pass +def create_execute_query_request(query: str, session_id: str, commit_tx: bool): + req = ydb_query.ExecuteQueryRequest( + session_id=session_id, + query_content=ydb_query.QueryContent.from_public( + query=query, + ), + tx_control=ydb_query.TransactionControl( + begin_tx=ydb_query.TransactionSettings( + tx_mode=QuerySerializableReadWrite(), + ), + commit_tx=commit_tx + ), + ) + + return req.to_proto() + +def wrap_execute_query_response(rpc_state, response_pb): + return convert.ResultSet.from_message(response_pb.result_set) diff --git a/ydb/query/session.py b/ydb/query/session.py index 6bd5bd39..461180ef 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -71,6 +71,19 @@ def _attach_call(self): _apis.QueryService.AttachSession, ) + def _execute_call(self, query: str, commit_tx: bool): + request = base.create_execute_query_request( + query=query, + session_id=self._state.session_id, + commit_tx=commit_tx + ) + return self._driver( + request, + _apis.QueryService.Stub, + _apis.QueryService.ExecuteQuery, + # base.wrap_execute_query_response + ) + class QuerySessionSync(BaseQuerySession): _stream = None @@ -96,7 +109,6 @@ def _attach(self): ).start() def _chech_session_status_loop(self, status_stream): - print("CHECK STATUS") try: for status in status_stream: if status.status != issues.StatusCode.SUCCESS: @@ -104,7 +116,6 @@ def _chech_session_status_loop(self, status_stream): self._state.reset(False) except Exception as e: pass - print("CHECK STATUS STOP") def delete(self) -> None: @@ -119,7 +130,20 @@ def create(self) -> None: self._create_call() self._attach() - def transaction(self, tx_mode: base.BaseQueryTxMode) -> base.IQueryTxContext: + def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxContext: if not self._state.session_id: return - return BaseTxContext(tx_mode) + return BaseTxContext( + self._driver, + self._state, + self, + tx_mode, + ) + + def execute(self, query: str, commit_tx: bool = True): + stream_it = self._execute_call(query, commit_tx) + + return _utilities.SyncResponseIterator( + stream_it, + lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), + ) \ No newline at end of file From 5eaeddca60ce5326367788ab96b6ff596d8a320a Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:26 +0300 Subject: [PATCH 13/57] temp --- docker-compose.yml | 2 +- examples/query-service/basic_example.py | 7 +- tests/query/test_query_transaction.py | 17 ++ ydb/query/base.py | 2 + ydb/query/transaction.py | 212 ++++++++++++++++++++++++ 5 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 tests/query/test_query_transaction.py create mode 100644 ydb/query/transaction.py diff --git a/docker-compose.yml b/docker-compose.yml index edbd56d1..50a31f12 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.3" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: ydbplatform/local-ydb:latest restart: always ports: - 2136:2136 diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index 584a5a05..4c107c1a 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -19,10 +19,11 @@ def main(): session = QuerySessionSync(driver) session.create() - it = session.execute("select 1; select 2;") + it = session.execute("select 1; select 2;", commit_tx=False) for result_set in it: - print(f"columns: {str(result_set.columns)}") - print(f"rows: {str(result_set.rows)}") + pass + # print(f"columns: {str(result_set.columns)}") + # print(f"rows: {str(result_set.rows)}") diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py new file mode 100644 index 00000000..9877499c --- /dev/null +++ b/tests/query/test_query_transaction.py @@ -0,0 +1,17 @@ +import pytest + +import ydb.query.session + +class TestQuerySession: + def test_transaction_begin(self, driver_sync): + session = ydb.query.session.QuerySessionSync(driver_sync) + + session.create() + + tx = session.transaction() + + assert tx._tx_state.tx_id == None + + tx.begin() + + assert tx._tx_state.tx_id != None diff --git a/ydb/query/base.py b/ydb/query/base.py index a494ad9d..859738fe 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -103,4 +103,6 @@ def create_execute_query_request(query: str, session_id: str, commit_tx: bool): def wrap_execute_query_response(rpc_state, response_pb): + print(response_pb) + return convert.ResultSet.from_message(response_pb.result_set) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py new file mode 100644 index 00000000..1fb8a74e --- /dev/null +++ b/ydb/query/transaction.py @@ -0,0 +1,212 @@ +import abc +import logging + +from .. import ( + _apis, + issues, + _utilities, +) +from .._grpc.grpcwrapper import ydb_query as _ydb_query +from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public + +from .._tx_ctx_impl import TxState, reset_tx_id_handler +from .._session_impl import bad_session_handler +from ..table import ( + AbstractTransactionModeBuilder, + ITxContext, + SerializableReadWrite +) + +logger = logging.getLogger(__name__) + +def patch_table_service_tx_mode_to_query_service(tx_mode: AbstractTransactionModeBuilder): + if tx_mode.name == 'snapshot_read_only': + tx_mode = _ydb_query_public.QuerySnapshotReadOnly() + elif tx_mode.name == 'serializable_read_write': + tx_mode = _ydb_query_public.QuerySerializableReadWrite() + elif tx_mode.name =='online_read_only': + tx_mode = _ydb_query_public.QueryOnlineReadOnly() + elif tx_mode.name == 'stale_read_only': + tx_mode = _ydb_query_public.QueryStaleReadOnly() + else: + raise issues.YDBInvalidArgumentError(f'Unknown transaction mode: {tx_mode.name}') + + return tx_mode + + +def _construct_tx_settings(tx_state): + tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode) + return tx_settings + + +def _create_begin_transaction_request(session_state, tx_state): + request = _ydb_query.BeginTransactionRequest( + session_id=session_state.session_id, + tx_settings=_construct_tx_settings(tx_state), + ).to_proto() + + print(request) + + return request + + +def _create_commit_transaction_request(session_state, tx_state): + request = _apis.ydb_query.CommitTransactionRequest() + request.tx_id = tx_state.tx_id + request.session_id = session_state.session_id + return request + +def _create_rollback_transaction_request(session_state, tx_state): + request = _apis.ydb_query.RollbackTransactionRequest() + request.tx_id = tx_state.tx_id + request.session_id = session_state.session_id + return request + + +@bad_session_handler +def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): + # session_state.complete_query() + # issues._process_response(response_pb.operation) + print("wrap result") + message = _ydb_query.BeginTransactionResponse.from_proto(response_pb) + tx_state.tx_id = message.tx_meta.id + return tx + + +@bad_session_handler +@reset_tx_id_handler +def wrap_result_on_rollback_or_commit_tx(rpc_state, response_pb, session_state, tx_state, tx): + + # issues._process_response(response_pb.operation) + # transaction successfully committed or rolled back + tx_state.tx_id = None + return tx + + +class BaseTxContext(ITxContext): + + _COMMIT = "commit" + _ROLLBACK = "rollback" + + def __init__(self, driver, session_state, session, tx_mode=None): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() and .async_begin() methods; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: A transaction mode, which is a one from the following choices: + 1) SerializableReadWrite() which is default mode; + 2) OnlineReadOnly(); + 3) StaleReadOnly(). + """ + self._driver = driver + if tx_mode is None: + tx_mode = patch_table_service_tx_mode_to_query_service(SerializableReadWrite()) + else: + tx_mode = patch_table_service_tx_mode_to_query_service(tx_mode) + self._tx_state = TxState(tx_mode) + self._session_state = session_state + self.session = session + self._finished = "" + + def __enter__(self): + """ + Enters a context manager and returns a session + + :return: A session instance + """ + return self + + def __exit__(self, *args, **kwargs): + """ + Closes a transaction context manager and rollbacks transaction if + it is not rolled back explicitly + """ + if self._tx_state.tx_id is not None: + # It's strictly recommended to close transactions directly + # by using commit_tx=True flag while executing statement or by + # .commit() or .rollback() methods, but here we trying to do best + # effort to avoid useless open transactions + logger.warning("Potentially leaked tx: %s", self._tx_state.tx_id) + try: + self.rollback() + except issues.Error: + logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) + + self._tx_state.tx_id = None + + @property + def session_id(self): + """ + A transaction's session id + + :return: A transaction's session id + """ + return self._session_state.session_id + + @property + def tx_id(self): + """ + Returns a id of open transaction or None otherwise + + :return: A id of open transaction or None otherwise + """ + return self._tx_state.tx_id + + def begin(self, settings=None): + """ + Explicitly begins a transaction + + :param settings: A request settings + + :return: An open transaction + """ + if self._tx_state.tx_id is not None: + return self + + print('try to begin tx') + + return self._driver( + _create_begin_transaction_request(self._session_state, self._tx_state), + _apis.QueryService.Stub, + _apis.QueryService.BeginTransaction, + wrap_result=wrap_tx_begin_response, + wrap_args=(self._session_state, self._tx_state, self), + ) + + def commit(self, settings=None): + """ + Calls commit on a transaction if it is open otherwise is no-op. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ + + self._set_finish(self._COMMIT) + + if self._tx_state.tx_id is None and not self._tx_state.dead: + return self + + return self._driver( + _create_commit_transaction_request(self._session_state, self._tx_state), + _apis.QueryService.Stub, + _apis.QueryService.CommitTransaction, + wrap_result_on_rollback_or_commit_tx, + settings, + (self._session_state, self._tx_state, self), + ) + + def rollback(self, settings=None): + pass + + def execute(self, query, parameters=None, commit_tx=False, settings=None): + pass \ No newline at end of file From 4c85d7613c04435b9338ebd28b8ba4df09845e69 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:26 +0300 Subject: [PATCH 14/57] wow tx begin works --- examples/query-service/basic_example.py | 6 +-- tests/query/test_query_transaction.py | 4 +- ydb/query/base.py | 52 ++++++++++++++++++++++--- ydb/query/session.py | 1 + ydb/query/transaction.py | 39 ++++++++----------- 5 files changed, 70 insertions(+), 32 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index 4c107c1a..e4af2c8e 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -21,9 +21,9 @@ def main(): it = session.execute("select 1; select 2;", commit_tx=False) for result_set in it: - pass - # print(f"columns: {str(result_set.columns)}") - # print(f"rows: {str(result_set.rows)}") + # pass + print(f"columns: {str(result_set.columns)}") + print(f"rows: {str(result_set.rows)}") diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 9877499c..71d9ac7f 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -10,8 +10,8 @@ def test_transaction_begin(self, driver_sync): tx = session.transaction() - assert tx._tx_state.tx_id == None + assert tx.tx_id == None tx.begin() - assert tx._tx_state.tx_id != None + assert tx.tx_id != None diff --git a/ydb/query/base.py b/ydb/query/base.py index 859738fe..92023eaf 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -17,9 +17,6 @@ class QueryClientSettings: ... -class IQueryTxContext: ... - - class QuerySessionState: _session_id: Optional[str] _node_id: Optional[int] @@ -60,6 +57,7 @@ def set_attached(self, attached: bool) -> None: class IQuerySession(abc.ABC): + @abc.abstractmethod def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = None): pass @@ -72,7 +70,48 @@ def delete(self) -> None: pass @abc.abstractmethod - def transaction(self, tx_mode: BaseQueryTxMode) -> IQueryTxContext: + def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": + pass + + +class IQueryTxContext(abc.ABC): + + @abc.abstractmethod + def __init__(self, driver: SupportedDriverType, session_state: QuerySessionState, session: IQuerySession, tx_mode: BaseQueryTxMode = None): + pass + + @abc.abstractmethod + def __enter__(self): + pass + + @abc.abstractmethod + def __exit__(self, *args, **kwargs): + pass + + @property + @abc.abstractmethod + def session_id(self): + pass + + @property + @abc.abstractmethod + def tx_id(self): + pass + + @abc.abstractmethod + def begin(): + pass + + @abc.abstractmethod + def commit(): + pass + + @abc.abstractmethod + def rollback(): + pass + + @abc.abstractmethod + def execute(query: str): pass @@ -103,6 +142,9 @@ def create_execute_query_request(query: str, session_id: str, commit_tx: bool): def wrap_execute_query_response(rpc_state, response_pb): - print(response_pb) + # print("RESP:") + # print(f"meta: {response_pb.tx_meta}") + # print(response_pb) + return convert.ResultSet.from_message(response_pb.result_set) diff --git a/ydb/query/session.py b/ydb/query/session.py index 461180ef..8aee19b4 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -77,6 +77,7 @@ def _execute_call(self, query: str, commit_tx: bool): session_id=self._state.session_id, commit_tx=commit_tx ) + print(request) return self._driver( request, _apis.QueryService.Stub, diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 1fb8a74e..14c5a039 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -11,27 +11,23 @@ from .._tx_ctx_impl import TxState, reset_tx_id_handler from .._session_impl import bad_session_handler -from ..table import ( - AbstractTransactionModeBuilder, - ITxContext, - SerializableReadWrite -) +from . import base logger = logging.getLogger(__name__) -def patch_table_service_tx_mode_to_query_service(tx_mode: AbstractTransactionModeBuilder): - if tx_mode.name == 'snapshot_read_only': - tx_mode = _ydb_query_public.QuerySnapshotReadOnly() - elif tx_mode.name == 'serializable_read_write': - tx_mode = _ydb_query_public.QuerySerializableReadWrite() - elif tx_mode.name =='online_read_only': - tx_mode = _ydb_query_public.QueryOnlineReadOnly() - elif tx_mode.name == 'stale_read_only': - tx_mode = _ydb_query_public.QueryStaleReadOnly() - else: - raise issues.YDBInvalidArgumentError(f'Unknown transaction mode: {tx_mode.name}') +# def patch_table_service_tx_mode_to_query_service(tx_mode: AbstractTransactionModeBuilder): +# if tx_mode.name == 'snapshot_read_only': +# tx_mode = _ydb_query_public.QuerySnapshotReadOnly() +# elif tx_mode.name == 'serializable_read_write': +# tx_mode = _ydb_query_public.QuerySerializableReadWrite() +# elif tx_mode.name =='online_read_only': +# tx_mode = _ydb_query_public.QueryOnlineReadOnly() +# elif tx_mode.name == 'stale_read_only': +# tx_mode = _ydb_query_public.QueryStaleReadOnly() +# else: +# raise issues.YDBInvalidArgumentError(f'Unknown transaction mode: {tx_mode.name}') - return tx_mode +# return tx_mode def _construct_tx_settings(tx_state): @@ -69,7 +65,8 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): # issues._process_response(response_pb.operation) print("wrap result") message = _ydb_query.BeginTransactionResponse.from_proto(response_pb) - tx_state.tx_id = message.tx_meta.id + + tx_state.tx_id = message.tx_meta.tx_id return tx @@ -83,7 +80,7 @@ def wrap_result_on_rollback_or_commit_tx(rpc_state, response_pb, session_state, return tx -class BaseTxContext(ITxContext): +class BaseTxContext(base.IQueryTxContext): _COMMIT = "commit" _ROLLBACK = "rollback" @@ -108,9 +105,7 @@ def __init__(self, driver, session_state, session, tx_mode=None): """ self._driver = driver if tx_mode is None: - tx_mode = patch_table_service_tx_mode_to_query_service(SerializableReadWrite()) - else: - tx_mode = patch_table_service_tx_mode_to_query_service(tx_mode) + tx_mode = _ydb_query_public.QuerySerializableReadWrite() self._tx_state = TxState(tx_mode) self._session_state = session_state self.session = session From 2ce3d6e4b46558a2c2860e39257e92237bdf17a1 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:26 +0300 Subject: [PATCH 15/57] tx state handler --- ydb/_grpc/grpcwrapper/ydb_query.py | 23 +++++++ ydb/query/base.py | 74 ++++++++++++++++++++--- ydb/query/transaction.py | 97 +++++++++++++++++++++--------- 3 files changed, 157 insertions(+), 37 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 542aa158..0bfdf792 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -116,6 +116,29 @@ def from_proto(msg: ydb_query_pb2.BeginTransactionResponse) -> "BeginTransaction tx_meta=TransactionMeta.from_proto(msg.tx_meta), ) + +@dataclass +class CommitTransactionResponse(IFromProto): + status: Optional[ServerStatus] + + @staticmethod + def from_proto(msg: ydb_query_pb2.CommitTransactionResponse) -> "CommitTransactionResponse": + return CommitTransactionResponse( + status=ServerStatus(msg.status, msg.issues), + ) + + +@dataclass +class RollbackTransactionResponse(IFromProto): + status: Optional[ServerStatus] + + @staticmethod + def from_proto(msg: ydb_query_pb2.RollbackTransactionResponse) -> "RollbackTransactionResponse": + return RollbackTransactionResponse( + status=ServerStatus(msg.status, msg.issues), + ) + + @dataclass class QueryContent(IFromPublic, IToProto): text: str diff --git a/ydb/query/base.py b/ydb/query/base.py index 92023eaf..40e562b3 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -1,4 +1,6 @@ import abc +import enum +import functools from typing import ( Optional, @@ -13,10 +15,51 @@ QuerySerializableReadWrite ) from .. import convert +from .. import issues class QueryClientSettings: ... +class QuerySessionStateEnum(enum.Enum): + NOT_INITIALIZED = "NOT_INITIALIZED" + CREATED = "CREATED" + CLOSED = "CLOSED" + + +class QuerySessionStateHelper(abc.ABC): + _VALID_TRANSITIONS = { + QuerySessionStateEnum.NOT_INITIALIZED: [QuerySessionStateEnum.CREATED], + QuerySessionStateEnum.CREATED: [QuerySessionStateEnum.CLOSED], + QuerySessionStateEnum.CLOSED: [] + } + + @classmethod + def valid_transition(cls, before: QuerySessionStateEnum, after: QuerySessionStateEnum) -> bool: + return after in cls._VALID_TRANSITIONS[before] + + +class QueryTxStateEnum(enum.Enum): + NOT_INITIALIZED = "NOT_INITIALIZED" + BEGINED = "BEGINED" + COMMITTED = "COMMITTED" + ROLLBACKED = "ROLLBACKED" + DEAD = "DEAD" + + +class QueryTxStateHelper(abc.ABC): + _VALID_TRANSITIONS = { + QueryTxStateEnum.NOT_INITIALIZED: [QueryTxStateEnum.BEGINED, QueryTxStateEnum.DEAD], + QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD], + QueryTxStateEnum.COMMITTED: [], + QueryTxStateEnum.ROLLBACKED: [], + QueryTxStateEnum.DEAD: [], + } + + @classmethod + def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: + return after in cls._VALID_TRANSITIONS[before] + + class QuerySessionState: _session_id: Optional[str] _node_id: Optional[int] @@ -99,15 +142,15 @@ def tx_id(self): pass @abc.abstractmethod - def begin(): + def begin(settings: QueryClientSettings = None): pass @abc.abstractmethod - def commit(): + def commit(settings: QueryClientSettings = None): pass @abc.abstractmethod - def rollback(): + def rollback(settings: QueryClientSettings = None): pass @abc.abstractmethod @@ -142,9 +185,26 @@ def create_execute_query_request(query: str, session_id: str, commit_tx: bool): def wrap_execute_query_response(rpc_state, response_pb): - # print("RESP:") - # print(f"meta: {response_pb.tx_meta}") - # print(response_pb) + return convert.ResultSet.from_message(response_pb.result_set) +X_YDB_SERVER_HINTS = "x-ydb-server-hints" +X_YDB_SESSION_CLOSE = "session-close" - return convert.ResultSet.from_message(response_pb.result_set) + +def _check_session_is_closing(rpc_state, session_state): + metadata = rpc_state.trailing_metadata() + if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []): + session_state.set_closing() + + +def bad_session_handler(func): + @functools.wraps(func) + def decorator(rpc_state, response_pb, session_state, *args, **kwargs): + try: + _check_session_is_closing(rpc_state, session_state) + return func(rpc_state, response_pb, session_state, *args, **kwargs) + except issues.BadSession: + session_state.reset() + raise + + return decorator \ No newline at end of file diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 14c5a039..fd3e26e8 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -1,5 +1,7 @@ import abc import logging +import enum +import functools from .. import ( _apis, @@ -9,8 +11,6 @@ from .._grpc.grpcwrapper import ydb_query as _ydb_query from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public -from .._tx_ctx_impl import TxState, reset_tx_id_handler -from .._session_impl import bad_session_handler from . import base logger = logging.getLogger(__name__) @@ -30,6 +30,38 @@ # return tx_mode +def reset_tx_id_handler(func): + @functools.wraps(func) + def decorator(rpc_state, response_pb, session_state, tx_state, *args, **kwargs): + try: + return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs) + except issues.Error: + tx_state.change_state(base.QueryTxStateEnum.DEAD) + tx_state.tx_id = None + raise + + return decorator + + +class QueryTxState: + def __init__(self, tx_mode: base.BaseQueryTxMode): + """ + Holds transaction context manager info + :param tx_mode: A mode of transaction + """ + self.tx_id = None + self.tx_mode = tx_mode + self._state = base.QueryTxStateEnum.NOT_INITIALIZED + + def check_invalid_transition(self, target: base.QueryTxStateEnum): + if not base.QueryTxStateHelper.is_valid_transition(self._state, target): + raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}") + + def change_state(self, target: base.QueryTxStateEnum): + self.check_invalid_transition(target) + self._state = target + + def _construct_tx_settings(tx_state): tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode) return tx_settings @@ -41,8 +73,6 @@ def _create_begin_transaction_request(session_state, tx_state): tx_settings=_construct_tx_settings(tx_state), ).to_proto() - print(request) - return request @@ -59,32 +89,35 @@ def _create_rollback_transaction_request(session_state, tx_state): return request -@bad_session_handler +@base.bad_session_handler def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): - # session_state.complete_query() - # issues._process_response(response_pb.operation) - print("wrap result") message = _ydb_query.BeginTransactionResponse.from_proto(response_pb) - + issues._process_response(message.status) + tx_state.change_state(base.QueryTxStateEnum.BEGINED) tx_state.tx_id = message.tx_meta.tx_id return tx -@bad_session_handler +@base.bad_session_handler @reset_tx_id_handler -def wrap_result_on_rollback_or_commit_tx(rpc_state, response_pb, session_state, tx_state, tx): +def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx): + message = _ydb_query.CommitTransactionResponse(response_pb) + issues._process_response(message.status) + tx_state.tx_id = None + tx_state.change_state(base.QueryTxStateEnum.COMMITTED) + return tx - # issues._process_response(response_pb.operation) - # transaction successfully committed or rolled back +@base.bad_session_handler +@reset_tx_id_handler +def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx): + message = _ydb_query.RollbackTransactionResponse(response_pb) + issues._process_response(message.status) tx_state.tx_id = None + tx_state.change_state(base.QueryTxStateEnum.ROLLBACKED) return tx class BaseTxContext(base.IQueryTxContext): - - _COMMIT = "commit" - _ROLLBACK = "rollback" - def __init__(self, driver, session_state, session, tx_mode=None): """ An object that provides a simple transaction context manager that allows statements execution @@ -106,7 +139,7 @@ def __init__(self, driver, session_state, session, tx_mode=None): self._driver = driver if tx_mode is None: tx_mode = _ydb_query_public.QuerySerializableReadWrite() - self._tx_state = TxState(tx_mode) + self._tx_state = QueryTxState(tx_mode) self._session_state = session_state self.session = session self._finished = "" @@ -163,17 +196,15 @@ def begin(self, settings=None): :return: An open transaction """ - if self._tx_state.tx_id is not None: - return self - - print('try to begin tx') + self._tx_state.check_invalid_transition(base.QueryTxStateEnum.BEGINED) return self._driver( _create_begin_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, _apis.QueryService.BeginTransaction, - wrap_result=wrap_tx_begin_response, - wrap_args=(self._session_state, self._tx_state, self), + wrap_tx_begin_response, + settings, + (self._session_state, self._tx_state, self), ) def commit(self, settings=None): @@ -186,22 +217,28 @@ def commit(self, settings=None): :return: A committed transaction or exception if commit is failed """ - self._set_finish(self._COMMIT) - - if self._tx_state.tx_id is None and not self._tx_state.dead: - return self + self._tx_state.check_invalid_transition(base.QueryTxStateEnum.COMMITTED) return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, _apis.QueryService.CommitTransaction, - wrap_result_on_rollback_or_commit_tx, + wrap_tx_commit_response, settings, (self._session_state, self._tx_state, self), ) def rollback(self, settings=None): - pass + self._tx_state.check_invalid_transition(base.QueryTxStateEnum.ROLLBACKED) + + return self._driver( + _create_rollback_transaction_request(self._session_state, self._tx_state), + _apis.QueryService.Stub, + _apis.QueryService.RollbackTransaction, + wrap_tx_rollback_response, + settings, + (self._session_state, self._tx_state, self), + ) def execute(self, query, parameters=None, commit_tx=False, settings=None): pass \ No newline at end of file From 75ea82d6e7f30217ed0873a48e817414f12816aa Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:26 +0300 Subject: [PATCH 16/57] refactor session --- tests/query/conftest.py | 27 +++++++ tests/query/test_query_session.py | 60 +++++++++++++-- ydb/query/base.py | 87 ++++++--------------- ydb/query/session.py | 123 +++++++++++++++++++++++++----- 4 files changed, 209 insertions(+), 88 deletions(-) create mode 100644 tests/query/conftest.py diff --git a/tests/query/conftest.py b/tests/query/conftest.py new file mode 100644 index 00000000..c098c631 --- /dev/null +++ b/tests/query/conftest.py @@ -0,0 +1,27 @@ +import pytest +from ydb.query.session import QuerySessionSync + + +@pytest.fixture +def session(driver_sync): + session = QuerySessionSync(driver_sync) + + yield session + + try: + session.delete() + except BaseException: + pass + +@pytest.fixture +def transaction(session): + session.create() + transaction = session.transaction() + + yield transaction + + try: + transaction.rollback() + except BaseException: + pass + diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index ef0c26f5..71e06b37 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -1,7 +1,5 @@ import pytest -import ydb.query.session - def _check_session_state_empty(session): assert session._state.session_id == None assert session._state.node_id == None @@ -13,8 +11,7 @@ def _check_session_state_full(session): assert session._state.attached == True class TestQuerySession: - def test_session_normal_lifecycle(self, driver_sync): - session = ydb.query.session.QuerySessionSync(driver_sync) + def test_session_normal_lifecycle(self, session): _check_session_state_empty(session) session.create() @@ -23,8 +20,7 @@ def test_session_normal_lifecycle(self, driver_sync): session.delete() _check_session_state_empty(session) - def test_second_create_do_nothing(self, driver_sync): - session = ydb.query.session.QuerySessionSync(driver_sync) + def test_second_create_do_nothing(self, session): session.create() _check_session_state_full(session) @@ -36,3 +32,55 @@ def test_second_create_do_nothing(self, driver_sync): assert session._state.session_id == session_id_before assert session._state.node_id == node_id_before + + def test_second_delete_do_nothing(self, session): + session.create() + + session.delete() + session.delete() + + def test_delete_before_create_not_possible(self, session): + with pytest.raises(RuntimeError): + session.delete() + + def test_create_after_delete_not_possible(self, session): + session.create() + session.delete() + with pytest.raises(RuntimeError): + session.create() + + def test_transaction_before_create_raises(self, session): + with pytest.raises(RuntimeError): + session.transaction() + + def test_transaction_after_delete_raises(self, session): + session.create() + + session.delete() + + with pytest.raises(RuntimeError): + session.transaction() + + def test_transaction_after_create_not_raises(self, session): + session.create() + session.transaction() + + def test_execute_before_create_raises(self, session): + with pytest.raises(RuntimeError): + session.execute("select 1;") + + def test_execute_after_delete_raises(self, session): + session.create() + session.delete() + with pytest.raises(RuntimeError): + session.execute("select 1;") + + def test_basic_execute(self, session): + session.create() + it = session.execute("select 1;") + result_sets = [result_set for result_set in it] + + assert len(result_sets) == 1 + assert len(result_sets[0].rows) == 1 + assert len(result_sets[0].columns) == 1 + assert list(result_sets[0].rows[0].values()) == [1] \ No newline at end of file diff --git a/ydb/query/base.py b/ydb/query/base.py index 40e562b3..7db6a5de 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -20,83 +20,40 @@ class QueryClientSettings: ... -class QuerySessionStateEnum(enum.Enum): - NOT_INITIALIZED = "NOT_INITIALIZED" - CREATED = "CREATED" - CLOSED = "CLOSED" - - -class QuerySessionStateHelper(abc.ABC): - _VALID_TRANSITIONS = { - QuerySessionStateEnum.NOT_INITIALIZED: [QuerySessionStateEnum.CREATED], - QuerySessionStateEnum.CREATED: [QuerySessionStateEnum.CLOSED], - QuerySessionStateEnum.CLOSED: [] - } - - @classmethod - def valid_transition(cls, before: QuerySessionStateEnum, after: QuerySessionStateEnum) -> bool: - return after in cls._VALID_TRANSITIONS[before] - - -class QueryTxStateEnum(enum.Enum): - NOT_INITIALIZED = "NOT_INITIALIZED" - BEGINED = "BEGINED" - COMMITTED = "COMMITTED" - ROLLBACKED = "ROLLBACKED" - DEAD = "DEAD" - - -class QueryTxStateHelper(abc.ABC): - _VALID_TRANSITIONS = { - QueryTxStateEnum.NOT_INITIALIZED: [QueryTxStateEnum.BEGINED, QueryTxStateEnum.DEAD], - QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD], - QueryTxStateEnum.COMMITTED: [], - QueryTxStateEnum.ROLLBACKED: [], - QueryTxStateEnum.DEAD: [], - } - - @classmethod - def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: - return after in cls._VALID_TRANSITIONS[before] - - -class QuerySessionState: - _session_id: Optional[str] - _node_id: Optional[int] - _attached: bool = False - _settings: Optional[QueryClientSettings] - +class IQuerySessionState(abc.ABC): def __init__(self, settings: QueryClientSettings = None): - self._settings = settings - self.reset() + pass + @abc.abstractmethod def reset(self) -> None: - self._session_id = None - self._node_id = None - self._attached = False + pass @property + @abc.abstractmethod def session_id(self) -> Optional[str]: - return self._session_id + pass - def set_session_id(self, session_id: str) -> "QuerySessionState": - self._session_id = session_id - return self + @abc.abstractmethod + def set_session_id(self, session_id: str) -> "IQuerySessionState": + pass @property + @abc.abstractmethod def node_id(self) -> Optional[int]: - return self._node_id + pass - def set_node_id(self, node_id: int) -> "QuerySessionState": - self._node_id = node_id - return self + @abc.abstractmethod + def set_node_id(self, node_id: int) -> "IQuerySessionState": + pass @property + @abc.abstractmethod def attached(self) -> bool: - return self._attached + pass - def set_attached(self, attached: bool) -> None: - self._attached = attached + @abc.abstractmethod + def set_attached(self, attached: bool) -> "IQuerySessionState": + pass class IQuerySession(abc.ABC): @@ -120,7 +77,7 @@ def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": class IQueryTxContext(abc.ABC): @abc.abstractmethod - def __init__(self, driver: SupportedDriverType, session_state: QuerySessionState, session: IQuerySession, tx_mode: BaseQueryTxMode = None): + def __init__(self, driver: SupportedDriverType, session_state: IQuerySessionState, session: IQuerySession, tx_mode: BaseQueryTxMode = None): pass @abc.abstractmethod @@ -184,9 +141,9 @@ def create_execute_query_request(query: str, session_id: str, commit_tx: bool): return req.to_proto() def wrap_execute_query_response(rpc_state, response_pb): - return convert.ResultSet.from_message(response_pb.result_set) + X_YDB_SERVER_HINTS = "x-ydb-server-hints" X_YDB_SESSION_CLOSE = "session-close" @@ -194,7 +151,7 @@ def wrap_execute_query_response(rpc_state, response_pb): def _check_session_is_closing(rpc_state, session_state): metadata = rpc_state.trailing_metadata() if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []): - session_state.set_closing() + session_state.set_closing() # TODO: clarify & implement def bad_session_handler(func): diff --git a/ydb/query/session.py b/ydb/query/session.py index 8aee19b4..e59a8acd 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -2,6 +2,7 @@ from abc import abstractmethod import asyncio import concurrent +import enum import logging import threading from typing import ( @@ -22,29 +23,112 @@ logger = logging.getLogger(__name__) -def wrapper_create_session(rpc_state, response_pb, session_state: base.QuerySessionState, session): - #TODO: process response +class QuerySessionStateEnum(enum.Enum): + NOT_INITIALIZED = "NOT_INITIALIZED" + CREATED = "CREATED" + CLOSED = "CLOSED" + + +class QuerySessionStateHelper(abc.ABC): + _VALID_TRANSITIONS = { + QuerySessionStateEnum.NOT_INITIALIZED: [QuerySessionStateEnum.CREATED], + QuerySessionStateEnum.CREATED: [QuerySessionStateEnum.CLOSED], + QuerySessionStateEnum.CLOSED: [] + } + + _READY_TO_USE = [ + QuerySessionStateEnum.CREATED + ] + + @classmethod + def valid_transition(cls, before: QuerySessionStateEnum, after: QuerySessionStateEnum) -> bool: + return after in cls._VALID_TRANSITIONS[before] + + @classmethod + def ready_to_use(cls, state: QuerySessionStateEnum) -> bool: + return state in cls._READY_TO_USE + + +class QuerySessionState(base.IQuerySessionState): + _session_id: Optional[str] + _node_id: Optional[int] + _attached: bool = False + _settings: Optional[base.QueryClientSettings] + _state: QuerySessionStateEnum + + def __init__(self, settings: base.QueryClientSettings = None): + self._settings = settings + self._state = QuerySessionStateEnum.NOT_INITIALIZED + self.reset() + + def reset(self) -> None: + self._session_id = None + self._node_id = None + self._attached = False + + @property + def session_id(self) -> Optional[str]: + return self._session_id + + def set_session_id(self, session_id: str) -> "QuerySessionState": + self._session_id = session_id + return self + + @property + def node_id(self) -> Optional[int]: + return self._node_id + + def set_node_id(self, node_id: int) -> "QuerySessionState": + self._node_id = node_id + return self + + @property + def attached(self) -> bool: + return self._attached + + def set_attached(self, attached: bool) -> "QuerySessionState": + self._attached = attached + + def _check_invalid_transition(self, target: QuerySessionStateEnum): + if not QuerySessionStateHelper.valid_transition(self._state, target): + raise RuntimeError(f"Session could not be moved from {self._state.value} to {target.value}") + + def _change_state(self, target: QuerySessionStateEnum): + self._check_invalid_transition(target) + self._state = target + + def _check_session_ready_to_use(self): + if not QuerySessionStateHelper.ready_to_use(self._state): + raise RuntimeError(f"Session is not ready to use, current state: {self._state.value}") + + def _already_in(self, target): + return self._state == target + + +def wrapper_create_session(rpc_state, response_pb, session_state: QuerySessionState, session): message = _ydb_query.CreateSessionResponse.from_proto(response_pb) + issues._process_response(message.status) session_state.set_session_id(message.session_id).set_node_id(message.node_id) return session -def wrapper_delete_session(rpc_state, response_pb, session_state: base.QuerySessionState, session): - #TODO: process response +def wrapper_delete_session(rpc_state, response_pb, session_state: QuerySessionState, session): message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) + issues._process_response(message.status) session_state.reset() + session_state._change_state(QuerySessionStateEnum.CLOSED) return session class BaseQuerySession(base.IQuerySession): _driver: base.SupportedDriverType _settings: Optional[base.QueryClientSettings] - _state: base.QuerySessionState + _state: QuerySessionState def __init__(self, driver: base.SupportedDriverType, settings: base.QueryClientSettings = None): self._driver = driver self._settings = settings - self._state = base.QuerySessionState(settings) + self._state = QuerySessionState(settings) def _create_call(self): return self._driver( @@ -77,12 +161,11 @@ def _execute_call(self, query: str, commit_tx: bool): session_id=self._state.session_id, commit_tx=commit_tx ) - print(request) + return self._driver( request, _apis.QueryService.Stub, _apis.QueryService.ExecuteQuery, - # base.wrap_execute_query_response ) class QuerySessionSync(BaseQuerySession): @@ -98,9 +181,9 @@ def _attach(self): first_response = next(status_stream) if first_response.status != issues.StatusCode.SUCCESS: pass - # raise common_utils.YdbStatusError(first_response) self._state.set_attached(True) + self._state._change_state(QuerySessionStateEnum.CREATED) threading.Thread( target=self._chech_session_status_loop, @@ -113,27 +196,31 @@ def _chech_session_status_loop(self, status_stream): try: for status in status_stream: if status.status != issues.StatusCode.SUCCESS: - print("STATUS NOT SUCCESS") - self._state.reset(False) + self._state.reset() + self._state._change_state(QuerySessionStateEnum.CLOSED) except Exception as e: pass def delete(self) -> None: - if not self._state.session_id: + if self._state._already_in(QuerySessionStateEnum.CLOSED): return + + self._state._check_invalid_transition(QuerySessionStateEnum.CLOSED) self._delete_call() self._stream.cancel() def create(self) -> None: - if self._state.session_id: + if self._state._already_in(QuerySessionStateEnum.CREATED): return + + self._state._check_invalid_transition(QuerySessionStateEnum.CREATED) self._create_call() self._attach() def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxContext: - if not self._state.session_id: - return + self._state._check_session_ready_to_use() + return BaseTxContext( self._driver, self._state, @@ -141,8 +228,10 @@ def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxCont tx_mode, ) - def execute(self, query: str, commit_tx: bool = True): - stream_it = self._execute_call(query, commit_tx) + def execute(self, query: str, parameters=None): + self._state._check_session_ready_to_use() + + stream_it = self._execute_call(query, commit_tx=True) return _utilities.SyncResponseIterator( stream_it, From d9a0424d7a26dc3b9ea3e9397db09da42fbbd817 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 15:46:26 +0300 Subject: [PATCH 17/57] refactor transaction --- docker-compose.yml | 2 +- tests/query/conftest.py | 2 +- tests/query/test_query_transaction.py | 43 ++++++++--- ydb/query/base.py | 47 +++++++----- ydb/query/transaction.py | 100 +++++++++++++++++++------- 5 files changed, 141 insertions(+), 53 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 50a31f12..cb37a377 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.3" services: ydb: - image: ydbplatform/local-ydb:latest + image: ydbplatform/local-ydb:trunk restart: always ports: - 2136:2136 diff --git a/tests/query/conftest.py b/tests/query/conftest.py index c098c631..1c1cf333 100644 --- a/tests/query/conftest.py +++ b/tests/query/conftest.py @@ -14,7 +14,7 @@ def session(driver_sync): pass @pytest.fixture -def transaction(session): +def tx(session): session.create() transaction = session.transaction() diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 71d9ac7f..0731882a 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,17 +1,44 @@ import pytest -import ydb.query.session +class TestQueryTransaction: + def test_tx_begin(self, tx): + assert tx.tx_id == None -class TestQuerySession: - def test_transaction_begin(self, driver_sync): - session = ydb.query.session.QuerySessionSync(driver_sync) + tx.begin() + assert tx.tx_id != None - session.create() + def test_tx_allow_double_commit(self, tx): + tx.begin() + tx.commit() + tx.commit() - tx = session.transaction() + def test_tx_allow_double_rollback(self, tx): + tx.begin() + tx.rollback() + tx.rollback() - assert tx.tx_id == None + def test_tx_commit_raises_before_begin(self, tx): + with pytest.raises(RuntimeError): + tx.commit() + def test_tx_rollback_raises_before_begin(self, tx): + with pytest.raises(RuntimeError): + tx.rollback() + + # def test_tx_execute_raises_before_begin(self, tx): + # with pytest.raises(RuntimeError): + # tx.execute("select 1;") + + def text_tx_execute_raises_after_commit(self, tx): tx.begin() + tx.commit() + with pytest.raises(RuntimeError): + tx.execute("select 1;") + + def text_tx_execute_raises_after_rollback(self, tx): + tx.begin() + tx.rollback() + with pytest.raises(RuntimeError): + tx.execute("select 1;") + - assert tx.tx_id != None diff --git a/ydb/query/base.py b/ydb/query/base.py index 7db6a5de..44cc94fd 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -124,19 +124,32 @@ def session(self) -> IQuerySession: pass -def create_execute_query_request(query: str, session_id: str, commit_tx: bool): - req = ydb_query.ExecuteQueryRequest( - session_id=session_id, - query_content=ydb_query.QueryContent.from_public( - query=query, - ), - tx_control=ydb_query.TransactionControl( - begin_tx=ydb_query.TransactionSettings( - tx_mode=QuerySerializableReadWrite(), +def create_execute_query_request(query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None): + if tx_id: + req = ydb_query.ExecuteQueryRequest( + session_id=session_id, + query_content=ydb_query.QueryContent.from_public( + query=query, ), - commit_tx=commit_tx - ), - ) + tx_control=ydb_query.TransactionControl( + tx_id=tx_id, + commit_tx=commit_tx + ), + ) + else: + tx_mode = tx_mode if tx_mode is not None else QuerySerializableReadWrite() + req = ydb_query.ExecuteQueryRequest( + session_id=session_id, + query_content=ydb_query.QueryContent.from_public( + query=query, + ), + tx_control=ydb_query.TransactionControl( + begin_tx=ydb_query.TransactionSettings( + tx_mode=tx_mode, + ), + commit_tx=commit_tx + ), + ) return req.to_proto() @@ -148,17 +161,17 @@ def wrap_execute_query_response(rpc_state, response_pb): X_YDB_SESSION_CLOSE = "session-close" -def _check_session_is_closing(rpc_state, session_state): - metadata = rpc_state.trailing_metadata() - if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []): - session_state.set_closing() # TODO: clarify & implement +# def _check_session_is_closing(rpc_state, session_state): +# metadata = rpc_state.trailing_metadata() +# if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []): +# session_state.set_closing() # TODO: clarify & implement def bad_session_handler(func): @functools.wraps(func) def decorator(rpc_state, response_pb, session_state, *args, **kwargs): try: - _check_session_is_closing(rpc_state, session_state) + # _check_session_is_closing(rpc_state, session_state) return func(rpc_state, response_pb, session_state, *args, **kwargs) except issues.BadSession: session_state.reset() diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index fd3e26e8..90f5f681 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -15,19 +15,37 @@ logger = logging.getLogger(__name__) -# def patch_table_service_tx_mode_to_query_service(tx_mode: AbstractTransactionModeBuilder): -# if tx_mode.name == 'snapshot_read_only': -# tx_mode = _ydb_query_public.QuerySnapshotReadOnly() -# elif tx_mode.name == 'serializable_read_write': -# tx_mode = _ydb_query_public.QuerySerializableReadWrite() -# elif tx_mode.name =='online_read_only': -# tx_mode = _ydb_query_public.QueryOnlineReadOnly() -# elif tx_mode.name == 'stale_read_only': -# tx_mode = _ydb_query_public.QueryStaleReadOnly() -# else: -# raise issues.YDBInvalidArgumentError(f'Unknown transaction mode: {tx_mode.name}') -# return tx_mode +class QueryTxStateEnum(enum.Enum): + NOT_INITIALIZED = "NOT_INITIALIZED" + BEGINED = "BEGINED" + COMMITTED = "COMMITTED" + ROLLBACKED = "ROLLBACKED" + DEAD = "DEAD" + + +class QueryTxStateHelper(abc.ABC): + _VALID_TRANSITIONS = { + QueryTxStateEnum.NOT_INITIALIZED: [QueryTxStateEnum.BEGINED, QueryTxStateEnum.DEAD], + QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD], + QueryTxStateEnum.COMMITTED: [], + QueryTxStateEnum.ROLLBACKED: [], + QueryTxStateEnum.DEAD: [], + } + + _TERMINAL_STATES = [ + QueryTxStateEnum.COMMITTED, + QueryTxStateEnum.ROLLBACKED, + QueryTxStateEnum.DEAD, + ] + + @classmethod + def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: + return after in cls._VALID_TRANSITIONS[before] + + @classmethod + def terminal(cls, state: QueryTxStateEnum) -> bool: + return state in cls._TERMINAL_STATES def reset_tx_id_handler(func): @@ -36,7 +54,7 @@ def decorator(rpc_state, response_pb, session_state, tx_state, *args, **kwargs): try: return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs) except issues.Error: - tx_state.change_state(base.QueryTxStateEnum.DEAD) + tx_state._change_state(QueryTxStateEnum.DEAD) tx_state.tx_id = None raise @@ -51,16 +69,23 @@ def __init__(self, tx_mode: base.BaseQueryTxMode): """ self.tx_id = None self.tx_mode = tx_mode - self._state = base.QueryTxStateEnum.NOT_INITIALIZED + self._state = QueryTxStateEnum.NOT_INITIALIZED - def check_invalid_transition(self, target: base.QueryTxStateEnum): - if not base.QueryTxStateHelper.is_valid_transition(self._state, target): + def _check_invalid_transition(self, target: QueryTxStateEnum): + if not QueryTxStateHelper.valid_transition(self._state, target): raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}") - def change_state(self, target: base.QueryTxStateEnum): - self.check_invalid_transition(target) + def _change_state(self, target: QueryTxStateEnum): + self._check_invalid_transition(target) self._state = target + def _check_tx_not_terminal(self): + if QueryTxStateHelper.terminal(self._state): + raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") + + def _already_in(self, target: QueryTxStateEnum): + return self._state == target + def _construct_tx_settings(tx_state): tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode) @@ -93,7 +118,7 @@ def _create_rollback_transaction_request(session_state, tx_state): def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): message = _ydb_query.BeginTransactionResponse.from_proto(response_pb) issues._process_response(message.status) - tx_state.change_state(base.QueryTxStateEnum.BEGINED) + tx_state._change_state(QueryTxStateEnum.BEGINED) tx_state.tx_id = message.tx_meta.tx_id return tx @@ -104,7 +129,7 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx) message = _ydb_query.CommitTransactionResponse(response_pb) issues._process_response(message.status) tx_state.tx_id = None - tx_state.change_state(base.QueryTxStateEnum.COMMITTED) + tx_state._change_state(QueryTxStateEnum.COMMITTED) return tx @base.bad_session_handler @@ -113,7 +138,7 @@ def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, t message = _ydb_query.RollbackTransactionResponse(response_pb) issues._process_response(message.status) tx_state.tx_id = None - tx_state.change_state(base.QueryTxStateEnum.ROLLBACKED) + tx_state._change_state(QueryTxStateEnum.ROLLBACKED) return tx @@ -196,7 +221,7 @@ def begin(self, settings=None): :return: An open transaction """ - self._tx_state.check_invalid_transition(base.QueryTxStateEnum.BEGINED) + self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) return self._driver( _create_begin_transaction_request(self._session_state, self._tx_state), @@ -216,8 +241,9 @@ def commit(self, settings=None): :return: A committed transaction or exception if commit is failed """ - - self._tx_state.check_invalid_transition(base.QueryTxStateEnum.COMMITTED) + if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + return + self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), @@ -229,7 +255,10 @@ def commit(self, settings=None): ) def rollback(self, settings=None): - self._tx_state.check_invalid_transition(base.QueryTxStateEnum.ROLLBACKED) + if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + return + + self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) return self._driver( _create_rollback_transaction_request(self._session_state, self._tx_state), @@ -240,5 +269,24 @@ def rollback(self, settings=None): (self._session_state, self._tx_state, self), ) + def _execute_call(self, query: str, commit_tx: bool): + request = base.create_execute_query_request( + query=query, + session_id=self._session_state.session_id, + commit_tx=commit_tx + ) + return self._driver( + request, + _apis.QueryService.Stub, + _apis.QueryService.ExecuteQuery, + ) + def execute(self, query, parameters=None, commit_tx=False, settings=None): - pass \ No newline at end of file + self._tx_state._check_tx_not_terminal() + + stream_it = self._execute_call(query, commit_tx) + + return _utilities.SyncResponseIterator( + stream_it, + lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), + ) From e82076426b0c69c26015d531a57680f62945fd91 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 16:05:43 +0300 Subject: [PATCH 18/57] codestyles fixes --- examples/query-service/basic_example.py | 3 +-- tests/query/conftest.py | 2 +- tests/query/test_query_session.py | 17 ++++++++++------- tests/query/test_query_transaction.py | 6 ++---- ydb/_grpc/grpcwrapper/common_utils.py | 1 - ydb/_grpc/grpcwrapper/ydb_query.py | 18 ++---------------- .../grpcwrapper/ydb_query_public_types.py | 2 -- ydb/query/base.py | 8 +++++--- ydb/query/session.py | 8 ++------ ydb/query/transaction.py | 2 ++ 10 files changed, 25 insertions(+), 42 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index e4af2c8e..d0987cd9 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -26,6 +26,5 @@ def main(): print(f"rows: {str(result_set.rows)}") - if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/tests/query/conftest.py b/tests/query/conftest.py index 1c1cf333..a7f0c34c 100644 --- a/tests/query/conftest.py +++ b/tests/query/conftest.py @@ -13,6 +13,7 @@ def session(driver_sync): except BaseException: pass + @pytest.fixture def tx(session): session.create() @@ -24,4 +25,3 @@ def tx(session): transaction.rollback() except BaseException: pass - diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index 71e06b37..dc0b7664 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -1,14 +1,17 @@ import pytest + def _check_session_state_empty(session): - assert session._state.session_id == None - assert session._state.node_id == None - assert session._state.attached == False + assert session._state.session_id is None + assert session._state.node_id is None + assert not session._state.attached + def _check_session_state_full(session): - assert session._state.session_id != None - assert session._state.node_id != None - assert session._state.attached == True + assert session._state.session_id is not None + assert session._state.node_id is not None + assert session._state.attached + class TestQuerySession: def test_session_normal_lifecycle(self, session): @@ -83,4 +86,4 @@ def test_basic_execute(self, session): assert len(result_sets) == 1 assert len(result_sets[0].rows) == 1 assert len(result_sets[0].columns) == 1 - assert list(result_sets[0].rows[0].values()) == [1] \ No newline at end of file + assert list(result_sets[0].rows[0].values()) == [1] diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 0731882a..0cde7d47 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -2,10 +2,10 @@ class TestQueryTransaction: def test_tx_begin(self, tx): - assert tx.tx_id == None + assert tx.tx_id is None tx.begin() - assert tx.tx_id != None + assert tx.tx_id is not None def test_tx_allow_double_commit(self, tx): tx.begin() @@ -40,5 +40,3 @@ def text_tx_execute_raises_after_rollback(self, tx): tx.rollback() with pytest.raises(RuntimeError): tx.execute("select 1;") - - diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 5d71f4d0..895d4036 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -247,7 +247,6 @@ async def _start_sync_driver(self, driver: ydb.Driver, request, stub, method): self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) - @dataclass(init=False) class ServerStatus(IFromProto): __slots__ = ("_grpc_status_code", "_issues") diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 0bfdf792..0026f050 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -2,7 +2,6 @@ import typing from typing import Optional -from google.protobuf.message import Message # Workaround for good IDE and universal for runtime if typing.TYPE_CHECKING: @@ -14,16 +13,9 @@ from .common_utils import ( IFromProto, - IFromProtoWithProtoType, IToProto, - IToPublic, IFromPublic, ServerStatus, - UnknownGrpcMessageError, - proto_duration_from_timedelta, - proto_timestamp_from_datetime, - datetime_from_proto_timestamp, - timedelta_from_proto_duration, ) @dataclass @@ -57,14 +49,6 @@ class AttachSessionRequest(IToProto): def to_proto(self) -> ydb_query_pb2.AttachSessionRequest: return ydb_query_pb2.AttachSessionRequest(session_id=self.session_id) -# @dataclass -# class SessionState(IFromProto): -# status: Optional[ServerStatus] - -# @staticmethod -# def from_proto(msg: ydb_query_pb2.SessionState) -> "SessionState": -# return SessionState(status=ServerStatus(msg.status, msg.issues)) - @dataclass class TransactionMeta(IFromProto): @@ -93,6 +77,7 @@ def to_proto(self) -> ydb_query_pb2.TransactionSettings: if self.tx_mode.name == 'stale_read_only': return ydb_query_pb2.TransactionSettings(stale_read_only=self.tx_mode.to_proto()) + @dataclass class BeginTransactionRequest(IToProto): session_id: str @@ -104,6 +89,7 @@ def to_proto(self) -> ydb_query_pb2.BeginTransactionRequest: tx_settings=self.tx_settings.to_proto(), ) + @dataclass class BeginTransactionResponse(IFromProto): status: Optional[ServerStatus] diff --git a/ydb/_grpc/grpcwrapper/ydb_query_public_types.py b/ydb/_grpc/grpcwrapper/ydb_query_public_types.py index 27d1e917..d79a2967 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_query_public_types.py @@ -1,8 +1,6 @@ import abc import typing -from google.protobuf.message import Message - from .common_utils import IToProto # Workaround for good IDE and universal for runtime diff --git a/ydb/query/base.py b/ydb/query/base.py index 44cc94fd..5895db73 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -1,5 +1,4 @@ import abc -import enum import functools from typing import ( @@ -17,7 +16,9 @@ from .. import convert from .. import issues -class QueryClientSettings: ... + +class QueryClientSettings: + pass class IQuerySessionState(abc.ABC): @@ -153,6 +154,7 @@ def create_execute_query_request(query: str, session_id: str, tx_id: str = None, return req.to_proto() + def wrap_execute_query_response(rpc_state, response_pb): return convert.ResultSet.from_message(response_pb.result_set) @@ -177,4 +179,4 @@ def decorator(rpc_state, response_pb, session_state, *args, **kwargs): session_state.reset() raise - return decorator \ No newline at end of file + return decorator diff --git a/ydb/query/session.py b/ydb/query/session.py index e59a8acd..9c8ed1ce 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -1,14 +1,9 @@ import abc -from abc import abstractmethod -import asyncio -import concurrent import enum import logging import threading from typing import ( - Any, Optional, - Set, ) from . import base @@ -168,6 +163,7 @@ def _execute_call(self, query: str, commit_tx: bool): _apis.QueryService.ExecuteQuery, ) + class QuerySessionSync(BaseQuerySession): _stream = None @@ -236,4 +232,4 @@ def execute(self, query: str, parameters=None): return _utilities.SyncResponseIterator( stream_it, lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), - ) \ No newline at end of file + ) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 90f5f681..067707bc 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -107,6 +107,7 @@ def _create_commit_transaction_request(session_state, tx_state): request.session_id = session_state.session_id return request + def _create_rollback_transaction_request(session_state, tx_state): request = _apis.ydb_query.RollbackTransactionRequest() request.tx_id = tx_state.tx_id @@ -132,6 +133,7 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx) tx_state._change_state(QueryTxStateEnum.COMMITTED) return tx + @base.bad_session_handler @reset_tx_id_handler def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx): From 779950fa38bdf2c0ae40a04456e448f63f4f3907 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 16:17:23 +0300 Subject: [PATCH 19/57] codestyle fixes --- examples/query-service/basic_example.py | 2 +- tests/query/test_query_transaction.py | 1 + ydb/_grpc/grpcwrapper/ydb_query.py | 21 ++++++++------------- ydb/query/base.py | 18 +++++++++++++----- ydb/query/session.py | 18 +++++++++--------- ydb/query/transaction.py | 2 +- 6 files changed, 33 insertions(+), 29 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index d0987cd9..cc93332a 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -26,5 +26,5 @@ def main(): print(f"rows: {str(result_set.rows)}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 0cde7d47..1b0d865a 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,5 +1,6 @@ import pytest + class TestQueryTransaction: def test_tx_begin(self, tx): assert tx.tx_id is None diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 0026f050..f5e16664 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -18,6 +18,7 @@ ServerStatus, ) + @dataclass class CreateSessionResponse(IFromProto): status: Optional[ServerStatus] @@ -68,13 +69,13 @@ def from_public(tx_mode: public_types.BaseQueryTxMode) -> "TransactionSettings": return TransactionSettings(tx_mode=tx_mode) def to_proto(self) -> ydb_query_pb2.TransactionSettings: - if self.tx_mode.name == 'snapshot_read_only': + if self.tx_mode.name == "snapshot_read_only": return ydb_query_pb2.TransactionSettings(snapshot_read_only=self.tx_mode.to_proto()) - if self.tx_mode.name == 'serializable_read_write': + if self.tx_mode.name == "serializable_read_write": return ydb_query_pb2.TransactionSettings(serializable_read_write=self.tx_mode.to_proto()) - if self.tx_mode.name == 'online_read_only': + if self.tx_mode.name == "online_read_only": return ydb_query_pb2.TransactionSettings(online_read_only=self.tx_mode.to_proto()) - if self.tx_mode.name == 'stale_read_only': + if self.tx_mode.name == "stale_read_only": return ydb_query_pb2.TransactionSettings(stale_read_only=self.tx_mode.to_proto()) @@ -87,7 +88,7 @@ def to_proto(self) -> ydb_query_pb2.BeginTransactionRequest: return ydb_query_pb2.BeginTransactionRequest( session_id=self.session_id, tx_settings=self.tx_settings.to_proto(), - ) + ) @dataclass @@ -146,14 +147,8 @@ class TransactionControl(IToProto): def to_proto(self) -> ydb_query_pb2.TransactionControl: if self.tx_id: - return ydb_query_pb2.TransactionControl( - tx_id=self.tx_id, - commit_tx=self.commit_tx, - ) - return ydb_query_pb2.TransactionControl( - begin_tx=self.begin_tx.to_proto(), - commit_tx=self.commit_tx - ) + return ydb_query_pb2.TransactionControl(tx_id=self.tx_id,commit_tx=self.commit_tx) + return ydb_query_pb2.TransactionControl(begin_tx=self.begin_tx.to_proto(), commit_tx=self.commit_tx) @dataclass diff --git a/ydb/query/base.py b/ydb/query/base.py index 5895db73..d3bfddae 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -11,7 +11,7 @@ from .._grpc.grpcwrapper import ydb_query from .._grpc.grpcwrapper.ydb_query_public_types import ( BaseQueryTxMode, - QuerySerializableReadWrite + QuerySerializableReadWrite, ) from .. import convert from .. import issues @@ -78,7 +78,13 @@ def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": class IQueryTxContext(abc.ABC): @abc.abstractmethod - def __init__(self, driver: SupportedDriverType, session_state: IQuerySessionState, session: IQuerySession, tx_mode: BaseQueryTxMode = None): + def __init__( + self, + driver: SupportedDriverType, + session_state: IQuerySessionState, + session: IQuerySession, + tx_mode: BaseQueryTxMode = None + ): pass @abc.abstractmethod @@ -125,7 +131,9 @@ def session(self) -> IQuerySession: pass -def create_execute_query_request(query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None): +def create_execute_query_request( + query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None +): if tx_id: req = ydb_query.ExecuteQueryRequest( session_id=session_id, @@ -134,7 +142,7 @@ def create_execute_query_request(query: str, session_id: str, tx_id: str = None, ), tx_control=ydb_query.TransactionControl( tx_id=tx_id, - commit_tx=commit_tx + commit_tx=commit_tx, ), ) else: @@ -148,7 +156,7 @@ def create_execute_query_request(query: str, session_id: str, tx_id: str = None, begin_tx=ydb_query.TransactionSettings( tx_mode=tx_mode, ), - commit_tx=commit_tx + commit_tx=commit_tx, ), ) diff --git a/ydb/query/session.py b/ydb/query/session.py index 9c8ed1ce..d0ebeebc 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -32,7 +32,7 @@ class QuerySessionStateHelper(abc.ABC): } _READY_TO_USE = [ - QuerySessionStateEnum.CREATED + QuerySessionStateEnum.CREATED, ] @classmethod @@ -154,7 +154,7 @@ def _execute_call(self, query: str, commit_tx: bool): request = base.create_execute_query_request( query=query, session_id=self._state.session_id, - commit_tx=commit_tx + commit_tx=commit_tx, ) return self._driver( @@ -189,13 +189,13 @@ def _attach(self): ).start() def _chech_session_status_loop(self, status_stream): - try: - for status in status_stream: - if status.status != issues.StatusCode.SUCCESS: - self._state.reset() - self._state._change_state(QuerySessionStateEnum.CLOSED) - except Exception as e: - pass + try: + for status in status_stream: + if status.status != issues.StatusCode.SUCCESS: + self._state.reset() + self._state._change_state(QuerySessionStateEnum.CLOSED) + except Exception as e: + pass def delete(self) -> None: diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 067707bc..77ff467f 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -275,7 +275,7 @@ def _execute_call(self, query: str, commit_tx: bool): request = base.create_execute_query_request( query=query, session_id=self._session_state.session_id, - commit_tx=commit_tx + commit_tx=commit_tx, ) return self._driver( request, From 66bfe0975aba95bcd5c7def4027049574d3401d6 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 16:21:04 +0300 Subject: [PATCH 20/57] more style fixes --- ydb/_grpc/grpcwrapper/ydb_query.py | 10 ++++++++-- ydb/query/base.py | 4 ++-- ydb/query/session.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index f5e16664..0a49ac74 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -147,8 +147,14 @@ class TransactionControl(IToProto): def to_proto(self) -> ydb_query_pb2.TransactionControl: if self.tx_id: - return ydb_query_pb2.TransactionControl(tx_id=self.tx_id,commit_tx=self.commit_tx) - return ydb_query_pb2.TransactionControl(begin_tx=self.begin_tx.to_proto(), commit_tx=self.commit_tx) + return ydb_query_pb2.TransactionControl( + tx_id=self.tx_id, + commit_tx=self.commit_tx, + ) + return ydb_query_pb2.TransactionControl( + begin_tx=self.begin_tx.to_proto(), + commit_tx=self.commit_tx, + ) @dataclass diff --git a/ydb/query/base.py b/ydb/query/base.py index d3bfddae..368d15c2 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -83,7 +83,7 @@ def __init__( driver: SupportedDriverType, session_state: IQuerySessionState, session: IQuerySession, - tx_mode: BaseQueryTxMode = None + tx_mode: BaseQueryTxMode = None, ): pass @@ -132,7 +132,7 @@ def session(self) -> IQuerySession: def create_execute_query_request( - query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None + query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None ): if tx_id: req = ydb_query.ExecuteQueryRequest( diff --git a/ydb/query/session.py b/ydb/query/session.py index d0ebeebc..1638ab4f 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -28,7 +28,7 @@ class QuerySessionStateHelper(abc.ABC): _VALID_TRANSITIONS = { QuerySessionStateEnum.NOT_INITIALIZED: [QuerySessionStateEnum.CREATED], QuerySessionStateEnum.CREATED: [QuerySessionStateEnum.CLOSED], - QuerySessionStateEnum.CLOSED: [] + QuerySessionStateEnum.CLOSED: [], } _READY_TO_USE = [ From 1195c4f4ca103ae85beb0b4eb9ddacca9a9fc8d6 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 17 Jul 2024 16:23:18 +0300 Subject: [PATCH 21/57] please tox im tired --- ydb/query/base.py | 1 - ydb/query/session.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ydb/query/base.py b/ydb/query/base.py index 368d15c2..dea5ff66 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -76,7 +76,6 @@ def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": class IQueryTxContext(abc.ABC): - @abc.abstractmethod def __init__( self, diff --git a/ydb/query/session.py b/ydb/query/session.py index 1638ab4f..32c0b79e 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -194,10 +194,9 @@ def _chech_session_status_loop(self, status_stream): if status.status != issues.StatusCode.SUCCESS: self._state.reset() self._state._change_state(QuerySessionStateEnum.CLOSED) - except Exception as e: + except Exception: pass - def delete(self) -> None: if self._state._already_in(QuerySessionStateEnum.CLOSED): return From 802c8414b73447350944f869f75b56d6154ceded Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 19 Jul 2024 14:29:19 +0300 Subject: [PATCH 22/57] basic pool & retries & example --- examples/query-service/basic_example.py | 59 +++++++++-- ydb/__init__.py | 1 + ydb/_grpc/grpcwrapper/ydb_query.py | 12 +-- ydb/query/__init__.py | 25 +++++ ydb/query/base.py | 41 +++++++- ydb/query/pool.py | 132 ++++++++++++++++++++++++ ydb/query/session.py | 4 +- ydb/query/transaction.py | 60 +++++++++-- 8 files changed, 307 insertions(+), 27 deletions(-) create mode 100644 ydb/query/pool.py diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index cc93332a..d0140494 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -1,7 +1,5 @@ import ydb -from ydb.query.session import QuerySessionSync - def main(): driver_config = ydb.DriverConfig( @@ -16,14 +14,57 @@ def main(): except TimeoutError: raise RuntimeError("Connect failed to YDB") - session = QuerySessionSync(driver) - session.create() + # client = ydb.QueryClientSync(driver) + # session = client.session().create() + pool = ydb.QuerySessionPool(driver) + # with pool.checkout() as session: + def callee(session): + print("="*50) + print("BEFORE ACTION") + it = session.execute("""SELECT COUNT(*) FROM example;""") + for result_set in it: + print(f"rows: {str(result_set.rows)}") + + print("="*50) + print("INSERT WITH COMMIT TX") + tx = session.transaction() + + tx.begin() + + tx.execute("""INSERT INTO example (key, value) VALUES (0033, "onepieceisreal");""") + + for result_set in tx.execute("""SELECT COUNT(*) FROM example;"""): + print(f"rows: {str(result_set.rows)}") + + tx.commit() + + print("="*50) + print("AFTER COMMIT TX") + + for result_set in session.execute("""SELECT COUNT(*) FROM example;"""): + print(f"rows: {str(result_set.rows)}") + + print("="*50) + print("INSERT WITH ROLLBACK TX") + + tx = session.transaction() + + tx.begin() + + tx.execute("""INSERT INTO example (key, value) VALUES (0044, "onepieceisreal");""") + + for result_set in tx.execute("""SELECT COUNT(*) FROM example;"""): + print(f"rows: {str(result_set.rows)}") + + tx.rollback() + + print("="*50) + print("AFTER ROLLBACK TX") + + for result_set in session.execute("""SELECT COUNT(*) FROM example;"""): + print(f"rows: {str(result_set.rows)}") - it = session.execute("select 1; select 2;", commit_tx=False) - for result_set in it: - # pass - print(f"columns: {str(result_set.columns)}") - print(f"rows: {str(result_set.rows)}") + pool.retry_operation_sync(callee) if __name__ == "__main__": diff --git a/ydb/__init__.py b/ydb/__init__.py index 0a09834b..fc911a44 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -19,6 +19,7 @@ from .tracing import * # noqa from .topic import * # noqa from .draft import * # noqa +from .query import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 0a49ac74..7196ec04 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -129,11 +129,11 @@ def from_proto(msg: ydb_query_pb2.RollbackTransactionResponse) -> "RollbackTrans @dataclass class QueryContent(IFromPublic, IToProto): text: str - syntax: Optional[str] = None + syntax: Optional[int] = None @staticmethod - def from_public(query: str) -> "QueryContent": - return QueryContent(text=query) + def from_public(query: str, syntax: int = None) -> "QueryContent": + return QueryContent(text=query, syntax=syntax) def to_proto(self) -> ydb_query_pb2.QueryContent: return ydb_query_pb2.QueryContent(text=self.text, syntax=self.syntax) @@ -163,16 +163,16 @@ class ExecuteQueryRequest(IToProto): query_content: QueryContent tx_control: TransactionControl concurrent_result_sets: Optional[bool] = False - exec_mode: Optional[str] = None + exec_mode: Optional[int] = None parameters: Optional[dict] = None - stats_mode: Optional[str] = None + stats_mode: Optional[int] = None def to_proto(self) -> ydb_query_pb2.ExecuteQueryRequest: return ydb_query_pb2.ExecuteQueryRequest( session_id=self.session_id, tx_control=self.tx_control.to_proto(), query_content=self.query_content.to_proto(), - exec_mode=ydb_query_pb2.EXEC_MODE_EXECUTE, + exec_mode=self.exec_mode, stats_mode=self.stats_mode, concurrent_result_sets=self.concurrent_result_sets, parameters=self.parameters, diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index e69de29b..e7e33d66 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -0,0 +1,25 @@ +from .base import ( + IQueryClient, + SupportedDriverType, + QueryClientSettings, +) + +from .session import QuerySessionSync + +from .._grpc.grpcwrapper.ydb_query_public_types import ( + QueryOnlineReadOnly, + QuerySerializableReadWrite, + QuerySnapshotReadOnly, + QueryStaleReadOnly, +) + +from .pool import QuerySessionPool + + +class QueryClientSync(IQueryClient): + def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): + self._driver = driver + self._settings = query_client_settings + + def session(self) -> QuerySessionSync: + return QuerySessionSync(self._driver, self._settings) diff --git a/ydb/query/base.py b/ydb/query/base.py index dea5ff66..e19b1826 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -1,4 +1,5 @@ import abc +import enum import functools from typing import ( @@ -63,7 +64,7 @@ def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = pass @abc.abstractmethod - def create(self) -> None: + def create(self) -> "IQuerySession": pass @abc.abstractmethod @@ -117,7 +118,7 @@ def rollback(settings: QueryClientSettings = None): pass @abc.abstractmethod - def execute(query: str): + def execute(query: str, commit_tx=False): pass @@ -130,19 +131,48 @@ def session(self) -> IQuerySession: pass +class QuerySyntax(enum.IntEnum): + UNSPECIFIED = 0 + YQL_V1 = 1 + PG = 2 + + +class QueryExecMode(enum.IntEnum): + UNSPECIFIED = 0 + PARSE = 10 + VALIDATE = 20 + EXPLAIN = 30 + EXECUTE = 50 + + def create_execute_query_request( - query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None + query: str, + session_id: str, + tx_id: str = None, + commit_tx: bool = False, + tx_mode: BaseQueryTxMode = None, + syntax: QuerySyntax = None, + exec_mode: QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ): + syntax = QuerySyntax.YQL_V1 if not syntax else syntax + exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode if tx_id: req = ydb_query.ExecuteQueryRequest( session_id=session_id, query_content=ydb_query.QueryContent.from_public( query=query, + syntax=syntax, ), tx_control=ydb_query.TransactionControl( tx_id=tx_id, commit_tx=commit_tx, ), + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, ) else: tx_mode = tx_mode if tx_mode is not None else QuerySerializableReadWrite() @@ -150,6 +180,7 @@ def create_execute_query_request( session_id=session_id, query_content=ydb_query.QueryContent.from_public( query=query, + syntax=syntax, ), tx_control=ydb_query.TransactionControl( begin_tx=ydb_query.TransactionSettings( @@ -157,12 +188,16 @@ def create_execute_query_request( ), commit_tx=commit_tx, ), + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, ) return req.to_proto() def wrap_execute_query_response(rpc_state, response_pb): + issues._process_response(response_pb) return convert.ResultSet.from_message(response_pb.result_set) diff --git a/ydb/query/pool.py b/ydb/query/pool.py new file mode 100644 index 00000000..229158ff --- /dev/null +++ b/ydb/query/pool.py @@ -0,0 +1,132 @@ +import abc +import time +from typing import Callable + +from . import base +from .session import ( + QuerySessionSync, + BaseQuerySession, +) +from .. import issues +from .._errors import check_retriable_error + + +class RetrySettings(object): + def __init__( + self, + max_retries: int = 10, + max_session_acquire_timeout: int = None, + on_ydb_error_callback: Callable = None, + idempotent: bool = False, + ): + self.max_retries = max_retries + self.max_session_acquire_timeout = max_session_acquire_timeout + self.on_ydb_error_callback = (lambda e: None) if on_ydb_error_callback is None else on_ydb_error_callback + self.retry_not_found = True + self.idempotent = idempotent + self.retry_internal_error = True + self.unknown_error_handler = lambda e: None + + +class YdbRetryOperationSleepOpt: + def __init__(self, timeout): + self.timeout = timeout + + def __eq__(self, other): + return type(self) == type(other) and self.timeout == other.timeout + + def __repr__(self): + return "YdbRetryOperationSleepOpt(%s)" % self.timeout + + +class YdbRetryOperationFinalResult: + def __init__(self, result): + self.result = result + self.exc = None + + def __eq__(self, other): + return type(self) == type(other) and self.result == other.result and self.exc == other.exc + + def __repr__(self): + return "YdbRetryOperationFinalResult(%s, exc=%s)" % (self.result, self.exc) + + def set_exception(self, exc): + self.exc = exc + + +def retry_operation_impl(callee: Callable, retry_settings: RetrySettings = None, *args, **kwargs): + retry_settings = RetrySettings() if retry_settings is None else retry_settings + status = None + + for attempt in range(retry_settings.max_retries + 1): + try: + result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) + yield result + + if result.exc is not None: + raise result.exc + + except issues.Error as e: + status = e + retry_settings.on_ydb_error_callback(e) + + retriable_info = check_retriable_error(e, retry_settings, attempt) + if not retriable_info.is_retriable: + raise + + skip_yield_error_types = [ + issues.Aborted, + issues.BadSession, + issues.NotFound, + issues.InternalError, + ] + + yield_sleep = True + for t in skip_yield_error_types: + if isinstance(e, t): + yield_sleep = False + + if yield_sleep: + yield YdbRetryOperationSleepOpt(retriable_info.sleep_timeout_seconds) + + except Exception as e: + # you should provide your own handler you want + retry_settings.unknown_error_handler(e) + raise + + raise status + + +class QuerySessionPool: + def __init__(self, driver: base.SupportedDriverType): + self._driver = driver + + def checkout(self): + return SimpleQuerySessionCheckout(self) + + def retry_operation_sync(self, callee: Callable, retry_settings: RetrySettings = None, *args, **kwargs): + retry_settings = RetrySettings() if retry_settings is None else retry_settings + + def wrapped_callee(): + with self.checkout() as session: + return callee(session, *args, **kwargs) + + opt_generator = retry_operation_impl(wrapped_callee, retry_settings, *args, **kwargs) + for next_opt in opt_generator: + if isinstance(next_opt, YdbRetryOperationSleepOpt): + time.sleep(next_opt.timeout) + else: + return next_opt.result + + +class SimpleQuerySessionCheckout: + def __init__(self, pool: QuerySessionPool): + self._pool = pool + self._session = QuerySessionSync(pool._driver) + + def __enter__(self): + self._session.create() + return self._session + + def __exit__(self, exc_type, exc_val, exc_tb): + self._session.delete() diff --git a/ydb/query/session.py b/ydb/query/session.py index 32c0b79e..4926b0bb 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -205,7 +205,7 @@ def delete(self) -> None: self._delete_call() self._stream.cancel() - def create(self) -> None: + def create(self) -> "QuerySessionSync": if self._state._already_in(QuerySessionStateEnum.CREATED): return @@ -213,6 +213,8 @@ def create(self) -> None: self._create_call() self._attach() + return self + def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxContext: self._state._check_session_ready_to_use() diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 77ff467f..df71d62c 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -97,7 +97,6 @@ def _create_begin_transaction_request(session_state, tx_state): session_id=session_state.session_id, tx_settings=_construct_tx_settings(tx_state), ).to_proto() - return request @@ -127,7 +126,7 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): @base.bad_session_handler @reset_tx_id_handler def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx): - message = _ydb_query.CommitTransactionResponse(response_pb) + message = _ydb_query.CommitTransactionResponse.from_proto(response_pb) issues._process_response(message.status) tx_state.tx_id = None tx_state._change_state(QueryTxStateEnum.COMMITTED) @@ -137,7 +136,7 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx) @base.bad_session_handler @reset_tx_id_handler def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx): - message = _ydb_query.RollbackTransactionResponse(response_pb) + message = _ydb_query.RollbackTransactionResponse.from_proto(response_pb) issues._process_response(message.status) tx_state.tx_id = None tx_state._change_state(QueryTxStateEnum.ROLLBACKED) @@ -170,6 +169,7 @@ def __init__(self, driver, session_state, session, tx_mode=None): self._session_state = session_state self.session = session self._finished = "" + self._prev_stream = None def __enter__(self): """ @@ -247,6 +247,8 @@ def commit(self, settings=None): return self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) + self._ensure_prev_stream_finished() + return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -262,6 +264,8 @@ def rollback(self, settings=None): self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) + self._ensure_prev_stream_finished() + return self._driver( _create_rollback_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -271,24 +275,64 @@ def rollback(self, settings=None): (self._session_state, self._tx_state, self), ) - def _execute_call(self, query: str, commit_tx: bool): + def _execute_call( + self, + query: str, + commit_tx: bool = False, + tx_mode: base.BaseQueryTxMode = None, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ): request = base.create_execute_query_request( query=query, session_id=self._session_state.session_id, commit_tx=commit_tx, + tx_id=self._tx_state.tx_id, + tx_mode=tx_mode, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, ) + return self._driver( request, _apis.QueryService.Stub, _apis.QueryService.ExecuteQuery, ) - def execute(self, query, parameters=None, commit_tx=False, settings=None): + def _ensure_prev_stream_finished(self): + if self._prev_stream is not None: + for _ in self._prev_stream: + pass + self._prev_stream = None + + def execute( + self, + query: str, + commit_tx: bool = False, + tx_mode: base.BaseQueryTxMode = None, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ): self._tx_state._check_tx_not_terminal() + self._ensure_prev_stream_finished() - stream_it = self._execute_call(query, commit_tx) - - return _utilities.SyncResponseIterator( + stream_it = self._execute_call( + query=query, + commit_tx=commit_tx, + tx_mode=tx_mode, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + self._prev_stream = _utilities.SyncResponseIterator( stream_it, lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), ) + return self._prev_stream From 16c1c0656fe2b422755115064a696980aa3ccd90 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 19 Jul 2024 14:32:23 +0300 Subject: [PATCH 23/57] style fixes --- examples/query-service/basic_example.py | 10 +++++----- ydb/__init__.py | 2 +- ydb/query/base.py | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index d0140494..9947a3ff 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -19,13 +19,13 @@ def main(): pool = ydb.QuerySessionPool(driver) # with pool.checkout() as session: def callee(session): - print("="*50) + print("=" * 50) print("BEFORE ACTION") it = session.execute("""SELECT COUNT(*) FROM example;""") for result_set in it: print(f"rows: {str(result_set.rows)}") - print("="*50) + print("=" * 50) print("INSERT WITH COMMIT TX") tx = session.transaction() @@ -38,13 +38,13 @@ def callee(session): tx.commit() - print("="*50) + print("=" * 50) print("AFTER COMMIT TX") for result_set in session.execute("""SELECT COUNT(*) FROM example;"""): print(f"rows: {str(result_set.rows)}") - print("="*50) + print("=" * 50) print("INSERT WITH ROLLBACK TX") tx = session.transaction() @@ -58,7 +58,7 @@ def callee(session): tx.rollback() - print("="*50) + print("=" * 50) print("AFTER ROLLBACK TX") for result_set in session.execute("""SELECT COUNT(*) FROM example;"""): diff --git a/ydb/__init__.py b/ydb/__init__.py index fc911a44..1caaaa02 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -19,7 +19,7 @@ from .tracing import * # noqa from .topic import * # noqa from .draft import * # noqa -from .query import * # noqa +from .query import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/query/base.py b/ydb/query/base.py index e19b1826..97a8d56c 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -155,7 +155,6 @@ def create_execute_query_request( exec_mode: QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, - ): syntax = QuerySyntax.YQL_V1 if not syntax else syntax exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode From a11e42af16ac8776f60e4311d981be488c78d3ee Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 19 Jul 2024 14:36:50 +0300 Subject: [PATCH 24/57] style fixes --- examples/query-service/basic_example.py | 1 + ydb/query/__init__.py | 10 +++++----- ydb/query/pool.py | 2 -- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index 9947a3ff..1ff264f1 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -18,6 +18,7 @@ def main(): # session = client.session().create() pool = ydb.QuerySessionPool(driver) # with pool.checkout() as session: + def callee(session): print("=" * 50) print("BEFORE ACTION") diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index e7e33d66..11660a33 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -7,13 +7,13 @@ from .session import QuerySessionSync from .._grpc.grpcwrapper.ydb_query_public_types import ( - QueryOnlineReadOnly, - QuerySerializableReadWrite, - QuerySnapshotReadOnly, - QueryStaleReadOnly, + QueryOnlineReadOnly, # noqa + QuerySerializableReadWrite, # noqa + QuerySnapshotReadOnly, # noqa + QueryStaleReadOnly, # noqa ) -from .pool import QuerySessionPool +from .pool import QuerySessionPool # noqa class QueryClientSync(IQueryClient): diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 229158ff..010b5faa 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -1,11 +1,9 @@ -import abc import time from typing import Callable from . import base from .session import ( QuerySessionSync, - BaseQuerySession, ) from .. import issues from .._errors import check_retriable_error From f27658f7fcc12f56416d247ad3ee149c53ff2093 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 19 Jul 2024 14:38:29 +0300 Subject: [PATCH 25/57] style fixes --- examples/query-service/basic_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index 1ff264f1..8b0ea193 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -18,7 +18,7 @@ def main(): # session = client.session().create() pool = ydb.QuerySessionPool(driver) # with pool.checkout() as session: - + def callee(session): print("=" * 50) print("BEFORE ACTION") From e580dc242f53cd8ff38937b276e8f0909510a206 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 19 Jul 2024 14:47:10 +0300 Subject: [PATCH 26/57] add dunder all to query module --- examples/query-service/basic_example.py | 4 ++-- ydb/query/__init__.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index 8b0ea193..eb21fa99 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -32,7 +32,7 @@ def callee(session): tx.begin() - tx.execute("""INSERT INTO example (key, value) VALUES (0033, "onepieceisreal");""") + tx.execute("""INSERT INTO example (key, value) VALUES (0055, "onepieceisreal");""") for result_set in tx.execute("""SELECT COUNT(*) FROM example;"""): print(f"rows: {str(result_set.rows)}") @@ -52,7 +52,7 @@ def callee(session): tx.begin() - tx.execute("""INSERT INTO example (key, value) VALUES (0044, "onepieceisreal");""") + tx.execute("""INSERT INTO example (key, value) VALUES (0066, "onepieceisreal");""") for result_set in tx.execute("""SELECT COUNT(*) FROM example;"""): print(f"rows: {str(result_set.rows)}") diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index 11660a33..3b504b0f 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -1,3 +1,12 @@ +__all__ = [ + "QueryOnlineReadOnly", + "QuerySerializableReadWrite", + "QuerySnapshotReadOnly", + "QueryStaleReadOnly", + "QuerySessionPool", + "QueryClientSync" +] + from .base import ( IQueryClient, SupportedDriverType, @@ -7,13 +16,13 @@ from .session import QuerySessionSync from .._grpc.grpcwrapper.ydb_query_public_types import ( - QueryOnlineReadOnly, # noqa - QuerySerializableReadWrite, # noqa - QuerySnapshotReadOnly, # noqa - QueryStaleReadOnly, # noqa + QueryOnlineReadOnly, + QuerySerializableReadWrite, + QuerySnapshotReadOnly, + QueryStaleReadOnly, ) -from .pool import QuerySessionPool # noqa +from .pool import QuerySessionPool class QueryClientSync(IQueryClient): From 424f8319419d439032439ee8ad2828b996c5572e Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 19 Jul 2024 14:48:49 +0300 Subject: [PATCH 27/57] style fix --- ydb/query/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index 3b504b0f..fbdfa7dd 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -4,7 +4,7 @@ "QuerySnapshotReadOnly", "QueryStaleReadOnly", "QuerySessionPool", - "QueryClientSync" + "QueryClientSync", ] from .base import ( From c90ccf48643e4925d3dd10c41c4923005eadea17 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 15:56:08 +0300 Subject: [PATCH 28/57] ability to execute ddl queries --- examples/query-service/basic_example.py | 23 +++++++---- ydb/_grpc/grpcwrapper/ydb_query.py | 5 ++- ydb/query/base.py | 53 ++++++++++++------------- ydb/query/pool.py | 18 ++++++++- ydb/query/session.py | 40 +++++++++++++++++-- 5 files changed, 96 insertions(+), 43 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index eb21fa99..e589eba6 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -14,15 +14,22 @@ def main(): except TimeoutError: raise RuntimeError("Connect failed to YDB") - # client = ydb.QueryClientSync(driver) - # session = client.session().create() pool = ydb.QuerySessionPool(driver) - # with pool.checkout() as session: + + print("=" * 50) + print("DELETE TABLE IF EXISTS") + pool.execute_with_retries("drop table if exists example;", ddl=True) + + print("=" * 50) + print("CREATE TABLE") + pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key));", ddl=True) def callee(session): print("=" * 50) + session.execute("""delete from example;""") + print("BEFORE ACTION") - it = session.execute("""SELECT COUNT(*) FROM example;""") + it = session.execute("""SELECT COUNT(*) as rows_count FROM example;""") for result_set in it: print(f"rows: {str(result_set.rows)}") @@ -34,7 +41,7 @@ def callee(session): tx.execute("""INSERT INTO example (key, value) VALUES (0055, "onepieceisreal");""") - for result_set in tx.execute("""SELECT COUNT(*) FROM example;"""): + for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") tx.commit() @@ -42,7 +49,7 @@ def callee(session): print("=" * 50) print("AFTER COMMIT TX") - for result_set in session.execute("""SELECT COUNT(*) FROM example;"""): + for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") print("=" * 50) @@ -54,7 +61,7 @@ def callee(session): tx.execute("""INSERT INTO example (key, value) VALUES (0066, "onepieceisreal");""") - for result_set in tx.execute("""SELECT COUNT(*) FROM example;"""): + for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") tx.rollback() @@ -62,7 +69,7 @@ def callee(session): print("=" * 50) print("AFTER ROLLBACK TX") - for result_set in session.execute("""SELECT COUNT(*) FROM example;"""): + for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") pool.retry_operation_sync(callee) diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 7196ec04..343ef8b8 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -161,16 +161,17 @@ def to_proto(self) -> ydb_query_pb2.TransactionControl: class ExecuteQueryRequest(IToProto): session_id: str query_content: QueryContent - tx_control: TransactionControl + tx_control: Optional[TransactionControl] = None concurrent_result_sets: Optional[bool] = False exec_mode: Optional[int] = None parameters: Optional[dict] = None stats_mode: Optional[int] = None def to_proto(self) -> ydb_query_pb2.ExecuteQueryRequest: + tx_control = self.tx_control.to_proto() if self.tx_control is not None else self.tx_control return ydb_query_pb2.ExecuteQueryRequest( session_id=self.session_id, - tx_control=self.tx_control.to_proto(), + tx_control=tx_control, query_content=self.query_content.to_proto(), exec_mode=self.exec_mode, stats_mode=self.stats_mode, diff --git a/ydb/query/base.py b/ydb/query/base.py index 97a8d56c..98a02a42 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -155,43 +155,40 @@ def create_execute_query_request( exec_mode: QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, + empty_tx_control: bool = False, ): syntax = QuerySyntax.YQL_V1 if not syntax else syntax exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode - if tx_id: - req = ydb_query.ExecuteQueryRequest( - session_id=session_id, - query_content=ydb_query.QueryContent.from_public( - query=query, - syntax=syntax, - ), - tx_control=ydb_query.TransactionControl( - tx_id=tx_id, - commit_tx=commit_tx, - ), - exec_mode=exec_mode, - parameters=parameters, - concurrent_result_sets=concurrent_result_sets, + + tx_control = None + if empty_tx_control: + tx_control = None + elif tx_id: + tx_control = ydb_query.TransactionControl( + tx_id=tx_id, + commit_tx=commit_tx, ) else: tx_mode = tx_mode if tx_mode is not None else QuerySerializableReadWrite() - req = ydb_query.ExecuteQueryRequest( - session_id=session_id, - query_content=ydb_query.QueryContent.from_public( - query=query, - syntax=syntax, + tx_control = ydb_query.TransactionControl( + begin_tx=ydb_query.TransactionSettings( + tx_mode=tx_mode, ), - tx_control=ydb_query.TransactionControl( - begin_tx=ydb_query.TransactionSettings( - tx_mode=tx_mode, - ), - commit_tx=commit_tx, - ), - exec_mode=exec_mode, - parameters=parameters, - concurrent_result_sets=concurrent_result_sets, + commit_tx=commit_tx, ) + req = ydb_query.ExecuteQueryRequest( + session_id=session_id, + query_content=ydb_query.QueryContent.from_public( + query=query, + syntax=syntax, + ), + tx_control=tx_control, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + return req.to_proto() diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 010b5faa..d227a0b0 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -91,8 +91,8 @@ def retry_operation_impl(callee: Callable, retry_settings: RetrySettings = None, # you should provide your own handler you want retry_settings.unknown_error_handler(e) raise - - raise status + if status: + raise status class QuerySessionPool: @@ -116,6 +116,20 @@ def wrapped_callee(): else: return next_opt.result + def execute_with_retries(self, query: str, ddl: bool = False, retry_settings: RetrySettings = None, *args, **kwargs): + retry_settings = RetrySettings() if retry_settings is None else retry_settings + with self.checkout() as session: + def wrapped_callee(): + it = session.execute(query, empty_tx_control=ddl) + return [result_set for result_set in it] + + opt_generator = retry_operation_impl(wrapped_callee, retry_settings, *args, **kwargs) + for next_opt in opt_generator: + if isinstance(next_opt, YdbRetryOperationSleepOpt): + time.sleep(next_opt.timeout) + else: + return next_opt.result + class SimpleQuerySessionCheckout: def __init__(self, pool: QuerySessionPool): diff --git a/ydb/query/session.py b/ydb/query/session.py index 4926b0bb..25d5658f 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -150,11 +150,27 @@ def _attach_call(self): _apis.QueryService.AttachSession, ) - def _execute_call(self, query: str, commit_tx: bool): + def _execute_call( + self, + query: str, + commit_tx: bool = False, + tx_mode: base.BaseQueryTxMode = None, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + empty_tx_control: bool = False, + ): request = base.create_execute_query_request( query=query, session_id=self._state.session_id, commit_tx=commit_tx, + tx_mode=tx_mode, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + empty_tx_control=empty_tx_control, ) return self._driver( @@ -225,10 +241,28 @@ def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxCont tx_mode, ) - def execute(self, query: str, parameters=None): + def execute( + self, + query: str, + tx_mode: base.BaseQueryTxMode = None, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + empty_tx_control: bool = False + ): self._state._check_session_ready_to_use() - stream_it = self._execute_call(query, commit_tx=True) + stream_it = self._execute_call( + query=query, + commit_tx=True, + tx_mode=tx_mode, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + empty_tx_control=empty_tx_control, + ) return _utilities.SyncResponseIterator( stream_it, From 6a740e4eb8ab9340a53bbb03fcc2d30bd127ccab Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 15:57:56 +0300 Subject: [PATCH 29/57] Revert "Refactor wrappers to split UnaryStream and StreamStream wrappers" This reverts commit 0ffe5453f9d2c92ab9134672e05d4bdd35aa5b65. --- ydb/_grpc/grpcwrapper/common_utils.py | 86 ++++++----------------- ydb/_topic_common/common_test.py | 8 +-- ydb/_topic_reader/topic_reader_asyncio.py | 4 +- ydb/_topic_writer/topic_writer_asyncio.py | 4 +- 4 files changed, 31 insertions(+), 71 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 895d4036..a7febd5b 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -145,7 +145,7 @@ def close(self): SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] -class AbstractGrpcWrapperAsyncIO(IGrpcWrapperAsyncIO, abc.ABC): +class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): from_client_grpc: asyncio.Queue from_server_grpc: AsyncIterator convert_server_grpc_to_wrapper: Callable[[Any], Any] @@ -163,6 +163,13 @@ def __init__(self, convert_server_grpc_to_wrapper): def __del__(self): self._clean_executor(wait=False) + async def start(self, driver: SupportedDriverType, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, stub, method) + else: + await self._start_sync_driver(driver, stub, method) + self._connection_state = "started" + def close(self): self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) if self._stream_call: @@ -174,35 +181,6 @@ def _clean_executor(self, wait: bool): if self._wait_executor: self._wait_executor.shutdown(wait) - async def receive(self) -> Any: - # todo handle grpc exceptions and convert it to internal exceptions - try: - grpc_message = await self.from_server_grpc.__anext__() - except (grpc.RpcError, grpc.aio.AioRpcError) as e: - raise connection._rpc_error_handler(self._connection_state, e) - - issues._process_response(grpc_message) - - if self._connection_state != "has_received_messages": - self._connection_state = "has_received_messages" - - # print("rekby, grpc, received", grpc_message) - return self.convert_server_grpc_to_wrapper(grpc_message) - - def write(self, wrap_message: IToProto): - grpc_message = wrap_message.to_proto() - # print("rekby, grpc, send", grpc_message) - self.from_client_grpc.put_nowait(grpc_message) - - -class GrpcWrapperStreamStreamAsyncIO(AbstractGrpcWrapperAsyncIO): - async def start(self, driver: SupportedDriverType, stub, method): - if asyncio.iscoroutinefunction(driver.__call__): - await self._start_asyncio_driver(driver, stub, method) - else: - await self._start_sync_driver(driver, stub, method) - self._connection_state = "started" - async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) stream_call = await driver( @@ -221,30 +199,25 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): self._stream_call = stream_call self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) + async def receive(self) -> Any: + # todo handle grpc exceptions and convert it to internal exceptions + try: + grpc_message = await self.from_server_grpc.__anext__() + except (grpc.RpcError, grpc.aio.AioRpcError) as e: + raise connection._rpc_error_handler(self._connection_state, e) -class GrpcWrapperUnaryStreamAsyncIO(AbstractGrpcWrapperAsyncIO): - async def start(self, driver: SupportedDriverType, request, stub, method): - if asyncio.iscoroutinefunction(driver.__call__): - await self._start_asyncio_driver(driver, request, stub, method) - else: - await self._start_sync_driver(driver, request, stub, method) - self._connection_state = "started" + issues._process_response(grpc_message) - async def _start_asyncio_driver(self, driver: ydb.aio.Driver, request, stub, method): - stream_call = await driver( - request, - stub, - method, - ) - self._stream_call = stream_call - self.from_server_grpc = stream_call.__aiter__() + if self._connection_state != "has_received_messages": + self._connection_state = "has_received_messages" - async def _start_sync_driver(self, driver: ydb.Driver, request, stub, method): - self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # print("rekby, grpc, received", grpc_message) + return self.convert_server_grpc_to_wrapper(grpc_message) - stream_call = await to_thread(driver, request, stub, method, executor=self._wait_executor) - self._stream_call = stream_call - self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) + def write(self, wrap_message: IToProto): + grpc_message = wrap_message.to_proto() + # print("rekby, grpc, send", grpc_message) + self.from_client_grpc.put_nowait(grpc_message) @dataclass(init=False) @@ -283,19 +256,6 @@ def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): return res -ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType) - - -def create_result_wrapper( - result_type: typing.Type[ResultType], -) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: - def wrapper(rpc_state, response_pb, driver=None): - # issues._process_response(response_pb.operation) - return result_type.from_proto(response_pb) - - return wrapper - - def callback_from_asyncio(callback: Union[Callable, Coroutine]) -> [asyncio.Future, asyncio.Task]: loop = asyncio.get_running_loop() diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index 1dadaa04..b31f9af9 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -8,7 +8,7 @@ from .common import CallFromSyncToAsync from .._grpc.grpcwrapper.common_utils import ( - GrpcWrapperStreamStreamAsyncIO, + GrpcWrapperAsyncIO, ServerStatus, callback_from_asyncio, ) @@ -77,7 +77,7 @@ async def async_failed(): @pytest.mark.asyncio -class TestGrpcWrapperStreamStreamAsyncIO: +class TestGrpcWrapperAsyncIO: async def test_convert_grpc_errors_to_ydb(self): class TestError(grpc.RpcError, grpc.Call): def __init__(self): @@ -93,7 +93,7 @@ class FromServerMock: async def __anext__(self): raise TestError() - wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None) + wrapper = GrpcWrapperAsyncIO(lambda: None) wrapper.from_server_grpc = FromServerMock() with pytest.raises(issues.Unauthenticated): @@ -107,7 +107,7 @@ async def __anext__(self): issues=[], ) - wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None) + wrapper = GrpcWrapperAsyncIO(lambda: None) wrapper.from_server_grpc = FromServerMock() with pytest.raises(issues.Overloaded): diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 8cc48a1d..81c6d9f4 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -18,7 +18,7 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - GrpcWrapperStreamStreamAsyncIO, + GrpcWrapperAsyncIO, ) from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, @@ -308,7 +308,7 @@ async def create( driver: SupportedDriverType, settings: topic_reader.PublicReaderSettings, ) -> "ReaderStream": - stream = GrpcWrapperStreamStreamAsyncIO(StreamReadMessage.FromServer.from_proto) + stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 064f19ce..007c8a54 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -39,7 +39,7 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - GrpcWrapperStreamStreamAsyncIO, + GrpcWrapperAsyncIO, ) logger = logging.getLogger(__name__) @@ -613,7 +613,7 @@ async def create( init_request: StreamWriteMessage.InitRequest, update_token_interval: Optional[Union[int, float]] = None, ) -> "WriterAsyncIOStream": - stream = GrpcWrapperStreamStreamAsyncIO(StreamWriteMessage.FromServer.from_proto) + stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) From 825c4626ef855167df2a91f9177656e352456bc9 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 16:04:45 +0300 Subject: [PATCH 30/57] new details to example --- examples/query-service/basic_example.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index e589eba6..a43bfa9a 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -24,9 +24,13 @@ def main(): print("CREATE TABLE") pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key));", ddl=True) + pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal');") + + def callee(session): print("=" * 50) - session.execute("""delete from example;""") + for _ in session.execute("""delete from example;"""): + pass print("BEFORE ACTION") it = session.execute("""SELECT COUNT(*) as rows_count FROM example;""") @@ -39,7 +43,7 @@ def callee(session): tx.begin() - tx.execute("""INSERT INTO example (key, value) VALUES (0055, "onepieceisreal");""") + tx.execute("""INSERT INTO example (key, value) VALUES (1, "onepieceisreal");""") for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") @@ -59,7 +63,7 @@ def callee(session): tx.begin() - tx.execute("""INSERT INTO example (key, value) VALUES (0066, "onepieceisreal");""") + tx.execute("""INSERT INTO example (key, value) VALUES (2, "onepieceisreal");""") for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") From 49cdce00fb94ef991881cd2a6f3972a37bb4fb5c Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 16:08:42 +0300 Subject: [PATCH 31/57] style fixes --- ydb/query/pool.py | 5 ++++- ydb/query/session.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index d227a0b0..6c2fc60b 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -116,9 +116,12 @@ def wrapped_callee(): else: return next_opt.result - def execute_with_retries(self, query: str, ddl: bool = False, retry_settings: RetrySettings = None, *args, **kwargs): + def execute_with_retries( + self, query: str, ddl: bool = False, retry_settings: RetrySettings = None, *args, **kwargs + ): retry_settings = RetrySettings() if retry_settings is None else retry_settings with self.checkout() as session: + def wrapped_callee(): it = session.execute(query, empty_tx_control=ddl) return [result_set for result_set in it] diff --git a/ydb/query/session.py b/ydb/query/session.py index 25d5658f..8288aa57 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -249,7 +249,7 @@ def execute( exec_mode: base.QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, - empty_tx_control: bool = False + empty_tx_control: bool = False, ): self._state._check_session_ready_to_use() From daea8a91192a6b58af8fc12505f0558a7454604b Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 16:15:17 +0300 Subject: [PATCH 32/57] style fixes --- examples/query-service/basic_example.py | 1 - ydb/query/pool.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index a43bfa9a..47a8a0d9 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -26,7 +26,6 @@ def main(): pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal');") - def callee(session): print("=" * 50) for _ in session.execute("""delete from example;"""): diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 6c2fc60b..bc3007d4 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -117,7 +117,7 @@ def wrapped_callee(): return next_opt.result def execute_with_retries( - self, query: str, ddl: bool = False, retry_settings: RetrySettings = None, *args, **kwargs + self, query: str, ddl: bool = False, retry_settings: RetrySettings = None, *args, **kwargs ): retry_settings = RetrySettings() if retry_settings is None else retry_settings with self.checkout() as session: From 89846fd82cc442ed85baa427c68e5bbece6cfa29 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 16:30:28 +0300 Subject: [PATCH 33/57] omit flag in client side --- examples/query-service/basic_example.py | 4 ++-- ydb/query/pool.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index 47a8a0d9..f5c91d4a 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -18,11 +18,11 @@ def main(): print("=" * 50) print("DELETE TABLE IF EXISTS") - pool.execute_with_retries("drop table if exists example;", ddl=True) + pool.execute_with_retries("drop table if exists example;") print("=" * 50) print("CREATE TABLE") - pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key));", ddl=True) + pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key));") pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal');") diff --git a/ydb/query/pool.py b/ydb/query/pool.py index bc3007d4..f1c52757 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -117,13 +117,13 @@ def wrapped_callee(): return next_opt.result def execute_with_retries( - self, query: str, ddl: bool = False, retry_settings: RetrySettings = None, *args, **kwargs + self, query: str, retry_settings: RetrySettings = None, *args, **kwargs ): retry_settings = RetrySettings() if retry_settings is None else retry_settings with self.checkout() as session: def wrapped_callee(): - it = session.execute(query, empty_tx_control=ddl) + it = session.execute(query, empty_tx_control=True) return [result_set for result_set in it] opt_generator = retry_operation_impl(wrapped_callee, retry_settings, *args, **kwargs) From 199e8d7900399082518ffd0cb89180f81aa3d693 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 16:32:29 +0300 Subject: [PATCH 34/57] pass args to session execute --- ydb/query/pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index f1c52757..4552e319 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -123,10 +123,10 @@ def execute_with_retries( with self.checkout() as session: def wrapped_callee(): - it = session.execute(query, empty_tx_control=True) + it = session.execute(query, empty_tx_control=True, *args, **kwargs) return [result_set for result_set in it] - opt_generator = retry_operation_impl(wrapped_callee, retry_settings, *args, **kwargs) + opt_generator = retry_operation_impl(wrapped_callee, retry_settings) for next_opt in opt_generator: if isinstance(next_opt, YdbRetryOperationSleepOpt): time.sleep(next_opt.timeout) From c3d1d2bf0d2bbbcb0e9029bd3099c52219007e35 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 22 Jul 2024 16:37:12 +0300 Subject: [PATCH 35/57] style fixes --- ydb/query/pool.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 4552e319..d5871d4c 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -116,9 +116,7 @@ def wrapped_callee(): else: return next_opt.result - def execute_with_retries( - self, query: str, retry_settings: RetrySettings = None, *args, **kwargs - ): + def execute_with_retries(self, query: str, retry_settings: RetrySettings = None, *args, **kwargs): retry_settings = RetrySettings() if retry_settings is None else retry_settings with self.checkout() as session: From ec7274f6c6fca7e1a211b340e2ed24bec3cbdc60 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 23 Jul 2024 12:16:49 +0300 Subject: [PATCH 36/57] interactive tx support --- tests/query/test_query_transaction.py | 11 ++++++++--- ydb/query/base.py | 13 +++++-------- ydb/query/transaction.py | 25 +++++++++++++++++++------ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 1b0d865a..0224a613 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -26,9 +26,14 @@ def test_tx_rollback_raises_before_begin(self, tx): with pytest.raises(RuntimeError): tx.rollback() - # def test_tx_execute_raises_before_begin(self, tx): - # with pytest.raises(RuntimeError): - # tx.execute("select 1;") + def test_tx_first_execute_begins_tx(self, tx): + tx.execute("select 1;") + tx.commit() + + def test_interactive_tx_commit(self, tx): + tx.execute("select 1;", commit_tx=True) + with pytest.raises(RuntimeError): + tx.execute("select 1;") def text_tx_execute_raises_after_commit(self, tx): tx.begin() diff --git a/ydb/query/base.py b/ydb/query/base.py index 98a02a42..e1936291 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -192,8 +192,12 @@ def create_execute_query_request( return req.to_proto() -def wrap_execute_query_response(rpc_state, response_pb): +def wrap_execute_query_response(rpc_state, response_pb, tx, commit_tx=False): issues._process_response(response_pb) + if response_pb.tx_meta and not tx.tx_id: + tx._handle_tx_meta(response_pb.tx_meta) + if commit_tx: + tx._move_to_commited() return convert.ResultSet.from_message(response_pb.result_set) @@ -201,17 +205,10 @@ def wrap_execute_query_response(rpc_state, response_pb): X_YDB_SESSION_CLOSE = "session-close" -# def _check_session_is_closing(rpc_state, session_state): -# metadata = rpc_state.trailing_metadata() -# if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []): -# session_state.set_closing() # TODO: clarify & implement - - def bad_session_handler(func): @functools.wraps(func) def decorator(rpc_state, response_pb, session_state, *args, **kwargs): try: - # _check_session_is_closing(rpc_state, session_state) return func(rpc_state, response_pb, session_state, *args, **kwargs) except issues.BadSession: session_state.reset() diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index df71d62c..973526e3 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -245,9 +245,8 @@ def commit(self, settings=None): """ if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): return - self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) - self._ensure_prev_stream_finished() + self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), @@ -262,9 +261,8 @@ def rollback(self, settings=None): if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): return - self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) - self._ensure_prev_stream_finished() + self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) return self._driver( _create_rollback_transaction_request(self._session_state, self._tx_state), @@ -309,6 +307,16 @@ def _ensure_prev_stream_finished(self): pass self._prev_stream = None + def _handle_tx_meta(self, tx_meta=None): + if not self.tx_id: + self._tx_state._change_state(QueryTxStateEnum.BEGINED) + self._tx_state.tx_id = tx_meta.id + + def _move_to_commited(self): + if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + return + self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + def execute( self, query: str, @@ -319,8 +327,8 @@ def execute( parameters: dict = None, concurrent_result_sets: bool = False, ): - self._tx_state._check_tx_not_terminal() self._ensure_prev_stream_finished() + self._tx_state._check_tx_not_terminal() stream_it = self._execute_call( query=query, @@ -333,6 +341,11 @@ def execute( ) self._prev_stream = _utilities.SyncResponseIterator( stream_it, - lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), + lambda resp: base.wrap_execute_query_response( + rpc_state=None, + response_pb=resp, + tx=self, + commit_tx=commit_tx, + ), ) return self._prev_stream From 6b07d63d248cb249168f81514e149947b3c925a7 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 23 Jul 2024 12:41:36 +0300 Subject: [PATCH 37/57] fix test errors --- ydb/query/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ydb/query/base.py b/ydb/query/base.py index e1936291..2f270d92 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -192,11 +192,11 @@ def create_execute_query_request( return req.to_proto() -def wrap_execute_query_response(rpc_state, response_pb, tx, commit_tx=False): +def wrap_execute_query_response(rpc_state, response_pb, tx=None, commit_tx=False): issues._process_response(response_pb) - if response_pb.tx_meta and not tx.tx_id: + if tx and response_pb.tx_meta and not tx.tx_id: tx._handle_tx_meta(response_pb.tx_meta) - if commit_tx: + if tx and commit_tx: tx._move_to_commited() return convert.ResultSet.from_message(response_pb.result_set) From 1d1e312f267af7d6011af379be50ef135c6862b1 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 23 Jul 2024 14:00:43 +0300 Subject: [PATCH 38/57] QuerySessionPool docstring --- ydb/query/pool.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index d5871d4c..2781e0ee 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -96,13 +96,26 @@ def retry_operation_impl(callee: Callable, retry_settings: RetrySettings = None, class QuerySessionPool: + """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" + def __init__(self, driver: base.SupportedDriverType): + """ + :param driver: A driver instance + """ self._driver = driver def checkout(self): + """Return a Session context manager, that opens session on enter and closes session on exit.""" return SimpleQuerySessionCheckout(self) def retry_operation_sync(self, callee: Callable, retry_settings: RetrySettings = None, *args, **kwargs): + """Special interface to execute a bunch of commands with session in a safe, retriable way. + + :param callee: A function, that works with session. + :param retry_settings: RetrySettings object. + + :return: Result sets or exception in case of execution errors. + """ retry_settings = RetrySettings() if retry_settings is None else retry_settings def wrapped_callee(): @@ -117,6 +130,13 @@ def wrapped_callee(): return next_opt.result def execute_with_retries(self, query: str, retry_settings: RetrySettings = None, *args, **kwargs): + """Special interface to execute a one-shot queries in a safe, retriable way. + + :param query: A query, yql or sql text. + :param retry_settings: RetrySettings object. + + :return: Result sets or exception in case of execution errors. + """ retry_settings = RetrySettings() if retry_settings is None else retry_settings with self.checkout() as session: From bbd47dad06971c47e762e8afbf58738021a01790 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 23 Jul 2024 14:52:29 +0300 Subject: [PATCH 39/57] Fix tests naming --- tests/query/test_query_transaction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 0224a613..fd7fb6bb 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -35,13 +35,13 @@ def test_interactive_tx_commit(self, tx): with pytest.raises(RuntimeError): tx.execute("select 1;") - def text_tx_execute_raises_after_commit(self, tx): + def test_tx_execute_raises_after_commit(self, tx): tx.begin() tx.commit() with pytest.raises(RuntimeError): tx.execute("select 1;") - def text_tx_execute_raises_after_rollback(self, tx): + def test_tx_execute_raises_after_rollback(self, tx): tx.begin() tx.rollback() with pytest.raises(RuntimeError): From a1f513a513f7ad01125d221b8e8dede9e701d76f Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 10:17:34 +0300 Subject: [PATCH 40/57] tx context manager tests --- tests/query/test_query_transaction.py | 20 ++++++++++++++++++++ ydb/query/transaction.py | 1 + 2 files changed, 21 insertions(+) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index fd7fb6bb..ee0c3a07 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,5 +1,6 @@ import pytest +from ydb.query.transaction import QueryTxStateEnum class TestQueryTransaction: def test_tx_begin(self, tx): @@ -46,3 +47,22 @@ def test_tx_execute_raises_after_rollback(self, tx): tx.rollback() with pytest.raises(RuntimeError): tx.execute("select 1;") + + def test_context_manager_rollbacks_tx(self, tx): + with tx: + tx.begin() + + assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED + + def test_context_manager_normal_flow(self, tx): + with tx: + tx.begin() + tx.execute("select 1;") + tx.commit() + + assert tx._tx_state._state == QueryTxStateEnum.COMMITTED + + def test_context_manager_does_not_hide_exceptions(self, tx): + with pytest.raises(RuntimeError): + with tx: + tx.commit() diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 973526e3..2bf76513 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -184,6 +184,7 @@ def __exit__(self, *args, **kwargs): Closes a transaction context manager and rollbacks transaction if it is not rolled back explicitly """ + self._ensure_prev_stream_finished() if self._tx_state.tx_id is not None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by From 68ce180ca728fa405418d2904391d9ba2ba845ec Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 10:24:36 +0300 Subject: [PATCH 41/57] add typing to tests --- tests/query/test_query_session.py | 28 ++++++++++++++------------- tests/query/test_query_transaction.py | 25 ++++++++++++------------ 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index dc0b7664..402fc8bb 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -1,20 +1,22 @@ import pytest +from ydb.query.session import QuerySessionSync -def _check_session_state_empty(session): + +def _check_session_state_empty(session: QuerySessionSync): assert session._state.session_id is None assert session._state.node_id is None assert not session._state.attached -def _check_session_state_full(session): +def _check_session_state_full(session: QuerySessionSync): assert session._state.session_id is not None assert session._state.node_id is not None assert session._state.attached class TestQuerySession: - def test_session_normal_lifecycle(self, session): + def test_session_normal_lifecycle(self, session: QuerySessionSync): _check_session_state_empty(session) session.create() @@ -23,7 +25,7 @@ def test_session_normal_lifecycle(self, session): session.delete() _check_session_state_empty(session) - def test_second_create_do_nothing(self, session): + def test_second_create_do_nothing(self, session: QuerySessionSync): session.create() _check_session_state_full(session) @@ -36,27 +38,27 @@ def test_second_create_do_nothing(self, session): assert session._state.session_id == session_id_before assert session._state.node_id == node_id_before - def test_second_delete_do_nothing(self, session): + def test_second_delete_do_nothing(self, session: QuerySessionSync): session.create() session.delete() session.delete() - def test_delete_before_create_not_possible(self, session): + def test_delete_before_create_not_possible(self, session: QuerySessionSync): with pytest.raises(RuntimeError): session.delete() - def test_create_after_delete_not_possible(self, session): + def test_create_after_delete_not_possible(self, session: QuerySessionSync): session.create() session.delete() with pytest.raises(RuntimeError): session.create() - def test_transaction_before_create_raises(self, session): + def test_transaction_before_create_raises(self, session: QuerySessionSync): with pytest.raises(RuntimeError): session.transaction() - def test_transaction_after_delete_raises(self, session): + def test_transaction_after_delete_raises(self, session: QuerySessionSync): session.create() session.delete() @@ -64,21 +66,21 @@ def test_transaction_after_delete_raises(self, session): with pytest.raises(RuntimeError): session.transaction() - def test_transaction_after_create_not_raises(self, session): + def test_transaction_after_create_not_raises(self, session: QuerySessionSync): session.create() session.transaction() - def test_execute_before_create_raises(self, session): + def test_execute_before_create_raises(self, session: QuerySessionSync): with pytest.raises(RuntimeError): session.execute("select 1;") - def test_execute_after_delete_raises(self, session): + def test_execute_after_delete_raises(self, session: QuerySessionSync): session.create() session.delete() with pytest.raises(RuntimeError): session.execute("select 1;") - def test_basic_execute(self, session): + def test_basic_execute(self, session: QuerySessionSync): session.create() it = session.execute("select 1;") result_sets = [result_set for result_set in it] diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index ee0c3a07..0cfd55e8 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,60 +1,61 @@ import pytest +from ydb.query.transaction import BaseTxContext from ydb.query.transaction import QueryTxStateEnum class TestQueryTransaction: - def test_tx_begin(self, tx): + def test_tx_begin(self, tx: BaseTxContext): assert tx.tx_id is None tx.begin() assert tx.tx_id is not None - def test_tx_allow_double_commit(self, tx): + def test_tx_allow_double_commit(self, tx: BaseTxContext): tx.begin() tx.commit() tx.commit() - def test_tx_allow_double_rollback(self, tx): + def test_tx_allow_double_rollback(self, tx: BaseTxContext): tx.begin() tx.rollback() tx.rollback() - def test_tx_commit_raises_before_begin(self, tx): + def test_tx_commit_raises_before_begin(self, tx: BaseTxContext): with pytest.raises(RuntimeError): tx.commit() - def test_tx_rollback_raises_before_begin(self, tx): + def test_tx_rollback_raises_before_begin(self, tx: BaseTxContext): with pytest.raises(RuntimeError): tx.rollback() - def test_tx_first_execute_begins_tx(self, tx): + def test_tx_first_execute_begins_tx(self, tx: BaseTxContext): tx.execute("select 1;") tx.commit() - def test_interactive_tx_commit(self, tx): + def test_interactive_tx_commit(self, tx: BaseTxContext): tx.execute("select 1;", commit_tx=True) with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_commit(self, tx): + def test_tx_execute_raises_after_commit(self, tx: BaseTxContext): tx.begin() tx.commit() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_rollback(self, tx): + def test_tx_execute_raises_after_rollback(self, tx: BaseTxContext): tx.begin() tx.rollback() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_context_manager_rollbacks_tx(self, tx): + def test_context_manager_rollbacks_tx(self, tx: BaseTxContext): with tx: tx.begin() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_context_manager_normal_flow(self, tx): + def test_context_manager_normal_flow(self, tx: BaseTxContext): with tx: tx.begin() tx.execute("select 1;") @@ -62,7 +63,7 @@ def test_context_manager_normal_flow(self, tx): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_context_manager_does_not_hide_exceptions(self, tx): + def test_context_manager_does_not_hide_exceptions(self, tx: BaseTxContext): with pytest.raises(RuntimeError): with tx: tx.commit() From 690c974c3463edb15392d10634877f1146b4496e Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 10:29:04 +0300 Subject: [PATCH 42/57] style fix --- tests/query/test_query_transaction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 0cfd55e8..287c151c 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -3,6 +3,7 @@ from ydb.query.transaction import BaseTxContext from ydb.query.transaction import QueryTxStateEnum + class TestQueryTransaction: def test_tx_begin(self, tx: BaseTxContext): assert tx.tx_id is None From c8e74a8fb6999f87ba8929cad09675216bc55322 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 13:16:55 +0300 Subject: [PATCH 43/57] move retry logic to standalone module --- tests/query/test_query_transaction.py | 4 +- ydb/_retries.py | 136 +++++++++++++++++++++ ydb/_topic_reader/topic_reader.py | 2 +- ydb/_topic_writer/topic_writer_asyncio.py | 4 +- ydb/aio/table.py | 6 +- ydb/query/pool.py | 107 +---------------- ydb/table.py | 139 +--------------------- ydb/table_test.py | 5 +- 8 files changed, 158 insertions(+), 245 deletions(-) create mode 100644 ydb/_retries.py diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 287c151c..6b185720 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -65,6 +65,6 @@ def test_context_manager_normal_flow(self, tx: BaseTxContext): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED def test_context_manager_does_not_hide_exceptions(self, tx: BaseTxContext): - with pytest.raises(RuntimeError): + with pytest.raises(Exception): with tx: - tx.commit() + raise Exception() diff --git a/ydb/_retries.py b/ydb/_retries.py new file mode 100644 index 00000000..5d4f6e6a --- /dev/null +++ b/ydb/_retries.py @@ -0,0 +1,136 @@ +import random +import time + +from . import issues +from ._errors import check_retriable_error + + +class BackoffSettings(object): + def __init__(self, ceiling=6, slot_duration=0.001, uncertain_ratio=0.5): + self.ceiling = ceiling + self.slot_duration = slot_duration + self.uncertain_ratio = uncertain_ratio + + def calc_timeout(self, retry_number): + slots_count = 1 << min(retry_number, self.ceiling) + max_duration_ms = slots_count * self.slot_duration * 1000.0 + # duration_ms = random.random() * max_duration_ms * uncertain_ratio) + max_duration_ms * (1 - uncertain_ratio) + duration_ms = max_duration_ms * (random.random() * self.uncertain_ratio + 1.0 - self.uncertain_ratio) + return duration_ms / 1000.0 + + +class RetrySettings(object): + def __init__( + self, + max_retries=10, + max_session_acquire_timeout=None, + on_ydb_error_callback=None, + backoff_ceiling=6, + backoff_slot_duration=1, + get_session_client_timeout=5, + fast_backoff_settings=None, + slow_backoff_settings=None, + idempotent=False, + ): + self.max_retries = max_retries + self.max_session_acquire_timeout = max_session_acquire_timeout + self.on_ydb_error_callback = (lambda e: None) if on_ydb_error_callback is None else on_ydb_error_callback + self.fast_backoff = BackoffSettings(10, 0.005) if fast_backoff_settings is None else fast_backoff_settings + self.slow_backoff = ( + BackoffSettings(backoff_ceiling, backoff_slot_duration) + if slow_backoff_settings is None + else slow_backoff_settings + ) + self.retry_not_found = True + self.idempotent = idempotent + self.retry_internal_error = True + self.unknown_error_handler = lambda e: None + self.get_session_client_timeout = get_session_client_timeout + if max_session_acquire_timeout is not None: + self.get_session_client_timeout = min(self.max_session_acquire_timeout, self.get_session_client_timeout) + + def with_fast_backoff(self, backoff_settings): + self.fast_backoff = backoff_settings + return self + + def with_slow_backoff(self, backoff_settings): + self.slow_backoff = backoff_settings + return self + + +class YdbRetryOperationSleepOpt(object): + def __init__(self, timeout): + self.timeout = timeout + + def __eq__(self, other): + return type(self) == type(other) and self.timeout == other.timeout + + def __repr__(self): + return "YdbRetryOperationSleepOpt(%s)" % self.timeout + + +class YdbRetryOperationFinalResult(object): + def __init__(self, result): + self.result = result + self.exc = None + + def __eq__(self, other): + return type(self) == type(other) and self.result == other.result and self.exc == other.exc + + def __repr__(self): + return "YdbRetryOperationFinalResult(%s, exc=%s)" % (self.result, self.exc) + + def set_exception(self, exc): + self.exc = exc + + +def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): + retry_settings = RetrySettings() if retry_settings is None else retry_settings + status = None + + for attempt in range(retry_settings.max_retries + 1): + try: + result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) + yield result + + if result.exc is not None: + raise result.exc + + except issues.Error as e: + status = e + retry_settings.on_ydb_error_callback(e) + + retriable_info = check_retriable_error(e, retry_settings, attempt) + if not retriable_info.is_retriable: + raise + + skip_yield_error_types = [ + issues.Aborted, + issues.BadSession, + issues.NotFound, + issues.InternalError, + ] + + yield_sleep = True + for t in skip_yield_error_types: + if isinstance(e, t): + yield_sleep = False + + if yield_sleep: + yield YdbRetryOperationSleepOpt(retriable_info.sleep_timeout_seconds) + + except Exception as e: + # you should provide your own handler you want + retry_settings.unknown_error_handler(e) + raise + + raise status + + +def retry_operation_sync(callee, retry_settings=None, *args, **kwargs): + opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs) + for next_opt in opt_generator: + if isinstance(next_opt, YdbRetryOperationSleepOpt): + time.sleep(next_opt.timeout) + else: + return next_opt.result diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 17fb2885..4ac6c441 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -10,7 +10,7 @@ Callable, ) -from ..table import RetrySettings +from .._retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 007c8a54..04b174b0 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -26,9 +26,9 @@ from .. import ( _apis, issues, - check_retriable_error, - RetrySettings, ) +from .._errors import check_retriable_error +from .._retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._grpc.grpcwrapper.ydb_topic import ( UpdateTokenRequest, diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 3c25f7d2..228581ae 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -15,7 +15,7 @@ BaseTxContext, ) from . import _utilities -from ydb import _apis, _session_impl +from ydb import _apis, _session_impl, _retries logger = logging.getLogger(__name__) @@ -214,9 +214,9 @@ async def retry_operation(callee, retry_settings=None, *args, **kwargs): # pyli Returns awaitable result of coroutine. If retries are not succussful exception is raised. """ - opt_generator = ydb.retry_operation_impl(callee, retry_settings, *args, **kwargs) + opt_generator = _retries.retry_operation_impl(callee, retry_settings, *args, **kwargs) for next_opt in opt_generator: - if isinstance(next_opt, ydb.YdbRetryOperationSleepOpt): + if isinstance(next_opt, _retries.YdbRetryOperationSleepOpt): await asyncio.sleep(next_opt.timeout) else: try: diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 2781e0ee..ce1207de 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -1,98 +1,13 @@ -import time from typing import Callable from . import base from .session import ( QuerySessionSync, ) -from .. import issues -from .._errors import check_retriable_error - - -class RetrySettings(object): - def __init__( - self, - max_retries: int = 10, - max_session_acquire_timeout: int = None, - on_ydb_error_callback: Callable = None, - idempotent: bool = False, - ): - self.max_retries = max_retries - self.max_session_acquire_timeout = max_session_acquire_timeout - self.on_ydb_error_callback = (lambda e: None) if on_ydb_error_callback is None else on_ydb_error_callback - self.retry_not_found = True - self.idempotent = idempotent - self.retry_internal_error = True - self.unknown_error_handler = lambda e: None - - -class YdbRetryOperationSleepOpt: - def __init__(self, timeout): - self.timeout = timeout - - def __eq__(self, other): - return type(self) == type(other) and self.timeout == other.timeout - - def __repr__(self): - return "YdbRetryOperationSleepOpt(%s)" % self.timeout - - -class YdbRetryOperationFinalResult: - def __init__(self, result): - self.result = result - self.exc = None - - def __eq__(self, other): - return type(self) == type(other) and self.result == other.result and self.exc == other.exc - - def __repr__(self): - return "YdbRetryOperationFinalResult(%s, exc=%s)" % (self.result, self.exc) - - def set_exception(self, exc): - self.exc = exc - - -def retry_operation_impl(callee: Callable, retry_settings: RetrySettings = None, *args, **kwargs): - retry_settings = RetrySettings() if retry_settings is None else retry_settings - status = None - - for attempt in range(retry_settings.max_retries + 1): - try: - result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) - yield result - - if result.exc is not None: - raise result.exc - - except issues.Error as e: - status = e - retry_settings.on_ydb_error_callback(e) - - retriable_info = check_retriable_error(e, retry_settings, attempt) - if not retriable_info.is_retriable: - raise - - skip_yield_error_types = [ - issues.Aborted, - issues.BadSession, - issues.NotFound, - issues.InternalError, - ] - - yield_sleep = True - for t in skip_yield_error_types: - if isinstance(e, t): - yield_sleep = False - - if yield_sleep: - yield YdbRetryOperationSleepOpt(retriable_info.sleep_timeout_seconds) - - except Exception as e: - # you should provide your own handler you want - retry_settings.unknown_error_handler(e) - raise - if status: - raise status +from .._retries import ( + RetrySettings, + retry_operation_sync, +) class QuerySessionPool: @@ -122,12 +37,7 @@ def wrapped_callee(): with self.checkout() as session: return callee(session, *args, **kwargs) - opt_generator = retry_operation_impl(wrapped_callee, retry_settings, *args, **kwargs) - for next_opt in opt_generator: - if isinstance(next_opt, YdbRetryOperationSleepOpt): - time.sleep(next_opt.timeout) - else: - return next_opt.result + return retry_operation_sync(wrapped_callee, retry_settings) def execute_with_retries(self, query: str, retry_settings: RetrySettings = None, *args, **kwargs): """Special interface to execute a one-shot queries in a safe, retriable way. @@ -144,12 +54,7 @@ def wrapped_callee(): it = session.execute(query, empty_tx_control=True, *args, **kwargs) return [result_set for result_set in it] - opt_generator = retry_operation_impl(wrapped_callee, retry_settings) - for next_opt in opt_generator: - if isinstance(next_opt, YdbRetryOperationSleepOpt): - time.sleep(next_opt.timeout) - else: - return next_opt.result + return retry_operation_sync(wrapped_callee, retry_settings) class SimpleQuerySessionCheckout: diff --git a/ydb/table.py b/ydb/table.py index c21392bb..b2a4c569 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -3,8 +3,6 @@ import ydb from abc import abstractmethod import logging -import time -import random import enum from . import ( @@ -20,7 +18,11 @@ _tx_ctx_impl, tracing, ) -from ._errors import check_retriable_error + +from ._retries import ( + retry_operation_sync, + RetrySettings, +) try: from . import interceptor @@ -840,137 +842,6 @@ def name(self): return self._name -class BackoffSettings(object): - def __init__(self, ceiling=6, slot_duration=0.001, uncertain_ratio=0.5): - self.ceiling = ceiling - self.slot_duration = slot_duration - self.uncertain_ratio = uncertain_ratio - - def calc_timeout(self, retry_number): - slots_count = 1 << min(retry_number, self.ceiling) - max_duration_ms = slots_count * self.slot_duration * 1000.0 - # duration_ms = random.random() * max_duration_ms * uncertain_ratio) + max_duration_ms * (1 - uncertain_ratio) - duration_ms = max_duration_ms * (random.random() * self.uncertain_ratio + 1.0 - self.uncertain_ratio) - return duration_ms / 1000.0 - - -class RetrySettings(object): - def __init__( - self, - max_retries=10, - max_session_acquire_timeout=None, - on_ydb_error_callback=None, - backoff_ceiling=6, - backoff_slot_duration=1, - get_session_client_timeout=5, - fast_backoff_settings=None, - slow_backoff_settings=None, - idempotent=False, - ): - self.max_retries = max_retries - self.max_session_acquire_timeout = max_session_acquire_timeout - self.on_ydb_error_callback = (lambda e: None) if on_ydb_error_callback is None else on_ydb_error_callback - self.fast_backoff = BackoffSettings(10, 0.005) if fast_backoff_settings is None else fast_backoff_settings - self.slow_backoff = ( - BackoffSettings(backoff_ceiling, backoff_slot_duration) - if slow_backoff_settings is None - else slow_backoff_settings - ) - self.retry_not_found = True - self.idempotent = idempotent - self.retry_internal_error = True - self.unknown_error_handler = lambda e: None - self.get_session_client_timeout = get_session_client_timeout - if max_session_acquire_timeout is not None: - self.get_session_client_timeout = min(self.max_session_acquire_timeout, self.get_session_client_timeout) - - def with_fast_backoff(self, backoff_settings): - self.fast_backoff = backoff_settings - return self - - def with_slow_backoff(self, backoff_settings): - self.slow_backoff = backoff_settings - return self - - -class YdbRetryOperationSleepOpt(object): - def __init__(self, timeout): - self.timeout = timeout - - def __eq__(self, other): - return type(self) == type(other) and self.timeout == other.timeout - - def __repr__(self): - return "YdbRetryOperationSleepOpt(%s)" % self.timeout - - -class YdbRetryOperationFinalResult(object): - def __init__(self, result): - self.result = result - self.exc = None - - def __eq__(self, other): - return type(self) == type(other) and self.result == other.result and self.exc == other.exc - - def __repr__(self): - return "YdbRetryOperationFinalResult(%s, exc=%s)" % (self.result, self.exc) - - def set_exception(self, exc): - self.exc = exc - - -def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): - retry_settings = RetrySettings() if retry_settings is None else retry_settings - status = None - - for attempt in range(retry_settings.max_retries + 1): - try: - result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) - yield result - - if result.exc is not None: - raise result.exc - - except issues.Error as e: - status = e - retry_settings.on_ydb_error_callback(e) - - retriable_info = check_retriable_error(e, retry_settings, attempt) - if not retriable_info.is_retriable: - raise - - skip_yield_error_types = [ - issues.Aborted, - issues.BadSession, - issues.NotFound, - issues.InternalError, - ] - - yield_sleep = True - for t in skip_yield_error_types: - if isinstance(e, t): - yield_sleep = False - - if yield_sleep: - yield YdbRetryOperationSleepOpt(retriable_info.sleep_timeout_seconds) - - except Exception as e: - # you should provide your own handler you want - retry_settings.unknown_error_handler(e) - raise - - raise status - - -def retry_operation_sync(callee, retry_settings=None, *args, **kwargs): - opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs) - for next_opt in opt_generator: - if isinstance(next_opt, YdbRetryOperationSleepOpt): - time.sleep(next_opt.timeout) - else: - return next_opt.result - - class TableClientSettings(object): def __init__(self): self._client_query_cache_enabled = False diff --git a/ydb/table_test.py b/ydb/table_test.py index 640662c4..aa45e41e 100644 --- a/ydb/table_test.py +++ b/ydb/table_test.py @@ -1,8 +1,9 @@ from unittest import mock -from . import ( +from . import issues + +from ._retries import ( retry_operation_impl, YdbRetryOperationFinalResult, - issues, YdbRetryOperationSleepOpt, RetrySettings, ) From 249e5696a8ac5be8f5ccf5695688d8011c7cc431 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 14:47:50 +0300 Subject: [PATCH 44/57] response iterator as context manager --- examples/query-service/basic_example.py | 39 +++++++++++++------------ tests/query/test_query_transaction.py | 8 +++++ ydb/query/base.py | 10 +++++++ ydb/query/transaction.py | 3 +- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index f5c91d4a..d78f5c71 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -38,39 +38,40 @@ def callee(session): print("=" * 50) print("INSERT WITH COMMIT TX") - tx = session.transaction() - tx.begin() + with session.transaction() as tx: + tx.begin() - tx.execute("""INSERT INTO example (key, value) VALUES (1, "onepieceisreal");""") + with tx.execute("""INSERT INTO example (key, value) VALUES (1, "onepieceisreal");""") as results: + pass - for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): - print(f"rows: {str(result_set.rows)}") + with tx.execute("""SELECT COUNT(*) as rows_count FROM example;""") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") - tx.commit() + tx.commit() - print("=" * 50) - print("AFTER COMMIT TX") + print("=" * 50) + print("AFTER COMMIT TX") - for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): - print(f"rows: {str(result_set.rows)}") + for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): + print(f"rows: {str(result_set.rows)}") print("=" * 50) print("INSERT WITH ROLLBACK TX") - tx = session.transaction() + with session.transaction() as tx: + tx.begin() - tx.begin() + tx.execute("""INSERT INTO example (key, value) VALUES (2, "onepieceisreal");""") - tx.execute("""INSERT INTO example (key, value) VALUES (2, "onepieceisreal");""") + for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): + print(f"rows: {str(result_set.rows)}") - for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): - print(f"rows: {str(result_set.rows)}") - - tx.rollback() + tx.rollback() - print("=" * 50) - print("AFTER ROLLBACK TX") + print("=" * 50) + print("AFTER ROLLBACK TX") for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): print(f"rows: {str(result_set.rows)}") diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 6b185720..0f09f9b1 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -68,3 +68,11 @@ def test_context_manager_does_not_hide_exceptions(self, tx: BaseTxContext): with pytest.raises(Exception): with tx: raise Exception() + + def test_execute_as_context_manager(self, tx: BaseTxContext): + tx.begin() + + with tx.execute("select 1;") as results: + res = [result_set for result_set in results] + + assert len(res) == 1 diff --git a/ydb/query/base.py b/ydb/query/base.py index 2f270d92..9221da8e 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -16,6 +16,7 @@ ) from .. import convert from .. import issues +from .. import _utilities class QueryClientSettings: @@ -215,3 +216,12 @@ def decorator(rpc_state, response_pb, session_state, *args, **kwargs): raise return decorator + + +class SyncResponseContextIterator(_utilities.SyncResponseIterator): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for _ in self.it: + pass diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 2bf76513..a859326f 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -6,7 +6,6 @@ from .. import ( _apis, issues, - _utilities, ) from .._grpc.grpcwrapper import ydb_query as _ydb_query from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public @@ -340,7 +339,7 @@ def execute( parameters=parameters, concurrent_result_sets=concurrent_result_sets, ) - self._prev_stream = _utilities.SyncResponseIterator( + self._prev_stream = base.SyncResponseContextIterator( stream_it, lambda resp: base.wrap_execute_query_response( rpc_state=None, From 8bb8a6ff96572cbffc624b5c358506c278e537d6 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 15:01:32 +0300 Subject: [PATCH 45/57] make retries module public --- ydb/__init__.py | 1 + ydb/_topic_reader/topic_reader.py | 2 +- ydb/_topic_writer/topic_writer_asyncio.py | 2 +- ydb/aio/table.py | 6 +++--- ydb/query/pool.py | 2 +- ydb/{_retries.py => retries.py} | 0 ydb/table.py | 2 +- ydb/table_test.py | 2 +- 8 files changed, 9 insertions(+), 8 deletions(-) rename ydb/{_retries.py => retries.py} (100%) diff --git a/ydb/__init__.py b/ydb/__init__.py index 1caaaa02..375f2f54 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -20,6 +20,7 @@ from .topic import * # noqa from .draft import * # noqa from .query import * # noqa +from .retries import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 4ac6c441..b907ee27 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -10,7 +10,7 @@ Callable, ) -from .._retries import RetrySettings +from ..retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 04b174b0..585e88ab 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -28,7 +28,7 @@ issues, ) from .._errors import check_retriable_error -from .._retries import RetrySettings +from ..retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._grpc.grpcwrapper.ydb_topic import ( UpdateTokenRequest, diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 228581ae..3c25f7d2 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -15,7 +15,7 @@ BaseTxContext, ) from . import _utilities -from ydb import _apis, _session_impl, _retries +from ydb import _apis, _session_impl logger = logging.getLogger(__name__) @@ -214,9 +214,9 @@ async def retry_operation(callee, retry_settings=None, *args, **kwargs): # pyli Returns awaitable result of coroutine. If retries are not succussful exception is raised. """ - opt_generator = _retries.retry_operation_impl(callee, retry_settings, *args, **kwargs) + opt_generator = ydb.retry_operation_impl(callee, retry_settings, *args, **kwargs) for next_opt in opt_generator: - if isinstance(next_opt, _retries.YdbRetryOperationSleepOpt): + if isinstance(next_opt, ydb.YdbRetryOperationSleepOpt): await asyncio.sleep(next_opt.timeout) else: try: diff --git a/ydb/query/pool.py b/ydb/query/pool.py index ce1207de..7124c55e 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -4,7 +4,7 @@ from .session import ( QuerySessionSync, ) -from .._retries import ( +from ..retries import ( RetrySettings, retry_operation_sync, ) diff --git a/ydb/_retries.py b/ydb/retries.py similarity index 100% rename from ydb/_retries.py rename to ydb/retries.py diff --git a/ydb/table.py b/ydb/table.py index b2a4c569..12856d61 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -19,7 +19,7 @@ tracing, ) -from ._retries import ( +from .retries import ( retry_operation_sync, RetrySettings, ) diff --git a/ydb/table_test.py b/ydb/table_test.py index aa45e41e..d5d86e05 100644 --- a/ydb/table_test.py +++ b/ydb/table_test.py @@ -1,7 +1,7 @@ from unittest import mock from . import issues -from ._retries import ( +from .retries import ( retry_operation_impl, YdbRetryOperationFinalResult, YdbRetryOperationSleepOpt, From 56f7b24ae425de984229b51216da9ccf1ca03a21 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 18:39:16 +0300 Subject: [PATCH 46/57] query session pool tests --- tests/query/conftest.py | 7 ++++ tests/query/test_query_session_pool.py | 47 ++++++++++++++++++++++++++ ydb/query/transaction.py | 1 + 3 files changed, 55 insertions(+) create mode 100644 tests/query/test_query_session_pool.py diff --git a/tests/query/conftest.py b/tests/query/conftest.py index a7f0c34c..113c1126 100644 --- a/tests/query/conftest.py +++ b/tests/query/conftest.py @@ -1,5 +1,6 @@ import pytest from ydb.query.session import QuerySessionSync +from ydb.query.pool import QuerySessionPool @pytest.fixture @@ -25,3 +26,9 @@ def tx(session): transaction.rollback() except BaseException: pass + + +@pytest.fixture +def pool(driver_sync): + pool = QuerySessionPool(driver_sync) + yield pool \ No newline at end of file diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py new file mode 100644 index 00000000..aee95446 --- /dev/null +++ b/tests/query/test_query_session_pool.py @@ -0,0 +1,47 @@ +import pytest + +from ydb.query.pool import QuerySessionPool +from ydb.query.session import QuerySessionSync, QuerySessionStateEnum + + +class TestQuerySessionPool: + def test_checkout_provides_created_session(self, pool: QuerySessionPool): + with pool.checkout() as session: + assert session._state._state == QuerySessionStateEnum.CREATED + + assert session._state._state == QuerySessionStateEnum.CLOSED + + def test_oneshot_query_normal(self, pool: QuerySessionPool): + res = pool.execute_with_retries("select 1;") + assert len(res) == 1 + + def test_oneshot_ddl_query(self, pool: QuerySessionPool): + pool.execute_with_retries("create table Queen(key UInt64, PRIMARY KEY (key));") + pool.execute_with_retries("drop table Queen;") + + def test_oneshot_query_raises(self, pool: QuerySessionPool): + with pytest.raises(Exception): + pool.execute_with_retries("Is this the real life? Is this just fantasy?") + + def test_retry_op_uses_created_session(self, pool: QuerySessionPool): + def callee(session: QuerySessionSync): + assert session._state._state == QuerySessionStateEnum.CREATED + pool.retry_operation_sync(callee) + + def test_retry_op_normal(self, pool: QuerySessionPool): + def callee(session: QuerySessionSync): + with session.transaction() as tx: + iterator = tx.execute("select 1;", commit_tx=True) + return [result_set for result_set in iterator] + + res = pool.retry_operation_sync(callee) + assert len(res) == 1 + + def test_retry_op_raises(self, pool: QuerySessionPool): + def callee(session: QuerySessionSync): + res = session.execute("Caught in a landslide, no escape from reality") + for _ in res: + pass + + with pytest.raises(Exception): + pool.retry_operation_sync(callee) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index a859326f..83724245 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -316,6 +316,7 @@ def _move_to_commited(self): if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + self._tx_state.tx_id = None def execute( self, From aa76736f13daf7ae4ca5b7af9693d8e92e238219 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 24 Jul 2024 18:42:27 +0300 Subject: [PATCH 47/57] style fixes --- tests/query/conftest.py | 2 +- tests/query/test_query_session_pool.py | 1 + ydb/query/base.py | 4 ---- ydb/query/session.py | 4 ++-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/query/conftest.py b/tests/query/conftest.py index 113c1126..277aaeba 100644 --- a/tests/query/conftest.py +++ b/tests/query/conftest.py @@ -31,4 +31,4 @@ def tx(session): @pytest.fixture def pool(driver_sync): pool = QuerySessionPool(driver_sync) - yield pool \ No newline at end of file + yield pool diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py index aee95446..069de7f1 100644 --- a/tests/query/test_query_session_pool.py +++ b/tests/query/test_query_session_pool.py @@ -26,6 +26,7 @@ def test_oneshot_query_raises(self, pool: QuerySessionPool): def test_retry_op_uses_created_session(self, pool: QuerySessionPool): def callee(session: QuerySessionSync): assert session._state._state == QuerySessionStateEnum.CREATED + pool.retry_operation_sync(callee) def test_retry_op_normal(self, pool: QuerySessionPool): diff --git a/ydb/query/base.py b/ydb/query/base.py index 9221da8e..e64871d5 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -202,10 +202,6 @@ def wrap_execute_query_response(rpc_state, response_pb, tx=None, commit_tx=False return convert.ResultSet.from_message(response_pb.result_set) -X_YDB_SERVER_HINTS = "x-ydb-server-hints" -X_YDB_SESSION_CLOSE = "session-close" - - def bad_session_handler(func): @functools.wraps(func) def decorator(rpc_state, response_pb, session_state, *args, **kwargs): diff --git a/ydb/query/session.py b/ydb/query/session.py index 8288aa57..83fd1795 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -198,13 +198,13 @@ def _attach(self): self._state._change_state(QuerySessionStateEnum.CREATED) threading.Thread( - target=self._chech_session_status_loop, + target=self._check_session_status_loop, args=(status_stream,), name="check session status thread", daemon=True, ).start() - def _chech_session_status_loop(self, status_stream): + def _check_session_status_loop(self, status_stream): try: for status in status_stream: if status.status != issues.StatusCode.SUCCESS: From 8f3b8afc303f2218c109e187a8776ac2ab755a65 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 25 Jul 2024 11:41:39 +0300 Subject: [PATCH 48/57] docstrings for base interfaces --- ydb/query/base.py | 139 ++++++++++++++++++++++++++++++++++++++----- ydb/query/pool.py | 4 ++ ydb/query/session.py | 41 +++++++++++++ 3 files changed, 169 insertions(+), 15 deletions(-) diff --git a/ydb/query/base.py b/ydb/query/base.py index e64871d5..8ed823d4 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -19,6 +19,20 @@ from .. import _utilities +class QuerySyntax(enum.IntEnum): + UNSPECIFIED = 0 + YQL_V1 = 1 + PG = 2 + + +class QueryExecMode(enum.IntEnum): + UNSPECIFIED = 0 + PARSE = 10 + VALIDATE = 20 + EXPLAIN = 30 + EXECUTE = 50 + + class QueryClientSettings: pass @@ -60,24 +74,60 @@ def set_attached(self, attached: bool) -> "IQuerySessionState": class IQuerySession(abc.ABC): + """Session object for Query Service. It is not recommended to control + session's lifecycle manually - use a QuerySessionPool is always a better choise. + """ + @abc.abstractmethod def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = None): pass @abc.abstractmethod def create(self) -> "IQuerySession": + """ + Creates a Session of Query Service on server side and attaches it. + + :return: Session object. + """ pass @abc.abstractmethod def delete(self) -> None: + """ + Deletes a Session of Query Service on server side and releases resources. + + :return: None + """ pass @abc.abstractmethod def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": + """ + Creates a transaction context manager with specified transaction mode. + + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + + :return: transaction context manager. + """ pass class IQueryTxContext(abc.ABC): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + 1) By explicit .begin(); + 2) On execution of a first statement, which is strictly recommended method, because that avoids + useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + """ + @abc.abstractmethod def __init__( self, @@ -90,36 +140,109 @@ def __init__( @abc.abstractmethod def __enter__(self): + """ + Enters a context manager and returns a transaction + + :return: A transaction instance + """ pass @abc.abstractmethod def __exit__(self, *args, **kwargs): + """ + Closes a transaction context manager and rollbacks transaction if + it is not rolled back explicitly + """ pass @property @abc.abstractmethod def session_id(self): + """ + A transaction's session id + + :return: A transaction's session id + """ pass @property @abc.abstractmethod def tx_id(self): + """ + Returns a id of open transaction or None otherwise + + :return: A id of open transaction or None otherwise + """ pass @abc.abstractmethod def begin(settings: QueryClientSettings = None): + """ + Explicitly begins a transaction + + :param settings: A request settings + + :return: None + """ pass @abc.abstractmethod def commit(settings: QueryClientSettings = None): + """ + Calls commit on a transaction if it is open. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ pass @abc.abstractmethod def rollback(settings: QueryClientSettings = None): + """ + Calls rollback on a transaction if it is open. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A rolled back transaction or exception if rollback is failed + """ pass @abc.abstractmethod - def execute(query: str, commit_tx=False): + def execute( + self, + query: str, + commit_tx: bool = False, + tx_mode: BaseQueryTxMode = None, + syntax: QuerySyntax = None, + exec_mode: QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ): + """ + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param commit_tx: A special flag that allows transaction commit. + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param exec_mode: Exec mode of the query, which is a one from the following choises: + 1) QueryExecMode.EXECUTE, which is default; + 2) QueryExecMode.EXPLAIN; + 3) QueryExecMode.VALIDATE; + 4) QueryExecMode.PARSE. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ pass @@ -132,20 +255,6 @@ def session(self) -> IQuerySession: pass -class QuerySyntax(enum.IntEnum): - UNSPECIFIED = 0 - YQL_V1 = 1 - PG = 2 - - -class QueryExecMode(enum.IntEnum): - UNSPECIFIED = 0 - PARSE = 10 - VALIDATE = 20 - EXPLAIN = 30 - EXECUTE = 50 - - def create_execute_query_request( query: str, session_id: str, diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 7124c55e..0a3d6d0a 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -17,10 +17,12 @@ def __init__(self, driver: base.SupportedDriverType): """ :param driver: A driver instance """ + self._driver = driver def checkout(self): """Return a Session context manager, that opens session on enter and closes session on exit.""" + return SimpleQuerySessionCheckout(self) def retry_operation_sync(self, callee: Callable, retry_settings: RetrySettings = None, *args, **kwargs): @@ -31,6 +33,7 @@ def retry_operation_sync(self, callee: Callable, retry_settings: RetrySettings = :return: Result sets or exception in case of execution errors. """ + retry_settings = RetrySettings() if retry_settings is None else retry_settings def wrapped_callee(): @@ -47,6 +50,7 @@ def execute_with_retries(self, query: str, retry_settings: RetrySettings = None, :return: Result sets or exception in case of execution errors. """ + retry_settings = RetrySettings() if retry_settings is None else retry_settings with self.checkout() as session: diff --git a/ydb/query/session.py b/ydb/query/session.py index 83fd1795..5251595f 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -181,6 +181,10 @@ def _execute_call( class QuerySessionSync(BaseQuerySession): + """Session object for Query Service. It is not recommended to control + session's lifecycle manually - use a QuerySessionPool is always a better choise. + """ + _stream = None def _attach(self): @@ -214,6 +218,11 @@ def _check_session_status_loop(self, status_stream): pass def delete(self) -> None: + """ + Deletes a Session of Query Service on server side and releases resources. + + :return: None + """ if self._state._already_in(QuerySessionStateEnum.CLOSED): return @@ -222,6 +231,11 @@ def delete(self) -> None: self._stream.cancel() def create(self) -> "QuerySessionSync": + """ + Creates a Session of Query Service on server side and attaches it. + + :return: QuerySessionSync object. + """ if self._state._already_in(QuerySessionStateEnum.CREATED): return @@ -232,6 +246,17 @@ def create(self) -> "QuerySessionSync": return self def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxContext: + """ + Creates a transaction context manager with specified transaction mode. + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + + :return transaction context manager. + + """ self._state._check_session_ready_to_use() return BaseTxContext( @@ -251,6 +276,22 @@ def execute( concurrent_result_sets: bool = False, empty_tx_control: bool = False, ): + """ + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ self._state._check_session_ready_to_use() stream_it = self._execute_call( From 24852f568e2de3219308ca1fe163b60a3644a7c7 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 25 Jul 2024 11:54:41 +0300 Subject: [PATCH 49/57] add logs about experimental api --- ydb/query/__init__.py | 5 +++++ ydb/query/pool.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index fbdfa7dd..923202db 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -7,6 +7,8 @@ "QueryClientSync", ] +import logging + from .base import ( IQueryClient, SupportedDriverType, @@ -24,9 +26,12 @@ from .pool import QuerySessionPool +logger = logging.getLogger(__name__) + class QueryClientSync(IQueryClient): def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): + logger.warning("QueryClientSync is an experimental API, which could be changed.") self._driver = driver self._settings = query_client_settings diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 0a3d6d0a..f76a60aa 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -1,3 +1,4 @@ +import logging from typing import Callable from . import base @@ -9,6 +10,8 @@ retry_operation_sync, ) +logger = logging.getLogger(__name__) + class QuerySessionPool: """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" @@ -18,6 +21,7 @@ def __init__(self, driver: base.SupportedDriverType): :param driver: A driver instance """ + logger.warning("QuerySessionPool is an experimental API, which could be changed.") self._driver = driver def checkout(self): From b4eeb02f3f8f8bdd74f8b0e13d7226a3417c6efe Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 25 Jul 2024 12:51:44 +0300 Subject: [PATCH 50/57] fixes after review --- examples/query-service/basic_example.py | 42 ++++++++++++++----------- ydb/query/session.py | 2 +- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index d78f5c71..de2920c3 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -18,23 +18,23 @@ def main(): print("=" * 50) print("DELETE TABLE IF EXISTS") - pool.execute_with_retries("drop table if exists example;") + pool.execute_with_retries("drop table if exists example") print("=" * 50) print("CREATE TABLE") - pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key));") + pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key))") - pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal');") + pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')") def callee(session): print("=" * 50) - for _ in session.execute("""delete from example;"""): + with session.execute("delete from example"): pass print("BEFORE ACTION") - it = session.execute("""SELECT COUNT(*) as rows_count FROM example;""") - for result_set in it: - print(f"rows: {str(result_set.rows)}") + with session.execute("SELECT COUNT(*) as rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") print("=" * 50) print("INSERT WITH COMMIT TX") @@ -42,19 +42,20 @@ def callee(session): with session.transaction() as tx: tx.begin() - with tx.execute("""INSERT INTO example (key, value) VALUES (1, "onepieceisreal");""") as results: + with tx.execute("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')"): pass - with tx.execute("""SELECT COUNT(*) as rows_count FROM example;""") as results: + with tx.execute("SELECT COUNT(*) as rows_count FROM example") as results: for result_set in results: print(f"rows: {str(result_set.rows)}") tx.commit() - print("=" * 50) - print("AFTER COMMIT TX") + print("=" * 50) + print("AFTER COMMIT TX") - for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): + with session.execute("SELECT COUNT(*) as rows_count FROM example") as results: + for result_set in results: print(f"rows: {str(result_set.rows)}") print("=" * 50) @@ -63,18 +64,21 @@ def callee(session): with session.transaction() as tx: tx.begin() - tx.execute("""INSERT INTO example (key, value) VALUES (2, "onepieceisreal");""") + with tx.execute("INSERT INTO example (key, value) VALUES (2, 'onepieceisreal')"): + pass - for result_set in tx.execute("""SELECT COUNT(*) as rows_count FROM example;"""): - print(f"rows: {str(result_set.rows)}") + with tx.execute("SELECT COUNT(*) as rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") tx.rollback() - print("=" * 50) - print("AFTER ROLLBACK TX") + print("=" * 50) + print("AFTER ROLLBACK TX") - for result_set in session.execute("""SELECT COUNT(*) as rows_count FROM example;"""): - print(f"rows: {str(result_set.rows)}") + with session.execute("SELECT COUNT(*) as rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") pool.retry_operation_sync(callee) diff --git a/ydb/query/session.py b/ydb/query/session.py index 5251595f..aea69ddf 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -305,7 +305,7 @@ def execute( empty_tx_control=empty_tx_control, ) - return _utilities.SyncResponseIterator( + return base.SyncResponseContextIterator( stream_it, lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), ) From 6f980fee4485c2cd0df6c25b4fe6518e43dc8650 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 25 Jul 2024 13:50:23 +0300 Subject: [PATCH 51/57] fix review comments --- tests/query/test_query_session_pool.py | 9 ++++----- tests/query/test_query_transaction.py | 2 ++ ydb/query/base.py | 2 +- ydb/query/transaction.py | 21 ++++++++++++++++++++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py index 069de7f1..eb237870 100644 --- a/tests/query/test_query_session_pool.py +++ b/tests/query/test_query_session_pool.py @@ -1,5 +1,5 @@ import pytest - +import ydb from ydb.query.pool import QuerySessionPool from ydb.query.session import QuerySessionSync, QuerySessionStateEnum @@ -20,7 +20,7 @@ def test_oneshot_ddl_query(self, pool: QuerySessionPool): pool.execute_with_retries("drop table Queen;") def test_oneshot_query_raises(self, pool: QuerySessionPool): - with pytest.raises(Exception): + with pytest.raises(ydb.GenericError): pool.execute_with_retries("Is this the real life? Is this just fantasy?") def test_retry_op_uses_created_session(self, pool: QuerySessionPool): @@ -40,9 +40,8 @@ def callee(session: QuerySessionSync): def test_retry_op_raises(self, pool: QuerySessionPool): def callee(session: QuerySessionSync): - res = session.execute("Caught in a landslide, no escape from reality") - for _ in res: + with session.execute("Caught in a landslide, no escape from reality"): pass - with pytest.raises(Exception): + with pytest.raises(ydb.GenericError): pool.retry_operation_sync(callee) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 0f09f9b1..eafa412d 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -21,10 +21,12 @@ def test_tx_allow_double_rollback(self, tx: BaseTxContext): tx.rollback() tx.rollback() + @pytest.mark.skip(reason="Not sure should we keep this behavior or not") def test_tx_commit_raises_before_begin(self, tx: BaseTxContext): with pytest.raises(RuntimeError): tx.commit() + @pytest.mark.skip(reason="Not sure should we keep this behavior or not") def test_tx_rollback_raises_before_begin(self, tx: BaseTxContext): with pytest.raises(RuntimeError): tx.rollback() diff --git a/ydb/query/base.py b/ydb/query/base.py index 8ed823d4..e93aa56e 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -328,5 +328,5 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - for _ in self.it: + for _ in self: pass diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 83724245..293b1d9c 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -25,7 +25,12 @@ class QueryTxStateEnum(enum.Enum): class QueryTxStateHelper(abc.ABC): _VALID_TRANSITIONS = { - QueryTxStateEnum.NOT_INITIALIZED: [QueryTxStateEnum.BEGINED, QueryTxStateEnum.DEAD], + QueryTxStateEnum.NOT_INITIALIZED: [ + QueryTxStateEnum.BEGINED, + QueryTxStateEnum.DEAD, + QueryTxStateEnum.COMMITTED, + QueryTxStateEnum.ROLLBACKED, + ], QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD], QueryTxStateEnum.COMMITTED: [], QueryTxStateEnum.ROLLBACKED: [], @@ -246,6 +251,13 @@ def commit(self, settings=None): if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): return self._ensure_prev_stream_finished() + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + # TODO(vgvoleg): Discuss should we raise before begin or not + self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + self._tx_state.tx_id = None + return + self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) return self._driver( @@ -262,6 +274,13 @@ def rollback(self, settings=None): return self._ensure_prev_stream_finished() + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + # TODO(vgvoleg): Discuss should we raise before begin or not + self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) + self._tx_state.tx_id = None + return + self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) return self._driver( From c0b125a2fa406682f0ea6d0e119b46e3025de5e3 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 25 Jul 2024 17:25:56 +0300 Subject: [PATCH 52/57] review fixes --- tests/query/test_query_session_pool.py | 8 +++++--- tests/query/test_query_transaction.py | 21 +++++++++++---------- ydb/query/transaction.py | 2 -- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py index eb237870..3c66c613 100644 --- a/tests/query/test_query_session_pool.py +++ b/tests/query/test_query_session_pool.py @@ -39,9 +39,11 @@ def callee(session: QuerySessionSync): assert len(res) == 1 def test_retry_op_raises(self, pool: QuerySessionPool): + class CustomException(Exception): + pass + def callee(session: QuerySessionSync): - with session.execute("Caught in a landslide, no escape from reality"): - pass + raise CustomException() - with pytest.raises(ydb.GenericError): + with pytest.raises(CustomException): pool.retry_operation_sync(callee) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index eafa412d..41f0f5b3 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -21,15 +21,13 @@ def test_tx_allow_double_rollback(self, tx: BaseTxContext): tx.rollback() tx.rollback() - @pytest.mark.skip(reason="Not sure should we keep this behavior or not") - def test_tx_commit_raises_before_begin(self, tx: BaseTxContext): - with pytest.raises(RuntimeError): - tx.commit() + def test_tx_commit_before_begin(self, tx: BaseTxContext): + tx.commit() + assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - @pytest.mark.skip(reason="Not sure should we keep this behavior or not") - def test_tx_rollback_raises_before_begin(self, tx: BaseTxContext): - with pytest.raises(RuntimeError): - tx.rollback() + def test_tx_rollback_before_begin(self, tx: BaseTxContext): + tx.rollback() + assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED def test_tx_first_execute_begins_tx(self, tx: BaseTxContext): tx.execute("select 1;") @@ -67,9 +65,12 @@ def test_context_manager_normal_flow(self, tx: BaseTxContext): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED def test_context_manager_does_not_hide_exceptions(self, tx: BaseTxContext): - with pytest.raises(Exception): + class CustomException(Exception): + pass + + with pytest.raises(CustomException): with tx: - raise Exception() + raise CustomException() def test_execute_as_context_manager(self, tx: BaseTxContext): tx.begin() diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 293b1d9c..970ac791 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -253,7 +253,6 @@ def commit(self, settings=None): self._ensure_prev_stream_finished() if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: - # TODO(vgvoleg): Discuss should we raise before begin or not self._tx_state._change_state(QueryTxStateEnum.COMMITTED) self._tx_state.tx_id = None return @@ -276,7 +275,6 @@ def rollback(self, settings=None): self._ensure_prev_stream_finished() if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: - # TODO(vgvoleg): Discuss should we raise before begin or not self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) self._tx_state.tx_id = None return From d0c9388ad8ce18d3b350c3cefa6cfff9cdd07546 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 26 Jul 2024 15:54:43 +0300 Subject: [PATCH 53/57] fix review --- tests/query/test_query_session.py | 10 ++ tests/query/test_query_transaction.py | 28 ++--- tox.ini | 1 + ydb/_grpc/grpcwrapper/ydb_query.py | 24 ++-- ydb/query/__init__.py | 1 + ydb/query/base.py | 137 ++++++++++++++------- ydb/query/pool.py | 21 ++-- ydb/query/session.py | 50 +++----- ydb/query/transaction.py | 171 +++++++++++++------------- ydb/table.py | 6 +- 10 files changed, 248 insertions(+), 201 deletions(-) diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index 402fc8bb..585c4c2b 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -89,3 +89,13 @@ def test_basic_execute(self, session: QuerySessionSync): assert len(result_sets[0].rows) == 1 assert len(result_sets[0].columns) == 1 assert list(result_sets[0].rows[0].values()) == [1] + + def test_two_results(self, session: QuerySessionSync): + session.create() + res = [] + with session.execute("select 1; select 2") as results: + for result_set in results: + if len(result_set.rows) > 0: + res.extend(list(result_set.rows[0].values())) + + assert res == [1, 2] diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 41f0f5b3..1c3fdda2 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,62 +1,62 @@ import pytest -from ydb.query.transaction import BaseTxContext +from ydb.query.transaction import BaseQueryTxContext from ydb.query.transaction import QueryTxStateEnum class TestQueryTransaction: - def test_tx_begin(self, tx: BaseTxContext): + def test_tx_begin(self, tx: BaseQueryTxContext): assert tx.tx_id is None tx.begin() assert tx.tx_id is not None - def test_tx_allow_double_commit(self, tx: BaseTxContext): + def test_tx_allow_double_commit(self, tx: BaseQueryTxContext): tx.begin() tx.commit() tx.commit() - def test_tx_allow_double_rollback(self, tx: BaseTxContext): + def test_tx_allow_double_rollback(self, tx: BaseQueryTxContext): tx.begin() tx.rollback() tx.rollback() - def test_tx_commit_before_begin(self, tx: BaseTxContext): + def test_tx_commit_before_begin(self, tx: BaseQueryTxContext): tx.commit() assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_tx_rollback_before_begin(self, tx: BaseTxContext): + def test_tx_rollback_before_begin(self, tx: BaseQueryTxContext): tx.rollback() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_tx_first_execute_begins_tx(self, tx: BaseTxContext): + def test_tx_first_execute_begins_tx(self, tx: BaseQueryTxContext): tx.execute("select 1;") tx.commit() - def test_interactive_tx_commit(self, tx: BaseTxContext): + def test_interactive_tx_commit(self, tx: BaseQueryTxContext): tx.execute("select 1;", commit_tx=True) with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_commit(self, tx: BaseTxContext): + def test_tx_execute_raises_after_commit(self, tx: BaseQueryTxContext): tx.begin() tx.commit() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_rollback(self, tx: BaseTxContext): + def test_tx_execute_raises_after_rollback(self, tx: BaseQueryTxContext): tx.begin() tx.rollback() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_context_manager_rollbacks_tx(self, tx: BaseTxContext): + def test_context_manager_rollbacks_tx(self, tx: BaseQueryTxContext): with tx: tx.begin() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_context_manager_normal_flow(self, tx: BaseTxContext): + def test_context_manager_normal_flow(self, tx: BaseQueryTxContext): with tx: tx.begin() tx.execute("select 1;") @@ -64,7 +64,7 @@ def test_context_manager_normal_flow(self, tx: BaseTxContext): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_context_manager_does_not_hide_exceptions(self, tx: BaseTxContext): + def test_context_manager_does_not_hide_exceptions(self, tx: BaseQueryTxContext): class CustomException(Exception): pass @@ -72,7 +72,7 @@ class CustomException(Exception): with tx: raise CustomException() - def test_execute_as_context_manager(self, tx: BaseTxContext): + def test_execute_as_context_manager(self, tx: BaseQueryTxContext): tx.begin() with tx.execute("select 1;") as results: diff --git a/tox.ini b/tox.ini index 7aca13db..8b5c06ae 100644 --- a/tox.ini +++ b/tox.ini @@ -75,6 +75,7 @@ builtins = _ max-line-length = 160 ignore=E203,W503 exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,ydb/public/api/protos/*,docs/*,ydb/public/api/grpc/*,persqueue/*,client/*,dbapi/*,ydb/default_pem.py,*docs/conf.py +per-file-ignores = ydb/table.py:F401 [pytest] asyncio_mode = auto diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py index 343ef8b8..befb02c7 100644 --- a/ydb/_grpc/grpcwrapper/ydb_query.py +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -21,7 +21,7 @@ @dataclass class CreateSessionResponse(IFromProto): - status: Optional[ServerStatus] + status: ServerStatus session_id: str node_id: int @@ -36,7 +36,7 @@ def from_proto(msg: ydb_query_pb2.CreateSessionResponse) -> "CreateSessionRespon @dataclass class DeleteSessionResponse(IFromProto): - status: Optional[ServerStatus] + status: ServerStatus @staticmethod def from_proto(msg: ydb_query_pb2.DeleteSessionResponse) -> "DeleteSessionResponse": @@ -129,10 +129,10 @@ def from_proto(msg: ydb_query_pb2.RollbackTransactionResponse) -> "RollbackTrans @dataclass class QueryContent(IFromPublic, IToProto): text: str - syntax: Optional[int] = None + syntax: int @staticmethod - def from_public(query: str, syntax: int = None) -> "QueryContent": + def from_public(query: str, syntax: int) -> "QueryContent": return QueryContent(text=query, syntax=syntax) def to_proto(self) -> ydb_query_pb2.QueryContent: @@ -141,9 +141,9 @@ def to_proto(self) -> ydb_query_pb2.QueryContent: @dataclass class TransactionControl(IToProto): - begin_tx: Optional[TransactionSettings] = None - commit_tx: Optional[bool] = None - tx_id: Optional[str] = None + begin_tx: Optional[TransactionSettings] + commit_tx: Optional[bool] + tx_id: Optional[str] def to_proto(self) -> ydb_query_pb2.TransactionControl: if self.tx_id: @@ -161,11 +161,11 @@ def to_proto(self) -> ydb_query_pb2.TransactionControl: class ExecuteQueryRequest(IToProto): session_id: str query_content: QueryContent - tx_control: Optional[TransactionControl] = None - concurrent_result_sets: Optional[bool] = False - exec_mode: Optional[int] = None - parameters: Optional[dict] = None - stats_mode: Optional[int] = None + tx_control: TransactionControl + concurrent_result_sets: bool + exec_mode: int + parameters: dict + stats_mode: int def to_proto(self) -> ydb_query_pb2.ExecuteQueryRequest: tx_control = self.tx_control.to_proto() if self.tx_control is not None else self.tx_control diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index 923202db..eb967abc 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -5,6 +5,7 @@ "QueryStaleReadOnly", "QuerySessionPool", "QueryClientSync", + "QuerySessionSync", ] import logging diff --git a/ydb/query/base.py b/ydb/query/base.py index e93aa56e..9947e5f2 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -3,6 +3,7 @@ import functools from typing import ( + Iterator, Optional, ) @@ -12,7 +13,6 @@ from .._grpc.grpcwrapper import ydb_query from .._grpc.grpcwrapper.ydb_query_public_types import ( BaseQueryTxMode, - QuerySerializableReadWrite, ) from .. import convert from .. import issues @@ -33,12 +33,29 @@ class QueryExecMode(enum.IntEnum): EXECUTE = 50 +class StatsMode(enum.IntEnum): + UNSPECIFIED = 0 + NONE = 10 + BASIC = 20 + FULL = 30 + PROFILE = 40 + + +class SyncResponseContextIterator(_utilities.SyncResponseIterator): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for _ in self: + pass + + class QueryClientSettings: pass class IQuerySessionState(abc.ABC): - def __init__(self, settings: QueryClientSettings = None): + def __init__(self, settings: Optional[QueryClientSettings] = None): pass @abc.abstractmethod @@ -79,7 +96,7 @@ class IQuerySession(abc.ABC): """ @abc.abstractmethod - def __init__(self, driver: SupportedDriverType, settings: QueryClientSettings = None): + def __init__(self, driver: SupportedDriverType, settings: Optional[QueryClientSettings] = None): pass @abc.abstractmethod @@ -101,7 +118,7 @@ def delete(self) -> None: pass @abc.abstractmethod - def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": + def transaction(self, tx_mode: Optional[BaseQueryTxMode] = None) -> "IQueryTxContext": """ Creates a transaction context manager with specified transaction mode. @@ -115,6 +132,27 @@ def transaction(self, tx_mode: BaseQueryTxMode) -> "IQueryTxContext": """ pass + @abc.abstractmethod + def execute( + self, + query: str, + syntax: Optional[QuerySyntax] = None, + exec_mode: Optional[QueryExecMode] = None, + parameters: Optional[dict] = None, + concurrent_result_sets: Optional[bool] = False, + ) -> Iterator: + """ + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ + class IQueryTxContext(abc.ABC): """ @@ -134,12 +172,30 @@ def __init__( driver: SupportedDriverType, session_state: IQuerySessionState, session: IQuerySession, - tx_mode: BaseQueryTxMode = None, + tx_mode: BaseQueryTxMode, ): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ pass @abc.abstractmethod - def __enter__(self): + def __enter__(self) -> "IQueryTxContext": """ Enters a context manager and returns a transaction @@ -151,13 +207,13 @@ def __enter__(self): def __exit__(self, *args, **kwargs): """ Closes a transaction context manager and rollbacks transaction if - it is not rolled back explicitly + it is not finished explicitly """ pass @property @abc.abstractmethod - def session_id(self): + def session_id(self) -> str: """ A transaction's session id @@ -167,7 +223,7 @@ def session_id(self): @property @abc.abstractmethod - def tx_id(self): + def tx_id(self) -> Optional[str]: """ Returns a id of open transaction or None otherwise @@ -176,37 +232,37 @@ def tx_id(self): pass @abc.abstractmethod - def begin(settings: QueryClientSettings = None): + def begin(self, settings: Optional[QueryClientSettings] = None) -> None: """ Explicitly begins a transaction :param settings: A request settings - :return: None + :return: None or exception if begin is failed """ pass @abc.abstractmethod - def commit(settings: QueryClientSettings = None): + def commit(self, settings: Optional[QueryClientSettings] = None) -> None: """ Calls commit on a transaction if it is open. If transaction execution failed then this method raises PreconditionFailed. :param settings: A request settings - :return: A committed transaction or exception if commit is failed + :return: None or exception if commit is failed """ pass @abc.abstractmethod - def rollback(settings: QueryClientSettings = None): + def rollback(self, settings: Optional[QueryClientSettings] = None) -> None: """ Calls rollback on a transaction if it is open. If transaction execution failed then this method raises PreconditionFailed. :param settings: A request settings - :return: A rolled back transaction or exception if rollback is failed + :return: None or exception if rollback is failed """ pass @@ -214,22 +270,16 @@ def rollback(settings: QueryClientSettings = None): def execute( self, query: str, - commit_tx: bool = False, - tx_mode: BaseQueryTxMode = None, - syntax: QuerySyntax = None, - exec_mode: QueryExecMode = None, - parameters: dict = None, - concurrent_result_sets: bool = False, - ): + commit_tx: Optional[bool] = False, + syntax: Optional[QuerySyntax] = None, + exec_mode: Optional[QueryExecMode] = None, + parameters: Optional[dict] = None, + concurrent_result_sets: Optional[bool] = False, + ) -> Iterator: """ Sends a query to Query Service :param query: (YQL or SQL text) to be executed. :param commit_tx: A special flag that allows transaction commit. - :param tx_mode: Transaction mode, which is a one from the following choises: - 1) QuerySerializableReadWrite() which is default mode; - 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); - 3) QuerySnapshotReadOnly(); - 4) QueryStaleReadOnly(). :param syntax: Syntax of the query, which is a one from the following choises: 1) QuerySyntax.YQL_V1, which is default; 2) QuerySyntax.PG. @@ -247,7 +297,7 @@ def execute( class IQueryClient(abc.ABC): - def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): + def __init__(self, driver: SupportedDriverType, query_client_settings: Optional[QueryClientSettings] = None): pass @abc.abstractmethod @@ -258,33 +308,34 @@ def session(self) -> IQuerySession: def create_execute_query_request( query: str, session_id: str, - tx_id: str = None, - commit_tx: bool = False, - tx_mode: BaseQueryTxMode = None, - syntax: QuerySyntax = None, - exec_mode: QueryExecMode = None, - parameters: dict = None, - concurrent_result_sets: bool = False, - empty_tx_control: bool = False, + tx_id: Optional[str], + commit_tx: Optional[bool], + tx_mode: Optional[BaseQueryTxMode], + syntax: Optional[QuerySyntax], + exec_mode: Optional[QueryExecMode], + parameters: Optional[dict], + concurrent_result_sets: Optional[bool], ): syntax = QuerySyntax.YQL_V1 if not syntax else syntax exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode + stats_mode = StatsMode.NONE # TODO: choise is not supported yet tx_control = None - if empty_tx_control: + if not tx_id and not tx_mode: tx_control = None elif tx_id: tx_control = ydb_query.TransactionControl( tx_id=tx_id, commit_tx=commit_tx, + begin_tx=None, ) else: - tx_mode = tx_mode if tx_mode is not None else QuerySerializableReadWrite() tx_control = ydb_query.TransactionControl( begin_tx=ydb_query.TransactionSettings( tx_mode=tx_mode, ), commit_tx=commit_tx, + tx_id=None, ) req = ydb_query.ExecuteQueryRequest( @@ -297,6 +348,7 @@ def create_execute_query_request( exec_mode=exec_mode, parameters=parameters, concurrent_result_sets=concurrent_result_sets, + stats_mode=stats_mode, ) return req.to_proto() @@ -321,12 +373,3 @@ def decorator(rpc_state, response_pb, session_state, *args, **kwargs): raise return decorator - - -class SyncResponseContextIterator(_utilities.SyncResponseIterator): - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - for _ in self: - pass diff --git a/ydb/query/pool.py b/ydb/query/pool.py index f76a60aa..eef084f9 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -1,5 +1,8 @@ import logging -from typing import Callable +from typing import ( + Callable, + Optional, +) from . import base from .session import ( @@ -24,12 +27,12 @@ def __init__(self, driver: base.SupportedDriverType): logger.warning("QuerySessionPool is an experimental API, which could be changed.") self._driver = driver - def checkout(self): + def checkout(self) -> "SimpleQuerySessionCheckout": """Return a Session context manager, that opens session on enter and closes session on exit.""" return SimpleQuerySessionCheckout(self) - def retry_operation_sync(self, callee: Callable, retry_settings: RetrySettings = None, *args, **kwargs): + def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs): """Special interface to execute a bunch of commands with session in a safe, retriable way. :param callee: A function, that works with session. @@ -46,7 +49,7 @@ def wrapped_callee(): return retry_operation_sync(wrapped_callee, retry_settings) - def execute_with_retries(self, query: str, retry_settings: RetrySettings = None, *args, **kwargs): + def execute_with_retries(self, query: str, retry_settings: Optional[RetrySettings] = None, *args, **kwargs): """Special interface to execute a one-shot queries in a safe, retriable way. :param query: A query, yql or sql text. @@ -56,13 +59,13 @@ def execute_with_retries(self, query: str, retry_settings: RetrySettings = None, """ retry_settings = RetrySettings() if retry_settings is None else retry_settings - with self.checkout() as session: - def wrapped_callee(): - it = session.execute(query, empty_tx_control=True, *args, **kwargs) + def wrapped_callee(): + with self.checkout() as session: + it = session.execute(query, *args, **kwargs) return [result_set for result_set in it] - return retry_operation_sync(wrapped_callee, retry_settings) + return retry_operation_sync(wrapped_callee, retry_settings) class SimpleQuerySessionCheckout: @@ -70,7 +73,7 @@ def __init__(self, pool: QuerySessionPool): self._pool = pool self._session = QuerySessionSync(pool._driver) - def __enter__(self): + def __enter__(self) -> base.IQuerySession: self._session.create() return self._session diff --git a/ydb/query/session.py b/ydb/query/session.py index aea69ddf..f5e74331 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -11,8 +11,9 @@ from .. import _apis, issues, _utilities from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper import ydb_query as _ydb_query +from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public -from .transaction import BaseTxContext +from .transaction import BaseQueryTxContext logger = logging.getLogger(__name__) @@ -45,16 +46,14 @@ def ready_to_use(cls, state: QuerySessionStateEnum) -> bool: class QuerySessionState(base.IQuerySessionState): - _session_id: Optional[str] - _node_id: Optional[int] + _session_id: Optional[str] = None + _node_id: Optional[int] = None _attached: bool = False - _settings: Optional[base.QueryClientSettings] - _state: QuerySessionStateEnum + _settings: Optional[base.QueryClientSettings] = None + _state: QuerySessionStateEnum = QuerySessionStateEnum.NOT_INITIALIZED def __init__(self, settings: base.QueryClientSettings = None): self._settings = settings - self._state = QuerySessionStateEnum.NOT_INITIALIZED - self.reset() def reset(self) -> None: self._session_id = None @@ -84,19 +83,19 @@ def attached(self) -> bool: def set_attached(self, attached: bool) -> "QuerySessionState": self._attached = attached - def _check_invalid_transition(self, target: QuerySessionStateEnum): + def _check_invalid_transition(self, target: QuerySessionStateEnum) -> None: if not QuerySessionStateHelper.valid_transition(self._state, target): raise RuntimeError(f"Session could not be moved from {self._state.value} to {target.value}") - def _change_state(self, target: QuerySessionStateEnum): + def _change_state(self, target: QuerySessionStateEnum) -> None: self._check_invalid_transition(target) self._state = target - def _check_session_ready_to_use(self): + def _check_session_ready_to_use(self) -> None: if not QuerySessionStateHelper.ready_to_use(self._state): raise RuntimeError(f"Session is not ready to use, current state: {self._state.value}") - def _already_in(self, target): + def _already_in(self, target) -> bool: return self._state == target @@ -117,12 +116,12 @@ def wrapper_delete_session(rpc_state, response_pb, session_state: QuerySessionSt class BaseQuerySession(base.IQuerySession): _driver: base.SupportedDriverType - _settings: Optional[base.QueryClientSettings] + _settings: base.QueryClientSettings _state: QuerySessionState - def __init__(self, driver: base.SupportedDriverType, settings: base.QueryClientSettings = None): + def __init__(self, driver: base.SupportedDriverType, settings: Optional[base.QueryClientSettings] = None): self._driver = driver - self._settings = settings + self._settings = settings if settings is not None else base.QueryClientSettings() self._state = QuerySessionState(settings) def _create_call(self): @@ -154,23 +153,21 @@ def _execute_call( self, query: str, commit_tx: bool = False, - tx_mode: base.BaseQueryTxMode = None, syntax: base.QuerySyntax = None, exec_mode: base.QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, - empty_tx_control: bool = False, ): request = base.create_execute_query_request( query=query, session_id=self._state.session_id, commit_tx=commit_tx, - tx_mode=tx_mode, + tx_mode=None, + tx_id=None, syntax=syntax, exec_mode=exec_mode, parameters=parameters, concurrent_result_sets=concurrent_result_sets, - empty_tx_control=empty_tx_control, ) return self._driver( @@ -245,7 +242,7 @@ def create(self) -> "QuerySessionSync": return self - def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxContext: + def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> base.IQueryTxContext: """ Creates a transaction context manager with specified transaction mode. :param tx_mode: Transaction mode, which is a one from the following choises: @@ -259,7 +256,9 @@ def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxCont """ self._state._check_session_ready_to_use() - return BaseTxContext( + tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() + + return BaseQueryTxContext( self._driver, self._state, self, @@ -269,21 +268,14 @@ def transaction(self, tx_mode: base.BaseQueryTxMode = None) -> base.IQueryTxCont def execute( self, query: str, - tx_mode: base.BaseQueryTxMode = None, syntax: base.QuerySyntax = None, exec_mode: base.QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, - empty_tx_control: bool = False, - ): + ) -> base.SyncResponseContextIterator: """ Sends a query to Query Service :param query: (YQL or SQL text) to be executed. - :param tx_mode: Transaction mode, which is a one from the following choises: - 1) QuerySerializableReadWrite() which is default mode; - 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); - 3) QuerySnapshotReadOnly(); - 4) QueryStaleReadOnly(). :param syntax: Syntax of the query, which is a one from the following choises: 1) QuerySyntax.YQL_V1, which is default; 2) QuerySyntax.PG. @@ -297,12 +289,10 @@ def execute( stream_it = self._execute_call( query=query, commit_tx=True, - tx_mode=tx_mode, syntax=syntax, exec_mode=exec_mode, parameters=parameters, concurrent_result_sets=concurrent_result_sets, - empty_tx_control=empty_tx_control, ) return base.SyncResponseContextIterator( diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 970ac791..0ae770be 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -2,13 +2,15 @@ import logging import enum import functools +from typing import ( + Optional, +) from .. import ( _apis, issues, ) from .._grpc.grpcwrapper import ydb_query as _ydb_query -from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public from . import base @@ -37,19 +39,13 @@ class QueryTxStateHelper(abc.ABC): QueryTxStateEnum.DEAD: [], } - _TERMINAL_STATES = [ - QueryTxStateEnum.COMMITTED, - QueryTxStateEnum.ROLLBACKED, - QueryTxStateEnum.DEAD, - ] - @classmethod def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: return after in cls._VALID_TRANSITIONS[before] @classmethod def terminal(cls, state: QueryTxStateEnum) -> bool: - return state in cls._TERMINAL_STATES + return len(cls._VALID_TRANSITIONS[state]) == 0 def reset_tx_id_handler(func): @@ -75,19 +71,19 @@ def __init__(self, tx_mode: base.BaseQueryTxMode): self.tx_mode = tx_mode self._state = QueryTxStateEnum.NOT_INITIALIZED - def _check_invalid_transition(self, target: QueryTxStateEnum): + def _check_invalid_transition(self, target: QueryTxStateEnum) -> None: if not QueryTxStateHelper.valid_transition(self._state, target): raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}") - def _change_state(self, target: QueryTxStateEnum): + def _change_state(self, target: QueryTxStateEnum) -> None: self._check_invalid_transition(target) self._state = target - def _check_tx_not_terminal(self): + def _check_tx_not_terminal(self) -> None: if QueryTxStateHelper.terminal(self._state): raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") - def _already_in(self, target: QueryTxStateEnum): + def _already_in(self, target: QueryTxStateEnum) -> bool: return self._state == target @@ -132,7 +128,6 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx): message = _ydb_query.CommitTransactionResponse.from_proto(response_pb) issues._process_response(message.status) - tx_state.tx_id = None tx_state._change_state(QueryTxStateEnum.COMMITTED) return tx @@ -142,54 +137,52 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx) def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx): message = _ydb_query.RollbackTransactionResponse.from_proto(response_pb) issues._process_response(message.status) - tx_state.tx_id = None tx_state._change_state(QueryTxStateEnum.ROLLBACKED) return tx -class BaseTxContext(base.IQueryTxContext): - def __init__(self, driver, session_state, session, tx_mode=None): +class BaseQueryTxContext(base.IQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): """ An object that provides a simple transaction context manager that allows statements execution in a transaction. You don't have to open transaction explicitly, because context manager encapsulates transaction control logic, and opens new transaction if: - 1) By explicit .begin() and .async_begin() methods; + 1) By explicit .begin() method; 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip This context manager is not thread-safe, so you should not manipulate on it concurrently. :param driver: A driver instance :param session_state: A state of session - :param tx_mode: A transaction mode, which is a one from the following choices: - 1) SerializableReadWrite() which is default mode; - 2) OnlineReadOnly(); - 3) StaleReadOnly(). + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). """ + self._driver = driver - if tx_mode is None: - tx_mode = _ydb_query_public.QuerySerializableReadWrite() self._tx_state = QueryTxState(tx_mode) self._session_state = session_state self.session = session - self._finished = "" self._prev_stream = None - def __enter__(self): + def __enter__(self) -> "BaseQueryTxContext": """ - Enters a context manager and returns a session + Enters a context manager and returns a transaction - :return: A session instance + :return: A transaction instance """ return self def __exit__(self, *args, **kwargs): """ Closes a transaction context manager and rollbacks transaction if - it is not rolled back explicitly + it is not finished explicitly """ self._ensure_prev_stream_finished() - if self._tx_state.tx_id is not None: + if self._tx_state._state == QueryTxStateEnum.BEGINED: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -200,10 +193,8 @@ def __exit__(self, *args, **kwargs): except issues.Error: logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) - self._tx_state.tx_id = None - @property - def session_id(self): + def session_id(self) -> str: """ A transaction's session id @@ -212,7 +203,7 @@ def session_id(self): return self._session_state.session_id @property - def tx_id(self): + def tx_id(self) -> Optional[str]: """ Returns a id of open transaction or None otherwise @@ -220,16 +211,7 @@ def tx_id(self): """ return self._tx_state.tx_id - def begin(self, settings=None): - """ - Explicitly begins a transaction - - :param settings: A request settings - - :return: An open transaction - """ - self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) - + def _begin_call(self, settings: Optional[base.QueryClientSettings]): return self._driver( _create_begin_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -239,26 +221,7 @@ def begin(self, settings=None): (self._session_state, self._tx_state, self), ) - def commit(self, settings=None): - """ - Calls commit on a transaction if it is open otherwise is no-op. If transaction execution - failed then this method raises PreconditionFailed. - - :param settings: A request settings - - :return: A committed transaction or exception if commit is failed - """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): - return - self._ensure_prev_stream_finished() - - if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: - self._tx_state._change_state(QueryTxStateEnum.COMMITTED) - self._tx_state.tx_id = None - return - - self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) - + def _commit_call(self, settings: Optional[base.QueryClientSettings]): return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -268,19 +231,7 @@ def commit(self, settings=None): (self._session_state, self._tx_state, self), ) - def rollback(self, settings=None): - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): - return - - self._ensure_prev_stream_finished() - - if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: - self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) - self._tx_state.tx_id = None - return - - self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) - + def _rollback_call(self, settings: Optional[base.QueryClientSettings]): return self._driver( _create_rollback_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -294,7 +245,6 @@ def _execute_call( self, query: str, commit_tx: bool = False, - tx_mode: base.BaseQueryTxMode = None, syntax: base.QuerySyntax = None, exec_mode: base.QueryExecMode = None, parameters: dict = None, @@ -305,7 +255,7 @@ def _execute_call( session_id=self._session_state.session_id, commit_tx=commit_tx, tx_id=self._tx_state.tx_id, - tx_mode=tx_mode, + tx_mode=self._tx_state.tx_mode, syntax=syntax, exec_mode=exec_mode, parameters=parameters, @@ -325,7 +275,7 @@ def _ensure_prev_stream_finished(self): self._prev_stream = None def _handle_tx_meta(self, tx_meta=None): - if not self.tx_id: + if not self.tx_id and tx_meta: self._tx_state._change_state(QueryTxStateEnum.BEGINED) self._tx_state.tx_id = tx_meta.id @@ -333,25 +283,70 @@ def _move_to_commited(self): if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) - self._tx_state.tx_id = None + + def begin(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """ + Explicitly begins a transaction + + :param settings: A request settings + + :return: None or exception if begin is failed + """ + self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) + + self._begin_call(settings) + + def commit(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """ + Calls commit on a transaction if it is open otherwise is no-op. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ + if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + return + self._ensure_prev_stream_finished() + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + return + + self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) + + self._commit_call(settings) + + def rollback(self, settings: Optional[base.QueryClientSettings] = None) -> None: + if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + return + + self._ensure_prev_stream_finished() + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) + return + + self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) + + self._rollback_call(settings) def execute( self, query: str, - commit_tx: bool = False, - tx_mode: base.BaseQueryTxMode = None, - syntax: base.QuerySyntax = None, - exec_mode: base.QueryExecMode = None, - parameters: dict = None, - concurrent_result_sets: bool = False, - ): + commit_tx: Optional[bool] = False, + syntax: Optional[base.QuerySyntax] = None, + exec_mode: Optional[base.QueryExecMode] = None, + parameters: Optional[dict] = None, + concurrent_result_sets: Optional[bool] = False, + ) -> base.SyncResponseContextIterator: + self._ensure_prev_stream_finished() self._tx_state._check_tx_not_terminal() stream_it = self._execute_call( query=query, commit_tx=commit_tx, - tx_mode=tx_mode, syntax=syntax, exec_mode=exec_mode, parameters=parameters, diff --git a/ydb/table.py b/ydb/table.py index 12856d61..ac9f9304 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -20,8 +20,12 @@ ) from .retries import ( - retry_operation_sync, + YdbRetryOperationFinalResult, # noqa + YdbRetryOperationSleepOpt, # noqa + BackoffSettings, # noqa + retry_operation_impl, # noqa RetrySettings, + retry_operation_sync, ) try: From 2d8a2ff5b1b352076aab2e31ee984a90483994fe Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 26 Jul 2024 17:09:37 +0300 Subject: [PATCH 54/57] style fixes --- ydb/query/base.py | 21 ++++++++----- ydb/query/session.py | 30 +++++++++++++------ ydb/query/transaction.py | 65 ++++++++++++++++++++++++++++------------ 3 files changed, 80 insertions(+), 36 deletions(-) diff --git a/ydb/query/base.py b/ydb/query/base.py index 9947e5f2..9fe6c21b 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -14,9 +14,11 @@ from .._grpc.grpcwrapper.ydb_query_public_types import ( BaseQueryTxMode, ) +from ..connection import _RpcState as RpcState from .. import convert from .. import issues from .. import _utilities +from .. import _apis class QuerySyntax(enum.IntEnum): @@ -42,7 +44,7 @@ class StatsMode(enum.IntEnum): class SyncResponseContextIterator(_utilities.SyncResponseIterator): - def __enter__(self): + def __enter__(self) -> "SyncResponseContextIterator": return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -315,7 +317,7 @@ def create_execute_query_request( exec_mode: Optional[QueryExecMode], parameters: Optional[dict], concurrent_result_sets: Optional[bool], -): +) -> ydb_query.ExecuteQueryRequest: syntax = QuerySyntax.YQL_V1 if not syntax else syntax exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode stats_mode = StatsMode.NONE # TODO: choise is not supported yet @@ -338,7 +340,7 @@ def create_execute_query_request( tx_id=None, ) - req = ydb_query.ExecuteQueryRequest( + return ydb_query.ExecuteQueryRequest( session_id=session_id, query_content=ydb_query.QueryContent.from_public( query=query, @@ -351,13 +353,16 @@ def create_execute_query_request( stats_mode=stats_mode, ) - return req.to_proto() - -def wrap_execute_query_response(rpc_state, response_pb, tx=None, commit_tx=False): +def wrap_execute_query_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.ExecuteQueryResponsePart, + tx: Optional[IQueryTxContext] = None, + commit_tx: Optional[bool] = False, +) -> convert.ResultSet: issues._process_response(response_pb) if tx and response_pb.tx_meta and not tx.tx_id: - tx._handle_tx_meta(response_pb.tx_meta) + tx._move_to_beginned(response_pb.tx_meta.id) if tx and commit_tx: tx._move_to_commited() return convert.ResultSet.from_message(response_pb.result_set) @@ -365,7 +370,7 @@ def wrap_execute_query_response(rpc_state, response_pb, tx=None, commit_tx=False def bad_session_handler(func): @functools.wraps(func) - def decorator(rpc_state, response_pb, session_state, *args, **kwargs): + def decorator(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs): try: return func(rpc_state, response_pb, session_state, *args, **kwargs) except issues.BadSession: diff --git a/ydb/query/session.py b/ydb/query/session.py index f5e74331..a6a02d92 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -3,12 +3,14 @@ import logging import threading from typing import ( + Iterable, Optional, ) from . import base from .. import _apis, issues, _utilities +from ..connection import _RpcState as RpcState from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper import ydb_query as _ydb_query from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public @@ -99,14 +101,24 @@ def _already_in(self, target) -> bool: return self._state == target -def wrapper_create_session(rpc_state, response_pb, session_state: QuerySessionState, session): +def wrapper_create_session( + rpc_state: RpcState, + response_pb: _apis.ydb_query.CreateSessionResponse, + session_state: QuerySessionState, + session: "BaseQuerySession" +) -> "BaseQuerySession": message = _ydb_query.CreateSessionResponse.from_proto(response_pb) issues._process_response(message.status) session_state.set_session_id(message.session_id).set_node_id(message.node_id) return session -def wrapper_delete_session(rpc_state, response_pb, session_state: QuerySessionState, session): +def wrapper_delete_session( + rpc_state: RpcState, + response_pb: _apis.ydb_query.DeleteSessionResponse, + session_state: QuerySessionState, + session: "BaseQuerySession" +) -> "BaseQuerySession": message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) issues._process_response(message.status) session_state.reset() @@ -124,7 +136,7 @@ def __init__(self, driver: base.SupportedDriverType, settings: Optional[base.Que self._settings = settings if settings is not None else base.QueryClientSettings() self._state = QuerySessionState(settings) - def _create_call(self): + def _create_call(self) -> "BaseQuerySession": return self._driver( _apis.ydb_query.CreateSessionRequest(), _apis.QueryService.Stub, @@ -133,7 +145,7 @@ def _create_call(self): wrap_args=(self._state, self), ) - def _delete_call(self): + def _delete_call(self) -> "BaseQuerySession": return self._driver( _apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, @@ -142,7 +154,7 @@ def _delete_call(self): wrap_args=(self._state, self), ) - def _attach_call(self): + def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]: return self._driver( _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, @@ -157,7 +169,7 @@ def _execute_call( exec_mode: base.QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, - ): + ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: request = base.create_execute_query_request( query=query, session_id=self._state.session_id, @@ -171,7 +183,7 @@ def _execute_call( ) return self._driver( - request, + request.to_proto(), _apis.QueryService.Stub, _apis.QueryService.ExecuteQuery, ) @@ -184,7 +196,7 @@ class QuerySessionSync(BaseQuerySession): _stream = None - def _attach(self): + def _attach(self) -> None: self._stream = self._attach_call() status_stream = _utilities.SyncResponseIterator( self._stream, @@ -205,7 +217,7 @@ def _attach(self): daemon=True, ).start() - def _check_session_status_loop(self, status_stream): + def _check_session_status_loop(self, status_stream: _utilities.SyncResponseIterator) -> None: try: for status in status_stream: if status.status != issues.StatusCode.SUCCESS: diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 0ae770be..154a893e 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -3,6 +3,7 @@ import enum import functools from typing import ( + Iterable, Optional, ) @@ -11,6 +12,7 @@ issues, ) from .._grpc.grpcwrapper import ydb_query as _ydb_query +from ..connection import _RpcState as RpcState from . import base @@ -50,7 +52,7 @@ def terminal(cls, state: QueryTxStateEnum) -> bool: def reset_tx_id_handler(func): @functools.wraps(func) - def decorator(rpc_state, response_pb, session_state, tx_state, *args, **kwargs): + def decorator(rpc_state, response_pb, session_state: base.IQuerySessionState, tx_state: QueryTxState, *args, **kwargs): try: return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs) except issues.Error: @@ -87,12 +89,14 @@ def _already_in(self, target: QueryTxStateEnum) -> bool: return self._state == target -def _construct_tx_settings(tx_state): +def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings: tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode) return tx_settings -def _create_begin_transaction_request(session_state, tx_state): +def _create_begin_transaction_request( + session_state: base.IQuerySessionState, tx_state: QueryTxState +) -> _apis.ydb_query.BeginTransactionRequest: request = _ydb_query.BeginTransactionRequest( session_id=session_state.session_id, tx_settings=_construct_tx_settings(tx_state), @@ -100,14 +104,18 @@ def _create_begin_transaction_request(session_state, tx_state): return request -def _create_commit_transaction_request(session_state, tx_state): +def _create_commit_transaction_request( + session_state: base.IQuerySessionState, tx_state: QueryTxState +) -> _apis.ydb_query.CommitTransactionRequest: request = _apis.ydb_query.CommitTransactionRequest() request.tx_id = tx_state.tx_id request.session_id = session_state.session_id return request -def _create_rollback_transaction_request(session_state, tx_state): +def _create_rollback_transaction_request( + session_state: base.IQuerySessionState, tx_state: QueryTxState +) -> _apis.ydb_query.RollbackTransactionRequest: request = _apis.ydb_query.RollbackTransactionRequest() request.tx_id = tx_state.tx_id request.session_id = session_state.session_id @@ -115,7 +123,13 @@ def _create_rollback_transaction_request(session_state, tx_state): @base.bad_session_handler -def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): +def wrap_tx_begin_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.BeginTransactionResponse, + session_state: base.IQuerySessionState, + tx_state: QueryTxState, + tx: "BaseQueryTxContext", +) -> "BaseQueryTxContext": message = _ydb_query.BeginTransactionResponse.from_proto(response_pb) issues._process_response(message.status) tx_state._change_state(QueryTxStateEnum.BEGINED) @@ -125,7 +139,13 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx): @base.bad_session_handler @reset_tx_id_handler -def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx): +def wrap_tx_commit_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.CommitTransactionResponse, + session_state: base.IQuerySessionState, + tx_state: QueryTxState, + tx: "BaseQueryTxContext", +) -> "BaseQueryTxContext": message = _ydb_query.CommitTransactionResponse.from_proto(response_pb) issues._process_response(message.status) tx_state._change_state(QueryTxStateEnum.COMMITTED) @@ -134,7 +154,13 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx) @base.bad_session_handler @reset_tx_id_handler -def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx): +def wrap_tx_rollback_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.RollbackTransactionResponse, + session_state: base.IQuerySessionState, + tx_state: QueryTxState, + tx: "BaseQueryTxContext", +) -> "BaseQueryTxContext": message = _ydb_query.RollbackTransactionResponse.from_proto(response_pb) issues._process_response(message.status) tx_state._change_state(QueryTxStateEnum.ROLLBACKED) @@ -211,7 +237,7 @@ def tx_id(self) -> Optional[str]: """ return self._tx_state.tx_id - def _begin_call(self, settings: Optional[base.QueryClientSettings]): + def _begin_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": return self._driver( _create_begin_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -221,7 +247,7 @@ def _begin_call(self, settings: Optional[base.QueryClientSettings]): (self._session_state, self._tx_state, self), ) - def _commit_call(self, settings: Optional[base.QueryClientSettings]): + def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -231,7 +257,7 @@ def _commit_call(self, settings: Optional[base.QueryClientSettings]): (self._session_state, self._tx_state, self), ) - def _rollback_call(self, settings: Optional[base.QueryClientSettings]): + def _rollback_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": return self._driver( _create_rollback_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -249,7 +275,7 @@ def _execute_call( exec_mode: base.QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, - ): + ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: request = base.create_execute_query_request( query=query, session_id=self._session_state.session_id, @@ -263,23 +289,24 @@ def _execute_call( ) return self._driver( - request, + request.to_proto(), _apis.QueryService.Stub, _apis.QueryService.ExecuteQuery, ) - def _ensure_prev_stream_finished(self): + def _ensure_prev_stream_finished(self) -> None: if self._prev_stream is not None: for _ in self._prev_stream: pass self._prev_stream = None - def _handle_tx_meta(self, tx_meta=None): - if not self.tx_id and tx_meta: - self._tx_state._change_state(QueryTxStateEnum.BEGINED) - self._tx_state.tx_id = tx_meta.id + def _move_to_beginned(self, tx_id: str) -> None: + if self._tx_state._already_in(QueryTxStateEnum.BEGINED): + return + self._tx_state._change_state(QueryTxStateEnum.BEGINED) + self._tx_state.tx_id = tx_id - def _move_to_commited(self): + def _move_to_commited(self) -> None: if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) From 630c1883742d088deb53a9acd349c518633aa549 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 26 Jul 2024 17:12:03 +0300 Subject: [PATCH 55/57] fixes --- ydb/query/session.py | 4 ++-- ydb/query/transaction.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ydb/query/session.py b/ydb/query/session.py index a6a02d92..5b9c00ed 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -105,7 +105,7 @@ def wrapper_create_session( rpc_state: RpcState, response_pb: _apis.ydb_query.CreateSessionResponse, session_state: QuerySessionState, - session: "BaseQuerySession" + session: "BaseQuerySession", ) -> "BaseQuerySession": message = _ydb_query.CreateSessionResponse.from_proto(response_pb) issues._process_response(message.status) @@ -117,7 +117,7 @@ def wrapper_delete_session( rpc_state: RpcState, response_pb: _apis.ydb_query.DeleteSessionResponse, session_state: QuerySessionState, - session: "BaseQuerySession" + session: "BaseQuerySession", ) -> "BaseQuerySession": message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) issues._process_response(message.status) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 154a893e..cfbbf89b 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -52,7 +52,9 @@ def terminal(cls, state: QueryTxStateEnum) -> bool: def reset_tx_id_handler(func): @functools.wraps(func) - def decorator(rpc_state, response_pb, session_state: base.IQuerySessionState, tx_state: QueryTxState, *args, **kwargs): + def decorator( + rpc_state, response_pb, session_state: base.IQuerySessionState, tx_state: QueryTxState, *args, **kwargs + ): try: return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs) except issues.Error: From 23366241908bbdc7e18df4249a1f54c171bffc29 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Sun, 28 Jul 2024 20:08:58 +0300 Subject: [PATCH 56/57] review fixes --- ydb/query/base.py | 4 ++-- ydb/query/transaction.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ydb/query/base.py b/ydb/query/base.py index 9fe6c21b..dcb56ca6 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -227,9 +227,9 @@ def session_id(self) -> str: @abc.abstractmethod def tx_id(self) -> Optional[str]: """ - Returns a id of open transaction or None otherwise + Returns an id of open transaction or None otherwise - :return: A id of open transaction or None otherwise + :return: An id of open transaction or None otherwise """ pass diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index cfbbf89b..6846b5a5 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -83,7 +83,7 @@ def _change_state(self, target: QueryTxStateEnum) -> None: self._check_invalid_transition(target) self._state = target - def _check_tx_not_terminal(self) -> None: + def _check_tx_ready_to_use(self) -> None: if QueryTxStateHelper.terminal(self._state): raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") @@ -233,9 +233,9 @@ def session_id(self) -> str: @property def tx_id(self) -> Optional[str]: """ - Returns a id of open transaction or None otherwise + Returns an id of open transaction or None otherwise - :return: A id of open transaction or None otherwise + :return: An id of open transaction or None otherwise """ return self._tx_state.tx_id @@ -371,7 +371,7 @@ def execute( ) -> base.SyncResponseContextIterator: self._ensure_prev_stream_finished() - self._tx_state._check_tx_not_terminal() + self._tx_state._check_tx_ready_to_use() stream_it = self._execute_call( query=query, From 49e747127961a0f83b0539635987b84e8277ab1b Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 30 Jul 2024 15:27:14 +0300 Subject: [PATCH 57/57] review fixes --- examples/query-service/basic_example.py | 14 +++++++------- tests/query/test_query_session.py | 5 +++-- ydb/query/pool.py | 8 +++++++- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index de2920c3..b355e10c 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -18,7 +18,7 @@ def main(): print("=" * 50) print("DELETE TABLE IF EXISTS") - pool.execute_with_retries("drop table if exists example") + pool.execute_with_retries("DROP TABLE IF EXISTS example") print("=" * 50) print("CREATE TABLE") @@ -28,11 +28,11 @@ def main(): def callee(session): print("=" * 50) - with session.execute("delete from example"): + with session.execute("DELETE FROM example"): pass print("BEFORE ACTION") - with session.execute("SELECT COUNT(*) as rows_count FROM example") as results: + with session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: for result_set in results: print(f"rows: {str(result_set.rows)}") @@ -45,7 +45,7 @@ def callee(session): with tx.execute("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')"): pass - with tx.execute("SELECT COUNT(*) as rows_count FROM example") as results: + with tx.execute("SELECT COUNT(*) AS rows_count FROM example") as results: for result_set in results: print(f"rows: {str(result_set.rows)}") @@ -54,7 +54,7 @@ def callee(session): print("=" * 50) print("AFTER COMMIT TX") - with session.execute("SELECT COUNT(*) as rows_count FROM example") as results: + with session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: for result_set in results: print(f"rows: {str(result_set.rows)}") @@ -67,7 +67,7 @@ def callee(session): with tx.execute("INSERT INTO example (key, value) VALUES (2, 'onepieceisreal')"): pass - with tx.execute("SELECT COUNT(*) as rows_count FROM example") as results: + with tx.execute("SELECT COUNT(*) AS rows_count FROM example") as results: for result_set in results: print(f"rows: {str(result_set.rows)}") @@ -76,7 +76,7 @@ def callee(session): print("=" * 50) print("AFTER ROLLBACK TX") - with session.execute("SELECT COUNT(*) as rows_count FROM example") as results: + with session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: for result_set in results: print(f"rows: {str(result_set.rows)}") diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index 585c4c2b..89b899bd 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -93,9 +93,10 @@ def test_basic_execute(self, session: QuerySessionSync): def test_two_results(self, session: QuerySessionSync): session.create() res = [] + with session.execute("select 1; select 2") as results: for result_set in results: if len(result_set.rows) > 0: - res.extend(list(result_set.rows[0].values())) + res.append(list(result_set.rows[0].values())) - assert res == [1, 2] + assert res == [[1], [2]] diff --git a/ydb/query/pool.py b/ydb/query/pool.py index eef084f9..bddd666a 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -2,6 +2,7 @@ from typing import ( Callable, Optional, + List, ) from . import base @@ -12,6 +13,7 @@ RetrySettings, retry_operation_sync, ) +from .. import convert logger = logging.getLogger(__name__) @@ -49,8 +51,12 @@ def wrapped_callee(): return retry_operation_sync(wrapped_callee, retry_settings) - def execute_with_retries(self, query: str, retry_settings: Optional[RetrySettings] = None, *args, **kwargs): + def execute_with_retries( + self, query: str, retry_settings: Optional[RetrySettings] = None, *args, **kwargs + ) -> List[convert.ResultSet]: """Special interface to execute a one-shot queries in a safe, retriable way. + Note: this method loads all data from stream before return, do not use this + method with huge read queries. :param query: A query, yql or sql text. :param retry_settings: RetrySettings object.