diff --git a/docker-compose.yml b/docker-compose.yml index edbd56d1..cb37a377 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.3" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: ydbplatform/local-ydb:trunk restart: always ports: - 2136:2136 diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py new file mode 100644 index 00000000..b355e10c --- /dev/null +++ b/examples/query-service/basic_example.py @@ -0,0 +1,87 @@ +import ydb + + +def main(): + driver_config = ydb.DriverConfig( + endpoint="grpc://localhost:2136", + database="/local", + # credentials=ydb.credentials_from_env_variables(), + # root_certificates=ydb.load_ydb_root_certificate(), + ) + try: + driver = ydb.Driver(driver_config) + driver.wait(timeout=5) + except TimeoutError: + raise RuntimeError("Connect failed to YDB") + + pool = ydb.QuerySessionPool(driver) + + print("=" * 50) + print("DELETE TABLE IF EXISTS") + pool.execute_with_retries("DROP TABLE IF EXISTS example") + + print("=" * 50) + print("CREATE TABLE") + pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key))") + + pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')") + + def callee(session): + print("=" * 50) + with session.execute("DELETE FROM example"): + pass + + print("BEFORE ACTION") + with session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") + + print("=" * 50) + print("INSERT WITH COMMIT TX") + + with session.transaction() as tx: + tx.begin() + + with tx.execute("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')"): + pass + + with tx.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") + + tx.commit() + + print("=" * 50) + print("AFTER COMMIT TX") + + with session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") + + print("=" * 50) + print("INSERT WITH ROLLBACK TX") + + with session.transaction() as tx: + tx.begin() + + with tx.execute("INSERT INTO example (key, value) VALUES (2, 'onepieceisreal')"): + pass + + with tx.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") + + tx.rollback() + + print("=" * 50) + print("AFTER ROLLBACK TX") + + with session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + for result_set in results: + print(f"rows: {str(result_set.rows)}") + + pool.retry_operation_sync(callee) + + +if __name__ == "__main__": + main() diff --git a/tests/query/__init__.py b/tests/query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/query/conftest.py b/tests/query/conftest.py new file mode 100644 index 00000000..277aaeba --- /dev/null +++ b/tests/query/conftest.py @@ -0,0 +1,34 @@ +import pytest +from ydb.query.session import QuerySessionSync +from ydb.query.pool import QuerySessionPool + + +@pytest.fixture +def session(driver_sync): + session = QuerySessionSync(driver_sync) + + yield session + + try: + session.delete() + except BaseException: + pass + + +@pytest.fixture +def tx(session): + session.create() + transaction = session.transaction() + + yield transaction + + try: + transaction.rollback() + except BaseException: + pass + + +@pytest.fixture +def pool(driver_sync): + pool = QuerySessionPool(driver_sync) + yield pool diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py new file mode 100644 index 00000000..89b899bd --- /dev/null +++ b/tests/query/test_query_session.py @@ -0,0 +1,102 @@ +import pytest + +from ydb.query.session import QuerySessionSync + + +def _check_session_state_empty(session: QuerySessionSync): + assert session._state.session_id is None + assert session._state.node_id is None + assert not session._state.attached + + +def _check_session_state_full(session: QuerySessionSync): + assert session._state.session_id is not None + assert session._state.node_id is not None + assert session._state.attached + + +class TestQuerySession: + def test_session_normal_lifecycle(self, session: QuerySessionSync): + _check_session_state_empty(session) + + session.create() + _check_session_state_full(session) + + session.delete() + _check_session_state_empty(session) + + def test_second_create_do_nothing(self, session: QuerySessionSync): + session.create() + _check_session_state_full(session) + + session_id_before = session._state.session_id + node_id_before = session._state.node_id + + session.create() + _check_session_state_full(session) + + assert session._state.session_id == session_id_before + assert session._state.node_id == node_id_before + + def test_second_delete_do_nothing(self, session: QuerySessionSync): + session.create() + + session.delete() + session.delete() + + def test_delete_before_create_not_possible(self, session: QuerySessionSync): + with pytest.raises(RuntimeError): + session.delete() + + def test_create_after_delete_not_possible(self, session: QuerySessionSync): + session.create() + session.delete() + with pytest.raises(RuntimeError): + session.create() + + def test_transaction_before_create_raises(self, session: QuerySessionSync): + with pytest.raises(RuntimeError): + session.transaction() + + def test_transaction_after_delete_raises(self, session: QuerySessionSync): + session.create() + + session.delete() + + with pytest.raises(RuntimeError): + session.transaction() + + def test_transaction_after_create_not_raises(self, session: QuerySessionSync): + session.create() + session.transaction() + + def test_execute_before_create_raises(self, session: QuerySessionSync): + with pytest.raises(RuntimeError): + session.execute("select 1;") + + def test_execute_after_delete_raises(self, session: QuerySessionSync): + session.create() + session.delete() + with pytest.raises(RuntimeError): + session.execute("select 1;") + + def test_basic_execute(self, session: QuerySessionSync): + session.create() + it = session.execute("select 1;") + result_sets = [result_set for result_set in it] + + assert len(result_sets) == 1 + assert len(result_sets[0].rows) == 1 + assert len(result_sets[0].columns) == 1 + assert list(result_sets[0].rows[0].values()) == [1] + + def test_two_results(self, session: QuerySessionSync): + session.create() + res = [] + + with session.execute("select 1; select 2") as results: + for result_set in results: + if len(result_set.rows) > 0: + res.append(list(result_set.rows[0].values())) + + assert res == [[1], [2]] diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py new file mode 100644 index 00000000..3c66c613 --- /dev/null +++ b/tests/query/test_query_session_pool.py @@ -0,0 +1,49 @@ +import pytest +import ydb +from ydb.query.pool import QuerySessionPool +from ydb.query.session import QuerySessionSync, QuerySessionStateEnum + + +class TestQuerySessionPool: + def test_checkout_provides_created_session(self, pool: QuerySessionPool): + with pool.checkout() as session: + assert session._state._state == QuerySessionStateEnum.CREATED + + assert session._state._state == QuerySessionStateEnum.CLOSED + + def test_oneshot_query_normal(self, pool: QuerySessionPool): + res = pool.execute_with_retries("select 1;") + assert len(res) == 1 + + def test_oneshot_ddl_query(self, pool: QuerySessionPool): + pool.execute_with_retries("create table Queen(key UInt64, PRIMARY KEY (key));") + pool.execute_with_retries("drop table Queen;") + + def test_oneshot_query_raises(self, pool: QuerySessionPool): + with pytest.raises(ydb.GenericError): + pool.execute_with_retries("Is this the real life? Is this just fantasy?") + + def test_retry_op_uses_created_session(self, pool: QuerySessionPool): + def callee(session: QuerySessionSync): + assert session._state._state == QuerySessionStateEnum.CREATED + + pool.retry_operation_sync(callee) + + def test_retry_op_normal(self, pool: QuerySessionPool): + def callee(session: QuerySessionSync): + with session.transaction() as tx: + iterator = tx.execute("select 1;", commit_tx=True) + return [result_set for result_set in iterator] + + res = pool.retry_operation_sync(callee) + assert len(res) == 1 + + def test_retry_op_raises(self, pool: QuerySessionPool): + class CustomException(Exception): + pass + + def callee(session: QuerySessionSync): + raise CustomException() + + with pytest.raises(CustomException): + pool.retry_operation_sync(callee) diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py new file mode 100644 index 00000000..1c3fdda2 --- /dev/null +++ b/tests/query/test_query_transaction.py @@ -0,0 +1,81 @@ +import pytest + +from ydb.query.transaction import BaseQueryTxContext +from ydb.query.transaction import QueryTxStateEnum + + +class TestQueryTransaction: + def test_tx_begin(self, tx: BaseQueryTxContext): + assert tx.tx_id is None + + tx.begin() + assert tx.tx_id is not None + + def test_tx_allow_double_commit(self, tx: BaseQueryTxContext): + tx.begin() + tx.commit() + tx.commit() + + def test_tx_allow_double_rollback(self, tx: BaseQueryTxContext): + tx.begin() + tx.rollback() + tx.rollback() + + def test_tx_commit_before_begin(self, tx: BaseQueryTxContext): + tx.commit() + assert tx._tx_state._state == QueryTxStateEnum.COMMITTED + + def test_tx_rollback_before_begin(self, tx: BaseQueryTxContext): + tx.rollback() + assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED + + def test_tx_first_execute_begins_tx(self, tx: BaseQueryTxContext): + tx.execute("select 1;") + tx.commit() + + def test_interactive_tx_commit(self, tx: BaseQueryTxContext): + tx.execute("select 1;", commit_tx=True) + with pytest.raises(RuntimeError): + tx.execute("select 1;") + + def test_tx_execute_raises_after_commit(self, tx: BaseQueryTxContext): + tx.begin() + tx.commit() + with pytest.raises(RuntimeError): + tx.execute("select 1;") + + def test_tx_execute_raises_after_rollback(self, tx: BaseQueryTxContext): + tx.begin() + tx.rollback() + with pytest.raises(RuntimeError): + tx.execute("select 1;") + + def test_context_manager_rollbacks_tx(self, tx: BaseQueryTxContext): + with tx: + tx.begin() + + assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED + + def test_context_manager_normal_flow(self, tx: BaseQueryTxContext): + with tx: + tx.begin() + tx.execute("select 1;") + tx.commit() + + assert tx._tx_state._state == QueryTxStateEnum.COMMITTED + + def test_context_manager_does_not_hide_exceptions(self, tx: BaseQueryTxContext): + class CustomException(Exception): + pass + + with pytest.raises(CustomException): + with tx: + raise CustomException() + + def test_execute_as_context_manager(self, tx: BaseQueryTxContext): + tx.begin() + + with tx.execute("select 1;") as results: + res = [result_set for result_set in results] + + assert len(res) == 1 diff --git a/tox.ini b/tox.ini index 7aca13db..8b5c06ae 100644 --- a/tox.ini +++ b/tox.ini @@ -75,6 +75,7 @@ builtins = _ max-line-length = 160 ignore=E203,W503 exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,ydb/public/api/protos/*,docs/*,ydb/public/api/grpc/*,persqueue/*,client/*,dbapi/*,ydb/default_pem.py,*docs/conf.py +per-file-ignores = ydb/table.py:F401 [pytest] asyncio_mode = auto diff --git a/ydb/__init__.py b/ydb/__init__.py index 0a09834b..375f2f54 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -19,6 +19,8 @@ from .tracing import * # noqa from .topic import * # noqa from .draft import * # noqa +from .query import * # noqa +from .retries import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/_apis.py b/ydb/_apis.py index 8c0b1164..2a9a14e8 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -10,6 +10,7 @@ ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, ydb_topic_v1_pb2_grpc, + ydb_query_v1_pb2_grpc, ) from ._grpc.v4.protos import ( @@ -20,6 +21,7 @@ ydb_value_pb2, ydb_operation_pb2, ydb_common_pb2, + ydb_query_pb2, ) else: @@ -30,6 +32,7 @@ ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, ydb_topic_v1_pb2_grpc, + ydb_query_v1_pb2_grpc, ) from ._grpc.common.protos import ( @@ -40,6 +43,7 @@ ydb_value_pb2, ydb_operation_pb2, ydb_common_pb2, + ydb_query_pb2, ) @@ -51,6 +55,7 @@ ydb_table = ydb_table_pb2 ydb_discovery = ydb_discovery_pb2 ydb_operation = ydb_operation_pb2 +ydb_query = ydb_query_pb2 class CmsService(object): @@ -111,3 +116,19 @@ class TopicService(object): DropTopic = "DropTopic" StreamRead = "StreamRead" StreamWrite = "StreamWrite" + + +class QueryService(object): + Stub = ydb_query_v1_pb2_grpc.QueryServiceStub + + CreateSession = "CreateSession" + DeleteSession = "DeleteSession" + AttachSession = "AttachSession" + + BeginTransaction = "BeginTransaction" + CommitTransaction = "CommitTransaction" + RollbackTransaction = "RollbackTransaction" + + ExecuteQuery = "ExecuteQuery" + ExecuteScript = "ExecuteScript" + FetchScriptResults = "FetchScriptResults" diff --git a/ydb/_grpc/grpcwrapper/ydb_query.py b/ydb/_grpc/grpcwrapper/ydb_query.py new file mode 100644 index 00000000..befb02c7 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_query.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass +import typing +from typing import Optional + + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_query_pb2 +else: + from ..common.protos import ydb_query_pb2 + +from . import ydb_query_public_types as public_types + +from .common_utils import ( + IFromProto, + IToProto, + IFromPublic, + ServerStatus, +) + + +@dataclass +class CreateSessionResponse(IFromProto): + status: ServerStatus + session_id: str + node_id: int + + @staticmethod + def from_proto(msg: ydb_query_pb2.CreateSessionResponse) -> "CreateSessionResponse": + return CreateSessionResponse( + status=ServerStatus(msg.status, msg.issues), + session_id=msg.session_id, + node_id=msg.node_id, + ) + + +@dataclass +class DeleteSessionResponse(IFromProto): + status: ServerStatus + + @staticmethod + def from_proto(msg: ydb_query_pb2.DeleteSessionResponse) -> "DeleteSessionResponse": + return DeleteSessionResponse(status=ServerStatus(msg.status, msg.issues)) + + +@dataclass +class AttachSessionRequest(IToProto): + session_id: str + + def to_proto(self) -> ydb_query_pb2.AttachSessionRequest: + return ydb_query_pb2.AttachSessionRequest(session_id=self.session_id) + + +@dataclass +class TransactionMeta(IFromProto): + tx_id: str + + @staticmethod + def from_proto(msg: ydb_query_pb2.TransactionMeta) -> "TransactionMeta": + return TransactionMeta(tx_id=msg.id) + + +@dataclass +class TransactionSettings(IFromPublic, IToProto): + tx_mode: public_types.BaseQueryTxMode + + @staticmethod + def from_public(tx_mode: public_types.BaseQueryTxMode) -> "TransactionSettings": + return TransactionSettings(tx_mode=tx_mode) + + def to_proto(self) -> ydb_query_pb2.TransactionSettings: + if self.tx_mode.name == "snapshot_read_only": + return ydb_query_pb2.TransactionSettings(snapshot_read_only=self.tx_mode.to_proto()) + if self.tx_mode.name == "serializable_read_write": + return ydb_query_pb2.TransactionSettings(serializable_read_write=self.tx_mode.to_proto()) + if self.tx_mode.name == "online_read_only": + return ydb_query_pb2.TransactionSettings(online_read_only=self.tx_mode.to_proto()) + if self.tx_mode.name == "stale_read_only": + return ydb_query_pb2.TransactionSettings(stale_read_only=self.tx_mode.to_proto()) + + +@dataclass +class BeginTransactionRequest(IToProto): + session_id: str + tx_settings: TransactionSettings + + def to_proto(self) -> ydb_query_pb2.BeginTransactionRequest: + return ydb_query_pb2.BeginTransactionRequest( + session_id=self.session_id, + tx_settings=self.tx_settings.to_proto(), + ) + + +@dataclass +class BeginTransactionResponse(IFromProto): + status: Optional[ServerStatus] + tx_meta: TransactionMeta + + @staticmethod + def from_proto(msg: ydb_query_pb2.BeginTransactionResponse) -> "BeginTransactionResponse": + return BeginTransactionResponse( + status=ServerStatus(msg.status, msg.issues), + tx_meta=TransactionMeta.from_proto(msg.tx_meta), + ) + + +@dataclass +class CommitTransactionResponse(IFromProto): + status: Optional[ServerStatus] + + @staticmethod + def from_proto(msg: ydb_query_pb2.CommitTransactionResponse) -> "CommitTransactionResponse": + return CommitTransactionResponse( + status=ServerStatus(msg.status, msg.issues), + ) + + +@dataclass +class RollbackTransactionResponse(IFromProto): + status: Optional[ServerStatus] + + @staticmethod + def from_proto(msg: ydb_query_pb2.RollbackTransactionResponse) -> "RollbackTransactionResponse": + return RollbackTransactionResponse( + status=ServerStatus(msg.status, msg.issues), + ) + + +@dataclass +class QueryContent(IFromPublic, IToProto): + text: str + syntax: int + + @staticmethod + def from_public(query: str, syntax: int) -> "QueryContent": + return QueryContent(text=query, syntax=syntax) + + def to_proto(self) -> ydb_query_pb2.QueryContent: + return ydb_query_pb2.QueryContent(text=self.text, syntax=self.syntax) + + +@dataclass +class TransactionControl(IToProto): + begin_tx: Optional[TransactionSettings] + commit_tx: Optional[bool] + tx_id: Optional[str] + + def to_proto(self) -> ydb_query_pb2.TransactionControl: + if self.tx_id: + return ydb_query_pb2.TransactionControl( + tx_id=self.tx_id, + commit_tx=self.commit_tx, + ) + return ydb_query_pb2.TransactionControl( + begin_tx=self.begin_tx.to_proto(), + commit_tx=self.commit_tx, + ) + + +@dataclass +class ExecuteQueryRequest(IToProto): + session_id: str + query_content: QueryContent + tx_control: TransactionControl + concurrent_result_sets: bool + exec_mode: int + parameters: dict + stats_mode: int + + def to_proto(self) -> ydb_query_pb2.ExecuteQueryRequest: + tx_control = self.tx_control.to_proto() if self.tx_control is not None else self.tx_control + return ydb_query_pb2.ExecuteQueryRequest( + session_id=self.session_id, + tx_control=tx_control, + query_content=self.query_content.to_proto(), + exec_mode=self.exec_mode, + stats_mode=self.stats_mode, + concurrent_result_sets=self.concurrent_result_sets, + parameters=self.parameters, + ) diff --git a/ydb/_grpc/grpcwrapper/ydb_query_public_types.py b/ydb/_grpc/grpcwrapper/ydb_query_public_types.py new file mode 100644 index 00000000..d79a2967 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_query_public_types.py @@ -0,0 +1,66 @@ +import abc +import typing + +from .common_utils import IToProto + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_query_pb2 +else: + from ..common.protos import ydb_query_pb2 + + +class BaseQueryTxMode(IToProto): + @property + @abc.abstractmethod + def name(self) -> str: + pass + + +class QuerySnapshotReadOnly(BaseQueryTxMode): + def __init__(self): + self._name = "snapshot_read_only" + + @property + def name(self) -> str: + return self._name + + def to_proto(self) -> ydb_query_pb2.SnapshotModeSettings: + return ydb_query_pb2.SnapshotModeSettings() + + +class QuerySerializableReadWrite(BaseQueryTxMode): + def __init__(self): + self._name = "serializable_read_write" + + @property + def name(self) -> str: + return self._name + + def to_proto(self) -> ydb_query_pb2.SerializableModeSettings: + return ydb_query_pb2.SerializableModeSettings() + + +class QueryOnlineReadOnly(BaseQueryTxMode): + def __init__(self, allow_inconsistent_reads: bool = False): + self.allow_inconsistent_reads = allow_inconsistent_reads + self._name = "online_read_only" + + @property + def name(self): + return self._name + + def to_proto(self) -> ydb_query_pb2.OnlineModeSettings: + return ydb_query_pb2.OnlineModeSettings(allow_inconsistent_reads=self.allow_inconsistent_reads) + + +class QueryStaleReadOnly(BaseQueryTxMode): + def __init__(self): + self._name = "stale_read_only" + + @property + def name(self): + return self._name + + def to_proto(self) -> ydb_query_pb2.StaleModeSettings: + return ydb_query_pb2.StaleModeSettings() diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 17fb2885..b907ee27 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -10,7 +10,7 @@ Callable, ) -from ..table import RetrySettings +from ..retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 007c8a54..585e88ab 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -26,9 +26,9 @@ from .. import ( _apis, issues, - check_retriable_error, - RetrySettings, ) +from .._errors import check_retriable_error +from ..retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._grpc.grpcwrapper.ydb_topic import ( UpdateTokenRequest, diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py new file mode 100644 index 00000000..eb967abc --- /dev/null +++ b/ydb/query/__init__.py @@ -0,0 +1,40 @@ +__all__ = [ + "QueryOnlineReadOnly", + "QuerySerializableReadWrite", + "QuerySnapshotReadOnly", + "QueryStaleReadOnly", + "QuerySessionPool", + "QueryClientSync", + "QuerySessionSync", +] + +import logging + +from .base import ( + IQueryClient, + SupportedDriverType, + QueryClientSettings, +) + +from .session import QuerySessionSync + +from .._grpc.grpcwrapper.ydb_query_public_types import ( + QueryOnlineReadOnly, + QuerySerializableReadWrite, + QuerySnapshotReadOnly, + QueryStaleReadOnly, +) + +from .pool import QuerySessionPool + +logger = logging.getLogger(__name__) + + +class QueryClientSync(IQueryClient): + def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): + logger.warning("QueryClientSync is an experimental API, which could be changed.") + self._driver = driver + self._settings = query_client_settings + + def session(self) -> QuerySessionSync: + return QuerySessionSync(self._driver, self._settings) diff --git a/ydb/query/base.py b/ydb/query/base.py new file mode 100644 index 00000000..dcb56ca6 --- /dev/null +++ b/ydb/query/base.py @@ -0,0 +1,380 @@ +import abc +import enum +import functools + +from typing import ( + Iterator, + Optional, +) + +from .._grpc.grpcwrapper.common_utils import ( + SupportedDriverType, +) +from .._grpc.grpcwrapper import ydb_query +from .._grpc.grpcwrapper.ydb_query_public_types import ( + BaseQueryTxMode, +) +from ..connection import _RpcState as RpcState +from .. import convert +from .. import issues +from .. import _utilities +from .. import _apis + + +class QuerySyntax(enum.IntEnum): + UNSPECIFIED = 0 + YQL_V1 = 1 + PG = 2 + + +class QueryExecMode(enum.IntEnum): + UNSPECIFIED = 0 + PARSE = 10 + VALIDATE = 20 + EXPLAIN = 30 + EXECUTE = 50 + + +class StatsMode(enum.IntEnum): + UNSPECIFIED = 0 + NONE = 10 + BASIC = 20 + FULL = 30 + PROFILE = 40 + + +class SyncResponseContextIterator(_utilities.SyncResponseIterator): + def __enter__(self) -> "SyncResponseContextIterator": + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for _ in self: + pass + + +class QueryClientSettings: + pass + + +class IQuerySessionState(abc.ABC): + def __init__(self, settings: Optional[QueryClientSettings] = None): + pass + + @abc.abstractmethod + def reset(self) -> None: + pass + + @property + @abc.abstractmethod + def session_id(self) -> Optional[str]: + pass + + @abc.abstractmethod + def set_session_id(self, session_id: str) -> "IQuerySessionState": + pass + + @property + @abc.abstractmethod + def node_id(self) -> Optional[int]: + pass + + @abc.abstractmethod + def set_node_id(self, node_id: int) -> "IQuerySessionState": + pass + + @property + @abc.abstractmethod + def attached(self) -> bool: + pass + + @abc.abstractmethod + def set_attached(self, attached: bool) -> "IQuerySessionState": + pass + + +class IQuerySession(abc.ABC): + """Session object for Query Service. It is not recommended to control + session's lifecycle manually - use a QuerySessionPool is always a better choise. + """ + + @abc.abstractmethod + def __init__(self, driver: SupportedDriverType, settings: Optional[QueryClientSettings] = None): + pass + + @abc.abstractmethod + def create(self) -> "IQuerySession": + """ + Creates a Session of Query Service on server side and attaches it. + + :return: Session object. + """ + pass + + @abc.abstractmethod + def delete(self) -> None: + """ + Deletes a Session of Query Service on server side and releases resources. + + :return: None + """ + pass + + @abc.abstractmethod + def transaction(self, tx_mode: Optional[BaseQueryTxMode] = None) -> "IQueryTxContext": + """ + Creates a transaction context manager with specified transaction mode. + + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + + :return: transaction context manager. + """ + pass + + @abc.abstractmethod + def execute( + self, + query: str, + syntax: Optional[QuerySyntax] = None, + exec_mode: Optional[QueryExecMode] = None, + parameters: Optional[dict] = None, + concurrent_result_sets: Optional[bool] = False, + ) -> Iterator: + """ + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ + + +class IQueryTxContext(abc.ABC): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + 1) By explicit .begin(); + 2) On execution of a first statement, which is strictly recommended method, because that avoids + useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + """ + + @abc.abstractmethod + def __init__( + self, + driver: SupportedDriverType, + session_state: IQuerySessionState, + session: IQuerySession, + tx_mode: BaseQueryTxMode, + ): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + pass + + @abc.abstractmethod + def __enter__(self) -> "IQueryTxContext": + """ + Enters a context manager and returns a transaction + + :return: A transaction instance + """ + pass + + @abc.abstractmethod + def __exit__(self, *args, **kwargs): + """ + Closes a transaction context manager and rollbacks transaction if + it is not finished explicitly + """ + pass + + @property + @abc.abstractmethod + def session_id(self) -> str: + """ + A transaction's session id + + :return: A transaction's session id + """ + pass + + @property + @abc.abstractmethod + def tx_id(self) -> Optional[str]: + """ + Returns an id of open transaction or None otherwise + + :return: An id of open transaction or None otherwise + """ + pass + + @abc.abstractmethod + def begin(self, settings: Optional[QueryClientSettings] = None) -> None: + """ + Explicitly begins a transaction + + :param settings: A request settings + + :return: None or exception if begin is failed + """ + pass + + @abc.abstractmethod + def commit(self, settings: Optional[QueryClientSettings] = None) -> None: + """ + Calls commit on a transaction if it is open. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: None or exception if commit is failed + """ + pass + + @abc.abstractmethod + def rollback(self, settings: Optional[QueryClientSettings] = None) -> None: + """ + Calls rollback on a transaction if it is open. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: None or exception if rollback is failed + """ + pass + + @abc.abstractmethod + def execute( + self, + query: str, + commit_tx: Optional[bool] = False, + syntax: Optional[QuerySyntax] = None, + exec_mode: Optional[QueryExecMode] = None, + parameters: Optional[dict] = None, + concurrent_result_sets: Optional[bool] = False, + ) -> Iterator: + """ + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param commit_tx: A special flag that allows transaction commit. + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param exec_mode: Exec mode of the query, which is a one from the following choises: + 1) QueryExecMode.EXECUTE, which is default; + 2) QueryExecMode.EXPLAIN; + 3) QueryExecMode.VALIDATE; + 4) QueryExecMode.PARSE. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ + pass + + +class IQueryClient(abc.ABC): + def __init__(self, driver: SupportedDriverType, query_client_settings: Optional[QueryClientSettings] = None): + pass + + @abc.abstractmethod + def session(self) -> IQuerySession: + pass + + +def create_execute_query_request( + query: str, + session_id: str, + tx_id: Optional[str], + commit_tx: Optional[bool], + tx_mode: Optional[BaseQueryTxMode], + syntax: Optional[QuerySyntax], + exec_mode: Optional[QueryExecMode], + parameters: Optional[dict], + concurrent_result_sets: Optional[bool], +) -> ydb_query.ExecuteQueryRequest: + syntax = QuerySyntax.YQL_V1 if not syntax else syntax + exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode + stats_mode = StatsMode.NONE # TODO: choise is not supported yet + + tx_control = None + if not tx_id and not tx_mode: + tx_control = None + elif tx_id: + tx_control = ydb_query.TransactionControl( + tx_id=tx_id, + commit_tx=commit_tx, + begin_tx=None, + ) + else: + tx_control = ydb_query.TransactionControl( + begin_tx=ydb_query.TransactionSettings( + tx_mode=tx_mode, + ), + commit_tx=commit_tx, + tx_id=None, + ) + + return ydb_query.ExecuteQueryRequest( + session_id=session_id, + query_content=ydb_query.QueryContent.from_public( + query=query, + syntax=syntax, + ), + tx_control=tx_control, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + stats_mode=stats_mode, + ) + + +def wrap_execute_query_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.ExecuteQueryResponsePart, + tx: Optional[IQueryTxContext] = None, + commit_tx: Optional[bool] = False, +) -> convert.ResultSet: + issues._process_response(response_pb) + if tx and response_pb.tx_meta and not tx.tx_id: + tx._move_to_beginned(response_pb.tx_meta.id) + if tx and commit_tx: + tx._move_to_commited() + return convert.ResultSet.from_message(response_pb.result_set) + + +def bad_session_handler(func): + @functools.wraps(func) + def decorator(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs): + try: + return func(rpc_state, response_pb, session_state, *args, **kwargs) + except issues.BadSession: + session_state.reset() + raise + + return decorator diff --git a/ydb/query/pool.py b/ydb/query/pool.py new file mode 100644 index 00000000..bddd666a --- /dev/null +++ b/ydb/query/pool.py @@ -0,0 +1,87 @@ +import logging +from typing import ( + Callable, + Optional, + List, +) + +from . import base +from .session import ( + QuerySessionSync, +) +from ..retries import ( + RetrySettings, + retry_operation_sync, +) +from .. import convert + +logger = logging.getLogger(__name__) + + +class QuerySessionPool: + """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" + + def __init__(self, driver: base.SupportedDriverType): + """ + :param driver: A driver instance + """ + + logger.warning("QuerySessionPool is an experimental API, which could be changed.") + self._driver = driver + + def checkout(self) -> "SimpleQuerySessionCheckout": + """Return a Session context manager, that opens session on enter and closes session on exit.""" + + return SimpleQuerySessionCheckout(self) + + def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs): + """Special interface to execute a bunch of commands with session in a safe, retriable way. + + :param callee: A function, that works with session. + :param retry_settings: RetrySettings object. + + :return: Result sets or exception in case of execution errors. + """ + + retry_settings = RetrySettings() if retry_settings is None else retry_settings + + def wrapped_callee(): + with self.checkout() as session: + return callee(session, *args, **kwargs) + + return retry_operation_sync(wrapped_callee, retry_settings) + + def execute_with_retries( + self, query: str, retry_settings: Optional[RetrySettings] = None, *args, **kwargs + ) -> List[convert.ResultSet]: + """Special interface to execute a one-shot queries in a safe, retriable way. + Note: this method loads all data from stream before return, do not use this + method with huge read queries. + + :param query: A query, yql or sql text. + :param retry_settings: RetrySettings object. + + :return: Result sets or exception in case of execution errors. + """ + + retry_settings = RetrySettings() if retry_settings is None else retry_settings + + def wrapped_callee(): + with self.checkout() as session: + it = session.execute(query, *args, **kwargs) + return [result_set for result_set in it] + + return retry_operation_sync(wrapped_callee, retry_settings) + + +class SimpleQuerySessionCheckout: + def __init__(self, pool: QuerySessionPool): + self._pool = pool + self._session = QuerySessionSync(pool._driver) + + def __enter__(self) -> base.IQuerySession: + self._session.create() + return self._session + + def __exit__(self, exc_type, exc_val, exc_tb): + self._session.delete() diff --git a/ydb/query/session.py b/ydb/query/session.py new file mode 100644 index 00000000..5b9c00ed --- /dev/null +++ b/ydb/query/session.py @@ -0,0 +1,313 @@ +import abc +import enum +import logging +import threading +from typing import ( + Iterable, + Optional, +) + +from . import base + +from .. import _apis, issues, _utilities +from ..connection import _RpcState as RpcState +from .._grpc.grpcwrapper import common_utils +from .._grpc.grpcwrapper import ydb_query as _ydb_query +from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public + +from .transaction import BaseQueryTxContext + + +logger = logging.getLogger(__name__) + + +class QuerySessionStateEnum(enum.Enum): + NOT_INITIALIZED = "NOT_INITIALIZED" + CREATED = "CREATED" + CLOSED = "CLOSED" + + +class QuerySessionStateHelper(abc.ABC): + _VALID_TRANSITIONS = { + QuerySessionStateEnum.NOT_INITIALIZED: [QuerySessionStateEnum.CREATED], + QuerySessionStateEnum.CREATED: [QuerySessionStateEnum.CLOSED], + QuerySessionStateEnum.CLOSED: [], + } + + _READY_TO_USE = [ + QuerySessionStateEnum.CREATED, + ] + + @classmethod + def valid_transition(cls, before: QuerySessionStateEnum, after: QuerySessionStateEnum) -> bool: + return after in cls._VALID_TRANSITIONS[before] + + @classmethod + def ready_to_use(cls, state: QuerySessionStateEnum) -> bool: + return state in cls._READY_TO_USE + + +class QuerySessionState(base.IQuerySessionState): + _session_id: Optional[str] = None + _node_id: Optional[int] = None + _attached: bool = False + _settings: Optional[base.QueryClientSettings] = None + _state: QuerySessionStateEnum = QuerySessionStateEnum.NOT_INITIALIZED + + def __init__(self, settings: base.QueryClientSettings = None): + self._settings = settings + + def reset(self) -> None: + self._session_id = None + self._node_id = None + self._attached = False + + @property + def session_id(self) -> Optional[str]: + return self._session_id + + def set_session_id(self, session_id: str) -> "QuerySessionState": + self._session_id = session_id + return self + + @property + def node_id(self) -> Optional[int]: + return self._node_id + + def set_node_id(self, node_id: int) -> "QuerySessionState": + self._node_id = node_id + return self + + @property + def attached(self) -> bool: + return self._attached + + def set_attached(self, attached: bool) -> "QuerySessionState": + self._attached = attached + + def _check_invalid_transition(self, target: QuerySessionStateEnum) -> None: + if not QuerySessionStateHelper.valid_transition(self._state, target): + raise RuntimeError(f"Session could not be moved from {self._state.value} to {target.value}") + + def _change_state(self, target: QuerySessionStateEnum) -> None: + self._check_invalid_transition(target) + self._state = target + + def _check_session_ready_to_use(self) -> None: + if not QuerySessionStateHelper.ready_to_use(self._state): + raise RuntimeError(f"Session is not ready to use, current state: {self._state.value}") + + def _already_in(self, target) -> bool: + return self._state == target + + +def wrapper_create_session( + rpc_state: RpcState, + response_pb: _apis.ydb_query.CreateSessionResponse, + session_state: QuerySessionState, + session: "BaseQuerySession", +) -> "BaseQuerySession": + message = _ydb_query.CreateSessionResponse.from_proto(response_pb) + issues._process_response(message.status) + session_state.set_session_id(message.session_id).set_node_id(message.node_id) + return session + + +def wrapper_delete_session( + rpc_state: RpcState, + response_pb: _apis.ydb_query.DeleteSessionResponse, + session_state: QuerySessionState, + session: "BaseQuerySession", +) -> "BaseQuerySession": + message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) + issues._process_response(message.status) + session_state.reset() + session_state._change_state(QuerySessionStateEnum.CLOSED) + return session + + +class BaseQuerySession(base.IQuerySession): + _driver: base.SupportedDriverType + _settings: base.QueryClientSettings + _state: QuerySessionState + + def __init__(self, driver: base.SupportedDriverType, settings: Optional[base.QueryClientSettings] = None): + self._driver = driver + self._settings = settings if settings is not None else base.QueryClientSettings() + self._state = QuerySessionState(settings) + + def _create_call(self) -> "BaseQuerySession": + return self._driver( + _apis.ydb_query.CreateSessionRequest(), + _apis.QueryService.Stub, + _apis.QueryService.CreateSession, + wrap_result=wrapper_create_session, + wrap_args=(self._state, self), + ) + + def _delete_call(self) -> "BaseQuerySession": + return self._driver( + _apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id), + _apis.QueryService.Stub, + _apis.QueryService.DeleteSession, + wrap_result=wrapper_delete_session, + wrap_args=(self._state, self), + ) + + def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]: + return self._driver( + _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), + _apis.QueryService.Stub, + _apis.QueryService.AttachSession, + ) + + def _execute_call( + self, + query: str, + commit_tx: bool = False, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: + request = base.create_execute_query_request( + query=query, + session_id=self._state.session_id, + commit_tx=commit_tx, + tx_mode=None, + tx_id=None, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + + return self._driver( + request.to_proto(), + _apis.QueryService.Stub, + _apis.QueryService.ExecuteQuery, + ) + + +class QuerySessionSync(BaseQuerySession): + """Session object for Query Service. It is not recommended to control + session's lifecycle manually - use a QuerySessionPool is always a better choise. + """ + + _stream = None + + def _attach(self) -> None: + self._stream = self._attach_call() + status_stream = _utilities.SyncResponseIterator( + self._stream, + lambda response: common_utils.ServerStatus.from_proto(response), + ) + + first_response = next(status_stream) + if first_response.status != issues.StatusCode.SUCCESS: + pass + + self._state.set_attached(True) + self._state._change_state(QuerySessionStateEnum.CREATED) + + threading.Thread( + target=self._check_session_status_loop, + args=(status_stream,), + name="check session status thread", + daemon=True, + ).start() + + def _check_session_status_loop(self, status_stream: _utilities.SyncResponseIterator) -> None: + try: + for status in status_stream: + if status.status != issues.StatusCode.SUCCESS: + self._state.reset() + self._state._change_state(QuerySessionStateEnum.CLOSED) + except Exception: + pass + + def delete(self) -> None: + """ + Deletes a Session of Query Service on server side and releases resources. + + :return: None + """ + if self._state._already_in(QuerySessionStateEnum.CLOSED): + return + + self._state._check_invalid_transition(QuerySessionStateEnum.CLOSED) + self._delete_call() + self._stream.cancel() + + def create(self) -> "QuerySessionSync": + """ + Creates a Session of Query Service on server side and attaches it. + + :return: QuerySessionSync object. + """ + if self._state._already_in(QuerySessionStateEnum.CREATED): + return + + self._state._check_invalid_transition(QuerySessionStateEnum.CREATED) + self._create_call() + self._attach() + + return self + + def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> base.IQueryTxContext: + """ + Creates a transaction context manager with specified transaction mode. + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + + :return transaction context manager. + + """ + self._state._check_session_ready_to_use() + + tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() + + return BaseQueryTxContext( + self._driver, + self._state, + self, + tx_mode, + ) + + def execute( + self, + query: str, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ) -> base.SyncResponseContextIterator: + """ + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ + self._state._check_session_ready_to_use() + + stream_it = self._execute_call( + query=query, + commit_tx=True, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + + return base.SyncResponseContextIterator( + stream_it, + lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), + ) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py new file mode 100644 index 00000000..6846b5a5 --- /dev/null +++ b/ydb/query/transaction.py @@ -0,0 +1,393 @@ +import abc +import logging +import enum +import functools +from typing import ( + Iterable, + Optional, +) + +from .. import ( + _apis, + issues, +) +from .._grpc.grpcwrapper import ydb_query as _ydb_query +from ..connection import _RpcState as RpcState + +from . import base + +logger = logging.getLogger(__name__) + + +class QueryTxStateEnum(enum.Enum): + NOT_INITIALIZED = "NOT_INITIALIZED" + BEGINED = "BEGINED" + COMMITTED = "COMMITTED" + ROLLBACKED = "ROLLBACKED" + DEAD = "DEAD" + + +class QueryTxStateHelper(abc.ABC): + _VALID_TRANSITIONS = { + QueryTxStateEnum.NOT_INITIALIZED: [ + QueryTxStateEnum.BEGINED, + QueryTxStateEnum.DEAD, + QueryTxStateEnum.COMMITTED, + QueryTxStateEnum.ROLLBACKED, + ], + QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD], + QueryTxStateEnum.COMMITTED: [], + QueryTxStateEnum.ROLLBACKED: [], + QueryTxStateEnum.DEAD: [], + } + + @classmethod + def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: + return after in cls._VALID_TRANSITIONS[before] + + @classmethod + def terminal(cls, state: QueryTxStateEnum) -> bool: + return len(cls._VALID_TRANSITIONS[state]) == 0 + + +def reset_tx_id_handler(func): + @functools.wraps(func) + def decorator( + rpc_state, response_pb, session_state: base.IQuerySessionState, tx_state: QueryTxState, *args, **kwargs + ): + try: + return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs) + except issues.Error: + tx_state._change_state(QueryTxStateEnum.DEAD) + tx_state.tx_id = None + raise + + return decorator + + +class QueryTxState: + def __init__(self, tx_mode: base.BaseQueryTxMode): + """ + Holds transaction context manager info + :param tx_mode: A mode of transaction + """ + self.tx_id = None + self.tx_mode = tx_mode + self._state = QueryTxStateEnum.NOT_INITIALIZED + + def _check_invalid_transition(self, target: QueryTxStateEnum) -> None: + if not QueryTxStateHelper.valid_transition(self._state, target): + raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}") + + def _change_state(self, target: QueryTxStateEnum) -> None: + self._check_invalid_transition(target) + self._state = target + + def _check_tx_ready_to_use(self) -> None: + if QueryTxStateHelper.terminal(self._state): + raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") + + def _already_in(self, target: QueryTxStateEnum) -> bool: + return self._state == target + + +def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings: + tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode) + return tx_settings + + +def _create_begin_transaction_request( + session_state: base.IQuerySessionState, tx_state: QueryTxState +) -> _apis.ydb_query.BeginTransactionRequest: + request = _ydb_query.BeginTransactionRequest( + session_id=session_state.session_id, + tx_settings=_construct_tx_settings(tx_state), + ).to_proto() + return request + + +def _create_commit_transaction_request( + session_state: base.IQuerySessionState, tx_state: QueryTxState +) -> _apis.ydb_query.CommitTransactionRequest: + request = _apis.ydb_query.CommitTransactionRequest() + request.tx_id = tx_state.tx_id + request.session_id = session_state.session_id + return request + + +def _create_rollback_transaction_request( + session_state: base.IQuerySessionState, tx_state: QueryTxState +) -> _apis.ydb_query.RollbackTransactionRequest: + request = _apis.ydb_query.RollbackTransactionRequest() + request.tx_id = tx_state.tx_id + request.session_id = session_state.session_id + return request + + +@base.bad_session_handler +def wrap_tx_begin_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.BeginTransactionResponse, + session_state: base.IQuerySessionState, + tx_state: QueryTxState, + tx: "BaseQueryTxContext", +) -> "BaseQueryTxContext": + message = _ydb_query.BeginTransactionResponse.from_proto(response_pb) + issues._process_response(message.status) + tx_state._change_state(QueryTxStateEnum.BEGINED) + tx_state.tx_id = message.tx_meta.tx_id + return tx + + +@base.bad_session_handler +@reset_tx_id_handler +def wrap_tx_commit_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.CommitTransactionResponse, + session_state: base.IQuerySessionState, + tx_state: QueryTxState, + tx: "BaseQueryTxContext", +) -> "BaseQueryTxContext": + message = _ydb_query.CommitTransactionResponse.from_proto(response_pb) + issues._process_response(message.status) + tx_state._change_state(QueryTxStateEnum.COMMITTED) + return tx + + +@base.bad_session_handler +@reset_tx_id_handler +def wrap_tx_rollback_response( + rpc_state: RpcState, + response_pb: _apis.ydb_query.RollbackTransactionResponse, + session_state: base.IQuerySessionState, + tx_state: QueryTxState, + tx: "BaseQueryTxContext", +) -> "BaseQueryTxContext": + message = _ydb_query.RollbackTransactionResponse.from_proto(response_pb) + issues._process_response(message.status) + tx_state._change_state(QueryTxStateEnum.ROLLBACKED) + return tx + + +class BaseQueryTxContext(base.IQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + + self._driver = driver + self._tx_state = QueryTxState(tx_mode) + self._session_state = session_state + self.session = session + self._prev_stream = None + + def __enter__(self) -> "BaseQueryTxContext": + """ + Enters a context manager and returns a transaction + + :return: A transaction instance + """ + return self + + def __exit__(self, *args, **kwargs): + """ + Closes a transaction context manager and rollbacks transaction if + it is not finished explicitly + """ + self._ensure_prev_stream_finished() + if self._tx_state._state == QueryTxStateEnum.BEGINED: + # It's strictly recommended to close transactions directly + # by using commit_tx=True flag while executing statement or by + # .commit() or .rollback() methods, but here we trying to do best + # effort to avoid useless open transactions + logger.warning("Potentially leaked tx: %s", self._tx_state.tx_id) + try: + self.rollback() + except issues.Error: + logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) + + @property + def session_id(self) -> str: + """ + A transaction's session id + + :return: A transaction's session id + """ + return self._session_state.session_id + + @property + def tx_id(self) -> Optional[str]: + """ + Returns an id of open transaction or None otherwise + + :return: An id of open transaction or None otherwise + """ + return self._tx_state.tx_id + + def _begin_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": + return self._driver( + _create_begin_transaction_request(self._session_state, self._tx_state), + _apis.QueryService.Stub, + _apis.QueryService.BeginTransaction, + wrap_tx_begin_response, + settings, + (self._session_state, self._tx_state, self), + ) + + def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": + return self._driver( + _create_commit_transaction_request(self._session_state, self._tx_state), + _apis.QueryService.Stub, + _apis.QueryService.CommitTransaction, + wrap_tx_commit_response, + settings, + (self._session_state, self._tx_state, self), + ) + + def _rollback_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": + return self._driver( + _create_rollback_transaction_request(self._session_state, self._tx_state), + _apis.QueryService.Stub, + _apis.QueryService.RollbackTransaction, + wrap_tx_rollback_response, + settings, + (self._session_state, self._tx_state, self), + ) + + def _execute_call( + self, + query: str, + commit_tx: bool = False, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + parameters: dict = None, + concurrent_result_sets: bool = False, + ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: + request = base.create_execute_query_request( + query=query, + session_id=self._session_state.session_id, + commit_tx=commit_tx, + tx_id=self._tx_state.tx_id, + tx_mode=self._tx_state.tx_mode, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + + return self._driver( + request.to_proto(), + _apis.QueryService.Stub, + _apis.QueryService.ExecuteQuery, + ) + + def _ensure_prev_stream_finished(self) -> None: + if self._prev_stream is not None: + for _ in self._prev_stream: + pass + self._prev_stream = None + + def _move_to_beginned(self, tx_id: str) -> None: + if self._tx_state._already_in(QueryTxStateEnum.BEGINED): + return + self._tx_state._change_state(QueryTxStateEnum.BEGINED) + self._tx_state.tx_id = tx_id + + def _move_to_commited(self) -> None: + if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + return + self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + + def begin(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """ + Explicitly begins a transaction + + :param settings: A request settings + + :return: None or exception if begin is failed + """ + self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) + + self._begin_call(settings) + + def commit(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """ + Calls commit on a transaction if it is open otherwise is no-op. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ + if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + return + self._ensure_prev_stream_finished() + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + return + + self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) + + self._commit_call(settings) + + def rollback(self, settings: Optional[base.QueryClientSettings] = None) -> None: + if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + return + + self._ensure_prev_stream_finished() + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) + return + + self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) + + self._rollback_call(settings) + + def execute( + self, + query: str, + commit_tx: Optional[bool] = False, + syntax: Optional[base.QuerySyntax] = None, + exec_mode: Optional[base.QueryExecMode] = None, + parameters: Optional[dict] = None, + concurrent_result_sets: Optional[bool] = False, + ) -> base.SyncResponseContextIterator: + + self._ensure_prev_stream_finished() + self._tx_state._check_tx_ready_to_use() + + stream_it = self._execute_call( + query=query, + commit_tx=commit_tx, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + self._prev_stream = base.SyncResponseContextIterator( + stream_it, + lambda resp: base.wrap_execute_query_response( + rpc_state=None, + response_pb=resp, + tx=self, + commit_tx=commit_tx, + ), + ) + return self._prev_stream diff --git a/ydb/retries.py b/ydb/retries.py new file mode 100644 index 00000000..5d4f6e6a --- /dev/null +++ b/ydb/retries.py @@ -0,0 +1,136 @@ +import random +import time + +from . import issues +from ._errors import check_retriable_error + + +class BackoffSettings(object): + def __init__(self, ceiling=6, slot_duration=0.001, uncertain_ratio=0.5): + self.ceiling = ceiling + self.slot_duration = slot_duration + self.uncertain_ratio = uncertain_ratio + + def calc_timeout(self, retry_number): + slots_count = 1 << min(retry_number, self.ceiling) + max_duration_ms = slots_count * self.slot_duration * 1000.0 + # duration_ms = random.random() * max_duration_ms * uncertain_ratio) + max_duration_ms * (1 - uncertain_ratio) + duration_ms = max_duration_ms * (random.random() * self.uncertain_ratio + 1.0 - self.uncertain_ratio) + return duration_ms / 1000.0 + + +class RetrySettings(object): + def __init__( + self, + max_retries=10, + max_session_acquire_timeout=None, + on_ydb_error_callback=None, + backoff_ceiling=6, + backoff_slot_duration=1, + get_session_client_timeout=5, + fast_backoff_settings=None, + slow_backoff_settings=None, + idempotent=False, + ): + self.max_retries = max_retries + self.max_session_acquire_timeout = max_session_acquire_timeout + self.on_ydb_error_callback = (lambda e: None) if on_ydb_error_callback is None else on_ydb_error_callback + self.fast_backoff = BackoffSettings(10, 0.005) if fast_backoff_settings is None else fast_backoff_settings + self.slow_backoff = ( + BackoffSettings(backoff_ceiling, backoff_slot_duration) + if slow_backoff_settings is None + else slow_backoff_settings + ) + self.retry_not_found = True + self.idempotent = idempotent + self.retry_internal_error = True + self.unknown_error_handler = lambda e: None + self.get_session_client_timeout = get_session_client_timeout + if max_session_acquire_timeout is not None: + self.get_session_client_timeout = min(self.max_session_acquire_timeout, self.get_session_client_timeout) + + def with_fast_backoff(self, backoff_settings): + self.fast_backoff = backoff_settings + return self + + def with_slow_backoff(self, backoff_settings): + self.slow_backoff = backoff_settings + return self + + +class YdbRetryOperationSleepOpt(object): + def __init__(self, timeout): + self.timeout = timeout + + def __eq__(self, other): + return type(self) == type(other) and self.timeout == other.timeout + + def __repr__(self): + return "YdbRetryOperationSleepOpt(%s)" % self.timeout + + +class YdbRetryOperationFinalResult(object): + def __init__(self, result): + self.result = result + self.exc = None + + def __eq__(self, other): + return type(self) == type(other) and self.result == other.result and self.exc == other.exc + + def __repr__(self): + return "YdbRetryOperationFinalResult(%s, exc=%s)" % (self.result, self.exc) + + def set_exception(self, exc): + self.exc = exc + + +def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): + retry_settings = RetrySettings() if retry_settings is None else retry_settings + status = None + + for attempt in range(retry_settings.max_retries + 1): + try: + result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) + yield result + + if result.exc is not None: + raise result.exc + + except issues.Error as e: + status = e + retry_settings.on_ydb_error_callback(e) + + retriable_info = check_retriable_error(e, retry_settings, attempt) + if not retriable_info.is_retriable: + raise + + skip_yield_error_types = [ + issues.Aborted, + issues.BadSession, + issues.NotFound, + issues.InternalError, + ] + + yield_sleep = True + for t in skip_yield_error_types: + if isinstance(e, t): + yield_sleep = False + + if yield_sleep: + yield YdbRetryOperationSleepOpt(retriable_info.sleep_timeout_seconds) + + except Exception as e: + # you should provide your own handler you want + retry_settings.unknown_error_handler(e) + raise + + raise status + + +def retry_operation_sync(callee, retry_settings=None, *args, **kwargs): + opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs) + for next_opt in opt_generator: + if isinstance(next_opt, YdbRetryOperationSleepOpt): + time.sleep(next_opt.timeout) + else: + return next_opt.result diff --git a/ydb/table.py b/ydb/table.py index c21392bb..ac9f9304 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -3,8 +3,6 @@ import ydb from abc import abstractmethod import logging -import time -import random import enum from . import ( @@ -20,7 +18,15 @@ _tx_ctx_impl, tracing, ) -from ._errors import check_retriable_error + +from .retries import ( + YdbRetryOperationFinalResult, # noqa + YdbRetryOperationSleepOpt, # noqa + BackoffSettings, # noqa + retry_operation_impl, # noqa + RetrySettings, + retry_operation_sync, +) try: from . import interceptor @@ -840,137 +846,6 @@ def name(self): return self._name -class BackoffSettings(object): - def __init__(self, ceiling=6, slot_duration=0.001, uncertain_ratio=0.5): - self.ceiling = ceiling - self.slot_duration = slot_duration - self.uncertain_ratio = uncertain_ratio - - def calc_timeout(self, retry_number): - slots_count = 1 << min(retry_number, self.ceiling) - max_duration_ms = slots_count * self.slot_duration * 1000.0 - # duration_ms = random.random() * max_duration_ms * uncertain_ratio) + max_duration_ms * (1 - uncertain_ratio) - duration_ms = max_duration_ms * (random.random() * self.uncertain_ratio + 1.0 - self.uncertain_ratio) - return duration_ms / 1000.0 - - -class RetrySettings(object): - def __init__( - self, - max_retries=10, - max_session_acquire_timeout=None, - on_ydb_error_callback=None, - backoff_ceiling=6, - backoff_slot_duration=1, - get_session_client_timeout=5, - fast_backoff_settings=None, - slow_backoff_settings=None, - idempotent=False, - ): - self.max_retries = max_retries - self.max_session_acquire_timeout = max_session_acquire_timeout - self.on_ydb_error_callback = (lambda e: None) if on_ydb_error_callback is None else on_ydb_error_callback - self.fast_backoff = BackoffSettings(10, 0.005) if fast_backoff_settings is None else fast_backoff_settings - self.slow_backoff = ( - BackoffSettings(backoff_ceiling, backoff_slot_duration) - if slow_backoff_settings is None - else slow_backoff_settings - ) - self.retry_not_found = True - self.idempotent = idempotent - self.retry_internal_error = True - self.unknown_error_handler = lambda e: None - self.get_session_client_timeout = get_session_client_timeout - if max_session_acquire_timeout is not None: - self.get_session_client_timeout = min(self.max_session_acquire_timeout, self.get_session_client_timeout) - - def with_fast_backoff(self, backoff_settings): - self.fast_backoff = backoff_settings - return self - - def with_slow_backoff(self, backoff_settings): - self.slow_backoff = backoff_settings - return self - - -class YdbRetryOperationSleepOpt(object): - def __init__(self, timeout): - self.timeout = timeout - - def __eq__(self, other): - return type(self) == type(other) and self.timeout == other.timeout - - def __repr__(self): - return "YdbRetryOperationSleepOpt(%s)" % self.timeout - - -class YdbRetryOperationFinalResult(object): - def __init__(self, result): - self.result = result - self.exc = None - - def __eq__(self, other): - return type(self) == type(other) and self.result == other.result and self.exc == other.exc - - def __repr__(self): - return "YdbRetryOperationFinalResult(%s, exc=%s)" % (self.result, self.exc) - - def set_exception(self, exc): - self.exc = exc - - -def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): - retry_settings = RetrySettings() if retry_settings is None else retry_settings - status = None - - for attempt in range(retry_settings.max_retries + 1): - try: - result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) - yield result - - if result.exc is not None: - raise result.exc - - except issues.Error as e: - status = e - retry_settings.on_ydb_error_callback(e) - - retriable_info = check_retriable_error(e, retry_settings, attempt) - if not retriable_info.is_retriable: - raise - - skip_yield_error_types = [ - issues.Aborted, - issues.BadSession, - issues.NotFound, - issues.InternalError, - ] - - yield_sleep = True - for t in skip_yield_error_types: - if isinstance(e, t): - yield_sleep = False - - if yield_sleep: - yield YdbRetryOperationSleepOpt(retriable_info.sleep_timeout_seconds) - - except Exception as e: - # you should provide your own handler you want - retry_settings.unknown_error_handler(e) - raise - - raise status - - -def retry_operation_sync(callee, retry_settings=None, *args, **kwargs): - opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs) - for next_opt in opt_generator: - if isinstance(next_opt, YdbRetryOperationSleepOpt): - time.sleep(next_opt.timeout) - else: - return next_opt.result - - class TableClientSettings(object): def __init__(self): self._client_query_cache_enabled = False diff --git a/ydb/table_test.py b/ydb/table_test.py index 640662c4..d5d86e05 100644 --- a/ydb/table_test.py +++ b/ydb/table_test.py @@ -1,8 +1,9 @@ from unittest import mock -from . import ( +from . import issues + +from .retries import ( retry_operation_impl, YdbRetryOperationFinalResult, - issues, YdbRetryOperationSleepOpt, RetrySettings, )