#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
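"""
Disaggregated prefill example for vllm-ascend.

A prefill (KV producer) process runs the prompt phase on NPUs 0-1 and a
decode (KV consumer) process runs generation on NPUs 2-3. The KV cache is
handed off between the two vLLM instances through the AscendHcclConnector,
and the two processes are synchronized with multiprocessing events.
"""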
import multiprocessing as mp
import os
import time
from multiprocessing import Event, Process


def clean_up():
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


def run_prefill(prefill_done, process_close):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
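    # max_tokens=1: the prefill node only needs to run the prompt (prefill)
    # phase and fill the KV cache; token generation happens on the decode node.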
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

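    # kv_role="kv_producer" marks this instance as the KV cache sender;
    # kv_parallel_size=2 accounts for the producer/consumer pair (assumption
    # inferred from the matching consumer config in run_decode below).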
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer","kv_parallel_size":2}'
    )

    # Set NPU memory utilization to 0.8
    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # Keep the prefill node running until the decode node is done; otherwise,
    # this process would exit prematurely and leave decoding incomplete.
    try:
        while not process_close.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
    finally:
        print("Cleanup prefill resources")
        del llm
        clean_up()


def run_decode(prefill_done):
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
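    # Unlike the prefill node, no max_tokens=1 cap here: the decode node
    # produces the actual completion (vLLM's default max_tokens applies).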
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

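    # Mirror of the producer config above, with kv_role="kv_consumer" so this
    # instance receives the KV cache instead of sending it.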
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
    )

    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    # Block until the prefill (producer) node signals that it is done.
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # Once prefill_done is set, the KV cache should already have been
    # transferred to this decode node, so decoding can start.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    del llm
    clean_up()


if __name__ == "__main__":
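    # Each worker needs its own device context, so use the 'spawn' start
    # method rather than the platform default fork (assumption: NPU runtimes,
    # like CUDA, do not survive fork()).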
    mp.set_start_method('spawn')

    prefill_done = Event()
    process_close = Event()
    prefill_process = Process(target=run_prefill,
                              args=(
                                  prefill_done,
                                  process_close,
                              ))
    decode_process = Process(target=run_decode, args=(prefill_done, ))

    # Start prefill node
    prefill_process.start()

    # Start decode node
    decode_process.start()

    # Wait for the decode node to finish
    decode_process.join()

    # Signal the prefill node to shut down, then wait for it to exit
    process_close.set()
    prefill_process.join()
    print("All processes done!")