Commit d849be3

Merge branch 'main' into main
2 parents: 777cfa7 + 6061f33

31 files changed: +1033 −231 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 13 additions & 4 deletions

@@ -46,11 +46,23 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-arm64-npu-1, linux-arm64-npu-4]
-        vllm_verison: [main, v0.8.3]
+        vllm_verison: [main, v0.8.4]
+    concurrency:
+      group: >
+        ${{
+        matrix.os == 'linux-arm64-npu-4'
+        && github.event.pull_request.number
+        && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
+        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
+        }}
+      cancel-in-progress: false
     name: vLLM Ascend test
     runs-on: ${{ matrix.os }}
     container:
       image: quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py3.10
+      env:
+        HF_ENDPOINT: https://hf-mirror.com
+        HF_TOKEN: ${{ secrets.HF_TOKEN }}
     steps:
       - name: Check npu and CANN info
         run: |
@@ -108,7 +120,6 @@ jobs:
       - name: Run vllm-project/vllm-ascend test on V0 engine
         env:
           VLLM_USE_V1: 0
-          HF_ENDPOINT: https://hf-mirror.com
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             pytest -sv tests/singlecard
@@ -122,7 +133,6 @@ jobs:
         env:
           VLLM_USE_V1: 1
           VLLM_WORKER_MULTIPROC_METHOD: spawn
-          HF_ENDPOINT: https://hf-mirror.com
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             pytest -sv tests/singlecard
@@ -136,6 +146,5 @@ jobs:
         env:
           VLLM_USE_V1: 0
           PYTORCH_NPU_ALLOC_CONF: max_split_size_mb:256
-          HF_ENDPOINT: https://hf-mirror.com
         run: |
           pytest -sv

docs/source/faqs.md

Lines changed: 49 additions & 1 deletion

@@ -55,7 +55,7 @@ After configuration, you can download our container from `m.daocloud.io/quay.io/

 ### 3. What models does vllm-ascend support?

-Currently, we have already fully tested and supported `Qwen` / `Deepseek` (V0 only) / `Llama` models, other models we have tested are shown [<u>here</u>](https://vllm-ascend.readthedocs.io/en/latest/user_guide/supported_models.html). Plus, accoding to users' feedback, `gemma3` and `glm4` are not supported yet. Besides, more models need test.
+Currently, we have fully tested and support the `Qwen` / `Deepseek` (V0 only) / `Llama` models; other models we have tested are shown [<u>here</u>](https://vllm-ascend.readthedocs.io/en/latest/user_guide/supported_models.html). Plus, according to users' feedback, `gemma3` and `glm4` are not supported yet. Besides, more models still need testing.

 ### 4. How to get in touch with our community?

@@ -69,3 +69,51 @@ There are many channels that you can communicate with our community developers /

### 5. What features does vllm-ascend V1 support?

Find more details [<u>here</u>](https://github.com/vllm-project/vllm-ascend/issues/414).

### 6. How to solve the problem of "Failed to infer device type" or "libatb.so: cannot open shared object file"?

Basically, the reason is that the NPU environment is not configured correctly. You can:
1. try `source /usr/local/Ascend/nnal/atb/set_env.sh` to enable the NNAL package.
2. try `source /usr/local/Ascend/ascend-toolkit/set_env.sh` to enable the CANN package.
3. try `npu-smi info` to check whether the NPU is working.

If none of the above steps work, you can run the following Python code to check whether there is any import error:

```
import torch
import torch_npu
import vllm
```

If the problem still persists, feel free to submit a GitHub issue.
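
If those imports succeed, a slightly more detailed check can confirm the devices are actually visible. This is a minimal sketch; it assumes `torch_npu` exposes the CUDA-style `torch.npu.is_available()` / `torch.npu.device_count()` helpers:

```python
import torch
import torch_npu  # noqa: F401  (registers the NPU backend with torch)

# Both calls mirror the familiar torch.cuda API.
print("NPU available:", torch.npu.is_available())
print("NPU device count:", torch.npu.device_count())
```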

### 7. Does vllm-ascend support Atlas 300I Duo?

No, vllm-ascend currently only supports the Atlas A2 series. We are working on it.

### 8. How does vllm-ascend perform?

Currently, only some models, such as `Qwen2 VL` and `Deepseek V3`, have been optimized; the others do not perform well enough yet. In the future, we will support graph mode and custom ops to improve the performance of vllm-ascend. Once the official release of vllm-ascend is out, you can also install `mindie-turbo` alongside `vllm-ascend` to speed up inference.

### 9. How does vllm-ascend work with vllm?

vllm-ascend is a plugin for vllm, and its version should match the version of vllm. For example, if you use vllm 0.7.3, you should use vllm-ascend 0.7.3 as well. For the main branch, we make sure `vllm-ascend` and `vllm` stay compatible on every commit.
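
A quick way to verify the two versions line up is to query the installed package metadata. A minimal sketch, assuming the packages are installed under the distribution names `vllm` and `vllm-ascend`:

```python
from importlib.metadata import PackageNotFoundError, version

# Print the installed versions side by side so mismatches are obvious.
for dist in ("vllm", "vllm-ascend"):
    try:
        print(f"{dist}: {version(dist)}")
    except PackageNotFoundError:
        print(f"{dist}: not installed")
```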

### 10. Does vllm-ascend support the Prefill Disaggregation feature?

Currently, only 1P1D is supported by vllm. For vllm-ascend, it will be added by [this PR](https://github.com/vllm-project/vllm-ascend/pull/432). For NPND, vllm is not stable and fully supported yet; we will make it stable and supported by vllm-ascend in the future.

### 11. Does vllm-ascend support quantization methods?

Currently, vllm-ascend has no native quantization support. Quantization support is work in progress; w8a8 will be supported first.

### 12. How to run the w8a8 DeepSeek model?

Currently, on v0.7.3 you need vllm + vllm-ascend + mindie-turbo to run w8a8; once v0.8.X is released, vllm + vllm-ascend alone will be enough. After installing the above packages, you can follow the steps below to run the w8a8 DeepSeek model:

1. Quantize the bf16 DeepSeek model, e.g. [unsloth/DeepSeek-R1-BF16](https://modelscope.cn/models/unsloth/DeepSeek-R1-BF16), with msModelSlim to get a w8a8 DeepSeek model. Find more details in the [msModelSlim doc](https://gitee.com/ascend/msit/tree/master/msmodelslim/msmodelslim/pytorch/llm_ptq).
2. Copy the content of `quant_model_description_w8a8_dynamic.json` into the `quantization_config` field of the quantized model's `config.json`.
3. Run inference with the quantized DeepSeek model (see the sketch below).
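
For step 3, the following is a minimal offline-inference sketch; the checkpoint path `/path/to/DeepSeek-R1-w8a8`, the parallel size, and the context length are placeholders to adjust for your setup:

```python
from vllm import LLM, SamplingParams

# Load the locally quantized checkpoint produced in steps 1-2.
llm = LLM(model="/path/to/DeepSeek-R1-w8a8",  # placeholder path
          tensor_parallel_size=8,             # adjust to your NPU count
          max_model_len=4096)

outputs = llm.generate(["Hello, how are you today?"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)
```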

### 13. There is no output in the log when loading models with vllm-ascend. How to solve it?

If you are using vllm 0.7.3, this is a known progress-bar display issue in vLLM, which has been resolved in [this PR](https://github.com/vllm-project/vllm/pull/12428); please cherry-pick it locally yourself. Otherwise, please file an issue.

docs/source/index.md

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ faqs
 :maxdepth: 1
 user_guide/suppoted_features
 user_guide/supported_models
+user_guide/env_vars
 user_guide/release_notes
 :::

docs/source/user_guide/env_vars.md

Lines changed: 9 additions & 0 deletions

# Environment Variables

vllm-ascend uses the following environment variables to configure the system:

:::{literalinclude} ../../../vllm_ascend/envs.py
:language: python
:start-after: begin-env-vars-definition
:end-before: end-env-vars-definition
:::
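
The `literalinclude` directive above pulls the definitions straight from `vllm_ascend/envs.py`, relying on a pair of marker comments in that file. A minimal sketch of the convention only (the dictionary name and variable shown are hypothetical; see the real file for the actual definitions):

```python
import os

# begin-env-vars-definition
env_variables = {
    # Hypothetical entry for illustration; not a real vllm-ascend variable.
    "SOME_ASCEND_FLAG": lambda: os.getenv("SOME_ASCEND_FLAG", "0"),
}
# end-env-vars-definition
```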
Lines changed: 140 additions & 0 deletions

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing as mp
import os
import time
from multiprocessing import Event, Process


def clean_up():
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


def run_prefill(prefill_done, process_close):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
    )

    # Set NPU memory utilization to 0.8
    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # Keep the prefill node running until the decode node is done;
    # otherwise the script might exit prematurely, causing incomplete decoding.
    try:
        while not process_close.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
    finally:
        print("Cleanup prefill resources")
        del llm
        clean_up()


def run_decode(prefill_done):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
    )

    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    # Wait for the prefill (producer) node before starting to consume.
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # Once prefill_done is set, the kv-cache should have been
    # transferred to this decode node, so we can start decoding.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    del llm
    clean_up()


if __name__ == "__main__":
    mp.get_context('spawn')

    prefill_done = Event()
    process_close = Event()
    prefill_process = Process(target=run_prefill,
                              args=(
                                  prefill_done,
                                  process_close,
                              ))
    decode_process = Process(target=run_decode, args=(prefill_done, ))

    # Start prefill node
    prefill_process.start()

    # Start decode node
    decode_process.start()

    # Wait for the decode node to finish
    decode_process.join()

    # Signal the prefill node to shut down, then wait for it
    process_close.set()
    prefill_process.join()
    prefill_process.terminate()
    print("All process done!")

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ requires = [
     "cmake>=3.26",
     "decorator",
     "numpy<2.0.0",
+    "packaging",
     "pip",
     "pybind11",
     "pyyaml",

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@
 cmake>=3.26
 decorator
 numpy<2.0.0
+packaging
 pybind11
 pyyaml
 scipy

tests/conftest.py

Lines changed: 0 additions & 3 deletions

@@ -29,16 +29,13 @@
 from vllm.distributed.parallel_state import (destroy_distributed_environment,
                                              destroy_model_parallel)
 from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
-from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import is_list_of

 from tests.model_utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)

-logger = init_logger(__name__)
-
 _M = TypeVar("_M")

 _PromptMultiModalInput = Union[List[_M], List[List[_M]]]

vllm_ascend/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -18,6 +18,10 @@

 def register():
     """Register the NPU platform."""
+    # Adapt the global patch here.
+    from vllm_ascend.utils import adapt_patch
+    adapt_patch(is_global_patch=True)
+
     return "vllm_ascend.platform.NPUPlatform"

vllm_ascend/distributed/__init__.py

Lines changed: 6 additions & 0 deletions

from vllm.distributed.kv_transfer.kv_connector.factory import \
    KVConnectorFactory

KVConnectorFactory.register_connector(
    "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
    "LLMDataDistConnector")
File renamed without changes.
