@@ -99,6 +99,7 @@ async def wait_message(self):
99
99
100
100
async def receive_batch (
101
101
self ,
102
+ max_messages : typing .Union [int , None ] = None ,
102
103
) -> typing .Union [datatypes .PublicBatch , None ]:
103
104
"""
104
105
Get one messages batch from reader.
@@ -107,7 +108,9 @@ async def receive_batch(
107
108
use asyncio.wait_for for wait with timeout.
108
109
"""
109
110
await self ._reconnector .wait_message ()
110
- return self ._reconnector .receive_batch_nowait ()
111
+ return self ._reconnector .receive_batch_nowait (
112
+ max_messages = max_messages ,
113
+ )
111
114
112
115
async def receive_message (self ) -> typing .Optional [datatypes .PublicMessage ]:
113
116
"""
@@ -214,8 +217,10 @@ async def wait_message(self):
214
217
await self ._state_changed .wait ()
215
218
self ._state_changed .clear ()
216
219
217
- def receive_batch_nowait (self ):
218
- return self ._stream_reader .receive_batch_nowait ()
220
+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
221
+ return self ._stream_reader .receive_batch_nowait (
222
+ max_messages = max_messages ,
223
+ )
219
224
220
225
def receive_message_nowait (self ):
221
226
return self ._stream_reader .receive_message_nowait ()
@@ -383,17 +388,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
383
388
partition_session_id , batch = self ._message_batches .popitem (last = False )
384
389
return partition_session_id , batch
385
390
386
- def receive_batch_nowait (self ):
391
+ def _cut_batch_by_max_messages (
392
+ batch : datatypes .PublicBatch ,
393
+ max_messages : int ,
394
+ ) -> typing .Tuple [datatypes .PublicBatch , datatypes .PublicBatch ]:
395
+ initial_length = len (batch .messages )
396
+ one_message_size = batch ._bytes_size // initial_length
397
+
398
+ new_batch = datatypes .PublicBatch (
399
+ messages = batch .messages [:max_messages ],
400
+ _partition_session = batch ._partition_session ,
401
+ _bytes_size = one_message_size * max_messages ,
402
+ _codec = batch ._codec ,
403
+ )
404
+
405
+ batch .messages = batch .messages [max_messages :]
406
+ batch ._bytes_size = one_message_size * (initial_length - max_messages )
407
+
408
+ return new_batch , batch
409
+
410
+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
387
411
if self ._get_first_error ():
388
412
raise self ._get_first_error ()
389
413
390
414
if not self ._message_batches :
391
415
return None
392
416
393
- _ , batch = self ._get_first_batch ()
394
- self ._buffer_release_bytes (batch ._bytes_size )
417
+ part_sess_id , batch = self ._get_first_batch ()
418
+
419
+ if max_messages is None or len (batch .messages ) <= max_messages :
420
+ self ._buffer_release_bytes (batch ._bytes_size )
421
+ return batch
422
+
423
+ cutted_batch , remaining_batch = self ._cut_batch_by_max_messages (batch , max_messages )
424
+
425
+ self ._message_batches [part_sess_id ] = remaining_batch
426
+ self ._buffer_release_bytes (cutted_batch ._bytes_size )
395
427
396
- return batch
428
+ return cutted_batch
397
429
398
430
def receive_message_nowait (self ):
399
431
if self ._get_first_error ():
0 commit comments