
Commit fd7b81e

Merge branch 'main' of https://github.com/raindaywhu/vllm-ascend into main

* 'main' of https://github.com/raindaywhu/vllm-ascend:
  [aclgraph] implement NPUPiecewiseBackend to enable aclgraph (vllm-project#836)
  [Bugfix][V1] Fix deepseek with v1 (vllm-project#958)
  [Perf] Refactor tensor disposal logic to reduce memory usage (vllm-project#966)
2 parents b8b6175 + 55c8bb5 commit fd7b81e


14 files changed: +413 -64 lines


tests/compile/test_aclgraph.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.

Run `pytest tests/compile/test_aclgraph.py`.
"""

import os

import pytest
import torch
from vllm import LLM, SamplingParams

from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from vllm_ascend.utils import vllm_version_is

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.skipif(
    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
    model: str,
    max_tokens: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    with monkeypatch.context() as m:
        prompts = [
            "Hello, my name is", "The president of the United States is",
            "The capital of France is", "The future of AI is"
        ]

        # aclgraph only support on v1
        m.setenv("VLLM_USE_V1", "1")

        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         temperature=0.0)
        # TODO: change to use vllmrunner when the registry of custom op is solved
        # while running pytest
        vllm_model = LLM(model)
        vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
        del vllm_model
        torch.npu.empty_cache()

        vllm_model = LLM(model, enforce_eager=True)
        vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
        del vllm_model
        torch.npu.empty_cache()

        vllm_aclgraph_outputs_list = []
        for output in vllm_aclgraph_outputs:
            vllm_aclgraph_outputs_list.append(
                (output.outputs[0].index, output.outputs[0].text))

        vllm_eager_outputs_list = []
        for output in vllm_eager_outputs:
            vllm_eager_outputs_list.append(
                (output.outputs[0].index, output.outputs[0].text))

        check_outputs_equal(
            outputs_0_lst=vllm_eager_outputs_list,
            outputs_1_lst=vllm_aclgraph_outputs_list,
            name_0="vllm_eager_outputs",
            name_1="vllm_aclgraph_outputs",
        )


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.skipif(
    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_MODELSCOPE", "True")
        m.setenv("VLLM_USE_V1", "1")
        with pytest.raises(NotImplementedError) as excinfo:
            VllmRunner("deepseek-ai/DeepSeek-V2-Lite-Chat",
                       max_model_len=1024,
                       enforce_eager=False)
        assert "ACL Graph does not support deepseek" in str(excinfo.value)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def __init__(
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
         swap_space: int = 4,
-        enforce_eager: Optional[bool] = False,
+        enforce_eager: Optional[bool] = True,
         **kwargs,
     ) -> None:
         self.model = LLM(
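
This default flip makes graph capture opt-in for VllmRunner-based tests, which is why several test files below add enforce_eager=True explicitly. A minimal sketch of the opt-in, reusing only names that appear in these diffs; not part of this commit:

from tests.conftest import VllmRunner

# Explicitly request graph capture now that eager execution is the default.
with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
                max_model_len=8192,
                enforce_eager=False,
                gpu_memory_utilization=0.7) as vllm_model:
    vllm_model.generate_greedy(["Hello, my name is"], 32)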

tests/long_term/spec_decode/e2e/test_v1_spec_decode.py

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,7 @@ def test_ngram_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm

@@ -85,6 +85,7 @@ def test_ngram_correctness(
                 "num_speculative_tokens": 3,
             },
             max_model_len=1024,
+            enforce_eager=True,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
@@ -135,6 +136,7 @@ def test_eagle_correctness(
                 "max_model_len": 2048,
             },
             max_model_len=2048,
+            enforce_eager=True,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0

tests/multicard/test_dynamic_npugraph_batchsize.py

Lines changed: 4 additions & 2 deletions
@@ -18,8 +18,7 @@
 import torch
 from vllm import LLM, SamplingParams

-# TODO: revert me when cuda hard code is fixed in 'VllmBackend'
-torch.cuda.CUDAGraph = torch.npu.NPUGraph
+from vllm_ascend.utils import vllm_version_is

 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
@@ -33,6 +32,9 @@
 ]


+@pytest.mark.skipif(
+    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
+    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("max_tokens", [64])

tests/multicard/test_offline_inference_distributed.py

Lines changed: 0 additions & 3 deletions
@@ -22,7 +22,6 @@
 """
 import os

-import pytest
 import vllm  # noqa: F401

 from tests.conftest import VllmRunner
@@ -47,8 +46,6 @@ def test_models_distributed_QwQ():
     vllm_model.generate_greedy(example_prompts, max_tokens)


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
-                    reason="deepseek v2 lite is not supported on v1")
 def test_models_distributed_DeepSeek():
     example_prompts = [
         "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",

tests/singlecard/test_offline_inference.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
-                    enforce_eager=False,
+                    enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)


vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 4 deletions
@@ -239,10 +239,8 @@ def build(self,
         # it blocks on all previous kernels.
         device = self.runner.device

-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
+        block_table = (self.runner.input_batch.block_table[0].
+                       get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
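
The rewrite drops the in-place copy into the block table and keeps only a slice of the first num_reqs rows. In PyTorch, basic slicing returns a view, so no new device memory is allocated and no extra write is issued. A small illustration of that property with a stand-in tensor; this is not code from this commit:

import torch

full = torch.zeros(1024, 64)   # stand-in for the full block table
rows = full[:8]                # basic slicing returns a view, not a copy
assert rows.data_ptr() == full.data_ptr()  # same underlying storage, nothing was written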
