
Commit 6a47a8c

whx-sjtu and hw_whx authored
[V1] Add v0 style schedule into v1 engine. (#512)
This PR adds the Ascend scheduler to the v1 engine to support a prefill-first schedule with the v1 engine.

---------

Signed-off-by: hw_whx <wanghexiang7@huawei.com>
Signed-off-by: hw_whx <2952154980@qq.com>
Co-authored-by: hw_whx <wanghexiang7@huawei.com>
1 parent b4cc2ca commit 6a47a8c

File tree

10 files changed: +1081 -69 lines changed

tests/test_scheduler.py

Lines changed: 335 additions & 0 deletions
@@ -0,0 +1,335 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

from vllm_ascend.core.scheduler import AscendScheduler

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "/data/weights/Qwen2.5-72B-Instruct",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
) -> AscendScheduler:
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    cache_config.num_gpu_blocks = 10000
    return AscendScheduler(scheduler_config,
                           model_config,
                           cache_config,
                           speculative_config=None,
                           lora_config=None,
                           log_stats=True)


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[List[PlaceholderRange]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[List[int]] = None,
):
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt=None,
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
            arrival_time=0,
        )
        requests.append(request)
    return requests


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


def test_schedule():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


def test_stop_via_update_from_output():
    """Test stopping behavior through update_from_output"""
    scheduler = create_scheduler()

    # Test case 1: Stop on EOS token
    requests = create_requests(num_requests=2, max_tokens=10)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 1,
                                           requests[1].request_id: 2
                                       },
                                       total_num_scheduled_tokens=3,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [],
                                           requests[1].request_id: [10]
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[EOS_TOKEN_ID],
                           [10, 11]],  # First request hits EOS, second continues
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped, second continues
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
    assert list(requests[1].output_token_ids) == [10, 11]

    # Test case 2: Stop on custom stop token
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2,
                               max_tokens=10,
                               stop_token_ids=[42, 43])
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 2
                                       },
                                       total_num_scheduled_tokens=5,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [10, 42],
                                           requests[1].request_id: [13]
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 42, 12],
                           [13, 14]],  # First request hits stop token
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped on custom token
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].stop_reason == 42
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 42]
    assert list(requests[1].output_token_ids) == [13, 14]

    # Test case 3: Stop on max tokens
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2, max_tokens=2)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 1
                                       },
                                       total_num_scheduled_tokens=4,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [10, 11],
                                           requests[1].request_id: []
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 11, 12],
                           [13]],  # First request exceeds max_tokens
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped due to length
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
    assert list(requests[1].output_token_ids) == [13]

    # Test case 4: Ignore EOS flag
    scheduler = create_scheduler()
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
    scheduler.requests[requests[0].request_id] = requests[0]
    scheduler.running.append(requests[0])
    scheduler.scheduled_req_ids.add(requests[0].request_id)

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={requests[0].request_id: 3},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={
            requests[0].request_id: [EOS_TOKEN_ID, 10]
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify request continues past EOS
    assert len(scheduler.running) == 1
    assert not requests[0].is_finished()
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
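
For orientation, a minimal driver sketch of the entry points the new test exercises (add_request, then schedule). The import path of the test helpers is an assumption, not part of this commit, and create_scheduler() points at a local weights path, so both need adjusting for your environment:

# Minimal sketch, assuming the helpers from tests/test_scheduler.py are importable
# and the model weights referenced inside create_scheduler() exist locally.
from tests.test_scheduler import create_requests, create_scheduler

scheduler = create_scheduler()
for request in create_requests(num_requests=4):
    scheduler.add_request(request)

# With the prefill-first (v0-style) policy, the first schedule() call batches all
# waiting prompts; test_schedule above asserts that num_scheduled_tokens maps each
# request id to its full prompt length on this first pass.
output = scheduler.schedule()
print(output.num_scheduled_tokens)

Running the new test file directly (e.g. pytest tests/test_scheduler.py) exercises the same flow end to end.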
