@@ -97,6 +97,7 @@ async def wait_message(self):
97
97
98
98
async def receive_batch (
99
99
self ,
100
+ max_messages : typing .Union [int , None ] = None ,
100
101
) -> typing .Union [datatypes .PublicBatch , None ]:
101
102
"""
102
103
Get one messages batch from reader.
@@ -105,7 +106,9 @@ async def receive_batch(
105
106
use asyncio.wait_for for wait with timeout.
106
107
"""
107
108
await self ._reconnector .wait_message ()
108
- return self ._reconnector .receive_batch_nowait ()
109
+ return self ._reconnector .receive_batch_nowait (
110
+ max_messages = max_messages ,
111
+ )
109
112
110
113
async def receive_message (self ) -> typing .Optional [datatypes .PublicMessage ]:
111
114
"""
@@ -212,8 +215,10 @@ async def wait_message(self):
212
215
await self ._state_changed .wait ()
213
216
self ._state_changed .clear ()
214
217
215
- def receive_batch_nowait (self ):
216
- return self ._stream_reader .receive_batch_nowait ()
218
+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
219
+ return self ._stream_reader .receive_batch_nowait (
220
+ max_messages = max_messages ,
221
+ )
217
222
218
223
def receive_message_nowait (self ):
219
224
return self ._stream_reader .receive_message_nowait ()
@@ -363,17 +368,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
363
368
first_id , batch = self ._message_batches .popitem (last = False )
364
369
return first_id , batch
365
370
366
- def receive_batch_nowait (self ):
371
+ def _cut_batch_by_max_messages (
372
+ batch : datatypes .PublicBatch ,
373
+ max_messages : int ,
374
+ ) -> typing .Tuple [datatypes .PublicBatch , datatypes .PublicBatch ]:
375
+ initial_length = len (batch .messages )
376
+ one_message_size = batch ._bytes_size // initial_length
377
+
378
+ new_batch = datatypes .PublicBatch (
379
+ messages = batch .messages [:max_messages ],
380
+ _partition_session = batch ._partition_session ,
381
+ _bytes_size = one_message_size * max_messages ,
382
+ _codec = batch ._codec ,
383
+ )
384
+
385
+ batch .messages = batch .messages [max_messages :]
386
+ batch ._bytes_size = one_message_size * (initial_length - max_messages )
387
+
388
+ return new_batch , batch
389
+
390
+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
367
391
if self ._get_first_error ():
368
392
raise self ._get_first_error ()
369
393
370
394
if not self ._message_batches :
371
395
return None
372
396
373
- _ , batch = self ._get_first_batch ()
374
- self ._buffer_release_bytes (batch ._bytes_size )
397
+ part_sess_id , batch = self ._get_first_batch ()
398
+
399
+ if max_messages is None or len (batch .messages ) <= max_messages :
400
+ self ._buffer_release_bytes (batch ._bytes_size )
401
+ return batch
402
+
403
+ cutted_batch , remaining_batch = self ._cut_batch_by_max_messages (batch , max_messages )
404
+
405
+ self ._message_batches [part_sess_id ] = remaining_batch
406
+ self ._buffer_release_bytes (cutted_batch ._bytes_size )
375
407
376
- return batch
408
+ return cutted_batch
377
409
378
410
def receive_message_nowait (self ):
379
411
if self ._get_first_error ():
0 commit comments