
Commit 1a1f9a6

Authored by ganyi1996ppo, SidaoY, linfeng-yuan, Yizhou Liu, and mengwei805
port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it?
This PR ports all the DeepSeek graph mode code and MTP code from v0.7.3 to the main branch.

Signed-off-by: SidaoY <1024863041@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: q00832892 <qiaoyang19@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>

Co-authored-by: SidaoY <1024863041@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
1 parent 086423d commit 1a1f9a6

33 files changed, +3368 −322 lines changed
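For orientation, the new examples in this commit gate the graph-mode path behind the VLLM_ENABLE_GRAPH_MODE environment flag and a compilation level passed to the LLM constructor. The snippet below is a minimal sketch assembled only from the flags and arguments that appear in the example files added here; it is not part of the commit itself, and the exact behaviour behind compilation_config=1 on Ascend is not documented in this PR description.

import os

# Toggle mirrored from examples/dp_offline/run_dp.sh in this commit.
os.environ["VLLM_ENABLE_GRAPH_MODE"] = "1"

from vllm import LLM  # noqa: E402

graph_mode = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",  # model used in the DP example
    tensor_parallel_size=4,
    trust_remote_code=True,
    max_model_len=4096,
    # The DP example selects compilation level 1 when graph mode is enabled.
    compilation_config=1 if graph_mode else 0,
)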
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
"""
This file demonstrates example usage of disaggregated prefill.
We launch two vLLM instances (NPU 0,1 for prefill and NPU 2,3 for decode)
and transfer the KV cache between them.
"""
import multiprocessing as mp
import os
import time


def clean_up():
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


def run_prefill(prefill_done, process_close):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
    )

    # Reserve 80% of device memory for the model. You may need to adjust
    # the value to fit your device.
    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # Keep the prefill node running until the decode node is done;
    # otherwise, the script might exit prematurely, causing incomplete decoding.
    try:
        while not process_close.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
    finally:
        print("Cleaning up prefill resources")
        del llm
        clean_up()


def run_decode(prefill_done):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
    )

    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    # Wait for the prefill (producer) node to finish before starting to decode.
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # Once prefill_done is set, the KV cache should have been transferred
    # to this decode node, so we can start decoding.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    del llm
    clean_up()


if __name__ == "__main__":
    # Use an explicit spawn context so the events and worker processes
    # are actually created with the spawn start method.
    ctx = mp.get_context('spawn')

    prefill_done = ctx.Event()
    process_close = ctx.Event()
    prefill_process = ctx.Process(target=run_prefill,
                                  args=(
                                      prefill_done,
                                      process_close,
                                  ))
    decode_process = ctx.Process(target=run_decode, args=(prefill_done, ))

    # Start the prefill node.
    prefill_process.start()

    # Start the decode node.
    decode_process.start()

    # Wait for the decode node to finish.
    decode_process.join()

    # Signal the prefill node to shut down, then wait for it.
    process_close.set()
    prefill_process.join()
    prefill_process.terminate()
    print("All processes done!")

examples/dp_offline/data_parallel.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
# SPDX-License-Identifier: Apache-2.0
# Usage:
#   launch with torchrun; see examples/dp_offline/run_dp.sh.
# A launcher is needed to create the data-parallel ranks; each rank then
# creates its own vLLM instance to process its share of the prompts.

import gc
import os

VLLM_ENABLE_GRAPH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"


def main():
    dp_rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dp_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ['MASTER_ADDR']
    master_port = os.environ['MASTER_PORT']
    tp_size = 4
    etp_size = 2

    os.environ["VLLM_DP_RANK"] = str(dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = master_addr
    os.environ["VLLM_DP_MASTER_PORT"] = master_port
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(
        str(i)
        for i in range(local_rank * tp_size, (local_rank + 1) * tp_size))

    import torch
    import torch_npu  # noqa
    from vllm import LLM, SamplingParams
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 4

    prompts_per_rank = len(prompts) // dp_size
    start = dp_rank * prompts_per_rank
    end = start + prompts_per_rank
    prompts = prompts[start:end]
    if len(prompts) == 0:
        prompts = ["Placeholder"]
    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
    num_seqs = len(prompts)

    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=4,
                                     min_tokens=4)
    # Create an LLM.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite-Chat",
        tensor_parallel_size=tp_size,
        trust_remote_code=True,
        expert_tensor_parallel_size=etp_size,
        max_model_len=4096,
        max_num_seqs=num_seqs,
        compilation_config=1 if VLLM_ENABLE_GRAPH_MODE else 0,
    )

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")

    del llm
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


if __name__ == "__main__":
    main()
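For reference, the per-rank device mapping computed in the script above works out as follows. This is a standalone sketch (not part of the commit) that simply replays the ASCEND_RT_VISIBLE_DEVICES arithmetic with the example's tp_size = 4 and the two local ranks launched by run_dp.sh.

# Replays the ASCEND_RT_VISIBLE_DEVICES computation from data_parallel.py.
tp_size = 4
for local_rank in range(2):  # two data-parallel ranks per node in run_dp.sh
    devices = ",".join(
        str(i) for i in range(local_rank * tp_size, (local_rank + 1) * tp_size))
    print(f"local_rank {local_rank} -> ASCEND_RT_VISIBLE_DEVICES={devices}")
    # local_rank 0 -> 0,1,2,3
    # local_rank 1 -> 4,5,6,7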

examples/dp_offline/run_dp.sh

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# ${local_ip} and ${ifname} must be set to this node's IP address and
# network interface name before running this script.
export HCCL_IF_IP=${local_ip}
export GLOO_SOCKET_IFNAME=${ifname}
export TP_SOCKET_IFNAME=${ifname}
export HCCL_SOCKET_IFNAME=${ifname}

# dp_size = node_size * dp_per_node (here: 1 * 2 = 2 data-parallel ranks)
node_size=1
node_rank=0
dp_per_node=2
master_addr=127.0.0.1
master_port=12345

# Clear graph/compile caches and old device logs.
rm -rf ./.torchair_cache/
rm -rf ./dynamo_*
rm -rf /root/ascend/log/debug/plog/*
export VLLM_ENABLE_GRAPH_MODE=0
export VLLM_ENABLE_MC2=0

torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
    --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \
    data_parallel.py

examples/offline_inference_npu_v1.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

if __name__ == "__main__":
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
    # Create an LLM.
    llm = LLM(model="/data/weights/deepseek-ai/deepseekv3-lite-base-latest",
              tensor_parallel_size=2,
              enforce_eager=True,
              trust_remote_code=True,
              max_model_len=1024)

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

tests/conftest.py

Lines changed: 9 additions & 2 deletions
@@ -26,15 +26,22 @@
 from PIL import Image
 from vllm import LLM, SamplingParams
 from vllm.config import TaskOption
-from vllm.distributed.parallel_state import (destroy_distributed_environment,
-                                             destroy_model_parallel)
 from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import is_list_of

 from tests.model_utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
+# TODO: remove this after the patch is merged into vllm. If we do not
+# explicitly apply the patch here, some of the patches may not take
+# effect in the pytest scenario.
+from vllm_ascend.utils import adapt_patch  # noqa E402
+
+adapt_patch(True)
+
+from vllm.distributed.parallel_state import (  # noqa E402
+    destroy_distributed_environment, destroy_model_parallel)

 _M = TypeVar("_M")
