
Commit f45a332

[Sched] Enhance the logic to remove stopped requests from queues (#20739)
1 parent: 6e2c176

File tree

3 files changed: +92 -17 lines changed

requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ requests >= 2.26.0
 tqdm
 blake3
 py-cpuinfo
-transformers >= 4.51.1
+transformers >= 4.53.2
 huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads.
 tokenizers >= 0.21.1 # Required for fast incremental detokenization.
 protobuf # Required by LlamaTokenizer.
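
If you want to verify a local environment against the new floor, a minimal check along these lines works. This snippet is illustrative and not part of the commit; it assumes the packaging library is available (pip itself depends on it):

# Illustrative only: verify the installed transformers version meets
# the new lower bound from requirements/common.txt.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("transformers"))
assert installed >= Version("4.53.2"), (
    f"transformers {installed} is older than the required 4.53.2")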

tests/v1/core/test_scheduler.py

Lines changed: 62 additions & 0 deletions
@@ -451,6 +451,7 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
+        req.status = RequestStatus.RUNNING
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -504,6 +505,7 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
+        req.status = RequestStatus.RUNNING
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -556,6 +558,7 @@ def test_stop_via_update_from_output():
         req.num_computed_tokens = req.num_tokens
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
+        req.status = RequestStatus.RUNNING
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
@@ -703,6 +706,65 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
     scheduler.update_from_output(scheduler_output1, model_runner_output)
 
 
+def test_preempt_during_execution():
+    # NOTE(woosuk): The actual number of available blocks is 10 instead of 11
+    # because block 0 is reserved as the null block.
+    scheduler = create_scheduler(max_num_batched_tokens=100,
+                                 block_size=16,
+                                 num_blocks=11,
+                                 enable_prefix_caching=False)
+    requests = create_requests(num_requests=2, num_tokens=80)
+
+    # Schedule the first request.
+    scheduler.add_request(requests[0])
+    scheduler_output0 = scheduler.schedule()
+    assert len(scheduler_output0.num_scheduled_tokens) == 1
+    assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
+
+    # Schedule the second request while the first request is still running.
+    # This scenario can occur in certain cases, when max_concurrent_batches > 1
+    # (e.g., when pipeline parallelism is used).
+    scheduler.add_request(requests[1])
+    scheduler_output1 = scheduler.schedule()
+    assert len(scheduler_output1.num_scheduled_tokens) == 1
+    assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
+
+    # Get the output of the first request.
+    model_runner_output0 = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[[0]],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(scheduler_output0, model_runner_output0)
+
+    # Schedule the first request again. This will cause the preemption
+    # of the second request because the KV cache is full.
+    _ = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0] == requests[0]
+    assert requests[1].status == RequestStatus.PREEMPTED
+
+    model_runner_output1 = ModelRunnerOutput(
+        req_ids=[requests[1].request_id],
+        req_id_to_index={requests[1].request_id: 0},
+        sampled_token_ids=[[42]],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(scheduler_output1, model_runner_output1)
+
+    # The second request (that is preempted) should be updated with the
+    # sampled token id.
+    assert len(requests[1].output_token_ids) == 1
+    assert requests[1].output_token_ids[0] == 42
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",

vllm/v1/core/sched/scheduler.py

Lines changed: 29 additions & 16 deletions
@@ -747,19 +747,21 @@ def update_from_output(
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
 
-        new_running: list[Request] = []
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None
 
-        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
-        # loop can be a performance bottleneck. We should do our best to avoid
-        # expensive operations inside the loop.
-        for request in self.running:
-            req_id = request.request_id
-            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
-            if num_tokens_scheduled == 0:
-                # The request was not scheduled in this step.
-                new_running.append(request)
+        # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
+        # the below loop can be a performance bottleneck. We should do our best
+        # to avoid expensive operations inside the loop.
+        stopped_running_reqs: set[Request] = set()
+        stopped_preempted_reqs: set[Request] = set()
+        for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
+            assert num_tokens_scheduled > 0
+            request = self.requests.get(req_id)
+            if request is None:
+                # The request is already finished. This can happen if the
+                # request is aborted while the model is executing it (e.g.,
+                # in pipeline parallelism).
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
@@ -792,6 +794,7 @@
             new_logprobs = None
             new_token_ids = generated_token_ids
             kv_transfer_params = None
+            status_before_stop = request.status
 
             # Append generated tokens and check for stop. Note that if
             # a request is still being prefilled, we expect the model runner
@@ -803,17 +806,22 @@
                     # This must be called before we make the EngineCoreOutput.
                     stopped = check_stop(request, self.max_model_len)
                     if stopped:
-                        kv_transfer_params = self._free_request(request)
                         del new_token_ids[num_new:]  # Trim new tokens if needed.
                         break
 
+            # Stop checking for pooler models.
             pooler_output = None
             if pooler_outputs:
                 pooler_output = pooler_outputs[req_index]
                 stopped = check_stop(request, self.max_model_len,
                                      pooler_output)
-                if stopped:
-                    kv_transfer_params = self._free_request(request)
+
+            if stopped:
+                kv_transfer_params = self._free_request(request)
+                if status_before_stop == RequestStatus.RUNNING:
+                    stopped_running_reqs.add(request)
+                else:
+                    stopped_preempted_reqs.add(request)
 
             # Extract sample logprobs if needed.
             if request.sampling_params is not None \
@@ -868,9 +876,14 @@
             # Invariant: EngineCore returns no partial prefill outputs.
             assert not prompt_logprobs_tensors
 
-            if not stopped:
-                new_running.append(request)
-        self.running = new_running
+        # Remove the stopped requests from the running and waiting queues.
+        if stopped_running_reqs:
+            self.running = [
+                req for req in self.running if req not in stopped_running_reqs
+            ]
+        if stopped_preempted_reqs:
+            # This is a rare case and unlikely to impact performance.
+            self.waiting.remove_requests(stopped_preempted_reqs)
 
         # KV Connector: update state for finished KV Transfers.
         self._update_from_kv_xfer_finished(model_runner_output)
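
To make the new control flow easy to see in isolation, here is a minimal standalone sketch of the pattern the hunks above introduce: stopped requests are bucketed by the status they had before the stop check (mirroring status_before_stop), and each queue is pruned once after the loop. The Request/RequestStatus classes and the prune_queues helper below are simplified stand-ins invented for illustration, not vLLM's real types:

# Minimal sketch of the queue-pruning pattern introduced above.
from dataclasses import dataclass
from enum import Enum, auto


class RequestStatus(Enum):
    RUNNING = auto()
    PREEMPTED = auto()
    FINISHED = auto()


@dataclass(eq=False)  # keep identity-based __eq__/__hash__ so sets track objects
class Request:
    request_id: str
    status: RequestStatus


def prune_queues(
    running: list[Request],
    waiting: list[Request],
    stopped: list[tuple[Request, RequestStatus]],
) -> tuple[list[Request], list[Request]]:
    """Drop stopped requests from both queues in one pass each."""
    # Bucket by the status each request had *before* the stop check,
    # mirroring status_before_stop above: the stop check overwrites
    # request.status, so it must be captured first.
    stopped_running = {req for req, st in stopped
                       if st == RequestStatus.RUNNING}
    stopped_preempted = {req for req, st in stopped
                         if st != RequestStatus.RUNNING}
    if stopped_running:
        # One O(len(running)) rebuild with O(1) set lookups, instead of
        # appending survivors to a new list inside the per-request loop.
        running = [req for req in running if req not in stopped_running]
    if stopped_preempted:
        # Rare: a request can stop while it sits preempted in the waiting
        # queue (its output arrived late, e.g. under pipeline parallelism),
        # so a linear removal is acceptable here.
        waiting = [req for req in waiting if req not in stopped_preempted]
    return running, waiting


# Tiny usage example: three running requests, the middle one stopped.
reqs = [Request(f"req-{i}", RequestStatus.RUNNING) for i in range(3)]
running, waiting = prune_queues(reqs, [], [(reqs[1], RequestStatus.RUNNING)])
assert [r.request_id for r in running] == ["req-0", "req-2"]
assert waiting == []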
