Skip to content

Commit 2e6ad26

Browse files
committed
Ability to batch messages in topic reader
1 parent 8928b00 commit 2e6ad26

File tree

3 files changed

+51
-38
lines changed

3 files changed

+51
-38
lines changed

tests/topics/test_topic_reader.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,14 @@ async def test_read_and_commit_with_close_reader(self, driver, topic_with_messag
4040
assert message != message2
4141

4242
async def test_read_and_commit_with_ack(self, driver, topic_with_messages, topic_consumer):
43-
reader = driver.topic_client.reader(topic_with_messages, topic_consumer)
44-
batch = await reader.receive_batch()
45-
await reader.commit_with_ack(batch)
43+
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
44+
message = await reader.receive_message()
45+
await reader.commit_with_ack(message)
4646

47-
reader = driver.topic_client.reader(topic_with_messages, topic_consumer)
48-
batch2 = await reader.receive_batch()
49-
assert batch.messages[0] != batch2.messages[0]
47+
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
48+
batch = await reader.receive_batch()
5049

51-
await reader.close()
50+
assert message != batch.messages[0]
5251

5352
async def test_read_compressed_messages(self, driver, topic_path, topic_consumer):
5453
async with driver.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
@@ -147,12 +146,12 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message
147146

148147
def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer):
149148
reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
150-
batch = reader.receive_batch()
151-
reader.commit_with_ack(batch)
149+
message = reader.receive_message()
150+
reader.commit_with_ack(message)
152151

153152
reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
154-
batch2 = reader.receive_batch()
155-
assert batch.messages[0] != batch2.messages[0]
153+
batch = reader.receive_batch()
154+
assert message != batch.messages[0]
156155

157156
def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer):
158157
with driver_sync.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:

tests/topics/test_topic_writer.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,9 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi
4141
async with driver.topic_client.writer(topic_path) as writer:
4242
await writer.write(ydb.TopicWriterMessage(data="123".encode()))
4343

44-
batch1 = await topic_reader.receive_batch()
45-
batch2 = await topic_reader.receive_batch()
44+
batch = await topic_reader.receive_batch()
4645

47-
assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
46+
assert batch.messages[0].producer_id != batch.messages[1].producer_id
4847

4948
async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
5049
async with driver.topic_client.writer(
@@ -83,12 +82,12 @@ async def test_write_multi_message_with_ack(
8382
assert batch.messages[0].seqno == 1
8483
assert batch.messages[0].data == "123".encode()
8584

86-
# remove the second receive_batch call once batching is implemented
87-
# https://github.com/ydb-platform/ydb-python-sdk/issues/142
88-
batch = await topic_reader.receive_batch()
89-
assert batch.messages[0].offset == 1
90-
assert batch.messages[0].seqno == 2
91-
assert batch.messages[0].data == "456".encode()
85+
# # remove the second receive_batch call once batching is implemented
86+
# # https://github.com/ydb-platform/ydb-python-sdk/issues/142
87+
# batch = await topic_reader.receive_batch()
88+
assert batch.messages[1].offset == 1
89+
assert batch.messages[1].seqno == 2
90+
assert batch.messages[1].data == "456".encode()
9291

9392
@pytest.mark.parametrize(
9493
"codec",
@@ -186,10 +185,9 @@ def test_random_producer_id(
186185
with driver_sync.topic_client.writer(topic_path) as writer:
187186
writer.write(ydb.TopicWriterMessage(data="123".encode()))
188187

189-
batch1 = topic_reader_sync.receive_batch()
190-
batch2 = topic_reader_sync.receive_batch()
188+
batch = topic_reader_sync.receive_batch()
191189

192-
assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
190+
assert batch.messages[0].producer_id != batch.messages[1].producer_id
193191

194192
def test_write_multi_message_with_ack(
195193
self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader
@@ -210,10 +208,10 @@ def test_write_multi_message_with_ack(
210208

211209
# remove the second receive_batch call once batching is implemented
212210
# https://github.com/ydb-platform/ydb-python-sdk/issues/142
213-
batch = topic_reader_sync.receive_batch()
214-
assert batch.messages[0].offset == 1
215-
assert batch.messages[0].seqno == 2
216-
assert batch.messages[0].data == "456".encode()
211+
# batch = topic_reader_sync.receive_batch()
212+
assert batch.messages[1].offset == 1
213+
assert batch.messages[1].seqno == 2
214+
assert batch.messages[1].data == "456".encode()
217215

218216
@pytest.mark.parametrize(
219217
"codec",

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import asyncio
44
import concurrent.futures
55
import gzip
6+
import random
67
import typing
78
from asyncio import Task
8-
from collections import deque
99
from typing import Optional, Set, Dict, Union, Callable
1010

1111
import ydb
@@ -264,7 +264,7 @@ class ReaderStream:
264264

265265
_state_changed: asyncio.Event
266266
_closed: bool
267-
_message_batches: typing.Deque[datatypes.PublicBatch]
267+
_message_batches: typing.Dict[int, datatypes.PublicBatch]
268268
_first_error: asyncio.Future[YdbError]
269269

270270
_update_token_interval: Union[int, float]
@@ -296,7 +296,7 @@ def __init__(
296296
self._closed = False
297297
self._first_error = asyncio.get_running_loop().create_future()
298298
self._batches_to_decode = asyncio.Queue()
299-
self._message_batches = deque()
299+
self._message_batches = dict()
300300

301301
self._update_token_interval = settings.update_token_interval
302302
self._get_token_function = get_token_function
@@ -359,29 +359,36 @@ async def wait_messages(self):
359359
await self._state_changed.wait()
360360
self._state_changed.clear()
361361

362+
def _get_random_batch(self):
363+
rnd_id = random.choice(list(self._message_batches.keys()))
364+
return rnd_id, self._message_batches[rnd_id]
365+
362366
def receive_batch_nowait(self):
363367
if self._get_first_error():
364368
raise self._get_first_error()
365369

366370
if not self._message_batches:
367371
return None
368372

369-
batch = self._message_batches.popleft()
373+
part_sess_id, batch = self._get_random_batch()
370374
self._buffer_release_bytes(batch._bytes_size)
375+
del self._message_batches[part_sess_id]
376+
371377
return batch
372378

373379
def receive_message_nowait(self):
374380
if self._get_first_error():
375381
raise self._get_first_error()
376382

377-
try:
378-
batch = self._message_batches[0]
379-
message = batch.pop_message()
380-
except IndexError:
383+
if not self._message_batches:
381384
return None
382385

383-
if batch.empty():
384-
self.receive_batch_nowait()
386+
part_sess_id, batch = self._get_random_batch()
387+
388+
message = batch.messages.pop(0)
389+
if len(batch.messages) == 0:
390+
self._buffer_release_bytes(batch._bytes_size)
391+
del self._message_batches[part_sess_id]
385392

386393
return message
387394

@@ -605,9 +612,18 @@ async def _decode_batches_loop(self):
605612
while True:
606613
batch = await self._batches_to_decode.get()
607614
await self._decode_batch_inplace(batch)
608-
self._message_batches.append(batch)
615+
self._add_batch_to_queue(batch)
609616
self._state_changed.set()
610617

618+
def _add_batch_to_queue(self, batch: datatypes.PublicBatch):
619+
part_sess_id = batch._partition_session.id
620+
if part_sess_id in self._message_batches:
621+
self._message_batches[part_sess_id].messages.extend(batch.messages)
622+
self._message_batches[part_sess_id]._bytes_size += batch._bytes_size
623+
return
624+
625+
self._message_batches[part_sess_id] = batch
626+
611627
async def _decode_batch_inplace(self, batch):
612628
if batch._codec == Codec.CODEC_RAW:
613629
return

0 commit comments

Comments
 (0)