diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 362be059..74c8bccd 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -40,15 +40,14 @@ async def test_read_and_commit_with_close_reader(self, driver, topic_with_messag assert message != message2 async def test_read_and_commit_with_ack(self, driver, topic_with_messages, topic_consumer): - reader = driver.topic_client.reader(topic_with_messages, topic_consumer) - batch = await reader.receive_batch() - await reader.commit_with_ack(batch) + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + message = await reader.receive_message() + await reader.commit_with_ack(message) - reader = driver.topic_client.reader(topic_with_messages, topic_consumer) - batch2 = await reader.receive_batch() - assert batch.messages[0] != batch2.messages[0] + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + batch = await reader.receive_batch() - await reader.close() + assert message != batch.messages[0] async def test_read_compressed_messages(self, driver, topic_path, topic_consumer): async with driver.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer: @@ -147,12 +146,12 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer): reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer) - batch = reader.receive_batch() - reader.commit_with_ack(batch) + message = reader.receive_message() + reader.commit_with_ack(message) reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer) - batch2 = reader.receive_batch() - assert batch.messages[0] != batch2.messages[0] + batch = reader.receive_batch() + assert message != batch.messages[0] def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer): with driver_sync.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer: diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 3817e34d..a3cdbe9d 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -41,10 +41,10 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi async with driver.topic_client.writer(topic_path) as writer: await writer.write(ydb.TopicWriterMessage(data="123".encode())) - batch1 = await topic_reader.receive_batch() - batch2 = await topic_reader.receive_batch() + msg1 = await topic_reader.receive_message() + msg2 = await topic_reader.receive_message() - assert batch1.messages[0].producer_id != batch2.messages[0].producer_id + assert msg1.producer_id != msg2.producer_id async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.writer( @@ -77,18 +77,16 @@ async def test_write_multi_message_with_ack( assert res1.offset == 0 assert res2.offset == 1 - batch = await topic_reader.receive_batch() + msg1 = await topic_reader.receive_message() + msg2 = await topic_reader.receive_message() - assert batch.messages[0].offset == 0 - assert batch.messages[0].seqno == 1 - assert batch.messages[0].data == "123".encode() + assert msg1.offset == 0 + assert msg1.seqno == 1 + assert msg1.data == "123".encode() - # remove second recieve batch when implement batching - # https://github.com/ydb-platform/ydb-python-sdk/issues/142 - batch = await topic_reader.receive_batch() - assert batch.messages[0].offset == 1 - assert batch.messages[0].seqno == 2 - assert batch.messages[0].data == "456".encode() + assert msg2.offset == 1 + assert msg2.seqno == 2 + assert msg2.data == "456".encode() @pytest.mark.parametrize( "codec", @@ -186,10 +184,10 @@ def test_random_producer_id( with driver_sync.topic_client.writer(topic_path) as writer: writer.write(ydb.TopicWriterMessage(data="123".encode())) - batch1 = topic_reader_sync.receive_batch() - batch2 = topic_reader_sync.receive_batch() + msg1 = topic_reader_sync.receive_message() + msg2 = topic_reader_sync.receive_message() - assert batch1.messages[0].producer_id != batch2.messages[0].producer_id + assert msg1.producer_id != msg2.producer_id def test_write_multi_message_with_ack( self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader @@ -202,18 +200,16 @@ def test_write_multi_message_with_ack( ] ) - batch = topic_reader_sync.receive_batch() + msg1 = topic_reader_sync.receive_message() + msg2 = topic_reader_sync.receive_message() - assert batch.messages[0].offset == 0 - assert batch.messages[0].seqno == 1 - assert batch.messages[0].data == "123".encode() + assert msg1.offset == 0 + assert msg1.seqno == 1 + assert msg1.data == "123".encode() - # remove second recieve batch when implement batching - # https://github.com/ydb-platform/ydb-python-sdk/issues/142 - batch = topic_reader_sync.receive_batch() - assert batch.messages[0].offset == 1 - assert batch.messages[0].seqno == 2 - assert batch.messages[0].data == "456".encode() + assert msg2.offset == 1 + assert msg2.seqno == 2 + assert msg2.data == "456".encode() @pytest.mark.parametrize( "codec", diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 28155ea7..0f15ff85 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -7,7 +7,7 @@ from collections import deque from dataclasses import dataclass, field import datetime -from typing import Union, Any, List, Dict, Deque, Optional +from typing import Union, Any, List, Dict, Deque, Optional, Tuple from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec from ydb._topic_reader import topic_reader_asyncio @@ -171,3 +171,11 @@ def alive(self) -> bool: def pop_message(self) -> PublicMessage: return self.messages.pop(0) + + def _extend(self, batch: PublicBatch) -> None: + self.messages.extend(batch.messages) + self._bytes_size += batch._bytes_size + + def _pop(self) -> Tuple[List[PublicMessage], bool]: + msgs_left = True if len(self.messages) > 1 else False + return self.messages.pop(0), msgs_left diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 81c6d9f4..92cd78c2 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -5,7 +5,7 @@ import gzip import typing from asyncio import Task -from collections import deque +from collections import OrderedDict from typing import Optional, Set, Dict, Union, Callable import ydb @@ -264,7 +264,7 @@ class ReaderStream: _state_changed: asyncio.Event _closed: bool - _message_batches: typing.Deque[datatypes.PublicBatch] + _message_batches: typing.Dict[int, datatypes.PublicBatch] # keys are partition session ID _first_error: asyncio.Future[YdbError] _update_token_interval: Union[int, float] @@ -296,7 +296,7 @@ def __init__( self._closed = False self._first_error = asyncio.get_running_loop().create_future() self._batches_to_decode = asyncio.Queue() - self._message_batches = deque() + self._message_batches = OrderedDict() self._update_token_interval = settings.update_token_interval self._get_token_function = get_token_function @@ -359,6 +359,10 @@ async def wait_messages(self): await self._state_changed.wait() self._state_changed.clear() + def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]: + partition_session_id, batch = self._message_batches.popitem(last=False) + return partition_session_id, batch + def receive_batch_nowait(self): if self._get_first_error(): raise self._get_first_error() @@ -366,22 +370,27 @@ def receive_batch_nowait(self): if not self._message_batches: return None - batch = self._message_batches.popleft() + _, batch = self._get_first_batch() self._buffer_release_bytes(batch._bytes_size) + return batch def receive_message_nowait(self): if self._get_first_error(): raise self._get_first_error() - try: - batch = self._message_batches[0] - message = batch.pop_message() - except IndexError: + if not self._message_batches: return None - if batch.empty(): - self.receive_batch_nowait() + part_sess_id, batch = self._get_first_batch() + + message, msgs_left = batch._pop() + + if not msgs_left: + self._buffer_release_bytes(batch._bytes_size) + else: + # TODO: we should somehow release bytes from single message as well + self._message_batches[part_sess_id] = batch return message @@ -605,9 +614,17 @@ async def _decode_batches_loop(self): while True: batch = await self._batches_to_decode.get() await self._decode_batch_inplace(batch) - self._message_batches.append(batch) + self._add_batch_to_queue(batch) self._state_changed.set() + def _add_batch_to_queue(self, batch: datatypes.PublicBatch): + part_sess_id = batch._partition_session.id + if part_sess_id in self._message_batches: + self._message_batches[part_sess_id]._extend(batch) + return + + self._message_batches[part_sess_id] = batch + async def _decode_batch_inplace(self, batch): if batch._codec == Codec.CODEC_RAW: return diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 9af91b1b..77bf57c3 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -4,7 +4,7 @@ import datetime import gzip import typing -from collections import deque +from collections import OrderedDict from dataclasses import dataclass from unittest import mock @@ -52,9 +52,9 @@ def default_executor(): executor.shutdown() -def stub_partition_session(): +def stub_partition_session(id: int = 0): return datatypes.PartitionSession( - id=0, + id=id, state=datatypes.PartitionSession.State.Active, topic_path="asd", partition_id=1, @@ -212,10 +212,10 @@ def create_message( _commit_end_offset=partition_session._next_message_start_commit_offset + offset_delta, ) - async def send_message(self, stream_reader, message: PublicMessage): - await self.send_batch(stream_reader, [message]) + async def send_message(self, stream_reader, message: PublicMessage, new_batch=True): + await self.send_batch(stream_reader, [message], new_batch=new_batch) - async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]): + async def send_batch(self, stream_reader, batch: typing.List[PublicMessage], new_batch=True): if len(batch) == 0: return @@ -223,10 +223,16 @@ async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]): for message in batch: assert message._partition_session is first_message._partition_session + partition_session_id = first_message._partition_session.id + def batch_count(): return len(stream_reader._message_batches) + def batch_size(): + return len(stream_reader._message_batches[partition_session_id].messages) + initial_batches = batch_count() + initial_batch_size = 0 if new_batch else batch_size() stream = stream_reader._stream # type: StreamMock stream.from_server.put_nowait( @@ -261,7 +267,10 @@ def batch_count(): ), ) ) - await wait_condition(lambda: batch_count() > initial_batches) + if new_batch: + await wait_condition(lambda: batch_count() > initial_batches) + else: + await wait_condition(lambda: batch_size() > initial_batch_size) async def test_unknown_error(self, stream, stream_reader_finish_with_error): class TestError(Exception): @@ -412,15 +421,11 @@ async def test_commit_ranges_for_received_messages( m2._commit_start_offset = m1.offset + 1 await self.send_message(stream_reader_started, m1) - await self.send_message(stream_reader_started, m2) - - await stream_reader_started.wait_messages() - received = stream_reader_started.receive_batch_nowait().messages - assert received == [m1] + await self.send_message(stream_reader_started, m2, new_batch=False) await stream_reader_started.wait_messages() received = stream_reader_started.receive_batch_nowait().messages - assert received == [m2] + assert received == [m1, m2] await stream_reader_started.close(False) @@ -860,7 +865,7 @@ def reader_batch_count(): assert stream_reader._buffer_size_bytes == initial_buffer_size - bytes_size - last_batch = stream_reader._message_batches[-1] + _, last_batch = stream_reader._message_batches.popitem() assert last_batch == PublicBatch( messages=[ PublicMessage( @@ -1059,74 +1064,74 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti @pytest.mark.parametrize( "batches_before,expected_message,batches_after", [ - ([], None, []), + ({}, None, {}), ( - [ - PublicBatch( + { + 0: PublicBatch( messages=[stub_message(1)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ) - ], + }, stub_message(1), - [], + {}, ), ( - [ - PublicBatch( + { + 0: PublicBatch( messages=[stub_message(1), stub_message(2)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ), - PublicBatch( + 1: PublicBatch( messages=[stub_message(3), stub_message(4)], - _partition_session=stub_partition_session(), + _partition_session=stub_partition_session(1), _bytes_size=0, _codec=Codec.CODEC_RAW, ), - ], + }, stub_message(1), - [ - PublicBatch( + { + 0: PublicBatch( messages=[stub_message(2)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ), - PublicBatch( + 1: PublicBatch( messages=[stub_message(3), stub_message(4)], - _partition_session=stub_partition_session(), + _partition_session=stub_partition_session(1), _bytes_size=0, _codec=Codec.CODEC_RAW, ), - ], + }, ), ( - [ - PublicBatch( + { + 0: PublicBatch( messages=[stub_message(1)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ), - PublicBatch( + 1: PublicBatch( messages=[stub_message(2), stub_message(3)], - _partition_session=stub_partition_session(), + _partition_session=stub_partition_session(1), _bytes_size=0, _codec=Codec.CODEC_RAW, ), - ], + }, stub_message(1), - [ - PublicBatch( + { + 1: PublicBatch( messages=[stub_message(2), stub_message(3)], - _partition_session=stub_partition_session(), + _partition_session=stub_partition_session(1), _bytes_size=0, _codec=Codec.CODEC_RAW, ) - ], + }, ), ], ) @@ -1137,11 +1142,11 @@ async def test_read_message( expected_message: PublicMessage, batches_after: typing.List[datatypes.PublicBatch], ): - stream_reader._message_batches = deque(batches_before) + stream_reader._message_batches = OrderedDict(batches_before) mess = stream_reader.receive_message_nowait() assert mess == expected_message - assert list(stream_reader._message_batches) == batches_after + assert dict(stream_reader._message_batches) == batches_after async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): assert stream_reader.receive_batch_nowait() is None @@ -1152,30 +1157,23 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi await self.send_message(stream_reader, mess1) mess2 = self.create_message(partition_session, 2, 1) - await self.send_message(stream_reader, mess2) + await self.send_message(stream_reader, mess2, new_batch=False) assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size received = stream_reader.receive_batch_nowait() assert received == PublicBatch( - messages=[mess1], + messages=[mess1, mess2], _partition_session=mess1._partition_session, - _bytes_size=self.default_batch_size, - _codec=Codec.CODEC_RAW, - ) - - received = stream_reader.receive_batch_nowait() - assert received == PublicBatch( - messages=[mess2], - _partition_session=mess2._partition_session, - _bytes_size=self.default_batch_size, + _bytes_size=self.default_batch_size * 2, _codec=Codec.CODEC_RAW, ) assert stream_reader._buffer_size_bytes == initial_buffer_size - assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message - assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message + assert ( + StreamReadMessage.ReadRequest(self.default_batch_size * 2) == stream.from_client.get_nowait().client_message + ) with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() @@ -1186,13 +1184,18 @@ async def test_receive_message_nowait(self, stream, stream_reader, partition_ses initial_buffer_size = stream_reader._buffer_size_bytes await self.send_batch( - stream_reader, [self.create_message(partition_session, 1, 1), self.create_message(partition_session, 2, 1)] + stream_reader, + [ + self.create_message(partition_session, 1, 1), + self.create_message(partition_session, 2, 1), + ], ) await self.send_batch( stream_reader, [ self.create_message(partition_session, 10, 1), ], + new_batch=False, ) assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size