@@ -20,9 +20,10 @@ def create_scheduler(
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
     enable_prefix_caching: Optional[bool] = None,
+    long_prefill_token_threshold: int = 0,
 ) -> Scheduler:
     '''Create scheduler under test.
-
+
     Args:
       model: model under test
       max_num_seqs: max sequences to schedule
@@ -38,6 +39,7 @@ def create_scheduler(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
+        long_prefill_token_threshold=long_prefill_token_threshold,
     )
     model_config = ModelConfig(
         model=model,
@@ -263,6 +265,78 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens
 
 
+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
+    """Test scheduling behavior with concurrent partial requests.
+
+    This test verifies that multiple long prefill requests in the RUNNING
+    state can be scheduled together.
+
+    """
+    scheduler = create_scheduler(
+        model="facebook/opt-125m",
+        max_num_batched_tokens=1024,
+        long_prefill_token_threshold=400,
+        enable_prefix_caching=enable_prefix_caching,
+    )
+    requests = create_requests(
+        num_requests=3,
+        num_tokens=800,
+    )
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 3
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+
+    # The first request is scheduled partially - 400.
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+    # The second request is scheduled partially - 400.
+    assert output.num_scheduled_tokens[requests[1].request_id] == 400
+    # The third request is also scheduled partially - 1024 - 400 - 400 = 224.
+    assert output.num_scheduled_tokens[requests[2].request_id] == 224
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    # Schedule the next step. All three requests are running.
+    # Processed the remaining prefills of the first and second requests.
+    output1 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output1.scheduled_new_reqs) == 0
+    assert len(output1.scheduled_cached_reqs) == 3
+    assert len(output1.finished_req_ids) == 0
+    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[2].request_id] == 224
+
+    # Schedule the third step. All three requests are running.
+    # First and second requests are in the decode stage.
+    # All the remaining tokens in the third request are processed.
+    scheduler.update_from_output(output1, model_runner_output)
+    output2 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output2.scheduled_new_reqs) == 0
+    assert len(output2.scheduled_cached_reqs) == 3
+    assert len(output2.finished_req_ids) == 0
+    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
+    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
+    assert output2.num_scheduled_tokens[
+        requests[2].request_id] == 800 - 224 - 224
+
+
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
     scheduler = create_scheduler()
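Note (not part of the diff): the per-step token counts asserted by the new test follow from the 1024-token batch budget, the 400-token `long_prefill_token_threshold` cap, and the three 800-token prompts. A minimal standalone sketch of that accounting, with a hypothetical helper name chosen here for illustration (it is not a vLLM API):

```python
# Sketch of the chunked-prefill accounting the test asserts, assuming a
# fixed per-step token budget, a per-request cap on long prefills, and
# one decode token per request once its prefill has finished.
def simulate_steps(num_steps: int,
                   prompt_len: int = 800,
                   budget: int = 1024,
                   long_prefill_cap: int = 400,
                   num_reqs: int = 3) -> list[list[int]]:
    remaining = [prompt_len] * num_reqs   # prefill tokens left per request
    history = []
    for _ in range(num_steps):
        left = budget
        step = []
        for i in range(num_reqs):
            if remaining[i] > 0:           # still prefilling: cap applies
                n = min(remaining[i], long_prefill_cap, left)
                remaining[i] -= n
            else:                          # prefill done: decode one token
                n = min(1, left)
            left -= n
            step.append(n)
        history.append(step)
    return history

print(simulate_steps(3))
# [[400, 400, 224], [400, 400, 224], [1, 1, 352]]
# i.e. 400/400/224 on the first two steps, then 1/1/(800 - 224 - 224),
# matching the asserts in the test above.
```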