Skip to content

Commit 157ab45

Browse files
authored
Merge pull request #494 from ydb-platform/batch_max_messages
Implement max_messages on recieve_batch
2 parents 3eedf72 + 10de36b commit 157ab45

File tree

4 files changed

+133
-8
lines changed

4 files changed

+133
-8
lines changed

ydb/_topic_reader/datatypes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,23 @@ def _extend(self, batch: PublicBatch) -> None:
179179
def _pop(self) -> Tuple[List[PublicMessage], bool]:
180180
msgs_left = True if len(self.messages) > 1 else False
181181
return self.messages.pop(0), msgs_left
182+
183+
def _pop_batch(self, message_count: int) -> PublicBatch:
184+
initial_length = len(self.messages)
185+
186+
if message_count >= initial_length:
187+
raise ValueError("Pop batch with size >= actual size is not supported.")
188+
189+
one_message_size = self._bytes_size // initial_length
190+
191+
new_batch = PublicBatch(
192+
messages=self.messages[:message_count],
193+
_partition_session=self._partition_session,
194+
_bytes_size=one_message_size * message_count,
195+
_codec=self._codec,
196+
)
197+
198+
self.messages = self.messages[message_count:]
199+
self._bytes_size = self._bytes_size - new_batch._bytes_size
200+
201+
return new_batch

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ async def wait_message(self):
9999

100100
async def receive_batch(
101101
self,
102+
max_messages: typing.Union[int, None] = None,
102103
) -> typing.Union[datatypes.PublicBatch, None]:
103104
"""
104105
Get one messages batch from reader.
@@ -107,7 +108,9 @@ async def receive_batch(
107108
use asyncio.wait_for for wait with timeout.
108109
"""
109110
await self._reconnector.wait_message()
110-
return self._reconnector.receive_batch_nowait()
111+
return self._reconnector.receive_batch_nowait(
112+
max_messages=max_messages,
113+
)
111114

112115
async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
113116
"""
@@ -214,8 +217,10 @@ async def wait_message(self):
214217
await self._state_changed.wait()
215218
self._state_changed.clear()
216219

217-
def receive_batch_nowait(self):
218-
return self._stream_reader.receive_batch_nowait()
220+
def receive_batch_nowait(self, max_messages: Optional[int] = None):
221+
return self._stream_reader.receive_batch_nowait(
222+
max_messages=max_messages,
223+
)
219224

220225
def receive_message_nowait(self):
221226
return self._stream_reader.receive_message_nowait()
@@ -383,17 +388,25 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
383388
partition_session_id, batch = self._message_batches.popitem(last=False)
384389
return partition_session_id, batch
385390

386-
def receive_batch_nowait(self):
391+
def receive_batch_nowait(self, max_messages: Optional[int] = None):
387392
if self._get_first_error():
388393
raise self._get_first_error()
389394

390395
if not self._message_batches:
391396
return None
392397

393-
_, batch = self._get_first_batch()
394-
self._buffer_release_bytes(batch._bytes_size)
398+
part_sess_id, batch = self._get_first_batch()
399+
400+
if max_messages is None or len(batch.messages) <= max_messages:
401+
self._buffer_release_bytes(batch._bytes_size)
402+
return batch
403+
404+
cutted_batch = batch._pop_batch(message_count=max_messages)
405+
406+
self._message_batches[part_sess_id] = batch
407+
self._buffer_release_bytes(cutted_batch._bytes_size)
395408

396-
return batch
409+
return cutted_batch
397410

398411
def receive_message_nowait(self):
399412
if self._get_first_error():

ydb/_topic_reader/topic_reader_asyncio_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,96 @@ async def test_read_message(
11481148
assert mess == expected_message
11491149
assert dict(stream_reader._message_batches) == batches_after
11501150

1151+
@pytest.mark.parametrize(
1152+
"batches_before,max_messages,actual_messages,batches_after",
1153+
[
1154+
(
1155+
{
1156+
0: PublicBatch(
1157+
messages=[stub_message(1)],
1158+
_partition_session=stub_partition_session(),
1159+
_bytes_size=4,
1160+
_codec=Codec.CODEC_RAW,
1161+
)
1162+
},
1163+
None,
1164+
1,
1165+
{},
1166+
),
1167+
(
1168+
{
1169+
0: PublicBatch(
1170+
messages=[stub_message(1), stub_message(2)],
1171+
_partition_session=stub_partition_session(),
1172+
_bytes_size=4,
1173+
_codec=Codec.CODEC_RAW,
1174+
),
1175+
1: PublicBatch(
1176+
messages=[stub_message(3), stub_message(4)],
1177+
_partition_session=stub_partition_session(1),
1178+
_bytes_size=4,
1179+
_codec=Codec.CODEC_RAW,
1180+
),
1181+
},
1182+
1,
1183+
1,
1184+
{
1185+
1: PublicBatch(
1186+
messages=[stub_message(3), stub_message(4)],
1187+
_partition_session=stub_partition_session(1),
1188+
_bytes_size=4,
1189+
_codec=Codec.CODEC_RAW,
1190+
),
1191+
0: PublicBatch(
1192+
messages=[stub_message(2)],
1193+
_partition_session=stub_partition_session(),
1194+
_bytes_size=2,
1195+
_codec=Codec.CODEC_RAW,
1196+
),
1197+
},
1198+
),
1199+
(
1200+
{
1201+
0: PublicBatch(
1202+
messages=[stub_message(1)],
1203+
_partition_session=stub_partition_session(),
1204+
_bytes_size=4,
1205+
_codec=Codec.CODEC_RAW,
1206+
),
1207+
1: PublicBatch(
1208+
messages=[stub_message(2), stub_message(3)],
1209+
_partition_session=stub_partition_session(1),
1210+
_bytes_size=4,
1211+
_codec=Codec.CODEC_RAW,
1212+
),
1213+
},
1214+
100,
1215+
1,
1216+
{
1217+
1: PublicBatch(
1218+
messages=[stub_message(2), stub_message(3)],
1219+
_partition_session=stub_partition_session(1),
1220+
_bytes_size=4,
1221+
_codec=Codec.CODEC_RAW,
1222+
)
1223+
},
1224+
),
1225+
],
1226+
)
1227+
async def test_read_batch_max_messages(
1228+
self,
1229+
stream_reader,
1230+
batches_before: typing.List[datatypes.PublicBatch],
1231+
max_messages: typing.Optional[int],
1232+
actual_messages: int,
1233+
batches_after: typing.List[datatypes.PublicBatch],
1234+
):
1235+
stream_reader._message_batches = OrderedDict(batches_before)
1236+
batch = stream_reader.receive_batch_nowait(max_messages=max_messages)
1237+
1238+
assert len(batch.messages) == actual_messages
1239+
assert stream_reader._message_batches == OrderedDict(batches_after)
1240+
11511241
async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
11521242
assert stream_reader.receive_batch_nowait() is None
11531243

ydb/_topic_reader/topic_reader_sync.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def receive_batch(
103103
self._check_closed()
104104

105105
return self._caller.safe_call_with_result(
106-
self._async_reader.receive_batch(),
106+
self._async_reader.receive_batch(
107+
max_messages=max_messages,
108+
),
107109
timeout,
108110
)
109111

0 commit comments

Comments
 (0)