From fd4a4bb66910ad7f3482f158e6d1a93d5fae26c8 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 10 Mar 2025 11:05:28 +0300 Subject: [PATCH 01/11] topic transactions example --- examples/topic/topic_transactions_example.py | 31 ++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 examples/topic/topic_transactions_example.py diff --git a/examples/topic/topic_transactions_example.py b/examples/topic/topic_transactions_example.py new file mode 100644 index 00000000..422c24ff --- /dev/null +++ b/examples/topic/topic_transactions_example.py @@ -0,0 +1,31 @@ +import ydb + + +def writer_example(driver: ydb.Driver, topic: str): + session_pool = ydb.QuerySessionPool(driver) + + def callee(tx: ydb.QueryTxContext): + tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) + # нужно ли внутри ретраить ? нужно разрешать ретраи с дедупликацией + # дефолт - без дедупликации, без ретраев и без producer_id. договорились. + + with tx.execute(query="select 1") as result_sets: + messages = [result_set.rows[0] for result_set in result_sets] + + tx_writer.write(messages) + + session_pool.retry_tx_sync(callee) + + +def reader_example(driver: ydb.Driver, reader: ydb.TopicReader): + session_pool = ydb.QuerySessionPool(driver) + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=5) + + with tx.execute(query="INSERT INTO max_values(val) VALUES ($val)", parameters={"$val": max(batch)}) as _: + pass + + # коммитим при выходе из лямбды + + session_pool.retry_tx_sync(callee) From e5a9db86fcd145c4c28effc138cb7f047470d78f Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 10 Mar 2025 11:05:28 +0300 Subject: [PATCH 02/11] topic tx basic test cases --- tests/topics/test_topic_transactions.py | 67 +++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/topics/test_topic_transactions.py diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py new 
file mode 100644 index 00000000..4f9c6053 --- /dev/null +++ b/tests/topics/test_topic_transactions.py @@ -0,0 +1,67 @@ +import asyncio +from asyncio import wait_for +import pytest +import ydb + + +@pytest.mark.skip("Not implemented yet.") +@pytest.mark.asyncio +class TestTopicTransactionalReader: + async def test_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1) + assert len(batch) == 1 + assert batch[0].data.decode() == "123" + + await pool.retry_tx_async(callee) + + msg = await wait_for(reader.receive_message(), 1) + assert msg.data.decode() == "456" + + async def test_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1) + assert len(batch) == 1 + assert batch[0].data.decode() == "123" + + await tx.rollback() + + await pool.retry_tx_async(callee) + + msg = await wait_for(reader.receive_message(), 1) + assert msg.data.decode() == "123" + + +@pytest.mark.skip("Not implemented yet.") +class TestTopicTransactionalWriter: + async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + await pool.retry_tx_async(callee) + + msg = await wait_for(topic_reader.receive_message(), 0.1) + assert msg.data.decode() == "123" + + 
async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + await tx.rollback() + + await pool.retry_tx_async(callee) + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) From 8512408ce953f9e5d1ef38ccc80eef2e5e5558d0 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 10 Mar 2025 11:05:28 +0300 Subject: [PATCH 03/11] temp --- examples/topic/topic_transactions_example.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/topic/topic_transactions_example.py b/examples/topic/topic_transactions_example.py index 422c24ff..720025a8 100644 --- a/examples/topic/topic_transactions_example.py +++ b/examples/topic/topic_transactions_example.py @@ -5,14 +5,13 @@ def writer_example(driver: ydb.Driver, topic: str): session_pool = ydb.QuerySessionPool(driver) def callee(tx: ydb.QueryTxContext): - tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) - # нужно ли внутри ретраить ? нужно разрешать ретраи с дедупликацией - # дефолт - без дедупликации, без ретраев и без producer_id. договорились. + tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) # <======= + # дефолт - без дедупликации, без ретраев и без producer_id. 
with tx.execute(query="select 1") as result_sets: messages = [result_set.rows[0] for result_set in result_sets] - tx_writer.write(messages) + tx_writer.write(messages) # вне зависимости от состояния вышестоящего стрима поведение должно быть одинаковое session_pool.retry_tx_sync(callee) @@ -21,11 +20,9 @@ def reader_example(driver: ydb.Driver, reader: ydb.TopicReader): session_pool = ydb.QuerySessionPool(driver) def callee(tx: ydb.QueryTxContext): - batch = reader.receive_batch_with_tx(tx, max_messages=5) + batch = reader.receive_batch_with_tx(tx, max_messages=5) # <======= with tx.execute(query="INSERT INTO max_values(val) VALUES ($val)", parameters={"$val": max(batch)}) as _: pass - # коммитим при выходе из лямбды - session_pool.retry_tx_sync(callee) From e64d447cfdc349920114fc6bd44d74b877d4b8d5 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 10 Mar 2025 11:05:28 +0300 Subject: [PATCH 04/11] temp --- ydb/_topic_writer/topic_writer_asyncio.py | 4 ++++ ydb/topic.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 32d8fefe..f7af42c3 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -164,6 +164,10 @@ async def wait_init(self) -> PublicWriterInitInfo: return await self._reconnector.wait_init() +class TxWriterAsyncIO(WriterAsyncIO): + ... 
+ + class WriterAsyncIOReconnector: _closed: bool _loop: asyncio.AbstractEventLoop diff --git a/ydb/topic.py b/ydb/topic.py index 55f4ea04..a230d0af 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -276,6 +276,26 @@ def writer( return TopicWriterAsyncIO(self._driver, settings, _client=self) + def tx_writer( + self, + tx, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default means auto-select + # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes + # the func will be called from multiple threads in parallel. + encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, + # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. + # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel.
+ encoder_executor: Optional[concurrent.futures.Executor] = None, + ) -> TopicTxWriterAsyncIO: + + def close(self): if self._closed: return From e1629aa37f2880e895560479886927925d10e5a6 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 10 Mar 2025 17:08:04 +0300 Subject: [PATCH 05/11] transaction identity --- examples/topic/topic_transactions_example.py | 6 +-- tests/query/test_query_transaction.py | 12 +++++ ydb/_grpc/grpcwrapper/ydb_topic.py | 16 +++++++ ydb/_topic_writer/topic_writer.py | 4 ++ ydb/_topic_writer/topic_writer_asyncio.py | 36 ++++++++++++-- .../topic_writer_asyncio_test.py | 47 ++++++++++++++++++- ydb/query/transaction.py | 6 +++ ydb/topic.py | 11 ++++- 8 files changed, 130 insertions(+), 8 deletions(-) diff --git a/examples/topic/topic_transactions_example.py b/examples/topic/topic_transactions_example.py index 720025a8..ac273025 100644 --- a/examples/topic/topic_transactions_example.py +++ b/examples/topic/topic_transactions_example.py @@ -5,13 +5,13 @@ def writer_example(driver: ydb.Driver, topic: str): session_pool = ydb.QuerySessionPool(driver) def callee(tx: ydb.QueryTxContext): - tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) # <======= + tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) # <======= # дефолт - без дедупликации, без ретраев и без producer_id. 
with tx.execute(query="select 1") as result_sets: messages = [result_set.rows[0] for result_set in result_sets] - tx_writer.write(messages) # вне зависимости от состояния вышестоящего стрима поведение должно быть одинаковое + tx_writer.write(messages) # вне зависимости от состояния вышестоящего стрима поведение должно быть одинаковое session_pool.retry_tx_sync(callee) @@ -20,7 +20,7 @@ def reader_example(driver: ydb.Driver, reader: ydb.TopicReader): session_pool = ydb.QuerySessionPool(driver) def callee(tx: ydb.QueryTxContext): - batch = reader.receive_batch_with_tx(tx, max_messages=5) # <======= + batch = reader.receive_batch_with_tx(tx, max_messages=5) # <======= with tx.execute(query="INSERT INTO max_values(val) VALUES ($val)", parameters={"$val": max(batch)}) as _: pass diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index dfc88897..4533e528 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -92,3 +92,15 @@ def test_execute_two_results(self, tx: QueryTxContext): assert res == [[1], [2]] assert counter == 2 + + def test_tx_identity_before_begin_raises(self, tx: QueryTxContext): + with pytest.raises(RuntimeError): + tx._tx_identity() + + def test_tx_identity_after_begin_works(self, tx: QueryTxContext): + tx.begin() + + identity = tx._tx_identity() + + assert identity.tx_id == tx.tx_id + assert identity.session_id == tx.session_id diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index 600dfb69..64fd3a10 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -142,6 +142,18 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: ######################################################################################################################## +@dataclass +class TransactionIdentity(IToProto): + tx_id: str + session_id: str + + def to_proto(self) -> ydb_topic_pb2.TransactionIdentity: + 
return ydb_topic_pb2.TransactionIdentity( + id=self.tx_id, + session=self.session_id, + ) + + class StreamWriteMessage: @dataclass() class InitRequest(IToProto): @@ -200,6 +212,7 @@ def from_proto( class WriteRequest(IToProto): messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] codec: int + tx_identity: Optional[TransactionIdentity] @dataclass class MessageData(IToProto): @@ -238,6 +251,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() proto.codec = self.codec + if self.tx_identity is not None: + proto.tx = self.tx_identity.to_proto() + for message in self.messages: proto_mess = proto.messages.add() proto_mess.CopyFrom(message.to_proto()) diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index aa5fe974..2515acab 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -11,6 +11,7 @@ import ydb.aio from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage +from .._grpc.grpcwrapper.ydb_topic import TransactionIdentity from .._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .. 
import connection @@ -205,6 +206,7 @@ def default_serializer_message_content(data: Any) -> bytes: def messages_to_proto_requests( messages: List[InternalMessage], + tx_identity: Optional[TransactionIdentity], ) -> List[StreamWriteMessage.FromClient]: gropus = _slit_messages_for_send(messages) @@ -215,6 +217,7 @@ def messages_to_proto_requests( StreamWriteMessage.WriteRequest( messages=list(map(InternalMessage.to_message_data, group)), codec=group[0].codec, + tx_identity=tx_identity, ) ) res.append(req) @@ -239,6 +242,7 @@ def messages_to_proto_requests( ), ], codec=20000, + tx_identity=None, ) ) .to_proto() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f7af42c3..1c1662b2 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -35,6 +35,7 @@ UpdateTokenRequest, UpdateTokenResponse, StreamWriteMessage, + TransactionIdentity, WriterMessagesFromServerToClient, ) from .._grpc.grpcwrapper.common_utils import ( @@ -43,6 +44,9 @@ GrpcWrapperAsyncIO, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + logger = logging.getLogger(__name__) @@ -165,7 +169,20 @@ async def wait_init(self) -> PublicWriterInitInfo: class TxWriterAsyncIO(WriterAsyncIO): - ... 
+ _tx: object + + def __init__( + self, + tx, + driver: SupportedDriverType, + settings: PublicWriterSettings, + _client=None, + ): + self._tx = tx + self._loop = asyncio.get_running_loop() + self._closed = False + self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx) + self._parent = _client class WriterAsyncIOReconnector: @@ -182,6 +199,7 @@ class WriterAsyncIOReconnector: _codec_selector_batch_num: int _codec_selector_last_codec: Optional[PublicCodec] _codec_selector_check_batches_interval: int + _tx: Optional["BaseQueryTxContext"] if typing.TYPE_CHECKING: _messages_for_encode: asyncio.Queue[List[InternalMessage]] @@ -199,7 +217,9 @@ class WriterAsyncIOReconnector: _stop_reason: asyncio.Future _init_info: Optional[PublicWriterInitInfo] - def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + def __init__( + self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None + ): self._closed = False self._loop = asyncio.get_running_loop() self._driver = driver @@ -209,6 +229,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._init_info = None self._stream_connected = asyncio.Event() self._settings = settings + self._tx = tx self._codec_functions = { PublicCodec.RAW: lambda data: data, @@ -358,10 +379,12 @@ async def _connection_loop(self): # noinspection PyBroadException stream_writer = None try: + tx_identity = None if self._tx is None else self._tx._tx_identity() stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._settings.update_token_interval, + tx_identity=tx_identity, ) try: if self._init_info is None: @@ -601,10 +624,13 @@ class WriterAsyncIOStream: _update_token_event: asyncio.Event _get_token_function: Optional[Callable[[], str]] + _tx_identity: Optional[TransactionIdentity] + def __init__( self, update_token_interval: Optional[Union[int, float]] = None, get_token_function: 
Optional[Callable[[], str]] = None, + tx_identity: Optional[TransactionIdentity] = None, ): self._closed = False @@ -613,6 +639,8 @@ def __init__( self._update_token_event = asyncio.Event() self._update_token_task = None + self._tx_identity = tx_identity + async def close(self): if self._closed: return @@ -629,6 +657,7 @@ async def create( driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, update_token_interval: Optional[Union[int, float]] = None, + tx_identity: Optional[TransactionIdentity] = None, ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) @@ -638,6 +667,7 @@ async def create( writer = WriterAsyncIOStream( update_token_interval=update_token_interval, get_token_function=creds.get_auth_token if creds else lambda: "", + tx_identity=tx_identity, ) await writer._start(stream, init_request) return writer @@ -684,7 +714,7 @@ def write(self, messages: List[InternalMessage]): if self._closed: raise RuntimeError("Can not write on closed stream.") - for request in messages_to_proto_requests(messages): + for request in messages_to_proto_requests(messages, self._tx_identity): self._stream.write(request) async def _update_token_loop(self): diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index b288d0f0..cf88f797 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -18,6 +18,7 @@ from .._grpc.grpcwrapper.ydb_topic import ( Codec, StreamWriteMessage, + TransactionIdentity, UpdateTokenRequest, UpdateTokenResponse, ) @@ -43,6 +44,12 @@ from ..credentials import AnonymousCredentials +FAKE_TRANSACTION_IDENTITY = TransactionIdentity( + tx_id="transaction_id", + session_id="session_id", +) + + @pytest.fixture def default_driver() -> aio.Driver: driver = mock.Mock(spec=aio.Driver) @@ -148,6 +155,44 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): 
expected_message = StreamWriteMessage.FromClient( StreamWriteMessage.WriteRequest( codec=Codec.CODEC_RAW, + tx_identity=None, + messages=[ + StreamWriteMessage.WriteRequest.MessageData( + seq_no=1, + created_at=now, + data=data, + metadata_items={}, + uncompressed_size=len(data), + partitioning=None, + ) + ], + ) + ) + + sent_message = await writer_and_stream.stream.from_client.get() + assert expected_message == sent_message + + async def test_write_a_message_with_tx(self, writer_and_stream: WriterWithMockedStream): + writer_and_stream.writer._tx_identity = FAKE_TRANSACTION_IDENTITY + + data = "123".encode() + now = datetime.datetime.now(datetime.timezone.utc) + writer_and_stream.writer.write( + [ + InternalMessage( + PublicMessage( + seqno=1, + created_at=now, + data=data, + ) + ) + ] + ) + + expected_message = StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + codec=Codec.CODEC_RAW, + tx_identity=FAKE_TRANSACTION_IDENTITY, messages=[ StreamWriteMessage.WriteRequest.MessageData( seq_no=1, @@ -264,7 +309,7 @@ def _create(self): res = DoubleQueueWriters() - async def async_create(driver, init_message, token_getter): + async def async_create(driver, init_message, token_getter, tx_identity): return res.get_first() monkeypatch.setattr(WriterAsyncIOStream, "create", async_create) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 414401da..1b6a8051 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -11,6 +11,7 @@ _apis, issues, ) +from .._grpc.grpcwrapper import ydb_topic as _ydb_topic from .._grpc.grpcwrapper import ydb_query as _ydb_query from ..connection import _RpcState as RpcState @@ -215,6 +216,11 @@ def tx_id(self) -> Optional[str]: """ return self._tx_state.tx_id + def _tx_identity(self) -> _ydb_topic.TransactionIdentity: + if not self.tx_id: + raise RuntimeError("Unable to get tx identity without started tx.") + return _ydb_topic.TransactionIdentity(self.tx_id, self.session_id) + def _begin_call(self, 
settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) diff --git a/ydb/topic.py b/ydb/topic.py index a230d0af..6d92575b 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -65,6 +65,7 @@ PublicWriteResult as TopicWriteResult, ) +from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter @@ -294,7 +295,15 @@ def tx_writer( # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. encoder_executor: Optional[concurrent.futures.Executor] = None, ) -> TopicTxWriterAsyncIO: - + args = locals().copy() + del args["self"] + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self) def close(self): if self._closed: From 3579d4a7dee8513bb04fe88bfd813a23135cfa01 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 11 Mar 2025 11:16:20 +0300 Subject: [PATCH 06/11] async topic tx with listener pattern --- tests/topics/test_topic_transactions.py | 6 +- ydb/_grpc/grpcwrapper/ydb_topic.py | 7 +- ydb/_topic_writer/topic_writer.py | 6 +- ydb/_topic_writer/topic_writer_asyncio.py | 19 +++- ydb/aio/query/pool.py | 2 + ydb/aio/query/transaction.py | 6 +- ydb/query/base.py | 114 ++++++++++++++++++++++ ydb/query/pool.py | 2 + ydb/query/transaction.py | 29 ++++-- ydb/topic.py | 1 + 10 files changed, 175 insertions(+), 17 deletions(-) diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py index 4f9c6053..f715d4a4 100644 --- a/tests/topics/test_topic_transactions.py +++ b/tests/topics/test_topic_transactions.py @@ -38,14 +38,14 @@ async def callee(tx: ydb.aio.QueryTxContext): assert 
msg.data.decode() == "123" -@pytest.mark.skip("Not implemented yet.") +# @pytest.mark.skip("Not implemented yet.") class TestTopicTransactionalWriter: async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): async with ydb.aio.QuerySessionPool(driver) as pool: async def callee(tx: ydb.aio.QueryTxContext): tx_writer = driver.topic_client.tx_writer(tx, topic_path) - tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) await pool.retry_tx_async(callee) @@ -57,7 +57,7 @@ async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: async def callee(tx: ydb.aio.QueryTxContext): tx_writer = driver.topic_client.tx_writer(tx, topic_path) - tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) await tx.rollback() diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index 64fd3a10..0da75bca 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -252,7 +252,7 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: proto.codec = self.codec if self.tx_identity is not None: - proto.tx = self.tx_identity.to_proto() + proto.tx.CopyFrom(self.tx_identity.to_proto()) for message in self.messages: proto_mess = proto.messages.add() @@ -314,6 +314,8 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr ) except ValueError: message_write_status = reason + elif proto_ack.HasField("written_in_tx"): + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWrittenInTx() else: raise NotImplementedError("unexpected ack status") @@ -326,6 +328,9 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr class StatusWritten: offset: int + class StatusWrittenInTx: + pass + @dataclass class StatusSkipped: reason: 
"StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 2515acab..a3e407ed 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -54,8 +54,12 @@ class Written: class Skipped: pass + @dataclass(eq=True) + class WrittenInTx: + pass + -PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped] +PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped, PublicWriteResult.WrittenInTx] class WriterSettings(PublicWriterSettings): diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 1c1662b2..5468ff66 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -44,6 +44,8 @@ GrpcWrapperAsyncIO, ) +from ..query.base import TxListenerAsyncIO + if typing.TYPE_CHECKING: from ..query.transaction import BaseQueryTxContext @@ -168,12 +170,12 @@ async def wait_init(self) -> PublicWriterInitInfo: return await self._reconnector.wait_init() -class TxWriterAsyncIO(WriterAsyncIO): - _tx: object +class TxWriterAsyncIO(WriterAsyncIO, TxListenerAsyncIO): + _tx: "BaseQueryTxContext" def __init__( self, - tx, + tx: "BaseQueryTxContext", driver: SupportedDriverType, settings: PublicWriterSettings, _client=None, @@ -183,6 +185,13 @@ def __init__( self._closed = False self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx) self._parent = _client + self._tx._add_listener(self) + + async def _on_before_commit(self): + await self.close() + + async def _on_before_rollback(self): + await self.close() class WriterAsyncIOReconnector: @@ -560,6 +569,8 @@ def _handle_receive_ack(self, ack): result = PublicWriteResult.Skipped() elif isinstance(status, write_ack_msg.StatusWritten): result = PublicWriteResult.Written(offset=status.offset) + elif isinstance(status, 
write_ack_msg.StatusWrittenInTx): + result = PublicWriteResult.WrittenInTx() else: raise TopicWriterError("internal error - receive unexpected ack message.") message_future.set_result(result) @@ -575,6 +586,7 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): while True: m = await self._new_messages.get() # type: InternalMessage + print("NEW MESSAGE") if m.seq_no > last_seq_no: writer.write([m]) except asyncio.CancelledError: @@ -606,6 +618,7 @@ async def flush(self): # wait last message await asyncio.wait(self._messages_future) + print("ALL MESSAGES WERE SENT TO SERVER") class WriterAsyncIOStream: diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py index 947db658..fda22388 100644 --- a/ydb/aio/query/pool.py +++ b/ydb/aio/query/pool.py @@ -158,6 +158,8 @@ async def retry_tx_async( async def wrapped_callee(): async with self.checkout() as session: async with session.transaction(tx_mode=tx_mode) as tx: + if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]: + await tx.begin() result = await callee(tx, *args, **kwargs) await tx.commit() return result diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index 5b63a32b..cb757324 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -57,6 +57,7 @@ async def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryT await self._begin_call(settings) return self + @base.with_async_transaction_events async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls commit on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. 
@@ -65,7 +66,7 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -76,6 +77,7 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: await self._commit_call(settings) + @base.with_async_transaction_events async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. @@ -84,7 +86,7 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: diff --git a/ydb/query/base.py b/ydb/query/base.py index 57a769bb..e20529e2 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -1,4 +1,5 @@ import abc +import asyncio import enum import functools @@ -196,3 +197,116 @@ def wrap_execute_query_response( return convert.ResultSet.from_message(response_pb.result_set, settings) return None + + +class TxListener: + def _on_before_commit(self): + pass + + def _on_after_commit(self, exc: typing.Optional[BaseException]): + pass + + def _on_before_rollback(self): + pass + + def _on_after_rollback(self, exc: typing.Optional[BaseException]): + pass + + +class TxListenerAsyncIO: + async def _on_before_commit(self): + pass + + async def _on_after_commit(self, exc: typing.Optional[BaseException]): + pass + + async def _on_before_rollback(self): + pass + + async def _on_after_rollback(self, exc: typing.Optional[BaseException]): 
+ pass + + +def with_transaction_events(method): + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + method_name = method.__name__ + before_event = f"_on_before_{method_name}" + after_event = f"_on_after_{method_name}" + + self._notify_listeners_sync(before_event) + + try: + result = method(self, *args, **kwargs) + + self._notify_listeners_sync(after_event, exc=None) + + return result + except BaseException as e: + self._notify_listeners_sync(after_event, exc=e) + raise + + return wrapper + + +def with_async_transaction_events(method): + @functools.wraps(method) + async def wrapper(self, *args, **kwargs): + method_name = method.__name__ + before_event = f"_on_before_{method_name}" + after_event = f"_on_after_{method_name}" + + await self._notify_listeners_async(before_event) + + try: + result = await method(self, *args, **kwargs) + + await self._notify_listeners_async(after_event, exc=None) + + return result + except BaseException as e: + await self._notify_listeners_async(after_event, exc=e) + raise + + return wrapper + + +class ListenerHandlerMixin: + def _init_listener_handler(self): + self.listeners = [] + + def _add_listener(self, listener): + if listener not in self.listeners: + self.listeners.append(listener) + return self + + def _remove_listener(self, listener): + if listener in self.listeners: + self.listeners.remove(listener) + return self + + def _clear_listeners(self): + self.listeners.clear() + return self + + def _notify_sync_listeners(self, event_name: str, **kwargs) -> None: + for listener in self.listeners: + if isinstance(listener, TxListener) and hasattr(listener, event_name): + getattr(listener, event_name)(**kwargs) + + async def _notify_async_listeners(self, event_name: str, **kwargs) -> None: + coros = [] + for listener in self.listeners: + if isinstance(listener, TxListenerAsyncIO) and hasattr(listener, event_name): + coros.append(getattr(listener, event_name)(**kwargs)) + + if coros: + await asyncio.gather(*coros) + + def 
_notify_listeners_sync(self, event_name: str, **kwargs) -> None: + self._notify_sync_listeners(event_name, **kwargs) + + async def _notify_listeners_async(self, event_name: str, **kwargs) -> None: + # self._notify_sync_listeners(event_name, **kwargs) + + await self._notify_async_listeners(event_name, **kwargs) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index e3775c4d..43cc2e8d 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -167,6 +167,8 @@ def retry_tx_sync( def wrapped_callee(): with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session: with session.transaction(tx_mode=tx_mode) as tx: + if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]: + tx.begin() result = callee(tx, *args, **kwargs) tx.commit() return result diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 1b6a8051..cbe91764 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -43,10 +43,22 @@ class QueryTxStateHelper(abc.ABC): QueryTxStateEnum.DEAD: [], } + _SKIP_TRANSITIONS = { + QueryTxStateEnum.NOT_INITIALIZED: [], + QueryTxStateEnum.BEGINED: [], + QueryTxStateEnum.COMMITTED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED], + QueryTxStateEnum.ROLLBACKED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED], + QueryTxStateEnum.DEAD: [], + } + @classmethod def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: return after in cls._VALID_TRANSITIONS[before] + @classmethod + def should_skip(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: + return after in cls._SKIP_TRANSITIONS[before] + @classmethod def terminal(cls, state: QueryTxStateEnum) -> bool: return len(cls._VALID_TRANSITIONS[state]) == 0 @@ -89,8 +101,8 @@ def _check_tx_ready_to_use(self) -> None: if QueryTxStateHelper.terminal(self._state): raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") - def _already_in(self, target: QueryTxStateEnum) -> bool: - return 
self._state == target + def _should_skip(self, target: QueryTxStateEnum) -> bool: + return QueryTxStateHelper.should_skip(self._state, target) def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings: @@ -171,7 +183,7 @@ def wrap_tx_rollback_response( return tx -class BaseQueryTxContext: +class BaseQueryTxContext(base.ListenerHandlerMixin): def __init__(self, driver, session_state, session, tx_mode): """ An object that provides a simple transaction context manager that allows statements execution @@ -197,6 +209,7 @@ def __init__(self, driver, session_state, session, tx_mode): self._session_state = session_state self.session = session self._prev_stream = None + self._init_listener_handler() @property def session_id(self) -> str: @@ -289,13 +302,13 @@ def _execute_call( ) def _move_to_beginned(self, tx_id: str) -> None: - if self._tx_state._already_in(QueryTxStateEnum.BEGINED) or not tx_id: + if self._tx_state._should_skip(QueryTxStateEnum.BEGINED) or not tx_id: return self._tx_state._change_state(QueryTxStateEnum.BEGINED) self._tx_state.tx_id = tx_id def _move_to_commited(self) -> None: - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) @@ -343,6 +356,7 @@ def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxConte return self + @base.with_transaction_events def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls commit on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. 
@@ -351,7 +365,7 @@ def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -362,6 +376,7 @@ def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: self._commit_call(settings) + @base.with_transaction_events def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. @@ -370,7 +385,7 @@ def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: diff --git a/ydb/topic.py b/ydb/topic.py index 6d92575b..6954a5ff 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -297,6 +297,7 @@ def tx_writer( ) -> TopicTxWriterAsyncIO: args = locals().copy() del args["self"] + del args["tx"] settings = TopicWriterSettings(**args) From 5b17cc2c988d883fb52f90fba8a047c8d101a6c1 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 24 Mar 2025 12:17:17 +0300 Subject: [PATCH 07/11] tx reader and writer based on callbacks --- .github/workflows/tests.yaml | 12 +- tests/topics/test_topic_transactions.py | 142 ++++++++++++++++++--- tox.ini | 6 +- ydb/_apis.py | 1 + ydb/_errors.py | 1 + ydb/_grpc/grpcwrapper/ydb_topic.py | 46 +++++++ ydb/_topic_reader/datatypes.py | 9 ++ ydb/_topic_reader/topic_reader_asyncio.py | 110 +++++++++++++++- ydb/_topic_writer/topic_writer_asyncio.py | 24 ++-- ydb/_topic_writer/topic_writer_sync.py | 46 ++++++- ydb/aio/query/transaction.py 
| 46 ++++++- ydb/issues.py | 4 + ydb/query/base.py | 145 ++++++++-------------- ydb/query/transaction.py | 60 +++++++-- ydb/topic.py | 31 +++++ 15 files changed, 537 insertions(+), 146 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index cbc0bc67..adbf779f 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -18,21 +18,15 @@ jobs: fail-fast: false matrix: python-version: [3.8, 3.9] - environment: [py-proto5, py-tls-proto5, py-proto4, py-tls-proto4, py-proto3, py-tls-proto3] - folder: [ydb, tests --ignore=tests/topics, tests/topics] + environment: [py, py-tls, py-proto4, py-tls-proto4, py-proto3, py-tls-proto3] + folder: [ydb, tests] exclude: - - environment: py-tls-proto5 + - environment: py-tls folder: ydb - environment: py-tls-proto4 folder: ydb - environment: py-tls-proto3 folder: ydb - - environment: py-tls-proto5 - folder: tests/topics - - environment: py-tls-proto4 - folder: tests/topics - - environment: py-tls-proto3 - folder: tests/topics steps: - uses: actions/checkout@v1 diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py index f715d4a4..b5fc8bab 100644 --- a/tests/topics/test_topic_transactions.py +++ b/tests/topics/test_topic_transactions.py @@ -1,24 +1,28 @@ import asyncio from asyncio import wait_for import pytest +from unittest import mock import ydb +DEFAULT_TIMEOUT = 0.1 +DEFAULT_RETRY_SETTINGS = ydb.RetrySettings(max_retries=1) + -@pytest.mark.skip("Not implemented yet.") @pytest.mark.asyncio class TestTopicTransactionalReader: async def test_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): - async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: - async with ydb.aio.QuerySessionPool(driver) as pool: + async with ydb.aio.QuerySessionPool(driver) as pool: + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: async def callee(tx: 
ydb.aio.QueryTxContext): - batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1) - assert len(batch) == 1 - assert batch[0].data.decode() == "123" + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" - await pool.retry_tx_async(callee) + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) - msg = await wait_for(reader.receive_message(), 1) + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) assert msg.data.decode() == "456" async def test_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): @@ -26,19 +30,40 @@ async def test_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic async with ydb.aio.QuerySessionPool(driver) as pool: async def callee(tx: ydb.aio.QueryTxContext): - batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1) - assert len(batch) == 1 - assert batch[0].data.decode() == "123" + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" await tx.rollback() - await pool.retry_tx_async(callee) + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) - msg = await wait_for(reader.receive_message(), 1) + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) assert msg.data.decode() == "123" + async def test_tx_failed_if_update_offsets_call_failed( + self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer + ): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + with mock.patch.object( + reader._reconnector, + "_do_commit_batches_with_tx_call", + side_effect=ydb.Error("Update offsets in tx 
failed"), + ): + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + with pytest.raises(ydb.Error): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + -# @pytest.mark.skip("Not implemented yet.") class TestTopicTransactionalWriter: async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): async with ydb.aio.QuerySessionPool(driver) as pool: @@ -47,7 +72,7 @@ async def callee(tx: ydb.aio.QueryTxContext): tx_writer = driver.topic_client.tx_writer(tx, topic_path) await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) - await pool.retry_tx_async(callee) + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) msg = await wait_for(topic_reader.receive_message(), 0.1) assert msg.data.decode() == "123" @@ -61,7 +86,92 @@ async def callee(tx: ydb.aio.QueryTxContext): await tx.rollback() - await pool.retry_tx_async(callee) + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + + async def test_no_msg_written_in_error_case( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + raise BaseException("error") + + with pytest.raises(BaseException): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) with pytest.raises(asyncio.TimeoutError): await wait_for(topic_reader.receive_message(), 0.1) + + 
async def test_msg_written_exactly_once_with_retries( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + error_raised = False + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + nonlocal error_raised + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + if not error_raised: + error_raised = True + raise ydb.issues.Unavailable("some retriable error") + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = await wait_for(topic_reader.receive_message(), 0.1) + assert msg.data.decode() == "123" + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + + +class TestTopicTransactionalWriterSync: + def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + pool.retry_tx_sync(callee, retry_settings=ydb.RetrySettings(max_retries=1)) + + msg = topic_reader_sync.receive_message(timeout=0.1) + assert msg.data.decode() == "123" + + def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + tx.rollback() + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=0.1) + + def test_no_msg_written_in_error_case( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO + ): + with 
ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + raise BaseException("error") + + with pytest.raises(BaseException): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=0.1) diff --git a/tox.ini b/tox.ini index df029d2a..f91e7d8a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py-proto5,py-proto4,py-proto3,py-tls-proto5,py-tls-proto4,py-tls-proto3,style,pylint,black,protoc,py-cov-proto4 +envlist = py,py-proto4,py-proto3,py-tls,py-tls-proto4,py-tls-proto3,style,pylint,black,protoc,py-cov-proto4 minversion = 4.2.6 skipsdist = True ignore_basepython_conflict = true @@ -30,7 +30,7 @@ deps = -r{toxinidir}/test-requirements.txt protobuf<4.0.0 -[testenv:py-proto5] +[testenv:py] commands = pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} deps = @@ -60,7 +60,7 @@ deps = -r{toxinidir}/test-requirements.txt protobuf<4.0.0 -[testenv:py-tls-proto5] +[testenv:py-tls] commands = pytest -v -m tls --docker-compose-remove-volumes --docker-compose=docker-compose-tls.yml {posargs} deps = diff --git a/ydb/_apis.py b/ydb/_apis.py index 2a9a14e8..e54f25d2 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -116,6 +116,7 @@ class TopicService(object): DropTopic = "DropTopic" StreamRead = "StreamRead" StreamWrite = "StreamWrite" + UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction" class QueryService(object): diff --git a/ydb/_errors.py b/ydb/_errors.py index 17002d25..1e2308ef 100644 --- a/ydb/_errors.py +++ b/ydb/_errors.py @@ -5,6 +5,7 @@ _errors_retriable_fast_backoff_types = [ issues.Unavailable, + issues.ClientInternalError, ] _errors_retriable_slow_backoff_types = [ issues.Aborted, diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py 
b/ydb/_grpc/grpcwrapper/ydb_topic.py index 0da75bca..6db50a11 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -1209,6 +1209,52 @@ def to_public(self) -> ydb_topic_public_types.PublicMeteringMode: return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED +@dataclass +class UpdateOffsetsInTransactionRequest(IToProto): + tx: TransactionIdentity + topics: List[UpdateOffsetsInTransactionRequest.TopicOffsets] + consumer: str + + def to_proto(self): + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest( + tx=self.tx.to_proto(), + consumer=self.consumer, + topics=list( + map( + UpdateOffsetsInTransactionRequest.TopicOffsets.to_proto, + self.topics, + ) + ), + ) + + @dataclass + class TopicOffsets(IToProto): + path: str + partitions: List[UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets] + + def to_proto(self): + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets( + path=self.path, + partitions=list( + map( + UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets.to_proto, + self.partitions, + ) + ), + ) + + @dataclass + class PartitionOffsets(IToProto): + partition_id: int + partition_offsets: List[OffsetsRange] + + def to_proto(self) -> ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets: + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets( + partition_id=self.partition_id, + partition_offsets=list(map(OffsetsRange.to_proto, self.partition_offsets)), + ) + + @dataclass class CreateTopicRequest(IToProto, IFromPublic): path: str diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index b48501af..74f06a08 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -108,6 +108,9 @@ def ack_notify(self, offset: int): waiter = self._ack_waiters.popleft() waiter._finish_ok() + def _update_last_commited_offset_if_needed(self, offset: int): + self.committed_offset = 
max(self.committed_offset, offset) + def close(self): if self.closed: return @@ -211,3 +214,9 @@ def _pop_batch(self, message_count: int) -> PublicBatch: self._bytes_size = self._bytes_size - new_batch._bytes_size return new_batch + + def _update_partition_offsets(self, tx, exc=None): + if exc is not None: + return + offsets = self._commit_get_offsets_range() + self._partition_session._update_last_commited_offset_if_needed(offsets.end) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 7061b4e4..edf40510 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -5,7 +5,7 @@ import gzip import typing from asyncio import Task -from collections import OrderedDict +from collections import defaultdict, OrderedDict from typing import Optional, Set, Dict, Union, Callable import ydb @@ -19,17 +19,24 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, + to_thread, GrpcWrapperAsyncIO, ) from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, UpdateTokenRequest, UpdateTokenResponse, + UpdateOffsetsInTransactionRequest, Codec, ) from .._errors import check_retriable_error import logging +from ..query.base import TxEvent + +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + logger = logging.getLogger(__name__) @@ -112,6 +119,23 @@ async def receive_batch( max_messages=max_messages, ) + async def receive_batch_with_tx( + self, + tx: "BaseQueryTxContext", + max_messages: typing.Union[int, None] = None, + ) -> typing.Union[datatypes.PublicBatch, None]: + """ + Get one messages batch from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. 
+ """ + await self._reconnector.wait_message() + return self._reconnector.receive_batch_with_tx_nowait( + tx=tx, + max_messages=max_messages, + ) + async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: """ Block until receive new message @@ -165,6 +189,7 @@ class ReaderReconnector: _state_changed: asyncio.Event _stream_reader: Optional["ReaderStream"] _first_error: asyncio.Future[YdbError] + _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]] def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._id = self._static_reader_reconnector_counter.inc_and_get() @@ -177,19 +202,26 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._background_tasks.add(asyncio.create_task(self._connection_loop())) self._first_error = asyncio.get_running_loop().create_future() + self._tx_to_batches_map = defaultdict(list) + async def _connection_loop(self): attempt = 0 while True: try: + logger.debug("reader reconnector: connection loop") self._stream_reader = await ReaderStream.create(self._id, self._driver, self._settings) attempt = 0 self._state_changed.set() await self._stream_reader.wait_error() except BaseException as err: + logger.debug("reader reconnector: stream error: %s", err) retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt) if not retry_info.is_retriable: + logger.debug("reader reconnector: error is not retriable") self._set_first_error(err) return + logger.debug("retriable error, sleep for %s", retry_info.sleep_timeout_seconds) + await asyncio.sleep(retry_info.sleep_timeout_seconds) attempt += 1 @@ -222,9 +254,85 @@ def receive_batch_nowait(self, max_messages: Optional[int] = None): max_messages=max_messages, ) + def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None): + batch = self._stream_reader.receive_batch_nowait( + max_messages=max_messages, + ) + + self._init_tx_if_needed(tx) + + self._tx_to_batches_map[tx.tx_id].append(batch) + + 
tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, None) # probably should be current loop + + return batch + def receive_message_nowait(self): return self._stream_reader.receive_message_nowait() + def _init_tx_if_needed(self, tx: "BaseQueryTxContext"): + if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks + tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, None) + tx._add_callback(TxEvent.AFTER_COMMIT, self._reconnect_if_tx_commit_failed, None) + tx._add_callback(TxEvent.AFTER_ROLLBACK, self._reconnect_after_tx_rollback, None) + + async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): + grouped_batches = defaultdict(lambda: defaultdict(list)) + for batch in self._tx_to_batches_map[tx.tx_id]: + grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch) + + request = UpdateOffsetsInTransactionRequest(tx=tx._tx_identity(), consumer=self._settings.consumer, topics=[]) + + for topic_path in grouped_batches: + topic_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets(path=topic_path, partitions=[]) + for partition_id in grouped_batches[topic_path]: + partition_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets( + partition_id=partition_id, + partition_offsets=[ + batch._commit_get_offsets_range() for batch in grouped_batches[topic_path][partition_id] + ], + ) + topic_offsets.partitions.append(partition_offsets) + request.topics.append(topic_offsets) + + try: + return await self._do_commit_batches_with_tx_call(request) + except BaseException: + exc = issues.ClientInternalError("Failed to update offsets in tx.") + tx._set_external_error(exc) + self._stream_reader._set_first_error(exc) + await asyncio.sleep(0) + finally: + del self._tx_to_batches_map[tx.tx_id] + + async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransactionRequest): + args = [ + request.to_proto(), + _apis.TopicService.Stub, + 
_apis.TopicService.UpdateOffsetsInTransaction, + topic_common.wrap_operation, + ] + + if asyncio.iscoroutinefunction(self._driver.__call__): + res = await self._driver(*args) + else: + res = await to_thread(self._driver, *args, executor=None) + + return res + + async def _reconnect_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + exc = exc if exc is not None else issues.ClientInternalError("Reconnect due to transaction rollback") + logger.debug("reader reconnect after tx rollback: %s", exc) + self._stream_reader._set_first_error(exc) + await asyncio.sleep(0) + + async def _reconnect_if_tx_commit_failed(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if exc is not None: + self._stream_reader._set_first_error( + issues.ClientInternalError("Reconnect due to transaction commit failed") + ) + await asyncio.sleep(0) + def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter: return self._stream_reader.commit(batch) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5468ff66..e685df01 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -44,7 +44,7 @@ GrpcWrapperAsyncIO, ) -from ..query.base import TxListenerAsyncIO +from ..query.base import TxEvent if typing.TYPE_CHECKING: from ..query.transaction import BaseQueryTxContext @@ -170,7 +170,7 @@ async def wait_init(self) -> PublicWriterInitInfo: return await self._reconnector.wait_init() -class TxWriterAsyncIO(WriterAsyncIO, TxListenerAsyncIO): +class TxWriterAsyncIO(WriterAsyncIO): _tx: "BaseQueryTxContext" def __init__( @@ -179,20 +179,28 @@ def __init__( driver: SupportedDriverType, settings: PublicWriterSettings, _client=None, + _is_implicit=False, ): self._tx = tx self._loop = asyncio.get_running_loop() self._closed = False self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx) self._parent = 
_client - self._tx._add_listener(self) + self._is_implicit = _is_implicit - async def _on_before_commit(self): - await self.close() + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop) + tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop) - async def _on_before_rollback(self): + async def _on_before_commit(self, tx: "BaseQueryTxContext"): + if self._is_implicit: + return await self.close() + async def _on_before_rollback(self, tx: "BaseQueryTxContext"): + if self._is_implicit: + return + await self.close(flush=False) + class WriterAsyncIOReconnector: _closed: bool @@ -423,7 +431,7 @@ async def _connection_loop(self): done.pop().result() # need for raise exception - reason of stop task except issues.Error as err: err_info = check_retriable_error(err, retry_settings, attempt) - if not err_info.is_retriable: + if not err_info.is_retriable or self._tx is not None: # no retries in tx writer self._stop(err) return @@ -586,7 +594,6 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): while True: m = await self._new_messages.get() # type: InternalMessage - print("NEW MESSAGE") if m.seq_no > last_seq_no: writer.write([m]) except asyncio.CancelledError: @@ -618,7 +625,6 @@ async def flush(self): # wait last message await asyncio.wait(self._messages_future) - print("ALL MESSAGES WERE SENT TO SERVER") class WriterAsyncIOStream: diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index a5193caf..440b7e01 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -14,13 +14,21 @@ TopicWriterClosedError, ) -from .topic_writer_asyncio import WriterAsyncIO +from ..query.base import TxEvent + +from .topic_writer_asyncio import ( + TxWriterAsyncIO, + WriterAsyncIO, +) from .._topic_common.common import ( _get_shared_event_loop, TimeoutType, CallFromSyncToAsync, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + 
class WriterSync: _caller: CallFromSyncToAsync @@ -122,3 +130,39 @@ def write_with_ack( self._check_closed() return self._caller.unsafe_call_with_result(self._async_writer.write_with_ack(messages), timeout=timeout) + + +class TxWriterSync(WriterSync): + def __init__( + self, + tx: "BaseQueryTxContext", + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + _parent=None, + ): + + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_async_writer(): + return TxWriterAsyncIO(tx, driver, settings, _is_implicit=True) + + self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None) + self._parent = _parent + + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, None) + tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, None) + + def _on_before_commit(self, tx: "BaseQueryTxContext"): + self.close() + + def _on_before_rollback(self, tx: "BaseQueryTxContext"): + self.close(flush=False) diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index cb757324..18ad1405 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -16,6 +16,28 @@ class QueryTxContext(BaseQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. 
+ + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + super().__init__(driver, session_state, session, tx_mode) + self._init_callback_handler(base.CallbackHandlerMode.ASYNC) + async def __aenter__(self) -> "QueryTxContext": """ Enters a context manager and returns a transaction @@ -30,7 +52,7 @@ async def __aexit__(self, *args, **kwargs): it is not finished explicitly """ await self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: + if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -57,7 +79,6 @@ async def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryT await self._begin_call(settings) return self - @base.with_async_transaction_events async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls commit on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. 
@@ -66,6 +87,8 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ + self._check_external_error_set() + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return @@ -75,9 +98,14 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: await self._ensure_prev_stream_finished() - await self._commit_call(settings) + try: + await self._execute_callbacks_async(base.TxEvent.BEFORE_COMMIT) + await self._commit_call(settings) + await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=None) + except BaseException as e: + await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=e) + raise e - @base.with_async_transaction_events async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. 
@@ -86,6 +114,8 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None :return: A committed transaction or exception if commit is failed """ + self._check_external_error_set() + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return @@ -95,7 +125,13 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None await self._ensure_prev_stream_finished() - await self._rollback_call(settings) + try: + await self._execute_callbacks_async(base.TxEvent.BEFORE_ROLLBACK) + await self._rollback_call(settings) + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=None) + except BaseException as e: + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=e) + raise e async def execute( self, diff --git a/ydb/issues.py b/ydb/issues.py index 065dcbc8..8b098667 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -179,6 +179,10 @@ class SessionPoolEmpty(Error, queue.Empty): status = StatusCode.SESSION_POOL_EMPTY +class ClientInternalError(Error): + status = StatusCode.CLIENT_INTERNAL_ERROR + + class UnexpectedGrpcMessage(Error): def __init__(self, message: str): super().__init__(message) diff --git a/ydb/query/base.py b/ydb/query/base.py index e20529e2..3d30eacc 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -2,6 +2,7 @@ import asyncio import enum import functools +from collections import defaultdict import typing from typing import ( @@ -18,6 +19,10 @@ from .. import _utilities from .. 
import _apis +from ydb._topic_common.common import CallFromSyncToAsync, _get_shared_event_loop +from ydb._grpc.grpcwrapper.common_utils import to_thread + + if typing.TYPE_CHECKING: from .transaction import BaseQueryTxContext @@ -199,114 +204,64 @@ def wrap_execute_query_response( return None -class TxListener: - def _on_before_commit(self): - pass - - def _on_after_commit(self, exc: typing.Optional[BaseException]): - pass - - def _on_before_rollback(self): - pass - - def _on_after_rollback(self, exc: typing.Optional[BaseException]): - pass - - -class TxListenerAsyncIO: - async def _on_before_commit(self): - pass - - async def _on_after_commit(self, exc: typing.Optional[BaseException]): - pass - - async def _on_before_rollback(self): - pass - - async def _on_after_rollback(self, exc: typing.Optional[BaseException]): - pass - +class TxEvent(enum.Enum): + BEFORE_COMMIT = "BEFORE_COMMIT" + AFTER_COMMIT = "AFTER_COMMIT" + BEFORE_ROLLBACK = "BEFORE_ROLLBACK" + AFTER_ROLLBACK = "AFTER_ROLLBACK" -def with_transaction_events(method): - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - method_name = method.__name__ - before_event = f"_on_before_{method_name}" - after_event = f"_on_after_{method_name}" - self._notify_listeners_sync(before_event) +class CallbackHandlerMode(enum.Enum): + SYNC = "SYNC" + ASYNC = "ASYNC" - try: - result = method(self, *args, **kwargs) - self._notify_listeners_sync(after_event, exc=None) - - return result - except BaseException as e: - self._notify_listeners_sync(after_event, exc=e) - raise +def _get_sync_callback(method: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]): + if asyncio.iscoroutinefunction(method): + if loop is None: + loop = _get_shared_event_loop() - return wrapper + def async_to_sync_callback(*args, **kwargs): + caller = CallFromSyncToAsync(loop) + return caller.safe_call_with_result(method(*args, **kwargs), 10) + return async_to_sync_callback + return method -def with_async_transaction_events(method): - 
@functools.wraps(method) - async def wrapper(self, *args, **kwargs): - method_name = method.__name__ - before_event = f"_on_before_{method_name}" - after_event = f"_on_after_{method_name}" - await self._notify_listeners_async(before_event) +def _get_async_callback(method: typing.Callable): + if asyncio.iscoroutinefunction(method): + return method - try: - result = await method(self, *args, **kwargs) - - await self._notify_listeners_async(after_event, exc=None) - - return result - except BaseException as e: - await self._notify_listeners_async(after_event, exc=e) - raise + async def sync_to_async_callback(*args, **kwargs): + return await to_thread(method, *args, **kwargs, executor=None) - return wrapper - - -class ListenerHandlerMixin: - def _init_listener_handler(self): - self.listeners = [] - - def _add_listener(self, listener): - if listener not in self.listeners: - self.listeners.append(listener) - return self - - def _remove_listener(self, listener): - if listener in self.listeners: - self.listeners.remove(listener) - return self - - def _clear_listeners(self): - self.listeners.clear() - return self + return sync_to_async_callback - def _notify_sync_listeners(self, event_name: str, **kwargs) -> None: - for listener in self.listeners: - if isinstance(listener, TxListener) and hasattr(listener, event_name): - getattr(listener, event_name)(**kwargs) - async def _notify_async_listeners(self, event_name: str, **kwargs) -> None: - coros = [] - for listener in self.listeners: - if isinstance(listener, TxListenerAsyncIO) and hasattr(listener, event_name): - coros.append(getattr(listener, event_name)(**kwargs)) +class CallbackHandler: + def _init_callback_handler(self, mode: CallbackHandlerMode) -> None: + self._callbacks = defaultdict(list) + self._callback_mode = mode - if coros: - await asyncio.gather(*coros) + def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None: + print(f"EXECUTE SYNC CALLBACKS FOR EVENT: {event_name}") + for callback in 
self._callbacks[event_name]: + callback(self, *args, **kwargs) - def _notify_listeners_sync(self, event_name: str, **kwargs) -> None: - self._notify_sync_listeners(event_name, **kwargs) + async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None: + print(f"EXECUTE ASYNC CALLBACKS FOR EVENT: {event_name}") + tasks = [asyncio.create_task(callback(self, *args, **kwargs)) for callback in self._callbacks[event_name]] + if not tasks: + return + await asyncio.gather(*tasks) - async def _notify_listeners_async(self, event_name: str, **kwargs) -> None: - # self._notify_sync_listeners(event_name, **kwargs) + def _prepare_callback( + self, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop] + ) -> typing.Callable: + if self._callback_mode == CallbackHandlerMode.SYNC: + return _get_sync_callback(callback, loop) + return _get_async_callback(callback) - await self._notify_async_listeners(event_name, **kwargs) + def _add_callback(self, event_name: str, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]): + self._callbacks[event_name].append(self._prepare_callback(callback, loop)) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index cbe91764..ae7642db 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -183,7 +183,7 @@ def wrap_tx_rollback_response( return tx -class BaseQueryTxContext(base.ListenerHandlerMixin): +class BaseQueryTxContext(base.CallbackHandler): def __init__(self, driver, session_state, session, tx_mode): """ An object that provides a simple transaction context manager that allows statements execution @@ -209,7 +209,7 @@ def __init__(self, driver, session_state, session, tx_mode): self._session_state = session_state self.session = session self._prev_stream = None - self._init_listener_handler() + self._external_error = None @property def session_id(self) -> str: @@ -234,6 +234,14 @@ def _tx_identity(self) -> _ydb_topic.TransactionIdentity: raise RuntimeError("Unable to 
get tx identity without started tx.") return _ydb_topic.TransactionIdentity(self.tx_id, self.session_id) + def _set_external_error(self, exc: BaseException) -> None: + self._external_error = exc + + def _check_external_error_set(self): + if self._external_error is None: + return + raise issues.ClientInternalError("Transaction was failed by external error.") from self._external_error + def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) @@ -247,6 +255,7 @@ def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxCo ) def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": + self._check_external_error_set() self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) return self._driver( @@ -259,6 +268,7 @@ def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxC ) def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": + self._check_external_error_set() self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) return self._driver( @@ -281,6 +291,7 @@ def _execute_call( settings: Optional[BaseRequestSettings], ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: self._tx_state._check_tx_ready_to_use() + self._check_external_error_set() request = base.create_execute_query_request( query=query, @@ -314,6 +325,29 @@ def _move_to_commited(self) -> None: class QueryTxContext(BaseQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. 
You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + + super().__init__(driver, session_state, session, tx_mode) + self._init_callback_handler(base.CallbackHandlerMode.SYNC) + def __enter__(self) -> "BaseQueryTxContext": """ Enters a context manager and returns a transaction @@ -328,7 +362,7 @@ def __exit__(self, *args, **kwargs): it is not finished explicitly """ self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: + if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -356,7 +390,6 @@ def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxConte return self - @base.with_transaction_events def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls commit on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. 
@@ -365,6 +398,7 @@ def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ + self._check_external_error_set() if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return @@ -374,9 +408,14 @@ def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: self._ensure_prev_stream_finished() - self._commit_call(settings) + try: + self._execute_callbacks_sync(base.TxEvent.BEFORE_COMMIT) + self._commit_call(settings) + self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=None) + except BaseException as e: # TODO: probably should be less wide + self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=e) + raise e - @base.with_transaction_events def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution failed then this method raises PreconditionFailed. @@ -385,6 +424,7 @@ def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ + self._check_external_error_set() if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return @@ -394,7 +434,13 @@ def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: self._ensure_prev_stream_finished() - self._rollback_call(settings) + try: + self._execute_callbacks_sync(base.TxEvent.BEFORE_ROLLBACK) + self._rollback_call(settings) + self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=None) + except BaseException as e: # TODO: probably should be less wide + self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=e) + raise e def execute( self, diff --git a/ydb/topic.py b/ydb/topic.py index 6954a5ff..26dc57b3 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -68,6 +68,7 @@ from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO from 
ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter +from ._topic_writer.topic_writer_sync import TxWriterSync as TxTopicWriter from ._topic_common.common import ( wrap_operation as _wrap_operation, @@ -517,6 +518,36 @@ def writer( return TopicWriter(self._driver, settings, _parent=self) + def tx_writer( + self, + tx, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes + # the func will be called from multiply threads in parallel. + encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, + # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. + # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. 
+ encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool + ) -> TopicWriter: + args = locals().copy() + del args["self"] + del args["tx"] + self._check_closed() + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TxTopicWriter(tx, self._driver, settings, _parent=self) + def close(self): if self._closed: return From 78f8149ad99920b859ca4423754fe9f3d8be7848 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 24 Mar 2025 16:02:02 +0300 Subject: [PATCH 08/11] sync version of docs --- tests/topics/test_topic_transactions.py | 288 +++++++++++++++++++++- ydb/_topic_reader/topic_reader_asyncio.py | 25 +- ydb/_topic_reader/topic_reader_sync.py | 28 +++ ydb/aio/query/transaction.py | 2 +- ydb/query/base.py | 2 - 5 files changed, 319 insertions(+), 26 deletions(-) diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py index b5fc8bab..a45a565e 100644 --- a/tests/topics/test_topic_transactions.py +++ b/tests/topics/test_topic_transactions.py @@ -4,7 +4,7 @@ from unittest import mock import ydb -DEFAULT_TIMEOUT = 0.1 +DEFAULT_TIMEOUT = 0.5 DEFAULT_RETRY_SETTINGS = ydb.RetrySettings(max_retries=1) @@ -19,11 +19,16 @@ async def callee(tx: ydb.aio.QueryTxContext): assert len(batch.messages) == 1 assert batch.messages[0].data.decode() == "123" + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "456" + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._reconnector._tx_to_batches_map) == 0 async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) - assert msg.data.decode() == "456" + assert msg.data.decode() == "789" async def test_rollback(self, driver: 
ydb.aio.Driver, topic_with_messages, topic_consumer): async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: @@ -37,6 +42,7 @@ async def callee(tx: ydb.aio.QueryTxContext): await tx.rollback() await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._reconnector._tx_to_batches_map) == 0 msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) assert msg.data.decode() == "123" @@ -57,12 +63,211 @@ async def callee(tx: ydb.aio.QueryTxContext): assert len(batch.messages) == 1 assert batch.messages[0].data.decode() == "123" - with pytest.raises(ydb.Error): + with pytest.raises(ydb.Error, match="Transaction was failed"): await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._reconnector._tx_to_batches_map) == 0 + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) assert msg.data.decode() == "123" + async def test_error_in_lambda(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + raise RuntimeError("Something went wrong") + + with pytest.raises(RuntimeError): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + async def test_error_during_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: 
ydb.aio.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + await tx.commit() + + with pytest.raises(ydb.Unavailable): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + async def test_error_during_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + with mock.patch.object( + tx, + "_rollback_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + await tx.rollback() + + with pytest.raises(ydb.Unavailable): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + +class TestTopicTransactionalReaderSync: + def test_commit(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with ydb.QuerySessionPool(driver_sync) as pool: + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + batch = 
reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "456" + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "789" + + def test_rollback(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + tx.rollback() + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_tx_failed_if_update_offsets_call_failed( + self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer + ): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + with mock.patch.object( + reader._async_reader._reconnector, + "_do_commit_batches_with_tx_call", + side_effect=ydb.Error("Update offsets in tx failed"), + ): + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + with pytest.raises(ydb.Error, match="Transaction was failed"): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert 
len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_error_in_lambda(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + raise RuntimeError("Something went wrong") + + with pytest.raises(RuntimeError): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_error_during_commit(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + tx.commit() + + with pytest.raises(ydb.Unavailable): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_error_during_rollback(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) 
as pool: + + def callee(tx: ydb.QueryTxContext): + with mock.patch.object( + tx, + "_rollback_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + tx.rollback() + + with pytest.raises(ydb.Unavailable): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + class TestTopicTransactionalWriter: async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): @@ -108,6 +313,27 @@ async def callee(tx: ydb.aio.QueryTxContext): with pytest.raises(asyncio.TimeoutError): await wait_for(topic_reader.receive_message(), 0.1) + async def test_no_msg_written_in_tx_commit_error( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + await tx.commit() + + with pytest.raises(ydb.Unavailable): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + async def test_msg_written_exactly_once_with_retries( self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO ): @@ -140,12 +366,12 @@ def callee(tx: ydb.QueryTxContext): tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) - pool.retry_tx_sync(callee, 
retry_settings=ydb.RetrySettings(max_retries=1)) + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) - msg = topic_reader_sync.receive_message(timeout=0.1) + msg = topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) assert msg.data.decode() == "123" - def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader): + def test_rollback(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader): with ydb.QuerySessionPool(driver_sync) as pool: def callee(tx: ydb.QueryTxContext): @@ -157,10 +383,10 @@ def callee(tx: ydb.QueryTxContext): pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) with pytest.raises(TimeoutError): - topic_reader_sync.receive_message(timeout=0.1) + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) def test_no_msg_written_in_error_case( - self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader ): with ydb.QuerySessionPool(driver_sync) as pool: @@ -174,4 +400,48 @@ def callee(tx: ydb.QueryTxContext): pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) with pytest.raises(TimeoutError): - topic_reader_sync.receive_message(timeout=0.1) + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_no_msg_written_in_tx_commit_error( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + tx.commit() + + with pytest.raises(ydb.Unavailable): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(TimeoutError): + 
topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_msg_written_exactly_once_with_retries( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + error_raised = False + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + nonlocal error_raised + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + if not error_raised: + error_raised = True + raise ydb.issues.Unavailable("some retriable error") + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index edf40510..4c239882 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -125,7 +125,7 @@ async def receive_batch_with_tx( max_messages: typing.Union[int, None] = None, ) -> typing.Union[datatypes.PublicBatch, None]: """ - Get one messages batch from reader. + Get one messages batch with tx from reader. All messages in a batch from same partition. use asyncio.wait_for for wait with timeout. 
@@ -208,19 +208,15 @@ async def _connection_loop(self): attempt = 0 while True: try: - print("reconnector connection loop") self._stream_reader = await ReaderStream.create(self._id, self._driver, self._settings) attempt = 0 self._state_changed.set() await self._stream_reader.wait_error() except BaseException as err: - print(f"FOUND EXCEPTION: {err}") retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt) if not retry_info.is_retriable: - print("ERROR IS NOT RETRIABLE") self._set_first_error(err) return - print(f"ERROR IS RETRIABLE, sleep for {retry_info.sleep_timeout_seconds}") await asyncio.sleep(retry_info.sleep_timeout_seconds) @@ -273,8 +269,8 @@ def receive_message_nowait(self): def _init_tx_if_needed(self, tx: "BaseQueryTxContext"): if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, None) - tx._add_callback(TxEvent.AFTER_COMMIT, self._reconnect_if_tx_commit_failed, None) - tx._add_callback(TxEvent.AFTER_ROLLBACK, self._reconnect_after_tx_rollback, None) + tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, None) + tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, None) async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): grouped_batches = defaultdict(lambda: defaultdict(list)) @@ -301,7 +297,6 @@ async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): exc = issues.ClientInternalError("Failed to update offsets in tx.") tx._set_external_error(exc) self._stream_reader._set_first_error(exc) - await asyncio.sleep(0) finally: del self._tx_to_batches_map[tx.tx_id] @@ -320,18 +315,20 @@ async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransact return res - async def _reconnect_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: - exc = exc if exc is not None else issues.ClientInternalError("Reconnect due to transaction rollback") - 
print("FIRST ERROR SET") + async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if tx.tx_id in self._tx_to_batches_map: + del self._tx_to_batches_map[tx.tx_id] + exc = issues.ClientInternalError("Reconnect due to transaction rollback") self._stream_reader._set_first_error(exc) - await asyncio.sleep(0) - async def _reconnect_if_tx_commit_failed(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if tx.tx_id in self._tx_to_batches_map: + del self._tx_to_batches_map[tx.tx_id] + if exc is not None: self._stream_reader._set_first_error( issues.ClientInternalError("Reconnect due to transaction commit failed") ) - await asyncio.sleep(0) def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter: return self._stream_reader.commit(batch) diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index eda1d374..1dfcee86 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -20,6 +20,9 @@ TopicReaderClosedError, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + class TopicReaderSync: _caller: CallFromSyncToAsync @@ -109,6 +112,31 @@ def receive_batch( timeout, ) + def receive_batch_with_tx( + self, + tx: "BaseQueryTxContext", + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Union[PublicBatch, None]: + """ + Get one messages batch with tx from reader + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. 
+ + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. + """ + self._check_closed() + + return self._caller.safe_call_with_result( + self._async_reader.receive_batch_with_tx( + tx=tx, + max_messages=max_messages, + ), + timeout, + ) + def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]): """ Put commit message to internal buffer. diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index 18ad1405..f0547e5f 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -130,7 +130,7 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None await self._rollback_call(settings) await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=None) except BaseException as e: - self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=e) + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=e) raise e async def execute( diff --git a/ydb/query/base.py b/ydb/query/base.py index 3d30eacc..a5ebedd9 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -245,12 +245,10 @@ def _init_callback_handler(self, mode: CallbackHandlerMode) -> None: self._callback_mode = mode def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None: - print(f"EXECUTE SYNC CALLBACKS FOR EVENT: {event_name}") for callback in self._callbacks[event_name]: callback(self, *args, **kwargs) async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None: - print(f"EXECUTE ASYNC CALLBACKS FOR EVENT: {event_name}") tasks = [asyncio.create_task(callback(self, *args, **kwargs)) for callback in self._callbacks[event_name]] if not tasks: return From 82ec438fb79c2361b875acf174cf3bd702edd38c Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 24 Mar 2025 17:39:31 +0300 Subject: 
[PATCH 09/11] add executable example --- examples/topic/topic_transactions_example.py | 94 ++++++++++++++++---- ydb/_topic_writer/topic_writer_asyncio.py | 13 +++ ydb/topic.py | 6 +- 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/examples/topic/topic_transactions_example.py b/examples/topic/topic_transactions_example.py index ac273025..2b9c6160 100644 --- a/examples/topic/topic_transactions_example.py +++ b/examples/topic/topic_transactions_example.py @@ -1,28 +1,90 @@ +import asyncio +import argparse +import logging import ydb -def writer_example(driver: ydb.Driver, topic: str): - session_pool = ydb.QuerySessionPool(driver) +async def connect(endpoint: str, database: str) -> ydb.aio.Driver: + config = ydb.DriverConfig(endpoint=endpoint, database=database) + config.credentials = ydb.credentials_from_env_variables() + driver = ydb.aio.Driver(config) + await driver.wait(5, fail_fast=True) + return driver - def callee(tx: ydb.QueryTxContext): - tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) # <======= - # дефолт - без дедупликации, без ретраев и без producer_id. 
- with tx.execute(query="select 1") as result_sets: - messages = [result_set.rows[0] for result_set in result_sets] +async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str): + try: + await driver.topic_client.drop_topic(topic) + except ydb.SchemeError: + pass - tx_writer.write(messages) # вне зависимости от состояния вышестоящего стрима поведение должно быть одинаковое + await driver.topic_client.create_topic(topic, consumers=[consumer]) - session_pool.retry_tx_sync(callee) +async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10): + async with ydb.aio.QuerySessionPool(driver) as session_pool: -def reader_example(driver: ydb.Driver, reader: ydb.TopicReader): - session_pool = ydb.QuerySessionPool(driver) + async def callee(tx: ydb.aio.QueryTxContext): + print(f"TX ID: {tx.tx_id}") + print(f"TX STATE: {tx._tx_state._state.value}") + tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic) + print(f"TX ID: {tx.tx_id}") + print(f"TX STATE: {tx._tx_state._state.value}") + for i in range(message_count): + result_stream = await tx.execute(query=f"select {i} as res") + messages = [result_set.rows[0]["res"] async for result_set in result_stream] - def callee(tx: ydb.QueryTxContext): - batch = reader.receive_batch_with_tx(tx, max_messages=5) # <======= + await tx_writer.write([ydb.TopicWriterMessage(data=str(message)) for message in messages]) - with tx.execute(query="INSERT INTO max_values(val) VALUES ($val)", parameters={"$val": max(batch)}) as _: - pass + print(f"Messages {messages} were written with tx.") - session_pool.retry_tx_sync(callee) + await session_pool.retry_tx_async(callee) + + +async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10): + async with driver.topic_client.reader(topic, consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as session_pool: + for _ in range(message_count): + + async def callee(tx: 
ydb.aio.QueryTxContext): + batch = await reader.receive_batch_with_tx(tx, max_messages=1) + print(f"Messages {batch.messages[0].data} were read with tx.") + + await session_pool.retry_tx_async(callee) + + +async def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""YDB topic basic example.\n""", + ) + parser.add_argument("-d", "--database", default="/local", help="Name of the database to use") + parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use") + parser.add_argument("-p", "--path", default="test-topic", help="Topic name") + parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name") + parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument( + "-s", + "--skip-drop-and-create-topic", + default=False, + action="store_true", + help="Use existed topic, skip remove it and re-create", + ) + + args = parser.parse_args() + + if args.verbose: + logger = logging.getLogger("topicexample") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + + driver = await connect(args.endpoint, args.database) + if not args.skip_drop_and_create_topic: + await create_topic(driver, args.path, args.consumer) + + await write_with_tx_example(driver, args.path) + await read_with_tx_example(driver, args.path, args.consumer) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index e685df01..f3ce7804 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -191,9 +191,22 @@ def __init__( tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop) tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop) + async def write( + self, + messages: Union[Message, List[Message]], + ): + """ + send one or number of 
messages to server. + it put message to internal buffer + + For wait with timeout use asyncio.wait_for. + """ + await self.write_with_ack(messages) + async def _on_before_commit(self, tx: "BaseQueryTxContext"): if self._is_implicit: return + await self.flush() await self.close() async def _on_before_rollback(self, tx: "BaseQueryTxContext"): diff --git a/ydb/topic.py b/ydb/topic.py index 26dc57b3..1762632d 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -25,6 +25,8 @@ "TopicWriteResult", "TopicWriter", "TopicWriterAsyncIO", + "TopicTxWriter", + "TopicTxWriterAsyncIO", "TopicWriterInitInfo", "TopicWriterMessage", "TopicWriterSettings", @@ -68,7 +70,7 @@ from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter -from ._topic_writer.topic_writer_sync import TxWriterSync as TxTopicWriter +from ._topic_writer.topic_writer_sync import TxWriterSync as TopicTxWriter from ._topic_common.common import ( wrap_operation as _wrap_operation, @@ -546,7 +548,7 @@ def tx_writer( if not settings.encoder_executor: settings.encoder_executor = self._executor - return TxTopicWriter(tx, self._driver, settings, _parent=self) + return TopicTxWriter(tx, self._driver, settings, _parent=self) def close(self): if self._closed: From c7c1f9f5e890ad4560ef204e87d35d4be309d7e4 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 25 Mar 2025 13:34:48 +0300 Subject: [PATCH 10/11] Change topics __del__ to warn if not closed --- tests/topics/test_topic_reader.py | 14 +++++++++----- ydb/_grpc/grpcwrapper/common_utils.py | 3 --- ydb/_topic_reader/topic_reader_asyncio.py | 3 +-- ydb/_topic_reader/topic_reader_sync.py | 6 +++++- ydb/_topic_writer/topic_writer_asyncio.py | 6 ++---- ydb/_topic_writer/topic_writer_sync.py | 6 +++++- ydb/aio/driver.py | 1 + ydb/driver.py | 1 + ydb/topic.py | 13 +++++++++---- 9 files 
changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 23b5b4be..623dc8c0 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -174,12 +174,13 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message assert message != message2 def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer): - reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer) - message = reader.receive_message() - reader.commit_with_ack(message) + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + message = reader.receive_message() + reader.commit_with_ack(message) + + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + batch = reader.receive_batch() - reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer) - batch = reader.receive_batch() assert message != batch.messages[0] def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer): @@ -247,3 +248,6 @@ async def wait(fut): datas.sort() assert datas == ["10", "11"] + + await reader0.close() + await reader1.close() diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 7fb5b684..6a7275b4 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -161,9 +161,6 @@ def __init__(self, convert_server_grpc_to_wrapper): self._stream_call = None self._wait_executor = None - def __del__(self): - self._clean_executor(wait=False) - async def start(self, driver: SupportedDriverType, stub, method): if asyncio.iscoroutinefunction(driver.__call__): await self._start_asyncio_driver(driver, stub, method) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 4c239882..6408234b 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ 
b/ydb/_topic_reader/topic_reader_asyncio.py @@ -95,8 +95,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def __del__(self): if not self._closed: - task = self._loop.create_task(self.close(flush=False)) - topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader") + logger.warning("Topic reader was not closed properly. Consider using method close().") async def wait_message(self): """ diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index 1dfcee86..3e6806d0 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import logging import typing from typing import List, Union, Optional @@ -23,6 +24,8 @@ if typing.TYPE_CHECKING: from ..query.transaction import BaseQueryTxContext +logger = logging.getLogger(__name__) + class TopicReaderSync: _caller: CallFromSyncToAsync @@ -55,7 +58,8 @@ async def create_reader(): self._parent = _parent def __del__(self): - self.close(flush=False) + if not self._closed: + logger.warning("Topic reader was not closed properly. Consider using method close().") def __enter__(self): return self diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f3ce7804..c86ada42 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -80,10 +80,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): raise def __del__(self): - if self._closed or self._loop.is_closed(): - return - - self._loop.call_soon(functools.partial(self.close, flush=False)) + if not self._closed: + logger.warning("Topic writer was not closed properly. 
Consider using method close().") async def close(self, *, flush: bool = True): if self._closed: diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 440b7e01..4796d7ac 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import typing from concurrent.futures import Future from typing import Union, List, Optional @@ -29,6 +30,8 @@ if typing.TYPE_CHECKING: from ..query.transaction import BaseQueryTxContext +logger = logging.getLogger(__name__) + class WriterSync: _caller: CallFromSyncToAsync @@ -71,7 +74,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): raise def __del__(self): - self.close(flush=False) + if not self._closed: + logger.warning("Topic writer was not closed properly. Consider using method close().") def close(self, *, flush: bool = True, timeout: TimeoutType = None): if self._closed: diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index 9cd6fd2b..267997fb 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -62,4 +62,5 @@ def __init__( async def stop(self, timeout=10): await self.table_client._stop_pool_if_needed(timeout=timeout) + self.topic_client.close() await super().stop(timeout=timeout) diff --git a/ydb/driver.py b/ydb/driver.py index 49bd223c..3998aeee 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -288,4 +288,5 @@ def __init__( def stop(self, timeout=10): self.table_client._stop_pool_if_needed(timeout=timeout) + self.topic_client.close() super().stop(timeout=timeout) diff --git a/ydb/topic.py b/ydb/topic.py index 1762632d..52f98e61 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -35,6 +35,7 @@ import concurrent.futures import datetime from dataclasses import dataclass +import logging from typing import List, Union, Mapping, Optional, Dict, Callable from . 
import aio, Credentials, _apis, issues @@ -92,6 +93,8 @@ PublicAlterAutoPartitioningSettings as TopicAlterAutoPartitioningSettings, ) +logger = logging.getLogger(__name__) + class TopicClientAsyncIO: _closed: bool @@ -112,7 +115,8 @@ def __init__(self, driver: aio.Driver, settings: Optional[TopicClientSettings] = ) def __del__(self): - self.close() + if not self._closed: + logger.warning("Topic client was not closed properly. Consider using method close().") async def create_topic( self, @@ -320,7 +324,7 @@ def _check_closed(self): if not self._closed: return - raise RuntimeError("Topic client closed") + raise issues.Error("Topic client closed") class TopicClient: @@ -343,7 +347,8 @@ def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings ) def __del__(self): - self.close() + if not self._closed: + logger.warning("Topic client was not closed properly. Consider using method close().") def create_topic( self, @@ -561,7 +566,7 @@ def _check_closed(self): if not self._closed: return - raise RuntimeError("Topic client closed") + raise issues.Error("Topic client closed") @dataclass From 22c466cd39573d32ffb792ce7fd1c2a1f48c5e75 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 25 Mar 2025 13:50:39 +0300 Subject: [PATCH 11/11] Fix review comments --- .../topic/topic_transactions_async_example.py | 86 +++++++++++++++++++ examples/topic/topic_transactions_example.py | 67 +++++++-------- tests/topics/test_topic_transactions.py | 22 +++++ ydb/_topic_reader/topic_reader_asyncio.py | 25 ++++-- ydb/_topic_writer/topic_writer_asyncio.py | 11 ++- 5 files changed, 163 insertions(+), 48 deletions(-) create mode 100644 examples/topic/topic_transactions_async_example.py diff --git a/examples/topic/topic_transactions_async_example.py b/examples/topic/topic_transactions_async_example.py new file mode 100644 index 00000000..cae61063 --- /dev/null +++ b/examples/topic/topic_transactions_async_example.py @@ -0,0 +1,86 @@ +import asyncio +import argparse 
+import logging +import ydb + + +async def connect(endpoint: str, database: str) -> ydb.aio.Driver: + config = ydb.DriverConfig(endpoint=endpoint, database=database) + config.credentials = ydb.credentials_from_env_variables() + driver = ydb.aio.Driver(config) + await driver.wait(5, fail_fast=True) + return driver + + +async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str): + try: + await driver.topic_client.drop_topic(topic) + except ydb.SchemeError: + pass + + await driver.topic_client.create_topic(topic, consumers=[consumer]) + + +async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10): + async with ydb.aio.QuerySessionPool(driver) as session_pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic) + + for i in range(message_count): + async with await tx.execute(query=f"select {i} as res;") as result_stream: + async for result_set in result_stream: + message = str(result_set.rows[0]["res"]) + await tx_writer.write(ydb.TopicWriterMessage(message)) + print(f"Message {result_set.rows[0]['res']} was written with tx.") + + await session_pool.retry_tx_async(callee) + + +async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10): + async with driver.topic_client.reader(topic, consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as session_pool: + for _ in range(message_count): + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await reader.receive_batch_with_tx(tx, max_messages=1) + print(f"Message {batch.messages[0].data.decode()} was read with tx.") + + await session_pool.retry_tx_async(callee) + + +async def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""YDB topic basic example.\n""", + ) + parser.add_argument("-d", "--database", default="/local", help="Name of the database to use") + 
parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use") + parser.add_argument("-p", "--path", default="test-topic", help="Topic name") + parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name") + parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument( + "-s", + "--skip-drop-and-create-topic", + default=False, + action="store_true", + help="Use existed topic, skip remove it and re-create", + ) + + args = parser.parse_args() + + if args.verbose: + logger = logging.getLogger("topicexample") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + + async with await connect(args.endpoint, args.database) as driver: + if not args.skip_drop_and_create_topic: + await create_topic(driver, args.path, args.consumer) + + await write_with_tx_example(driver, args.path) + await read_with_tx_example(driver, args.path, args.consumer) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/topic/topic_transactions_example.py b/examples/topic/topic_transactions_example.py index 2b9c6160..0f7432e7 100644 --- a/examples/topic/topic_transactions_example.py +++ b/examples/topic/topic_transactions_example.py @@ -1,59 +1,54 @@ -import asyncio import argparse import logging import ydb -async def connect(endpoint: str, database: str) -> ydb.aio.Driver: +def connect(endpoint: str, database: str) -> ydb.Driver: config = ydb.DriverConfig(endpoint=endpoint, database=database) config.credentials = ydb.credentials_from_env_variables() - driver = ydb.aio.Driver(config) - await driver.wait(5, fail_fast=True) + driver = ydb.Driver(config) + driver.wait(5, fail_fast=True) return driver -async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str): +def create_topic(driver: ydb.Driver, topic: str, consumer: str): try: - await driver.topic_client.drop_topic(topic) + driver.topic_client.drop_topic(topic) except ydb.SchemeError: pass - 
await driver.topic_client.create_topic(topic, consumers=[consumer]) + driver.topic_client.create_topic(topic, consumers=[consumer]) -async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10): - async with ydb.aio.QuerySessionPool(driver) as session_pool: +def write_with_tx_example(driver: ydb.Driver, topic: str, message_count: int = 10): + with ydb.QuerySessionPool(driver) as session_pool: - async def callee(tx: ydb.aio.QueryTxContext): - print(f"TX ID: {tx.tx_id}") - print(f"TX STATE: {tx._tx_state._state.value}") - tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic) - print(f"TX ID: {tx.tx_id}") - print(f"TX STATE: {tx._tx_state._state.value}") - for i in range(message_count): - result_stream = await tx.execute(query=f"select {i} as res") - messages = [result_set.rows[0]["res"] async for result_set in result_stream] - - await tx_writer.write([ydb.TopicWriterMessage(data=str(message)) for message in messages]) + def callee(tx: ydb.QueryTxContext): + tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) - print(f"Messages {messages} were written with tx.") + for i in range(message_count): + result_stream = tx.execute(query=f"select {i} as res;") + for result_set in result_stream: + message = str(result_set.rows[0]["res"]) + tx_writer.write(ydb.TopicWriterMessage(message)) + print(f"Message {message} was written with tx.") - await session_pool.retry_tx_async(callee) + session_pool.retry_tx_sync(callee) -async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10): - async with driver.topic_client.reader(topic, consumer) as reader: - async with ydb.aio.QuerySessionPool(driver) as session_pool: +def read_with_tx_example(driver: ydb.Driver, topic: str, consumer: str, message_count: int = 10): + with driver.topic_client.reader(topic, consumer) as reader: + with ydb.QuerySessionPool(driver) as session_pool: for _ in range(message_count): - async 
def callee(tx: ydb.aio.QueryTxContext): - batch = await reader.receive_batch_with_tx(tx, max_messages=1) - print(f"Messages {batch.messages[0].data} were read with tx.") + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1) + print(f"Message {batch.messages[0].data.decode()} was read with tx.") - await session_pool.retry_tx_async(callee) + session_pool.retry_tx_sync(callee) -async def main(): +def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description="""YDB topic basic example.\n""", @@ -78,13 +73,13 @@ async def main(): logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) - driver = await connect(args.endpoint, args.database) - if not args.skip_drop_and_create_topic: - await create_topic(driver, args.path, args.consumer) + with connect(args.endpoint, args.database) as driver: + if not args.skip_drop_and_create_topic: + create_topic(driver, args.path, args.consumer) - await write_with_tx_example(driver, args.path) - await read_with_tx_example(driver, args.path, args.consumer) + write_with_tx_example(driver, args.path) + read_with_tx_example(driver, args.path, args.consumer) if __name__ == "__main__": - asyncio.run(main()) + main() diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py index a45a565e..b79df740 100644 --- a/tests/topics/test_topic_transactions.py +++ b/tests/topics/test_topic_transactions.py @@ -357,6 +357,17 @@ async def callee(tx: ydb.aio.QueryTxContext): with pytest.raises(asyncio.TimeoutError): await wait_for(topic_reader.receive_message(), 0.1) + async def test_writes_do_not_conflict_with_executes(self, driver: ydb.aio.Driver, topic_path): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + for _ in range(3): + async with await tx.execute("select 1"): + await 
tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + class TestTopicTransactionalWriterSync: def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader): @@ -445,3 +456,14 @@ def callee(tx: ydb.QueryTxContext): with pytest.raises(TimeoutError): topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_writes_do_not_conflict_with_executes(self, driver_sync: ydb.Driver, topic_path): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + for _ in range(3): + with tx.execute("select 1"): + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 6408234b..c9704d55 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -84,7 +84,7 @@ def __init__( ): self._loop = asyncio.get_running_loop() self._closed = False - self._reconnector = ReaderReconnector(driver, settings) + self._reconnector = ReaderReconnector(driver, settings, self._loop) self._parent = _parent async def __aenter__(self): @@ -190,10 +190,16 @@ class ReaderReconnector: _first_error: asyncio.Future[YdbError] _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]] - def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + def __init__( + self, + driver: Driver, + settings: topic_reader.PublicReaderSettings, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): self._id = self._static_reader_reconnector_counter.inc_and_get() self._settings = settings self._driver = driver + self._loop = loop if loop is not None else asyncio.get_running_loop() self._background_tasks = set() self._state_changed = asyncio.Event() @@ 
-201,7 +207,7 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._background_tasks.add(asyncio.create_task(self._connection_loop())) self._first_error = asyncio.get_running_loop().create_future() - self._tx_to_batches_map = defaultdict(list) + self._tx_to_batches_map = dict() async def _connection_loop(self): attempt = 0 @@ -254,22 +260,23 @@ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: O max_messages=max_messages, ) - self._init_tx_if_needed(tx) + self._init_tx(tx) self._tx_to_batches_map[tx.tx_id].append(batch) - tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, None) # probably should be current loop + tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop) return batch def receive_message_nowait(self): return self._stream_reader.receive_message_nowait() - def _init_tx_if_needed(self, tx: "BaseQueryTxContext"): + def _init_tx(self, tx: "BaseQueryTxContext"): if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks - tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, None) - tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, None) - tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, None) + self._tx_to_batches_map[tx.tx_id] = [] + tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop) + tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop) + tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop) async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): grouped_batches = defaultdict(lambda: defaultdict(list)) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index c86ada42..1ea6c250 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -1,7 +1,6 @@ import asyncio import concurrent.futures 
import datetime -import functools import gzip import typing from collections import deque @@ -186,6 +185,10 @@ def __init__( self._parent = _client self._is_implicit = _is_implicit + # For some reason, creating partition could conflict with other session operations. + # Could be removed later. + self._first_write = True + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop) tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop) @@ -199,12 +202,14 @@ async def write( For wait with timeout use asyncio.wait_for. """ - await self.write_with_ack(messages) + if self._first_write: + self._first_write = False + return await super().write_with_ack(messages) + return await super().write(messages) async def _on_before_commit(self, tx: "BaseQueryTxContext"): if self._is_implicit: return - await self.flush() await self.close() async def _on_before_rollback(self, tx: "BaseQueryTxContext"):