Commit 082ab86

[V1] Support long_prefill_token_threshold in v1 scheduler (#15419)
Signed-off-by: Lu Fang <lufang@fb.com>
1 parent 6aa196c commit 082ab86

File tree

tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler_e2e.py
vllm/engine/arg_utils.py
vllm/v1/core/sched/scheduler.py

4 files changed: +113 -4 lines changed

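For orientation, a minimal usage sketch (not part of the commit; the model name and values are illustrative). It assumes the offline LLM entrypoint forwards long_prefill_token_threshold to the engine, as the new e2e test below does, and that the V1 engine is opted into via VLLM_USE_V1:

import os

# Assumption: opt in to the V1 engine, as the new e2e test does.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM

# Cap each request's prefill at 512 prompt tokens per scheduling step so that
# several long prompts can share one token budget (illustrative values).
llm = LLM(
    "facebook/opt-125m",
    enforce_eager=True,
    long_prefill_token_threshold=512,
    max_num_batched_tokens=2048,
)
print(llm.generate(["Hello my name is Robert and I"])[0].outputs[0].text)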

tests/v1/core/test_scheduler.py

Lines changed: 75 additions & 1 deletion
@@ -20,9 +20,10 @@ def create_scheduler(
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
     enable_prefix_caching: Optional[bool] = None,
+    long_prefill_token_threshold: int = 0,
 ) -> Scheduler:
     '''Create scheduler under test.
-
+
     Args:
       model: model under test
       max_num_seqs: max sequences to schedule
@@ -38,6 +39,7 @@ def create_scheduler(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_num_batched_tokens,
+        long_prefill_token_threshold=long_prefill_token_threshold,
     )
     model_config = ModelConfig(
         model=model,
@@ -263,6 +265,78 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens


+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
+    """Test scheduling behavior with concurrent partial requests.
+
+    This test verifies that multiple long prefill requests in the RUNNING
+    state can be scheduled together.
+
+    """
+    scheduler = create_scheduler(
+        model="facebook/opt-125m",
+        max_num_batched_tokens=1024,
+        long_prefill_token_threshold=400,
+        enable_prefix_caching=enable_prefix_caching,
+    )
+    requests = create_requests(
+        num_requests=3,
+        num_tokens=800,
+    )
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 3
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+
+    # The first request is scheduled partially - 400.
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+    # The second request is scheduled partially - 400.
+    assert output.num_scheduled_tokens[requests[1].request_id] == 400
+    # The third request is also scheduled partially - 1024 - 400 - 400 = 224.
+    assert output.num_scheduled_tokens[requests[2].request_id] == 224
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    # Schedule the next step. All three requests are running.
+    # Processed the remaining prefills of the first and second requests.
+    output1 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output1.scheduled_new_reqs) == 0
+    assert len(output1.scheduled_cached_reqs) == 3
+    assert len(output1.finished_req_ids) == 0
+    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
+    assert output1.num_scheduled_tokens[requests[2].request_id] == 224
+
+    # Schedule the third step. All three requests are running.
+    # First and second requests are in the decode stage.
+    # All the remaining tokens in the third request are processed.
+    scheduler.update_from_output(output1, model_runner_output)
+    output2 = scheduler.schedule()
+    assert len(scheduler.running) == 3
+    assert len(output2.scheduled_new_reqs) == 0
+    assert len(output2.scheduled_cached_reqs) == 3
+    assert len(output2.finished_req_ids) == 0
+    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
+    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
+    assert output2.num_scheduled_tokens[
+        requests[2].request_id] == 800 - 224 - 224
+
+
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
     scheduler = create_scheduler()
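
The asserted token counts above follow directly from the per-step cap this commit adds to the scheduler. Below is a standalone sketch (plain Python with a hypothetical helper, no vLLM imports) that replays the budget split for three 800-token prompts under a 1024-token budget and a 400-token threshold:

def split_budget(remaining, token_budget, threshold):
    # Give each request min(remaining, threshold, budget left) tokens,
    # mirroring the capping rule added in vllm/v1/core/sched/scheduler.py.
    scheduled = []
    for rem in remaining:
        n = min(rem, threshold) if threshold > 0 else rem
        n = min(n, token_budget)
        scheduled.append(n)
        token_budget -= n
    return scheduled


remaining = [800, 800, 800]  # prompt tokens left per request
for step in (1, 2, 3):
    got = split_budget(remaining, token_budget=1024, threshold=400)
    # Requests whose prefill already finished decode 1 token instead
    # (a simplification; the real scheduler derives this differently).
    got = [g if r > 0 else 1 for g, r in zip(got, remaining)]
    remaining = [max(r - g, 0) for r, g in zip(remaining, got)]
    print(f"step {step}: {got}")
# step 1: [400, 400, 224]
# step 2: [400, 400, 224]
# step 3: [1, 1, 352]     (352 == 800 - 224 - 224)

The printed splits match the values asserted against output, output1, and output2 in the test above.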

tests/v1/core/test_scheduler_e2e.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import pytest
+
+from vllm import LLM
+
+if os.getenv("VLLM_USE_V1", "0") != "1":
+    pytest.skip("Test package requires V1", allow_module_level=True)
+
+MODEL = "meta-llama/Llama-3.2-1B"
+PROMPT = "Hello my name is Robert and I"
+
+
+@pytest.fixture(scope="module")
+def model() -> LLM:
+    return LLM(MODEL,
+               enforce_eager=True,
+               enable_prefix_caching=True,
+               long_prefill_token_threshold=2,
+               max_num_batched_tokens=6,
+               max_num_seqs=3)
+
+
+def test_concurrent_partial_prefill(model):
+    outputs = model.generate([PROMPT] * 3)
+    assert len(outputs) == 3
+    for output in outputs:
+        assert len(output.outputs) == 1

vllm/engine/arg_utils.py

Lines changed: 1 addition & 3 deletions
@@ -1625,9 +1625,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         if (self.max_num_partial_prefills
                 != EngineArgs.max_num_partial_prefills
                 or self.max_long_partial_prefills
-                != EngineArgs.max_long_partial_prefills
-                or self.long_prefill_token_threshold
-                != EngineArgs.long_prefill_token_threshold):
+                != EngineArgs.max_long_partial_prefills):
             _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                                recommend_to_remove=False)
             return False

vllm/v1/core/sched/scheduler.py

Lines changed: 8 additions & 0 deletions
@@ -152,6 +152,10 @@ def schedule(self) -> SchedulerOutput:
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
+            if self.scheduler_config.long_prefill_token_threshold > 0:
+                num_new_tokens = min(
+                    num_new_tokens,
+                    self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
 
@@ -299,6 +303,10 @@ def schedule(self) -> SchedulerOutput:
                 num_computed_tokens -= self.block_size
                 num_new_tokens = self.block_size
                 computed_blocks.pop()
+            if self.scheduler_config.long_prefill_token_threshold > 0:
+                num_new_tokens = min(
+                    num_new_tokens,
+                    self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
 
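Both hunks apply the same rule: when long_prefill_token_threshold is set, a request's prefill is chunked into at most that many tokens per step, and never more than the remaining token budget. As a rough back-of-the-envelope sketch (hypothetical helper, not part of the commit), the number of scheduler steps a single long prompt needs before it can start decoding is then roughly:

import math

def prefill_steps(prompt_len: int, threshold: int, token_budget: int) -> int:
    # Lower bound on scheduler steps to finish one request's prefill when it
    # is capped at min(threshold, token_budget) tokens per step; other
    # requests sharing the budget can only increase this.
    per_step = min(threshold, token_budget) if threshold > 0 else token_budget
    return math.ceil(prompt_len / per_step)

# The unit test's numbers: an 800-token prompt with a 400-token threshold and
# a 1024-token budget finishes prefill in 2 steps, then starts decoding.
assert prefill_steps(800, threshold=400, token_budget=1024) == 2
# With the threshold disabled (0), the whole prompt fits in one step.
assert prefill_steps(800, threshold=0, token_budget=1024) == 1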
