Skip to content

Commit e0f4e4b

Browse files
committed
Implement max_messages on recieve_batch
1 parent 60a4504 commit e0f4e4b

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ async def wait_message(self):
9797

9898
async def receive_batch(
9999
self,
100+
max_messages: typing.Union[int, None] = None,
100101
) -> typing.Union[datatypes.PublicBatch, None]:
101102
"""
102103
Get one messages batch from reader.
@@ -105,7 +106,9 @@ async def receive_batch(
105106
use asyncio.wait_for for wait with timeout.
106107
"""
107108
await self._reconnector.wait_message()
108-
return self._reconnector.receive_batch_nowait()
109+
return self._reconnector.receive_batch_nowait(
110+
max_messages=max_messages,
111+
)
109112

110113
async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
111114
"""
@@ -212,8 +215,10 @@ async def wait_message(self):
212215
await self._state_changed.wait()
213216
self._state_changed.clear()
214217

215-
def receive_batch_nowait(self):
216-
return self._stream_reader.receive_batch_nowait()
218+
def receive_batch_nowait(self, max_messages: Optional[int] = None):
219+
return self._stream_reader.receive_batch_nowait(
220+
max_messages=max_messages,
221+
)
217222

218223
def receive_message_nowait(self):
219224
return self._stream_reader.receive_message_nowait()
@@ -363,17 +368,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
363368
first_id, batch = self._message_batches.popitem(last=False)
364369
return first_id, batch
365370

366-
def receive_batch_nowait(self):
371+
def _cut_batch_by_max_messages(
372+
batch: datatypes.PublicBatch,
373+
max_messages: int,
374+
) -> typing.Tuple[datatypes.PublicBatch, datatypes.PublicBatch]:
375+
initial_length = len(batch.messages)
376+
one_message_size = batch._bytes_size // initial_length
377+
378+
new_batch = datatypes.PublicBatch(
379+
messages=batch.messages[:max_messages],
380+
_partition_session=batch._partition_session,
381+
_bytes_size=one_message_size*max_messages,
382+
_codec=batch._codec,
383+
)
384+
385+
batch.messages = batch.messages[max_messages:]
386+
batch._bytes_size = one_message_size * (initial_length - max_messages)
387+
388+
return new_batch, batch
389+
390+
def receive_batch_nowait(self, max_messages: Optional[int] = None):
367391
if self._get_first_error():
368392
raise self._get_first_error()
369393

370394
if not self._message_batches:
371395
return None
372396

373-
_, batch = self._get_first_batch()
374-
self._buffer_release_bytes(batch._bytes_size)
397+
part_sess_id, batch = self._get_first_batch()
398+
399+
if max_messages is None or len(batch.messages) <= max_messages:
400+
self._buffer_release_bytes(batch._bytes_size)
401+
return batch
402+
403+
cutted_batch, remaining_batch = self._cut_batch_by_max_messages(batch, max_messages)
404+
405+
self._message_batches[part_sess_id] = remaining_batch
406+
self._buffer_release_bytes(cutted_batch._bytes_size)
375407

376-
return batch
408+
return cutted_batch
377409

378410
def receive_message_nowait(self):
379411
if self._get_first_error():

ydb/_topic_reader/topic_reader_sync.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def receive_batch(
103103
self._check_closed()
104104

105105
return self._caller.safe_call_with_result(
106-
self._async_reader.receive_batch(),
106+
self._async_reader.receive_batch(
107+
max_messages=max_messages,
108+
),
107109
timeout,
108110
)
109111

0 commit comments

Comments
 (0)