@@ -451,6 +451,7 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
+        req.status = RequestStatus.RUNNING

     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -504,6 +505,7 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
+        req.status = RequestStatus.RUNNING

     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -556,6 +558,7 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
+        req.status = RequestStatus.RUNNING

     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
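The three hunks above are the same one-line fix applied to three copies of the setup: the test registers hand-built requests directly in scheduler.requests and scheduler.running, bypassing scheduler.schedule(), so it must also mark them RUNNING by hand; otherwise each request presumably keeps the status it was created with, and update_from_output sees a request whose status disagrees with its presence in the running queue. A minimal sketch of that repeated setup, as a hypothetical helper (the helper name is an assumption, not part of the diff; the fields and RequestStatus are the ones the diff itself uses):

# Hypothetical helper (not in the diff) capturing the repeated setup above.
def add_running_request(scheduler, req):
    # Pretend the whole prompt has already been prefilled.
    req.num_computed_tokens = req.num_tokens
    # Register the request the way schedule() would: tracked by id,
    # placed in the running queue, and explicitly marked RUNNING.
    scheduler.requests[req.request_id] = req
    scheduler.running.append(req)
    req.status = RequestStatus.RUNNING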
@@ -703,6 +706,65 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
     scheduler.update_from_output(scheduler_output1, model_runner_output)


+def test_preempt_during_execution():
+    # NOTE(woosuk): The actual number of available blocks is 10 instead of 11
+    # because block 0 is reserved as the null block.
+    scheduler = create_scheduler(max_num_batched_tokens=100,
+                                 block_size=16,
+                                 num_blocks=11,
+                                 enable_prefix_caching=False)
+    requests = create_requests(num_requests=2, num_tokens=80)
+
+    # Schedule the first request.
+    scheduler.add_request(requests[0])
+    scheduler_output0 = scheduler.schedule()
+    assert len(scheduler_output0.num_scheduled_tokens) == 1
+    assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
+
+    # Schedule the second request while the first request is still running.
+    # This scenario can occur in certain cases, when max_concurrent_batches > 1
+    # (e.g., when pipeline parallelism is used).
+    scheduler.add_request(requests[1])
+    scheduler_output1 = scheduler.schedule()
+    assert len(scheduler_output1.num_scheduled_tokens) == 1
+    assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
+
+    # Get the output of the first request.
+    model_runner_output0 = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[[0]],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(scheduler_output0, model_runner_output0)
+
+    # Schedule the first request again. This will cause the preemption
+    # of the second request because the KV cache is full.
+    _ = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0] == requests[0]
+    assert requests[1].status == RequestStatus.PREEMPTED
+
+    model_runner_output1 = ModelRunnerOutput(
+        req_ids=[requests[1].request_id],
+        req_id_to_index={requests[1].request_id: 0},
+        sampled_token_ids=[[42]],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(scheduler_output1, model_runner_output1)
+
+    # The second request (that is preempted) should be updated with the
+    # sampled token id.
+    assert len(requests[1].output_token_ids) == 1
+    assert requests[1].output_token_ids[0] == 42
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",