
Commit d4d3094

Implement Async Scheduling (#19970)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 85bd659 commit d4d3094
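
This commit adds asynchronous scheduling to the vLLM v1 scheduler and covers it with new tests: with async_scheduling enabled, the batch for step N+1 can be scheduled before the model output of step N has been applied. The sketch below condenses the loop the new tests drive. It is illustrative only: the absolute tests.v1.core.utils import path and the inline fake model output are assumptions; the test file itself uses a relative .utils import and a _make_model_runner_output helper, both shown in the diff below.

# Condensed sketch of the pipelined (async) scheduling loop exercised by the
# new tests. The import path and the faked model-runner output are
# illustrative assumptions, not part of the commit's public API.
from collections import deque

from vllm.v1.outputs import ModelRunnerOutput

from tests.v1.core.utils import create_requests, create_scheduler  # assumed path

scheduler = create_scheduler(async_scheduling=True)
for req in create_requests(num_requests=2, max_tokens=4):
    scheduler.add_request(req)

pending = deque()
pending.append(scheduler.schedule())  # batch for step N
pending.append(scheduler.schedule())  # batch for step N+1, before step N's output is known

while pending:
    sched_output = pending.popleft()
    req_ids = list(sched_output.num_scheduled_tokens.keys())
    # Stand-in for the real model runner (mirrors _make_model_runner_output).
    fake_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={rid: i for i, rid in enumerate(req_ids)},
        sampled_token_ids=[[i] for i in range(len(req_ids))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(sched_output, fake_output)
    nxt = scheduler.schedule()
    if nxt.num_scheduled_tokens:
        pending.append(nxt)

assert scheduler.get_num_unfinished_requests() == 0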

File tree

11 files changed: +508 -148 lines changed


tests/v1/core/__init__.py

Whitespace-only changes.

tests/v1/core/test_async_scheduler.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque

import pytest

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus

from .utils import create_requests, create_scheduler


def _make_model_runner_output(
        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
    req_ids = list(scheduler_output.num_scheduled_tokens.keys())
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={
            req_id: i
            for i, req_id in enumerate(req_ids)
        },
        sampled_token_ids=[[i] for i in range(len(req_ids))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )


@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5])
def test_stop_by_max_tokens(max_tokens: int):
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=2, max_tokens=max_tokens)
    req0, req1 = requests

    sched_outputs: deque[SchedulerOutput] = deque()
    scheduler.add_request(req0)
    sched_outputs.append(scheduler.schedule())

    scheduler.add_request(req1)
    sched_outputs.append(scheduler.schedule())

    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    assert scheduler.get_num_unfinished_requests() == 0
    assert req0.num_output_tokens == max_tokens
    assert req1.num_output_tokens == max_tokens


def test_abort():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_preempt():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_prefix_caching_for_prefill_dedup():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=3,
                               same_prompt=True)
    requests_copy = requests.copy()

    # Two requests with the same prompt.
    req0 = requests.pop(0)
    req1 = requests.pop(0)
    scheduler.add_request(req0)
    scheduler.add_request(req1)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_output = scheduler.schedule()
    sched_outputs.append(sched_output)
    # Make sure prefix caching de-duplicates the prompts in the same step,
    # so all the blocks except the last are shared between the two requests.
    assert len(sched_output.num_scheduled_tokens) == 2
    num_blocks = num_prompt_tokens // BLOCK_SIZE
    assert req0.num_cached_tokens == 0
    assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE

    sched_outputs.append(scheduler.schedule())
    while sched_outputs:
        if requests:
            scheduler.add_request(requests.pop(0))
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    # Other requests scheduled after the two requests should also get
    # prefix cache hit.
    assert scheduler.get_num_unfinished_requests() == 0
    for req in requests_copy[1:]:
        assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE


def test_prefix_caching_for_multi_turn():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    num_output_tokens = 200
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=num_output_tokens)

    for req in requests:
        scheduler.add_request(req)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    # Process the requests.
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)
    assert scheduler.get_num_unfinished_requests() == 0

    # Create next-turn requests whose prompts are the full output of the
    # previous turn.
    next_turn_requests = create_requests(
        num_requests=5,
        num_tokens=num_prompt_tokens + num_output_tokens,
        max_tokens=num_output_tokens,
    )
    for i, req in enumerate(next_turn_requests):
        req.prompt_token_ids = (requests[i].prompt_token_ids +
                                list(requests[i].output_token_ids))
    # Schedule the next-turn requests.
    for req in next_turn_requests:
        scheduler.add_request(req)
    sched_outputs.append(scheduler.schedule())

    # Make sure the next-turn requests get prefix cache hit by the previous
    # requests.
    for req in next_turn_requests:
        assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE *
                BLOCK_SIZE)

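The bounds asserted in the two prefix-caching tests above follow from block arithmetic on the test constants; a worked check (values copied from the tests, the variable names here are only for illustration):

# test_prefix_caching_for_prefill_dedup: identical 100-token prompts with
# 16-token blocks share every full block, so a later request should report
# at least 6 * 16 = 96 cached tokens; the 4-token tail block need not be shared.
num_prompt_tokens = 100
BLOCK_SIZE = 16
num_full_blocks = num_prompt_tokens // BLOCK_SIZE      # 6
assert num_full_blocks * BLOCK_SIZE == 96

# test_prefix_caching_for_multi_turn: the next-turn prompt is the previous
# prompt plus its 200 output tokens, i.e. 300 tokens, of which the full
# blocks (300 // 16 * 16 = 288 tokens) are expected to be cache hits.
num_output_tokens = 200
next_turn_prompt_len = num_prompt_tokens + num_output_tokens  # 300
assert next_turn_prompt_len // BLOCK_SIZE * BLOCK_SIZE == 288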
tests/v1/core/test_scheduler.py

Lines changed: 1 addition & 127 deletions
@@ -19,133 +19,7 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.v1.structured_output.request import StructuredOutputRequest
 
-EOS_TOKEN_ID = 50256
-
-
-def create_scheduler(
-    model: str = "facebook/opt-125m",
-    max_num_seqs: int = 16,
-    max_num_batched_tokens: int = 8192,
-    enable_prefix_caching: Optional[bool] = None,
-    long_prefill_token_threshold: int = 0,
-    disable_chunked_mm_input: bool = False,
-    use_kv_connector: bool = False,
-    num_blocks: int = 10000,
-    block_size: int = 16,
-    max_model_len: Optional[int] = None,
-    num_speculative_tokens: Optional[int] = None,
-    skip_tokenizer_init: bool = False,
-) -> Scheduler:
-    '''Create scheduler under test.
-
-    Args:
-      model: model under test
-      max_num_seqs: max sequences to schedule
-      max_num_batch_tokens: max num tokens to batch
-      enable_prefix_caching: optionally force APC config
-                             (True/False) or use default
-                             (None)
-
-    Returns:
-      {class}`Scheduler` instance
-    '''
-    if max_model_len is None:
-        max_model_len = max_num_batched_tokens
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_model_len,
-        long_prefill_token_threshold=long_prefill_token_threshold,
-        disable_chunked_mm_input=disable_chunked_mm_input,
-        enable_chunked_prefill=True,
-    )
-    model_config = ModelConfig(
-        model=model,
-        task="auto",
-        tokenizer=model,
-        tokenizer_mode="auto",
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-        skip_tokenizer_init=skip_tokenizer_init,
-    )
-    # Cache config, optionally force APC
-    kwargs_cache = ({} if enable_prefix_caching is None else {
-        'enable_prefix_caching': enable_prefix_caching
-    })
-    cache_config = CacheConfig(
-        block_size=block_size,
-        gpu_memory_utilization=0.9,
-        swap_space=0,
-        cache_dtype="auto",
-        **kwargs_cache,
-    )
-    kv_transfer_config = KVTransferConfig(
-        kv_connector="SharedStorageConnector",
-        kv_role="kv_both",
-        kv_connector_extra_config={"shared_storage_path": "local_storage"},
-    ) if use_kv_connector else None
-
-    speculative_config: Optional[SpeculativeConfig] = None
-    if num_speculative_tokens is not None:
-        speculative_config = SpeculativeConfig(
-            model="ngram", num_speculative_tokens=num_speculative_tokens)
-
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-        kv_transfer_config=kv_transfer_config,
-        speculative_config=speculative_config,
-    )
-    kv_cache_config = KVCacheConfig(
-        num_blocks=num_blocks,  # A large number of blocks to hold all requests
-        kv_cache_tensors=[],
-        kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
-        ],
-    )
-    cache_config.num_gpu_blocks = num_blocks
-    return Scheduler(
-        vllm_config=vllm_config,
-        kv_cache_config=kv_cache_config,
-        log_stats=True,
-        structured_output_manager=StructuredOutputManager(vllm_config),
-    )
-
-
-def create_requests(num_requests: int,
-                    num_tokens: int = 10,
-                    mm_positions: Optional[list[PlaceholderRange]] = None,
-                    max_tokens: int = 16,
-                    stop_token_ids: Optional[list[int]] = None,
-                    prompt_logprobs: Optional[int] = None):
-    sampling_params = SamplingParams(ignore_eos=False,
-                                     max_tokens=max_tokens,
-                                     stop_token_ids=stop_token_ids,
-                                     prompt_logprobs=prompt_logprobs)
-    requests = []
-    for i in range(num_requests):
-        if mm_positions is not None:
-            mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
-        else:
-            mm_position = None
-            mm_inputs = None
-        request = Request(
-            request_id=f"{i}",
-            prompt_token_ids=[i] * num_tokens,
-            sampling_params=sampling_params,
-            pooling_params=None,
-            multi_modal_inputs=mm_inputs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
-            eos_token_id=EOS_TOKEN_ID,
-        )
-        requests.append(request)
-    return requests
+from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
 
 
 def test_add_requests():
