
Commit 44a8301

eeethenQ and ZihuiQian authored
[Feature] Add PD separation feature (#432)
### What this PR does / why we need it?

Adapt the Disaggregated Prefill feature to Ascend devices.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The test usage is provided along with the PR, in examples/offline_disaggregated_prefill_npu.py. To run it:

```
export PROMPT_DEVICE_ID=0,1
export DECODE_DEVICE_ID=2,3
python examples/offline_disaggregated_prefill_npu.py
```

---------

Signed-off-by: ZihuiQian <qianzihui@huawei.com>
Co-authored-by: ZihuiQian <qianzihui@huawei.com>
1 parent c7f6584 commit 44a8301

File tree

8 files changed: +634 -8 lines changed

examples/offline_disaggregated_prefill_npu.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing as mp
import os
import time
from multiprocessing import Event, Process


def clean_up():
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


def run_prefill(prefill_done, process_close):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
    )

    # Set NPU memory utilization to 0.8
    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # To keep the prefill node running in case the decode node is not done
    # otherwise, the script might exit prematurely, causing incomplete decoding.
    try:
        while not process_close.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
    finally:
        print("Cleanup prefill resources")
        del llm
        clean_up()


def run_decode(prefill_done):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
    )

    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    # Wait for the producer to start the consumer
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # At this point when the prefill_done is set, the kv-cache should have been
    # transferred to this decode node, so we can start decoding.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    del llm
    clean_up()


if __name__ == "__main__":
    mp.get_context('spawn')

    prefill_done = Event()
    process_close = Event()
    prefill_process = Process(target=run_prefill,
                              args=(
                                  prefill_done,
                                  process_close,
                              ))
    decode_process = Process(target=run_decode, args=(prefill_done, ))

    # Start prefill node
    prefill_process.start()

    # Start decode node
    decode_process.start()

    # Terminate the prefill node when decode is finished
    decode_process.join()

    # Terminate prefill process
    process_close.set()
    prefill_process.join()
    prefill_process.terminate()
    print("All process done!")

vllm_ascend/distributed/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from vllm.distributed.kv_transfer.kv_connector.factory import \
    KVConnectorFactory

KVConnectorFactory.register_connector(
    "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
    "LLMDataDistConnector")
File renamed without changes.
