Skip to content

Commit b2e88b5

Browse files
committed
Implement max_messages on receive_batch
1 parent 4877cc8 commit b2e88b5

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ async def wait_message(self):
9999

100100
async def receive_batch(
101101
self,
102+
max_messages: typing.Union[int, None] = None,
102103
) -> typing.Union[datatypes.PublicBatch, None]:
103104
"""
104105
Get one messages batch from reader.
@@ -107,7 +108,9 @@ async def receive_batch(
107108
use asyncio.wait_for for wait with timeout.
108109
"""
109110
await self._reconnector.wait_message()
110-
return self._reconnector.receive_batch_nowait()
111+
return self._reconnector.receive_batch_nowait(
112+
max_messages=max_messages,
113+
)
111114

112115
async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
113116
"""
@@ -214,8 +217,10 @@ async def wait_message(self):
214217
await self._state_changed.wait()
215218
self._state_changed.clear()
216219

217-
def receive_batch_nowait(self):
218-
return self._stream_reader.receive_batch_nowait()
220+
def receive_batch_nowait(self, max_messages: Optional[int] = None):
221+
return self._stream_reader.receive_batch_nowait(
222+
max_messages=max_messages,
223+
)
219224

220225
def receive_message_nowait(self):
221226
return self._stream_reader.receive_message_nowait()
@@ -383,17 +388,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
383388
partition_session_id, batch = self._message_batches.popitem(last=False)
384389
return partition_session_id, batch
385390

386-
def receive_batch_nowait(self):
391+
def _cut_batch_by_max_messages(
392+
batch: datatypes.PublicBatch,
393+
max_messages: int,
394+
) -> typing.Tuple[datatypes.PublicBatch, datatypes.PublicBatch]:
395+
initial_length = len(batch.messages)
396+
one_message_size = batch._bytes_size // initial_length
397+
398+
new_batch = datatypes.PublicBatch(
399+
messages=batch.messages[:max_messages],
400+
_partition_session=batch._partition_session,
401+
_bytes_size=one_message_size*max_messages,
402+
_codec=batch._codec,
403+
)
404+
405+
batch.messages = batch.messages[max_messages:]
406+
batch._bytes_size = one_message_size * (initial_length - max_messages)
407+
408+
return new_batch, batch
409+
410+
def receive_batch_nowait(self, max_messages: Optional[int] = None):
387411
if self._get_first_error():
388412
raise self._get_first_error()
389413

390414
if not self._message_batches:
391415
return None
392416

393-
_, batch = self._get_first_batch()
394-
self._buffer_release_bytes(batch._bytes_size)
417+
part_sess_id, batch = self._get_first_batch()
418+
419+
if max_messages is None or len(batch.messages) <= max_messages:
420+
self._buffer_release_bytes(batch._bytes_size)
421+
return batch
422+
423+
cutted_batch, remaining_batch = self._cut_batch_by_max_messages(batch, max_messages)
424+
425+
self._message_batches[part_sess_id] = remaining_batch
426+
self._buffer_release_bytes(cutted_batch._bytes_size)
395427

396-
return batch
428+
return cutted_batch
397429

398430
def receive_message_nowait(self):
399431
if self._get_first_error():

ydb/_topic_reader/topic_reader_sync.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def receive_batch(
103103
self._check_closed()
104104

105105
return self._caller.safe_call_with_result(
106-
self._async_reader.receive_batch(),
106+
self._async_reader.receive_batch(
107+
max_messages=max_messages,
108+
),
107109
timeout,
108110
)
109111

0 commit comments

Comments
 (0)