Skip to content

Commit 20a805d

Browse files
committed
fix: ascend_scheduler adapt v0.9.0
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent 5903547 commit 20a805d

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ jobs:
112112
# pytest -sv tests/singlecard/test_scheduler.py
113113
# guided decoding doesn't work, fix it later
114114
# pytest -sv tests/singlecard/test_guided_decoding.py
115-
pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
115+
pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_guided_decoding.py
116116
else
117117
pytest -sv tests/multicard/test_ilama_lora_tp2.py
118118
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py

tests/singlecard/test_scheduler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from vllm.v1.structured_output import StructuredOutputManager
3232

3333
from vllm_ascend.core.scheduler import AscendScheduler
34+
from tests.conftest import VllmRunner
3435

3536
EOS_TOKEN_ID = 50256
3637

@@ -394,3 +395,27 @@ def test_stop_via_update_from_output():
394395
assert len(scheduler.running) == 1
395396
assert not requests[0].is_finished()
396397
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
398+
399+
MODELS = [
400+
"Qwen/Qwen3-0.6B-Base",
401+
]
402+
403+
@pytest.mark.parametrize("model", MODELS)
404+
@pytest.mark.parametrize("dtype", ["float16"])
405+
@pytest.mark.parametrize("max_tokens", [5])
406+
def test_models(model: str, dtype: str, max_tokens: int) -> None:
407+
# 5042 tokens for gemma2
408+
# gemma2 has alternating sliding window size of 4096
409+
# we need a prompt with more than 4096 tokens to test the sliding window
410+
prompt = "The following numbers of the sequence " + ", ".join(
411+
str(i) for i in range(1024)) + " are:"
412+
example_prompts = [prompt]
413+
414+
with VllmRunner(model,
415+
max_model_len=8192,
416+
dtype=dtype,
417+
enforce_eager=True,
418+
gpu_memory_utilization=0.7,
419+
enable_prefix_caching=False,
420+
additional_config={'ascend_schduler_config': {}}) as vllm_model:
421+
vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/core/scheduler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,15 @@ def skip_cur_request():
130130

131131
assert num_new_tokens > 0
132132
watermark = getattr(self.scheduler_config, "watermark", 0.01)
133-
if not self._check_watermark_for_prefill(
134-
request, num_new_tokens, computed_blocks, watermark):
133+
if not self._check_watermark_for_prefill(request, num_new_tokens,
134+
computed_blocks.blocks,
135+
watermark):
135136
# Scheduling would exceed watermark, skip.
136137
skip_cur_request()
137138
continue
138139

139140
new_blocks = self.kv_cache_manager.allocate_slots(
140-
request, num_new_tokens, computed_blocks)
141+
request, num_new_tokens, new_computed_blocks=computed_blocks)
141142
if new_blocks is None:
142143
# The request cannot be scheduled.
143144
break
@@ -155,9 +156,8 @@ def skip_cur_request():
155156

156157
if self.lora_config and request.lora_request:
157158
scheduled_loras.add(request.lora_request.lora_int_id)
158-
req_to_new_block_ids[request.request_id] = [
159-
b.block_id for b in computed_blocks + new_blocks
160-
]
159+
req_to_new_block_ids[request.request_id] = (
160+
self.kv_cache_manager.get_block_ids(request.request_id))
161161
# Update request info.
162162
num_scheduled_tokens[request.request_id] = num_new_tokens
163163
token_budget -= num_new_tokens
@@ -215,9 +215,8 @@ def skip_cur_request():
215215
# Schedule the request.
216216
scheduled_running_reqs.append(request)
217217
self.scheduled_req_ids.add(request.request_id)
218-
req_to_new_block_ids[request.request_id] = [
219-
b.block_id for b in new_blocks
220-
]
218+
req_to_new_block_ids[request.request_id] = (
219+
new_blocks.get_block_ids())
221220
num_scheduled_tokens[request.request_id] = num_new_tokens
222221
token_budget -= num_new_tokens
223222
req_index += 1
@@ -326,7 +325,8 @@ def _check_watermark_for_prefill(self,
326325
len(computed_blocks) * self.block_size)
327326
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
328327
self.block_size)
329-
req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
328+
req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
329+
request.request_id]
330330
num_new_blocks = (num_required_blocks - len(req_blocks) -
331331
len(computed_blocks))
332332
num_evictable_computed_blocks = sum(1 for blk in computed_blocks

0 commit comments

Comments (0)