diff --git a/examples/query-service/basic_example_asyncio.py b/examples/query-service/basic_example_asyncio.py new file mode 100644 index 00000000..cd7a4919 --- /dev/null +++ b/examples/query-service/basic_example_asyncio.py @@ -0,0 +1,143 @@ +import asyncio +import ydb + + +async def main(): + driver_config = ydb.DriverConfig( + endpoint="grpc://localhost:2136", + database="/local", + # credentials=ydb.credentials_from_env_variables(), + # root_certificates=ydb.load_ydb_root_certificate(), + ) + try: + driver = ydb.aio.Driver(driver_config) + await driver.wait(timeout=5) + except TimeoutError: + raise RuntimeError("Connect failed to YDB") + + pool = ydb.aio.QuerySessionPoolAsync(driver) + + print("=" * 50) + print("DELETE TABLE IF EXISTS") + await pool.execute_with_retries("DROP TABLE IF EXISTS example") + + print("=" * 50) + print("CREATE TABLE") + await pool.execute_with_retries("CREATE TABLE example(key UInt64, value String, PRIMARY KEY (key))") + + await pool.execute_with_retries("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')") + + async def callee(session): + print("=" * 50) + async with await session.execute("DELETE FROM example"): + pass + + print("BEFORE ACTION") + async with await session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + print("=" * 50) + print("INSERT WITH COMMIT TX") + + async with session.transaction() as tx: + await tx.begin() + + async with await tx.execute("INSERT INTO example (key, value) VALUES (1, 'onepieceisreal')"): + pass + + async with await tx.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + await tx.commit() + + print("=" * 50) + print("AFTER COMMIT TX") + + async with await session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + print("=" * 50) + print("INSERT WITH ROLLBACK TX") + + async with session.transaction() as tx: + await tx.begin() + + async with await tx.execute("INSERT INTO example (key, value) VALUES (2, 'onepieceisreal')"): + pass + + async with await tx.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + await tx.rollback() + + print("=" * 50) + print("AFTER ROLLBACK TX") + + async with await session.execute("SELECT COUNT(*) AS rows_count FROM example") as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + await pool.retry_operation_async(callee) + + async def callee(session: ydb.aio.QuerySessionAsync): + query_print = """select $a""" + + print("=" * 50) + print("Check implicit typed parameters") + + values = [ + 1, + 1.0, + True, + "text", + {"4": 8, "15": 16, "23": 42}, + [{"name": "Michael"}, {"surname": "Scott"}], + ] + + for value in values: + print(f"value: {value}") + async with await session.transaction().execute( + query=query_print, + parameters={"$a": value}, + commit_tx=True, + ) as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + print("=" * 50) + print("Check typed parameters as tuple pair") + + typed_value = ([1, 2, 3], ydb.ListType(ydb.PrimitiveType.Int64)) + print(f"value: {typed_value}") + + async with await session.transaction().execute( + query=query_print, + parameters={"$a": typed_value}, + commit_tx=True, + ) as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + print("=" * 50) + print("Check typed parameters as ydb.TypedValue") + + typed_value = ydb.TypedValue(111, ydb.PrimitiveType.Int64) + print(f"value: {typed_value}") + + async with await session.transaction().execute( + query=query_print, + parameters={"$a": typed_value}, + commit_tx=True, + ) as results: + async for result_set in results: + print(f"rows: {str(result_set.rows)}") + + await pool.retry_operation_async(callee) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/aio/query/__init__.py b/tests/aio/query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/aio/query/conftest.py b/tests/aio/query/conftest.py new file mode 100644 index 00000000..0fbdbd38 --- /dev/null +++ b/tests/aio/query/conftest.py @@ -0,0 +1,34 @@ +import pytest +from ydb.aio.query.session import QuerySessionAsync +from ydb.aio.query.pool import QuerySessionPoolAsync + + +@pytest.fixture +async def session(driver): + session = QuerySessionAsync(driver) + + yield session + + try: + await session.delete() + except BaseException: + pass + + +@pytest.fixture +async def tx(session): + await session.create() + transaction = session.transaction() + + yield transaction + + try: + await transaction.rollback() + except BaseException: + pass + + +@pytest.fixture +def pool(driver): + pool = QuerySessionPoolAsync(driver) + yield pool diff --git a/tests/aio/query/test_query_session.py b/tests/aio/query/test_query_session.py new file mode 100644 index 00000000..117e39af --- /dev/null +++ b/tests/aio/query/test_query_session.py @@ -0,0 +1,112 @@ +import pytest +from ydb.aio.query.session import QuerySessionAsync + + +def _check_session_state_empty(session: QuerySessionAsync): + assert session._state.session_id is None + assert session._state.node_id is None + assert not session._state.attached + + +def _check_session_state_full(session: QuerySessionAsync): + assert session._state.session_id is not None + assert session._state.node_id is not None + assert session._state.attached + + +class TestAsyncQuerySession: + @pytest.mark.asyncio + async def test_session_normal_lifecycle(self, session: QuerySessionAsync): + _check_session_state_empty(session) + + await session.create() + _check_session_state_full(session) + + await session.delete() + _check_session_state_empty(session) + + @pytest.mark.asyncio + async def test_second_create_do_nothing(self, session: QuerySessionAsync): + await session.create() + _check_session_state_full(session) + + session_id_before = session._state.session_id + node_id_before = session._state.node_id + + await session.create() + _check_session_state_full(session) + + assert session._state.session_id == session_id_before + assert session._state.node_id == node_id_before + + @pytest.mark.asyncio + async def test_second_delete_do_nothing(self, session: QuerySessionAsync): + await session.create() + + await session.delete() + await session.delete() + + @pytest.mark.asyncio + async def test_delete_before_create_not_possible(self, session: QuerySessionAsync): + with pytest.raises(RuntimeError): + await session.delete() + + @pytest.mark.asyncio + async def test_create_after_delete_not_possible(self, session: QuerySessionAsync): + await session.create() + await session.delete() + with pytest.raises(RuntimeError): + await session.create() + + def test_transaction_before_create_raises(self, session: QuerySessionAsync): + with pytest.raises(RuntimeError): + session.transaction() + + @pytest.mark.asyncio + async def test_transaction_after_delete_raises(self, session: QuerySessionAsync): + await session.create() + + await session.delete() + + with pytest.raises(RuntimeError): + session.transaction() + + @pytest.mark.asyncio + async def test_transaction_after_create_not_raises(self, session: QuerySessionAsync): + await session.create() + session.transaction() + + @pytest.mark.asyncio + async def test_execute_before_create_raises(self, session: QuerySessionAsync): + with pytest.raises(RuntimeError): + await session.execute("select 1;") + + @pytest.mark.asyncio + async def test_execute_after_delete_raises(self, session: QuerySessionAsync): + await session.create() + await session.delete() + with pytest.raises(RuntimeError): + await session.execute("select 1;") + + @pytest.mark.asyncio + async def test_basic_execute(self, session: QuerySessionAsync): + await session.create() + it = await session.execute("select 1;") + result_sets = [result_set async for result_set in it] + + assert len(result_sets) == 1 + assert len(result_sets[0].rows) == 1 + assert len(result_sets[0].columns) == 1 + assert list(result_sets[0].rows[0].values()) == [1] + + @pytest.mark.asyncio + async def test_two_results(self, session: QuerySessionAsync): + await session.create() + res = [] + + async with await session.execute("select 1; select 2") as results: + async for result_set in results: + if len(result_set.rows) > 0: + res.append(list(result_set.rows[0].values())) + + assert res == [[1], [2]] diff --git a/tests/aio/query/test_query_session_pool.py b/tests/aio/query/test_query_session_pool.py new file mode 100644 index 00000000..e544f7b6 --- /dev/null +++ b/tests/aio/query/test_query_session_pool.py @@ -0,0 +1,56 @@ +import pytest +import ydb +from ydb.aio.query.pool import QuerySessionPoolAsync +from ydb.aio.query.session import QuerySessionAsync, QuerySessionStateEnum + + +class TestQuerySessionPoolAsync: + @pytest.mark.asyncio + async def test_checkout_provides_created_session(self, pool: QuerySessionPoolAsync): + async with pool.checkout() as session: + assert session._state._state == QuerySessionStateEnum.CREATED + + assert session._state._state == QuerySessionStateEnum.CLOSED + + @pytest.mark.asyncio + async def test_oneshot_query_normal(self, pool: QuerySessionPoolAsync): + res = await pool.execute_with_retries("select 1;") + assert len(res) == 1 + + @pytest.mark.asyncio + async def test_oneshot_ddl_query(self, pool: QuerySessionPoolAsync): + await pool.execute_with_retries("create table Queen(key UInt64, PRIMARY KEY (key));") + await pool.execute_with_retries("drop table Queen;") + + @pytest.mark.asyncio + async def test_oneshot_query_raises(self, pool: QuerySessionPoolAsync): + with pytest.raises(ydb.GenericError): + await pool.execute_with_retries("Is this the real life? Is this just fantasy?") + + @pytest.mark.asyncio + async def test_retry_op_uses_created_session(self, pool: QuerySessionPoolAsync): + async def callee(session: QuerySessionAsync): + assert session._state._state == QuerySessionStateEnum.CREATED + + await pool.retry_operation_async(callee) + + @pytest.mark.asyncio + async def test_retry_op_normal(self, pool: QuerySessionPoolAsync): + async def callee(session: QuerySessionAsync): + async with session.transaction() as tx: + iterator = await tx.execute("select 1;", commit_tx=True) + return [result_set async for result_set in iterator] + + res = await pool.retry_operation_async(callee) + assert len(res) == 1 + + @pytest.mark.asyncio + async def test_retry_op_raises(self, pool: QuerySessionPoolAsync): + class CustomException(Exception): + pass + + async def callee(session: QuerySessionAsync): + raise CustomException() + + with pytest.raises(CustomException): + await pool.retry_operation_async(callee) diff --git a/tests/aio/query/test_query_transaction.py b/tests/aio/query/test_query_transaction.py new file mode 100644 index 00000000..e332b086 --- /dev/null +++ b/tests/aio/query/test_query_transaction.py @@ -0,0 +1,94 @@ +import pytest + +from ydb.aio.query.transaction import QueryTxContextAsync +from ydb.query.transaction import QueryTxStateEnum + + +class TestAsyncQueryTransaction: + @pytest.mark.asyncio + async def test_tx_begin(self, tx: QueryTxContextAsync): + assert tx.tx_id is None + + await tx.begin() + assert tx.tx_id is not None + + @pytest.mark.asyncio + async def test_tx_allow_double_commit(self, tx: QueryTxContextAsync): + await tx.begin() + await tx.commit() + await tx.commit() + + @pytest.mark.asyncio + async def test_tx_allow_double_rollback(self, tx: QueryTxContextAsync): + await tx.begin() + await tx.rollback() + await tx.rollback() + + @pytest.mark.asyncio + async def test_tx_commit_before_begin(self, tx: QueryTxContextAsync): + await tx.commit() + assert tx._tx_state._state == QueryTxStateEnum.COMMITTED + + @pytest.mark.asyncio + async def test_tx_rollback_before_begin(self, tx: QueryTxContextAsync): + await tx.rollback() + assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED + + @pytest.mark.asyncio + async def test_tx_first_execute_begins_tx(self, tx: QueryTxContextAsync): + await tx.execute("select 1;") + await tx.commit() + + @pytest.mark.asyncio + async def test_interactive_tx_commit(self, tx: QueryTxContextAsync): + await tx.execute("select 1;", commit_tx=True) + with pytest.raises(RuntimeError): + await tx.execute("select 1;") + + @pytest.mark.asyncio + async def test_tx_execute_raises_after_commit(self, tx: QueryTxContextAsync): + await tx.begin() + await tx.commit() + with pytest.raises(RuntimeError): + await tx.execute("select 1;") + + @pytest.mark.asyncio + async def test_tx_execute_raises_after_rollback(self, tx: QueryTxContextAsync): + await tx.begin() + await tx.rollback() + with pytest.raises(RuntimeError): + await tx.execute("select 1;") + + @pytest.mark.asyncio + async def test_context_manager_rollbacks_tx(self, tx: QueryTxContextAsync): + async with tx: + await tx.begin() + + assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED + + @pytest.mark.asyncio + async def test_context_manager_normal_flow(self, tx: QueryTxContextAsync): + async with tx: + await tx.begin() + await tx.execute("select 1;") + await tx.commit() + + assert tx._tx_state._state == QueryTxStateEnum.COMMITTED + + @pytest.mark.asyncio + async def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContextAsync): + class CustomException(Exception): + pass + + with pytest.raises(CustomException): + async with tx: + raise CustomException() + + @pytest.mark.asyncio + async def test_execute_as_context_manager(self, tx: QueryTxContextAsync): + await tx.begin() + + async with await tx.execute("select 1;") as results: + res = [result_set async for result_set in results] + + assert len(res) == 1 diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 1c3fdda2..07a43fa6 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,62 +1,62 @@ import pytest -from ydb.query.transaction import BaseQueryTxContext +from ydb.query.transaction import QueryTxContextSync from ydb.query.transaction import QueryTxStateEnum class TestQueryTransaction: - def test_tx_begin(self, tx: BaseQueryTxContext): + def test_tx_begin(self, tx: QueryTxContextSync): assert tx.tx_id is None tx.begin() assert tx.tx_id is not None - def test_tx_allow_double_commit(self, tx: BaseQueryTxContext): + def test_tx_allow_double_commit(self, tx: QueryTxContextSync): tx.begin() tx.commit() tx.commit() - def test_tx_allow_double_rollback(self, tx: BaseQueryTxContext): + def test_tx_allow_double_rollback(self, tx: QueryTxContextSync): tx.begin() tx.rollback() tx.rollback() - def test_tx_commit_before_begin(self, tx: BaseQueryTxContext): + def test_tx_commit_before_begin(self, tx: QueryTxContextSync): tx.commit() assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_tx_rollback_before_begin(self, tx: BaseQueryTxContext): + def test_tx_rollback_before_begin(self, tx: QueryTxContextSync): tx.rollback() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_tx_first_execute_begins_tx(self, tx: BaseQueryTxContext): + def test_tx_first_execute_begins_tx(self, tx: QueryTxContextSync): tx.execute("select 1;") tx.commit() - def test_interactive_tx_commit(self, tx: BaseQueryTxContext): + def test_interactive_tx_commit(self, tx: QueryTxContextSync): tx.execute("select 1;", commit_tx=True) with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_commit(self, tx: BaseQueryTxContext): + def test_tx_execute_raises_after_commit(self, tx: QueryTxContextSync): tx.begin() tx.commit() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_rollback(self, tx: BaseQueryTxContext): + def test_tx_execute_raises_after_rollback(self, tx: QueryTxContextSync): tx.begin() tx.rollback() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_context_manager_rollbacks_tx(self, tx: BaseQueryTxContext): + def test_context_manager_rollbacks_tx(self, tx: QueryTxContextSync): with tx: tx.begin() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_context_manager_normal_flow(self, tx: BaseQueryTxContext): + def test_context_manager_normal_flow(self, tx: QueryTxContextSync): with tx: tx.begin() tx.execute("select 1;") @@ -64,7 +64,7 @@ def test_context_manager_normal_flow(self, tx: BaseQueryTxContext): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_context_manager_does_not_hide_exceptions(self, tx: BaseQueryTxContext): + def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContextSync): class CustomException(Exception): pass @@ -72,7 +72,7 @@ class CustomException(Exception): with tx: raise CustomException() - def test_execute_as_context_manager(self, tx: BaseQueryTxContext): + def test_execute_as_context_manager(self, tx: QueryTxContextSync): tx.begin() with tx.execute("select 1;") as results: diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index a7febd5b..7fb5b684 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -24,7 +24,8 @@ from google.protobuf.duration_pb2 import Duration as ProtoDuration from google.protobuf.timestamp_pb2 import Timestamp as ProtoTimeStamp -import ydb.aio +from ...driver import Driver +from ...aio.driver import Driver as DriverIO # Workaround for good IDE and universal for runtime if typing.TYPE_CHECKING: @@ -142,7 +143,7 @@ def close(self): ... -SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] +SupportedDriverType = Union[Driver, DriverIO] class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): @@ -181,7 +182,7 @@ def _clean_executor(self, wait: bool): if self._wait_executor: self._wait_executor.shutdown(wait) - async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): + async def _start_asyncio_driver(self, driver: DriverIO, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) stream_call = await driver( requests_iterator, @@ -191,7 +192,7 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): self._stream_call = stream_call self.from_server_grpc = stream_call.__aiter__() - async def _start_sync_driver(self, driver: ydb.Driver, stub, method): + async def _start_sync_driver(self, driver: Driver, stub, method): requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc) self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) diff --git a/ydb/aio/__init__.py b/ydb/aio/__init__.py index acc44db5..0e7d4e74 100644 --- a/ydb/aio/__init__.py +++ b/ydb/aio/__init__.py @@ -1,2 +1,3 @@ from .driver import Driver # noqa from .table import SessionPool, retry_operation # noqa +from .query import QuerySessionPoolAsync, QuerySessionAsync # noqa diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 10cbead6..454378b0 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -7,6 +7,9 @@ def cancel(self): self.it.cancel() return self + def __iter__(self): + return self + def __aiter__(self): return self diff --git a/ydb/aio/query/__init__.py b/ydb/aio/query/__init__.py new file mode 100644 index 00000000..829d7b54 --- /dev/null +++ b/ydb/aio/query/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "QuerySessionPoolAsync", + "QuerySessionAsync", +] + +from .pool import QuerySessionPoolAsync +from .session import QuerySessionAsync diff --git a/ydb/aio/query/base.py b/ydb/aio/query/base.py new file mode 100644 index 00000000..3800ce3d --- /dev/null +++ b/ydb/aio/query/base.py @@ -0,0 +1,11 @@ +from .. import _utilities + + +class AsyncResponseContextIterator(_utilities.AsyncResponseIterator): + async def __aenter__(self) -> "AsyncResponseContextIterator": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # To close stream on YDB it is necessary to scroll through it to the end + async for _ in self: + pass diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py new file mode 100644 index 00000000..53f11a03 --- /dev/null +++ b/ydb/aio/query/pool.py @@ -0,0 +1,93 @@ +import logging +from typing import ( + Callable, + Optional, + List, +) + +from .session import ( + QuerySessionAsync, +) +from ...retries import ( + RetrySettings, + retry_operation_async, +) +from ... import convert +from ..._grpc.grpcwrapper import common_utils + +logger = logging.getLogger(__name__) + + +class QuerySessionPoolAsync: + """QuerySessionPoolAsync is an object to simplify operations with sessions of Query Service.""" + + def __init__(self, driver: common_utils.SupportedDriverType): + """ + :param driver: A driver instance + """ + + logger.warning("QuerySessionPoolAsync is an experimental API, which could be changed.") + self._driver = driver + + def checkout(self) -> "SimpleQuerySessionCheckoutAsync": + """WARNING: This API is experimental and could be changed. + Return a Session context manager, that opens session on enter and closes session on exit. + """ + + return SimpleQuerySessionCheckoutAsync(self) + + async def retry_operation_async( + self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs + ): + """WARNING: This API is experimental and could be changed. + Special interface to execute a bunch of commands with session in a safe, retriable way. + + :param callee: A function, that works with session. + :param retry_settings: RetrySettings object. + + :return: Result sets or exception in case of execution errors. + """ + + retry_settings = RetrySettings() if retry_settings is None else retry_settings + + async def wrapped_callee(): + async with self.checkout() as session: + return await callee(session, *args, **kwargs) + + return await retry_operation_async(wrapped_callee, retry_settings) + + async def execute_with_retries( + self, query: str, retry_settings: Optional[RetrySettings] = None, *args, **kwargs + ) -> List[convert.ResultSet]: + """WARNING: This API is experimental and could be changed. + Special interface to execute a one-shot queries in a safe, retriable way. + Note: this method loads all data from stream before return, do not use this + method with huge read queries. + + :param query: A query, yql or sql text. + :param retry_settings: RetrySettings object. + + :return: Result sets or exception in case of execution errors. + """ + + retry_settings = RetrySettings() if retry_settings is None else retry_settings + + async def wrapped_callee(): + async with self.checkout() as session: + it = await session.execute(query, *args, **kwargs) + return [result_set async for result_set in it] + + return await retry_operation_async(wrapped_callee, retry_settings) + + +class SimpleQuerySessionCheckoutAsync: + def __init__(self, pool: QuerySessionPoolAsync): + self._pool = pool + self._session = QuerySessionAsync(pool._driver) + + async def __aenter__(self) -> QuerySessionAsync: + await self._session.create() + return self._session + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._session.delete() diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py new file mode 100644 index 00000000..3b918e61 --- /dev/null +++ b/ydb/aio/query/session.py @@ -0,0 +1,140 @@ +import asyncio + +from typing import ( + Optional, +) + +from .base import AsyncResponseContextIterator +from .transaction import QueryTxContextAsync +from .. import _utilities +from ... import issues +from ..._grpc.grpcwrapper import common_utils +from ..._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public + +from ...query import base +from ...query.session import ( + BaseQuerySession, + QuerySessionStateEnum, +) + + +class QuerySessionAsync(BaseQuerySession): + """Session object for Query Service. It is not recommended to control + session's lifecycle manually - use a QuerySessionPool is always a better choise. + """ + + _loop: asyncio.AbstractEventLoop + _status_stream: _utilities.AsyncResponseIterator = None + + def __init__( + self, + driver: common_utils.SupportedDriverType, + settings: Optional[base.QueryClientSettings] = None, + loop: asyncio.AbstractEventLoop = None, + ): + super(QuerySessionAsync, self).__init__(driver, settings) + self._loop = loop if loop is not None else asyncio.get_running_loop() + + async def _attach(self) -> None: + self._stream = await self._attach_call() + self._status_stream = _utilities.AsyncResponseIterator( + self._stream, + lambda response: common_utils.ServerStatus.from_proto(response), + ) + + first_response = await self._status_stream.next() + if first_response.status != issues.StatusCode.SUCCESS: + pass + + self._state.set_attached(True) + self._state._change_state(QuerySessionStateEnum.CREATED) + + self._loop.create_task(self._check_session_status_loop(), name="check session status task") + + async def _check_session_status_loop(self) -> None: + try: + async for status in self._status_stream: + if status.status != issues.StatusCode.SUCCESS: + self._state.reset() + self._state._change_state(QuerySessionStateEnum.CLOSED) + except Exception: + if not self._state._already_in(QuerySessionStateEnum.CLOSED): + self._state.reset() + self._state._change_state(QuerySessionStateEnum.CLOSED) + + async def delete(self) -> None: + """WARNING: This API is experimental and could be changed. + + Deletes a Session of Query Service on server side and releases resources. + + :return: None + """ + if self._state._already_in(QuerySessionStateEnum.CLOSED): + return + + self._state._check_invalid_transition(QuerySessionStateEnum.CLOSED) + await self._delete_call() + self._stream.cancel() + + async def create(self) -> "QuerySessionAsync": + """WARNING: This API is experimental and could be changed. + + Creates a Session of Query Service on server side and attaches it. + + :return: QuerySessionSync object. + """ + if self._state._already_in(QuerySessionStateEnum.CREATED): + return + + self._state._check_invalid_transition(QuerySessionStateEnum.CREATED) + await self._create_call() + await self._attach() + + return self + + def transaction(self, tx_mode=None) -> QueryTxContextAsync: + self._state._check_session_ready_to_use() + tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() + + return QueryTxContextAsync( + self._driver, + self._state, + self, + tx_mode, + ) + + async def execute( + self, + query: str, + parameters: dict = None, + syntax: base.QuerySyntax = None, + exec_mode: base.QueryExecMode = None, + concurrent_result_sets: bool = False, + ) -> AsyncResponseContextIterator: + """WARNING: This API is experimental and could be changed. + + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param parameters: dict with parameters and YDB types; + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ + self._state._check_session_ready_to_use() + + stream_it = await self._execute_call( + query=query, + commit_tx=True, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + + return AsyncResponseContextIterator( + stream_it, + lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp), + ) diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py new file mode 100644 index 00000000..f8e332fa --- /dev/null +++ b/ydb/aio/query/transaction.py @@ -0,0 +1,151 @@ +import logging +from typing import ( + Optional, +) + +from .base import AsyncResponseContextIterator +from ... import issues +from ...query import base +from ...query.transaction import ( + BaseQueryTxContext, + QueryTxStateEnum, +) + +logger = logging.getLogger(__name__) + + +class QueryTxContextAsync(BaseQueryTxContext): + async def __aenter__(self) -> "QueryTxContextAsync": + """ + Enters a context manager and returns a transaction + + :return: A transaction instance + """ + return self + + async def __aexit__(self, *args, **kwargs): + """ + Closes a transaction context manager and rollbacks transaction if + it is not finished explicitly + """ + await self._ensure_prev_stream_finished() + if self._tx_state._state == QueryTxStateEnum.BEGINED: + # It's strictly recommended to close transactions directly + # by using commit_tx=True flag while executing statement or by + # .commit() or .rollback() methods, but here we trying to do best + # effort to avoid useless open transactions + logger.warning("Potentially leaked tx: %s", self._tx_state.tx_id) + try: + await self.rollback() + except issues.Error: + logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) + + async def _ensure_prev_stream_finished(self) -> None: + if self._prev_stream is not None: + async with self._prev_stream: + pass + self._prev_stream = None + + async def begin(self, settings: Optional[base.QueryClientSettings] = None) -> "QueryTxContextAsync": + """WARNING: This API is experimental and could be changed. + + Explicitly begins a transaction + + :param settings: A request settings + + :return: None or exception if begin is failed + """ + await self._begin_call(settings) + + async def commit(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """WARNING: This API is experimental and could be changed. + + Calls commit on a transaction if it is open otherwise is no-op. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ + if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + return + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + self._tx_state._change_state(QueryTxStateEnum.COMMITTED) + return + + await self._ensure_prev_stream_finished() + + await self._commit_call(settings) + + async def rollback(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """WARNING: This API is experimental and could be changed. + + Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ + if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + return + + if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: + self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) + return + + await self._ensure_prev_stream_finished() + + await self._rollback_call(settings) + + async def execute( + self, + query: str, + parameters: Optional[dict] = None, + commit_tx: Optional[bool] = False, + syntax: Optional[base.QuerySyntax] = None, + exec_mode: Optional[base.QueryExecMode] = None, + concurrent_result_sets: Optional[bool] = False, + settings: Optional[base.QueryClientSettings] = None, + ) -> AsyncResponseContextIterator: + """WARNING: This API is experimental and could be changed. + + Sends a query to Query Service + :param query: (YQL or SQL text) to be executed. + :param parameters: dict with parameters and YDB types; + :param commit_tx: A special flag that allows transaction commit. + :param syntax: Syntax of the query, which is a one from the following choises: + 1) QuerySyntax.YQL_V1, which is default; + 2) QuerySyntax.PG. + :param exec_mode: Exec mode of the query, which is a one from the following choises: + 1) QueryExecMode.EXECUTE, which is default; + 2) QueryExecMode.EXPLAIN; + 3) QueryExecMode.VALIDATE; + 4) QueryExecMode.PARSE. + :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; + + :return: Iterator with result sets + """ + await self._ensure_prev_stream_finished() + + stream_it = await self._execute_call( + query=query, + commit_tx=commit_tx, + syntax=syntax, + exec_mode=exec_mode, + parameters=parameters, + concurrent_result_sets=concurrent_result_sets, + ) + + settings = settings if settings is not None else self.session._settings + self._prev_stream = AsyncResponseContextIterator( + stream_it, + lambda resp: base.wrap_execute_query_response( + rpc_state=None, + response_pb=resp, + tx=self, + commit_tx=commit_tx, + ), + ) + return self._prev_stream diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index eb967abc..40e512cd 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -11,13 +11,12 @@ import logging from .base import ( - IQueryClient, - SupportedDriverType, QueryClientSettings, ) from .session import QuerySessionSync +from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper.ydb_query_public_types import ( QueryOnlineReadOnly, QuerySerializableReadWrite, @@ -30,8 +29,8 @@ logger = logging.getLogger(__name__) -class QueryClientSync(IQueryClient): - def __init__(self, driver: SupportedDriverType, query_client_settings: QueryClientSettings = None): +class QueryClientSync: + def __init__(self, driver: common_utils.SupportedDriverType, query_client_settings: QueryClientSettings = None): logger.warning("QueryClientSync is an experimental API, which could be changed.") self._driver = driver self._settings = query_client_settings diff --git a/ydb/query/base.py b/ydb/query/base.py index eef51ee6..55087d0c 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -2,14 +2,11 @@ import enum import functools +import typing from typing import ( - Iterator, Optional, ) -from .._grpc.grpcwrapper.common_utils import ( - SupportedDriverType, -) from .._grpc.grpcwrapper import ydb_query from .._grpc.grpcwrapper.ydb_query_public_types import ( BaseQueryTxMode, @@ -20,6 +17,9 @@ from .. import _utilities from .. import _apis +if typing.TYPE_CHECKING: + from .transaction import BaseQueryTxContext + class QuerySyntax(enum.IntEnum): UNSPECIFIED = 0 @@ -48,6 +48,7 @@ def __enter__(self) -> "SyncResponseContextIterator": return self def __exit__(self, exc_type, exc_val, exc_tb): + # To close stream on YDB it is necessary to scroll through it to the end for _ in self: pass @@ -117,231 +118,6 @@ def set_attached(self, attached: bool) -> "IQuerySessionState": pass -class IQuerySession(abc.ABC): - """Session object for Query Service. It is not recommended to control - session's lifecycle manually - use a QuerySessionPool is always a better choise. - """ - - @abc.abstractmethod - def __init__(self, driver: SupportedDriverType, settings: Optional[QueryClientSettings] = None): - pass - - @abc.abstractmethod - def create(self) -> "IQuerySession": - """WARNING: This API is experimental and could be changed. - - Creates a Session of Query Service on server side and attaches it. - - :return: Session object. - """ - pass - - @abc.abstractmethod - def delete(self) -> None: - """WARNING: This API is experimental and could be changed. - - Deletes a Session of Query Service on server side and releases resources. - - :return: None - """ - pass - - @abc.abstractmethod - def transaction(self, tx_mode: Optional[BaseQueryTxMode] = None) -> "IQueryTxContext": - """WARNING: This API is experimental and could be changed. - - Creates a transaction context manager with specified transaction mode. - - :param tx_mode: Transaction mode, which is a one from the following choises: - 1) QuerySerializableReadWrite() which is default mode; - 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); - 3) QuerySnapshotReadOnly(); - 4) QueryStaleReadOnly(). - - :return: transaction context manager. - """ - pass - - @abc.abstractmethod - def execute( - self, - query: str, - parameters: Optional[dict] = None, - syntax: Optional[QuerySyntax] = None, - exec_mode: Optional[QueryExecMode] = None, - concurrent_result_sets: Optional[bool] = False, - ) -> Iterator: - """WARNING: This API is experimental and could be changed. - - Sends a query to Query Service - :param query: (YQL or SQL text) to be executed. - :param syntax: Syntax of the query, which is a one from the following choises: - 1) QuerySyntax.YQL_V1, which is default; - 2) QuerySyntax.PG. - :param parameters: dict with parameters and YDB types; - :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; - - :return: Iterator with result sets - """ - - -class IQueryTxContext(abc.ABC): - """ - An object that provides a simple transaction context manager that allows statements execution - in a transaction. You don't have to open transaction explicitly, because context manager encapsulates - transaction control logic, and opens new transaction if: - 1) By explicit .begin(); - 2) On execution of a first statement, which is strictly recommended method, because that avoids - useless round trip - - This context manager is not thread-safe, so you should not manipulate on it concurrently. - """ - - @abc.abstractmethod - def __init__( - self, - driver: SupportedDriverType, - session_state: IQuerySessionState, - session: IQuerySession, - tx_mode: BaseQueryTxMode, - ): - """ - An object that provides a simple transaction context manager that allows statements execution - in a transaction. You don't have to open transaction explicitly, because context manager encapsulates - transaction control logic, and opens new transaction if: - - 1) By explicit .begin() method; - 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip - - This context manager is not thread-safe, so you should not manipulate on it concurrently. - - :param driver: A driver instance - :param session_state: A state of session - :param tx_mode: Transaction mode, which is a one from the following choises: - 1) QuerySerializableReadWrite() which is default mode; - 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); - 3) QuerySnapshotReadOnly(); - 4) QueryStaleReadOnly(). - """ - pass - - @abc.abstractmethod - def __enter__(self) -> "IQueryTxContext": - """ - Enters a context manager and returns a transaction - - :return: A transaction instance - """ - pass - - @abc.abstractmethod - def __exit__(self, *args, **kwargs): - """ - Closes a transaction context manager and rollbacks transaction if - it is not finished explicitly - """ - pass - - @property - @abc.abstractmethod - def session_id(self) -> str: - """ - A transaction's session id - - :return: A transaction's session id - """ - pass - - @property - @abc.abstractmethod - def tx_id(self) -> Optional[str]: - """ - Returns an id of open transaction or None otherwise - - :return: An id of open transaction or None otherwise - """ - pass - - @abc.abstractmethod - def begin(self, settings: Optional[QueryClientSettings] = None) -> "IQueryTxContext": - """WARNING: This API is experimental and could be changed. - - Explicitly begins a transaction - - :param settings: A request settings - - :return: Transaction object or exception if begin is failed - """ - pass - - @abc.abstractmethod - def commit(self, settings: Optional[QueryClientSettings] = None) -> None: - """WARNING: This API is experimental and could be changed. - - Calls commit on a transaction if it is open. If transaction execution - failed then this method raises PreconditionFailed. - - :param settings: A request settings - - :return: None or exception if commit is failed - """ - pass - - @abc.abstractmethod - def rollback(self, settings: Optional[QueryClientSettings] = None) -> None: - """WARNING: This API is experimental and could be changed. - - Calls rollback on a transaction if it is open. If transaction execution - failed then this method raises PreconditionFailed. - - :param settings: A request settings - - :return: None or exception if rollback is failed - """ - pass - - @abc.abstractmethod - def execute( - self, - query: str, - commit_tx: Optional[bool] = False, - syntax: Optional[QuerySyntax] = None, - exec_mode: Optional[QueryExecMode] = None, - parameters: Optional[dict] = None, - concurrent_result_sets: Optional[bool] = False, - settings: Optional[QueryClientSettings] = None, - ) -> Iterator: - """WARNING: This API is experimental and could be changed. - - Sends a query to Query Service - :param query: (YQL or SQL text) to be executed. - :param commit_tx: A special flag that allows transaction commit. - :param syntax: Syntax of the query, which is a one from the following choises: - 1) QuerySyntax.YQL_V1, which is default; - 2) QuerySyntax.PG. - :param exec_mode: Exec mode of the query, which is a one from the following choises: - 1) QueryExecMode.EXECUTE, which is default; - 2) QueryExecMode.EXPLAIN; - 3) QueryExecMode.VALIDATE; - 4) QueryExecMode.PARSE. - :param parameters: dict with parameters and YDB types; - :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; - :param settings: An additional request settings QueryClientSettings; - - :return: Iterator with result sets - """ - pass - - -class IQueryClient(abc.ABC): - def __init__(self, driver: SupportedDriverType, query_client_settings: Optional[QueryClientSettings] = None): - pass - - @abc.abstractmethod - def session(self) -> IQuerySession: - pass - - def create_execute_query_request( query: str, session_id: str, @@ -392,7 +168,7 @@ def create_execute_query_request( def wrap_execute_query_response( rpc_state: RpcState, response_pb: _apis.ydb_query.ExecuteQueryResponsePart, - tx: Optional[IQueryTxContext] = None, + tx: Optional["BaseQueryTxContext"] = None, commit_tx: Optional[bool] = False, settings: Optional[QueryClientSettings] = None, ) -> convert.ResultSet: diff --git a/ydb/query/pool.py b/ydb/query/pool.py index bc214ecb..bf868352 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -5,7 +5,6 @@ List, ) -from . import base from .session import ( QuerySessionSync, ) @@ -14,6 +13,8 @@ retry_operation_sync, ) from .. import convert +from .._grpc.grpcwrapper import common_utils + logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ class QuerySessionPool: """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" - def __init__(self, driver: base.SupportedDriverType): + def __init__(self, driver: common_utils.SupportedDriverType): """ :param driver: A driver instance """ @@ -97,7 +98,7 @@ def __init__(self, pool: QuerySessionPool): self._pool = pool self._session = QuerySessionSync(pool._driver) - def __enter__(self) -> base.IQuerySession: + def __enter__(self) -> QuerySessionSync: self._session.create() return self._session diff --git a/ydb/query/session.py b/ydb/query/session.py index 1fa3025d..4b051dc1 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -15,7 +15,7 @@ from .._grpc.grpcwrapper import ydb_query as _ydb_query from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public -from .transaction import BaseQueryTxContext +from .transaction import QueryTxContextSync logger = logging.getLogger(__name__) @@ -126,12 +126,12 @@ def wrapper_delete_session( return session -class BaseQuerySession(base.IQuerySession): - _driver: base.SupportedDriverType +class BaseQuerySession: + _driver: common_utils.SupportedDriverType _settings: base.QueryClientSettings _state: QuerySessionState - def __init__(self, driver: base.SupportedDriverType, settings: Optional[base.QueryClientSettings] = None): + def __init__(self, driver: common_utils.SupportedDriverType, settings: Optional[base.QueryClientSettings] = None): self._driver = driver self._settings = settings if settings is not None else base.QueryClientSettings() self._state = QuerySessionState(settings) @@ -224,7 +224,9 @@ def _check_session_status_loop(self, status_stream: _utilities.SyncResponseItera self._state.reset() self._state._change_state(QuerySessionStateEnum.CLOSED) except Exception: - pass + if not self._state._already_in(QuerySessionStateEnum.CLOSED): + self._state.reset() + self._state._change_state(QuerySessionStateEnum.CLOSED) def delete(self) -> None: """WARNING: This API is experimental and could be changed. @@ -256,7 +258,7 @@ def create(self) -> "QuerySessionSync": return self - def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> base.IQueryTxContext: + def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> QueryTxContextSync: """WARNING: This API is experimental and could be changed. Creates a transaction context manager with specified transaction mode. @@ -273,7 +275,7 @@ def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> base.IQ tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() - return BaseQueryTxContext( + return QueryTxContextSync( self._driver, self._state, self, diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index f42571c2..750a94b0 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -169,7 +169,7 @@ def wrap_tx_rollback_response( return tx -class BaseQueryTxContext(base.IQueryTxContext): +class BaseQueryTxContext: def __init__(self, driver, session_state, session, tx_mode): """ An object that provides a simple transaction context manager that allows statements execution @@ -196,31 +196,6 @@ def __init__(self, driver, session_state, session, tx_mode): self.session = session self._prev_stream = None - def __enter__(self) -> "BaseQueryTxContext": - """ - Enters a context manager and returns a transaction - - :return: A transaction instance - """ - return self - - def __exit__(self, *args, **kwargs): - """ - Closes a transaction context manager and rollbacks transaction if - it is not finished explicitly - """ - self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: - # It's strictly recommended to close transactions directly - # by using commit_tx=True flag while executing statement or by - # .commit() or .rollback() methods, but here we trying to do best - # effort to avoid useless open transactions - logger.warning("Potentially leaked tx: %s", self._tx_state.tx_id) - try: - self.rollback() - except issues.Error: - logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) - @property def session_id(self) -> str: """ @@ -240,6 +215,8 @@ def tx_id(self) -> Optional[str]: return self._tx_state.tx_id def _begin_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": + self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) + return self._driver( _create_begin_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -250,6 +227,8 @@ def _begin_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQuer ) def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": + self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) + return self._driver( _create_commit_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -260,6 +239,8 @@ def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQue ) def _rollback_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext": + self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) + return self._driver( _create_rollback_transaction_request(self._session_state, self._tx_state), _apis.QueryService.Stub, @@ -278,6 +259,8 @@ def _execute_call( parameters: dict = None, concurrent_result_sets: bool = False, ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: + self._tx_state._check_tx_ready_to_use() + request = base.create_execute_query_request( query=query, session_id=self._session_state.session_id, @@ -296,12 +279,6 @@ def _execute_call( _apis.QueryService.ExecuteQuery, ) - def _ensure_prev_stream_finished(self) -> None: - if self._prev_stream is not None: - for _ in self._prev_stream: - pass - self._prev_stream = None - def _move_to_beginned(self, tx_id: str) -> None: if self._tx_state._already_in(QueryTxStateEnum.BEGINED): return @@ -313,7 +290,40 @@ def _move_to_commited(self) -> None: return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) - def begin(self, settings: Optional[base.QueryClientSettings] = None) -> "BaseQueryTxContext": + +class QueryTxContextSync(BaseQueryTxContext): + def __enter__(self) -> "BaseQueryTxContext": + """ + Enters a context manager and returns a transaction + + :return: A transaction instance + """ + return self + + def __exit__(self, *args, **kwargs): + """ + Closes a transaction context manager and rollbacks transaction if + it is not finished explicitly + """ + self._ensure_prev_stream_finished() + if self._tx_state._state == QueryTxStateEnum.BEGINED: + # It's strictly recommended to close transactions directly + # by using commit_tx=True flag while executing statement or by + # .commit() or .rollback() methods, but here we trying to do best + # effort to avoid useless open transactions + logger.warning("Potentially leaked tx: %s", self._tx_state.tx_id) + try: + self.rollback() + except issues.Error: + logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) + + def _ensure_prev_stream_finished(self) -> None: + if self._prev_stream is not None: + with self._prev_stream: + pass + self._prev_stream = None + + def begin(self, settings: Optional[base.QueryClientSettings] = None) -> "QueryTxContextSync": """WARNING: This API is experimental and could be changed. Explicitly begins a transaction @@ -322,8 +332,6 @@ def begin(self, settings: Optional[base.QueryClientSettings] = None) -> "BaseQue :return: Transaction object or exception if begin is failed """ - self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) - self._begin_call(settings) return self @@ -340,27 +348,33 @@ def commit(self, settings: Optional[base.QueryClientSettings] = None) -> None: """ if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): return - self._ensure_prev_stream_finished() if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: self._tx_state._change_state(QueryTxStateEnum.COMMITTED) return - self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) + self._ensure_prev_stream_finished() self._commit_call(settings) def rollback(self, settings: Optional[base.QueryClientSettings] = None) -> None: + """WARNING: This API is experimental and could be changed. + + Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution + failed then this method raises PreconditionFailed. + + :param settings: A request settings + + :return: A committed transaction or exception if commit is failed + """ if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): return - self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED) return - self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) + self._ensure_prev_stream_finished() self._rollback_call(settings) @@ -378,6 +392,7 @@ def execute( Sends a query to Query Service :param query: (YQL or SQL text) to be executed. + :param parameters: dict with parameters and YDB types; :param commit_tx: A special flag that allows transaction commit. :param syntax: Syntax of the query, which is a one from the following choises: 1) QuerySyntax.YQL_V1, which is default; @@ -387,14 +402,12 @@ def execute( 2) QueryExecMode.EXPLAIN; 3) QueryExecMode.VALIDATE; 4) QueryExecMode.PARSE. - :param parameters: dict with parameters and YDB types; :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False; :param settings: An additional request settings QueryClientSettings; :return: Iterator with result sets """ self._ensure_prev_stream_finished() - self._tx_state._check_tx_ready_to_use() stream_it = self._execute_call( query=query, diff --git a/ydb/retries.py b/ydb/retries.py index 5d4f6e6a..c9c23b1a 100644 --- a/ydb/retries.py +++ b/ydb/retries.py @@ -1,3 +1,4 @@ +import asyncio import random import time @@ -134,3 +135,27 @@ def retry_operation_sync(callee, retry_settings=None, *args, **kwargs): time.sleep(next_opt.timeout) else: return next_opt.result + + +async def retry_operation_async(callee, retry_settings=None, *args, **kwargs): # pylint: disable=W1113 + """ + The retry operation helper can be used to retry a coroutine that raises YDB specific + exceptions. + + :param callee: A coroutine to retry. + :param retry_settings: An instance of ydb.RetrySettings that describes how the coroutine + should be retried. If None, default instance of retry settings will be used. + :param args: A tuple with positional arguments to be passed into the coroutine. + :param kwargs: A dictionary with keyword arguments to be passed into the coroutine. + + Returns awaitable result of coroutine. If retries are not succussful exception is raised. + """ + opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs) + for next_opt in opt_generator: + if isinstance(next_opt, YdbRetryOperationSleepOpt): + await asyncio.sleep(next_opt.timeout) + else: + try: + return await next_opt.result + except BaseException as e: # pylint: disable=W0703 + next_opt.set_exception(e)