Skip to content

Commit 0a07253

Browse files
committed
tests fixes
1 parent 2e6ad26 commit 0a07253

File tree

3 files changed

+75
-65
lines changed

3 files changed

+75
-65
lines changed

tests/topics/test_topic_writer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import ydb.aio
99

10+
from ydb._topic_common.test_helpers import wait_condition
11+
1012

1113
@pytest.mark.asyncio
1214
class TestTopicWriterAsyncIO:
@@ -43,6 +45,10 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi
4345

4446
batch = await topic_reader.receive_batch()
4547

48+
if len(batch.messages) == 1:
49+
batch2 = await topic_reader.receive_batch()
50+
batch.messages.extend(batch2.messages)
51+
4652
assert batch.messages[0].producer_id != batch.messages[1].producer_id
4753

4854
async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
@@ -201,14 +207,14 @@ def test_write_multi_message_with_ack(
201207
)
202208

203209
batch = topic_reader_sync.receive_batch()
210+
if len(batch.messages) == 1:
211+
batch2 = topic_reader_sync.receive_batch()
212+
batch.messages.extend(batch2.messages)
204213

205214
assert batch.messages[0].offset == 0
206215
assert batch.messages[0].seqno == 1
207216
assert batch.messages[0].data == "123".encode()
208217

209-
# remove the second receive_batch call once batching is implemented
210-
# https://github.com/ydb-platform/ydb-python-sdk/issues/142
211-
# batch = topic_reader_sync.receive_batch()
212218
assert batch.messages[1].offset == 1
213219
assert batch.messages[1].seqno == 2
214220
assert batch.messages[1].data == "456".encode()

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import random
77
import typing
88
from asyncio import Task
9+
from collections import OrderedDict
910
from typing import Optional, Set, Dict, Union, Callable
1011

1112
import ydb
@@ -296,7 +297,7 @@ def __init__(
296297
self._closed = False
297298
self._first_error = asyncio.get_running_loop().create_future()
298299
self._batches_to_decode = asyncio.Queue()
299-
self._message_batches = dict()
300+
self._message_batches = OrderedDict()
300301

301302
self._update_token_interval = settings.update_token_interval
302303
self._get_token_function = get_token_function
@@ -359,9 +360,9 @@ async def wait_messages(self):
359360
await self._state_changed.wait()
360361
self._state_changed.clear()
361362

362-
def _get_random_batch(self):
363-
rnd_id = random.choice(list(self._message_batches.keys()))
364-
return rnd_id, self._message_batches[rnd_id]
363+
def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
364+
first_id, batch = self._message_batches.popitem(last = False)
365+
return first_id, batch
365366

366367
def receive_batch_nowait(self):
367368
if self._get_first_error():
@@ -370,9 +371,8 @@ def receive_batch_nowait(self):
370371
if not self._message_batches:
371372
return None
372373

373-
part_sess_id, batch = self._get_random_batch()
374+
_, batch = self._get_first_batch()
374375
self._buffer_release_bytes(batch._bytes_size)
375-
del self._message_batches[part_sess_id]
376376

377377
return batch
378378

@@ -383,12 +383,15 @@ def receive_message_nowait(self):
383383
if not self._message_batches:
384384
return None
385385

386-
part_sess_id, batch = self._get_random_batch()
386+
part_sess_id, batch = self._get_first_batch()
387387

388388
message = batch.messages.pop(0)
389+
389390
if len(batch.messages) == 0:
390391
self._buffer_release_bytes(batch._bytes_size)
391-
del self._message_batches[part_sess_id]
392+
else:
393+
# TODO: also release the buffer bytes of a single consumed message, not only those of a fully drained batch
394+
self._message_batches[part_sess_id] = batch
392395

393396
return message
394397

ydb/_topic_reader/topic_reader_asyncio_test.py

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import datetime
55
import gzip
66
import typing
7-
from collections import deque
7+
from collections import OrderedDict
88
from dataclasses import dataclass
99
from unittest import mock
1010

@@ -52,9 +52,9 @@ def default_executor():
5252
executor.shutdown()
5353

5454

55-
def stub_partition_session():
55+
def stub_partition_session(id: int = 0):
5656
return datatypes.PartitionSession(
57-
id=0,
57+
id=id,
5858
state=datatypes.PartitionSession.State.Active,
5959
topic_path="asd",
6060
partition_id=1,
@@ -212,21 +212,27 @@ def create_message(
212212
_commit_end_offset=partition_session._next_message_start_commit_offset + offset_delta,
213213
)
214214

215-
async def send_message(self, stream_reader, message: PublicMessage):
216-
await self.send_batch(stream_reader, [message])
215+
async def send_message(self, stream_reader, message: PublicMessage, new_batch=True):
216+
await self.send_batch(stream_reader, [message], new_batch=new_batch)
217217

218-
async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]):
218+
async def send_batch(self, stream_reader, batch: typing.List[PublicMessage], new_batch=True):
219219
if len(batch) == 0:
220220
return
221221

222222
first_message = batch[0]
223223
for message in batch:
224224
assert message._partition_session is first_message._partition_session
225225

226+
partition_session_id = first_message._partition_session.id
227+
226228
def batch_count():
227229
return len(stream_reader._message_batches)
228230

231+
def batch_size():
232+
return len(stream_reader._message_batches[partition_session_id].messages)
233+
229234
initial_batches = batch_count()
235+
initial_batch_size = batch_size() if not new_batch else 0
230236

231237
stream = stream_reader._stream # type: StreamMock
232238
stream.from_server.put_nowait(
@@ -261,7 +267,10 @@ def batch_count():
261267
),
262268
)
263269
)
264-
await wait_condition(lambda: batch_count() > initial_batches)
270+
if new_batch:
271+
await wait_condition(lambda: batch_count() > initial_batches)
272+
else:
273+
await wait_condition(lambda: batch_size() > initial_batch_size)
265274

266275
async def test_unknown_error(self, stream, stream_reader_finish_with_error):
267276
class TestError(Exception):
@@ -412,15 +421,11 @@ async def test_commit_ranges_for_received_messages(
412421
m2._commit_start_offset = m1.offset + 1
413422

414423
await self.send_message(stream_reader_started, m1)
415-
await self.send_message(stream_reader_started, m2)
416-
417-
await stream_reader_started.wait_messages()
418-
received = stream_reader_started.receive_batch_nowait().messages
419-
assert received == [m1]
424+
await self.send_message(stream_reader_started, m2, new_batch=False)
420425

421426
await stream_reader_started.wait_messages()
422427
received = stream_reader_started.receive_batch_nowait().messages
423-
assert received == [m2]
428+
assert received == [m1, m2]
424429

425430
await stream_reader_started.close(False)
426431

@@ -860,7 +865,7 @@ def reader_batch_count():
860865

861866
assert stream_reader._buffer_size_bytes == initial_buffer_size - bytes_size
862867

863-
last_batch = stream_reader._message_batches[-1]
868+
_, last_batch = stream_reader._message_batches.popitem()
864869
assert last_batch == PublicBatch(
865870
messages=[
866871
PublicMessage(
@@ -1059,74 +1064,74 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti
10591064
@pytest.mark.parametrize(
10601065
"batches_before,expected_message,batches_after",
10611066
[
1062-
([], None, []),
1067+
({}, None, {}),
10631068
(
1064-
[
1065-
PublicBatch(
1069+
{
1070+
0: PublicBatch(
10661071
messages=[stub_message(1)],
10671072
_partition_session=stub_partition_session(),
10681073
_bytes_size=0,
10691074
_codec=Codec.CODEC_RAW,
10701075
)
1071-
],
1076+
},
10721077
stub_message(1),
1073-
[],
1078+
{},
10741079
),
10751080
(
1076-
[
1077-
PublicBatch(
1081+
{
1082+
0: PublicBatch(
10781083
messages=[stub_message(1), stub_message(2)],
10791084
_partition_session=stub_partition_session(),
10801085
_bytes_size=0,
10811086
_codec=Codec.CODEC_RAW,
10821087
),
1083-
PublicBatch(
1088+
1: PublicBatch(
10841089
messages=[stub_message(3), stub_message(4)],
1085-
_partition_session=stub_partition_session(),
1090+
_partition_session=stub_partition_session(1),
10861091
_bytes_size=0,
10871092
_codec=Codec.CODEC_RAW,
10881093
),
1089-
],
1094+
},
10901095
stub_message(1),
1091-
[
1092-
PublicBatch(
1096+
{
1097+
0: PublicBatch(
10931098
messages=[stub_message(2)],
10941099
_partition_session=stub_partition_session(),
10951100
_bytes_size=0,
10961101
_codec=Codec.CODEC_RAW,
10971102
),
1098-
PublicBatch(
1103+
1: PublicBatch(
10991104
messages=[stub_message(3), stub_message(4)],
1100-
_partition_session=stub_partition_session(),
1105+
_partition_session=stub_partition_session(1),
11011106
_bytes_size=0,
11021107
_codec=Codec.CODEC_RAW,
11031108
),
1104-
],
1109+
},
11051110
),
11061111
(
1107-
[
1108-
PublicBatch(
1112+
{
1113+
0: PublicBatch(
11091114
messages=[stub_message(1)],
11101115
_partition_session=stub_partition_session(),
11111116
_bytes_size=0,
11121117
_codec=Codec.CODEC_RAW,
11131118
),
1114-
PublicBatch(
1119+
1: PublicBatch(
11151120
messages=[stub_message(2), stub_message(3)],
1116-
_partition_session=stub_partition_session(),
1121+
_partition_session=stub_partition_session(1),
11171122
_bytes_size=0,
11181123
_codec=Codec.CODEC_RAW,
11191124
),
1120-
],
1125+
},
11211126
stub_message(1),
1122-
[
1123-
PublicBatch(
1127+
{
1128+
1: PublicBatch(
11241129
messages=[stub_message(2), stub_message(3)],
1125-
_partition_session=stub_partition_session(),
1130+
_partition_session=stub_partition_session(1),
11261131
_bytes_size=0,
11271132
_codec=Codec.CODEC_RAW,
11281133
)
1129-
],
1134+
},
11301135
),
11311136
],
11321137
)
@@ -1137,11 +1142,11 @@ async def test_read_message(
11371142
expected_message: PublicMessage,
11381143
batches_after: typing.List[datatypes.PublicBatch],
11391144
):
1140-
stream_reader._message_batches = deque(batches_before)
1145+
stream_reader._message_batches = OrderedDict(batches_before)
11411146
mess = stream_reader.receive_message_nowait()
11421147

11431148
assert mess == expected_message
1144-
assert list(stream_reader._message_batches) == batches_after
1149+
assert dict(stream_reader._message_batches) == batches_after
11451150

11461151
async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
11471152
assert stream_reader.receive_batch_nowait() is None
@@ -1152,30 +1157,21 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi
11521157
await self.send_message(stream_reader, mess1)
11531158

11541159
mess2 = self.create_message(partition_session, 2, 1)
1155-
await self.send_message(stream_reader, mess2)
1160+
await self.send_message(stream_reader, mess2, new_batch=False)
11561161

11571162
assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size
11581163

11591164
received = stream_reader.receive_batch_nowait()
11601165
assert received == PublicBatch(
1161-
messages=[mess1],
1166+
messages=[mess1, mess2],
11621167
_partition_session=mess1._partition_session,
1163-
_bytes_size=self.default_batch_size,
1164-
_codec=Codec.CODEC_RAW,
1165-
)
1166-
1167-
received = stream_reader.receive_batch_nowait()
1168-
assert received == PublicBatch(
1169-
messages=[mess2],
1170-
_partition_session=mess2._partition_session,
1171-
_bytes_size=self.default_batch_size,
1168+
_bytes_size=self.default_batch_size * 2,
11721169
_codec=Codec.CODEC_RAW,
11731170
)
11741171

11751172
assert stream_reader._buffer_size_bytes == initial_buffer_size
11761173

1177-
assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
1178-
assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
1174+
assert StreamReadMessage.ReadRequest(self.default_batch_size * 2) == stream.from_client.get_nowait().client_message
11791175

11801176
with pytest.raises(asyncio.QueueEmpty):
11811177
stream.from_client.get_nowait()
@@ -1186,13 +1182,18 @@ async def test_receive_message_nowait(self, stream, stream_reader, partition_ses
11861182
initial_buffer_size = stream_reader._buffer_size_bytes
11871183

11881184
await self.send_batch(
1189-
stream_reader, [self.create_message(partition_session, 1, 1), self.create_message(partition_session, 2, 1)]
1185+
stream_reader,
1186+
[
1187+
self.create_message(partition_session, 1, 1),
1188+
self.create_message(partition_session, 2, 1),
1189+
],
11901190
)
11911191
await self.send_batch(
11921192
stream_reader,
11931193
[
11941194
self.create_message(partition_session, 10, 1),
11951195
],
1196+
new_batch=False,
11961197
)
11971198

11981199
assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size

0 commit comments

Comments
 (0)