diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 0f15ff85..01501638 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -179,3 +179,23 @@ def _extend(self, batch: PublicBatch) -> None: def _pop(self) -> Tuple[List[PublicMessage], bool]: msgs_left = True if len(self.messages) > 1 else False return self.messages.pop(0), msgs_left + + def _pop_batch(self, message_count: int) -> PublicBatch: + initial_length = len(self.messages) + + if message_count >= initial_length: + raise ValueError("Pop batch with size >= actual size is not supported.") + + one_message_size = self._bytes_size // initial_length + + new_batch = PublicBatch( + messages=self.messages[:message_count], + _partition_session=self._partition_session, + _bytes_size=one_message_size * message_count, + _codec=self._codec, + ) + + self.messages = self.messages[message_count:] + self._bytes_size = self._bytes_size - new_batch._bytes_size + + return new_batch diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 68ac5451..6833492d 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -99,6 +99,7 @@ async def wait_message(self): async def receive_batch( self, + max_messages: typing.Union[int, None] = None, ) -> typing.Union[datatypes.PublicBatch, None]: """ Get one messages batch from reader. @@ -107,7 +108,9 @@ async def receive_batch( use asyncio.wait_for for wait with timeout. """ await self._reconnector.wait_message() - return self._reconnector.receive_batch_nowait() + return self._reconnector.receive_batch_nowait( + max_messages=max_messages, + ) async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: """ @@ -214,8 +217,10 @@ async def wait_message(self): await self._state_changed.wait() self._state_changed.clear() - def receive_batch_nowait(self): - return self._stream_reader.receive_batch_nowait() + def receive_batch_nowait(self, max_messages: Optional[int] = None): + return self._stream_reader.receive_batch_nowait( + max_messages=max_messages, + ) def receive_message_nowait(self): return self._stream_reader.receive_message_nowait() @@ -383,17 +388,25 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]: partition_session_id, batch = self._message_batches.popitem(last=False) return partition_session_id, batch - def receive_batch_nowait(self): + def receive_batch_nowait(self, max_messages: Optional[int] = None): if self._get_first_error(): raise self._get_first_error() if not self._message_batches: return None - _, batch = self._get_first_batch() - self._buffer_release_bytes(batch._bytes_size) + part_sess_id, batch = self._get_first_batch() + + if max_messages is None or len(batch.messages) <= max_messages: + self._buffer_release_bytes(batch._bytes_size) + return batch + + cutted_batch = batch._pop_batch(message_count=max_messages) + + self._message_batches[part_sess_id] = batch + self._buffer_release_bytes(cutted_batch._bytes_size) - return batch + return cutted_batch def receive_message_nowait(self): if self._get_first_error(): diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 77bf57c3..4c76cd1d 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1148,6 +1148,96 @@ async def test_read_message( assert mess == expected_message assert dict(stream_reader._message_batches) == batches_after + @pytest.mark.parametrize( + "batches_before,max_messages,actual_messages,batches_after", + [ + ( + { + 0: PublicBatch( + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ) + }, + None, + 1, + {}, + ), + ( + { + 0: PublicBatch( + messages=[stub_message(1), stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ), + 1: PublicBatch( + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(1), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ), + }, + 1, + 1, + { + 1: PublicBatch( + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(1), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ), + 0: PublicBatch( + messages=[stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=2, + _codec=Codec.CODEC_RAW, + ), + }, + ), + ( + { + 0: PublicBatch( + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ), + 1: PublicBatch( + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(1), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ), + }, + 100, + 1, + { + 1: PublicBatch( + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(1), + _bytes_size=4, + _codec=Codec.CODEC_RAW, + ) + }, + ), + ], + ) + async def test_read_batch_max_messages( + self, + stream_reader, + batches_before: typing.List[datatypes.PublicBatch], + max_messages: typing.Optional[int], + actual_messages: int, + batches_after: typing.List[datatypes.PublicBatch], + ): + stream_reader._message_batches = OrderedDict(batches_before) + batch = stream_reader.receive_batch_nowait(max_messages=max_messages) + + assert len(batch.messages) == actual_messages + assert stream_reader._message_batches == OrderedDict(batches_after) + async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): assert stream_reader.receive_batch_nowait() is None diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index c266de82..3048d3c4 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -103,7 +103,9 @@ def receive_batch( self._check_closed() return self._caller.safe_call_with_result( - self._async_reader.receive_batch(), + self._async_reader.receive_batch( + max_messages=max_messages, + ), timeout, )