@@ -1148,6 +1148,96 @@ async def test_read_message(
1148
1148
assert mess == expected_message
1149
1149
assert dict (stream_reader ._message_batches ) == batches_after
1150
1150
1151
+ @pytest .mark .parametrize (
1152
+ "batches_before,max_messages,actual_messages,batches_after" ,
1153
+ [
1154
+ (
1155
+ {
1156
+ 0 : PublicBatch (
1157
+ messages = [stub_message (1 )],
1158
+ _partition_session = stub_partition_session (),
1159
+ _bytes_size = 0 ,
1160
+ _codec = Codec .CODEC_RAW ,
1161
+ )
1162
+ },
1163
+ None ,
1164
+ 1 ,
1165
+ {},
1166
+ ),
1167
+ (
1168
+ {
1169
+ 0 : PublicBatch (
1170
+ messages = [stub_message (1 ), stub_message (2 )],
1171
+ _partition_session = stub_partition_session (),
1172
+ _bytes_size = 0 ,
1173
+ _codec = Codec .CODEC_RAW ,
1174
+ ),
1175
+ 1 : PublicBatch (
1176
+ messages = [stub_message (3 ), stub_message (4 )],
1177
+ _partition_session = stub_partition_session (1 ),
1178
+ _bytes_size = 0 ,
1179
+ _codec = Codec .CODEC_RAW ,
1180
+ ),
1181
+ },
1182
+ 1 ,
1183
+ 1 ,
1184
+ {
1185
+ 1 : PublicBatch (
1186
+ messages = [stub_message (3 ), stub_message (4 )],
1187
+ _partition_session = stub_partition_session (1 ),
1188
+ _bytes_size = 0 ,
1189
+ _codec = Codec .CODEC_RAW ,
1190
+ ),
1191
+ 0 : PublicBatch (
1192
+ messages = [stub_message (2 )],
1193
+ _partition_session = stub_partition_session (),
1194
+ _bytes_size = 0 ,
1195
+ _codec = Codec .CODEC_RAW ,
1196
+ ),
1197
+ },
1198
+ ),
1199
+ (
1200
+ {
1201
+ 0 : PublicBatch (
1202
+ messages = [stub_message (1 )],
1203
+ _partition_session = stub_partition_session (),
1204
+ _bytes_size = 0 ,
1205
+ _codec = Codec .CODEC_RAW ,
1206
+ ),
1207
+ 1 : PublicBatch (
1208
+ messages = [stub_message (2 ), stub_message (3 )],
1209
+ _partition_session = stub_partition_session (1 ),
1210
+ _bytes_size = 0 ,
1211
+ _codec = Codec .CODEC_RAW ,
1212
+ ),
1213
+ },
1214
+ 100 ,
1215
+ 1 ,
1216
+ {
1217
+ 1 : PublicBatch (
1218
+ messages = [stub_message (2 ), stub_message (3 )],
1219
+ _partition_session = stub_partition_session (1 ),
1220
+ _bytes_size = 0 ,
1221
+ _codec = Codec .CODEC_RAW ,
1222
+ )
1223
+ },
1224
+ ),
1225
+ ],
1226
+ )
1227
+ async def test_read_batch_max_messages (
1228
+ self ,
1229
+ stream_reader ,
1230
+ batches_before : typing .List [datatypes .PublicBatch ],
1231
+ max_messages : typing .Optional [int ],
1232
+ actual_messages : int ,
1233
+ batches_after : typing .List [datatypes .PublicBatch ],
1234
+ ):
1235
+ stream_reader ._message_batches = OrderedDict (batches_before )
1236
+ batch = stream_reader .receive_batch_nowait (max_messages = max_messages )
1237
+
1238
+ assert len (batch .messages ) == actual_messages
1239
+ assert stream_reader ._message_batches == OrderedDict (batches_after )
1240
+
1151
1241
async def test_receive_batch_nowait (self , stream , stream_reader , partition_session ):
1152
1242
assert stream_reader .receive_batch_nowait () is None
1153
1243
0 commit comments