Commit a156db7

feat: add offline inference example
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 561679b commit a156db7

1 file changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
"""
This file demonstrates the example usage of disaggregated prefilling. We will
launch 2 vLLM instances (NPU 0,1,2,3 for prefill and NPU 4,5,6,7 for decode),
and then transfer the KV cache between them.
"""

import multiprocessing as mp
import os
from multiprocessing import Event, Process, Queue
from typing import List, Literal


def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"],
                           local_server_id: str):
    kv_rank = 0 if role == "kv_producer" else 1
    return f"""{{
        "kv_connector": "AscendHcclConnectorV1",
        "kv_buffer_device": "npu",
        "kv_role": "{role}",
        "kv_rank": {kv_rank},
        "kv_parallel_size": 2,
        "kv_connector_extra_config": {{
            "local_server_id": "{local_server_id}"
        }}
    }}"""


def clean_up():
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


def run_prefill(
    prefill_done,
    process_close,
    prompt_q: Queue,
    prompts: List[str],
    model: str,
    local_server_id: str,
    visible_devices: str,
):
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices
    tensor_parallel_size = len(visible_devices.split(","))

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        get_kv_transfer_config(
            role="kv_producer",
            local_server_id=local_server_id,
        ))

    llm = LLM(
        model=model,
        trust_remote_code=True,
        enforce_eager=True,
        enable_prefix_caching=False,
        kv_transfer_config=ktc,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.9,
        max_model_len=40,
    )

    result = llm.generate(prompts, sampling_params)
    for output in result:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"[Prefill] Prompt: {prompt!r}, Generated text: {generated_text!r}"
        )
        prompt_q.put(prompt + generated_text)
    prompt_q.close()

    print("[Prefill] DONE.")
    prefill_done.set()

    # Keep the prefill node running until the decode node is done; otherwise,
    # the script might exit prematurely, causing incomplete decoding.
    process_close.wait()

    del llm
    clean_up()


def run_decode(
    prefill_done,
    prompt_q: Queue,
    num_prompts: int,
    model: str,
    local_server_id: str,
    visible_devices: str,
):
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices
    tensor_parallel_size = len(visible_devices.split(","))

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    ktc = KVTransferConfig.from_cli(
        get_kv_transfer_config(
            role="kv_consumer",
            local_server_id=local_server_id,
        ))

    llm = LLM(
        model=model,
        trust_remote_code=True,
        enforce_eager=True,
        enable_prefix_caching=False,
        kv_transfer_config=ktc,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.9,
        max_model_len=40,
    )

    # Wait for the producer (prefill node) to finish before consuming
    print("[Decode] Waiting for prefill node to finish...")
    prefill_done.wait()

    # Get the prompts from the queue
    prompts = []
    for _ in range(num_prompts):
        prompts.append(prompt_q.get())

    # Once prefill_done is set, the KV cache should have been transferred to
    # this decode node, so we can start decoding.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"[Decode] Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("[Decode] DONE.")

    # Must delete the llm instance, otherwise the process will not exit
    del llm
    clean_up()


if __name__ == "__main__":
    # Use the "spawn" start method so worker processes start from a fresh
    # interpreter rather than inheriting the parent's state.
    mp.set_start_method("spawn")

    model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    # Set the server id and device ids for prefill and decode nodes
    prompt_server_id = "server-0"
    prompt_device_ids = "0,1,2,3"
    decode_server_id = "server-1"
    decode_device_ids = "4,5,6,7"

    prompts = [
        "Hello, how are you today?",
        "Hi, what is your name?",
        "Tell me a very long story.",
        "what is your favourite book?",
    ]
    num_prompts = len(prompts)

    prompt_q: Queue = Queue(num_prompts)
    prefill_done = Event()
    process_close = Event()

    prefill_process = Process(
        target=run_prefill,
        args=(prefill_done, process_close, prompt_q, prompts, model,
              prompt_server_id, prompt_device_ids),
    )
    decode_process = Process(
        target=run_decode,
        args=(prefill_done, prompt_q, num_prompts, model, decode_server_id,
              decode_device_ids),
    )

    # Start prefill node
    prefill_process.start()
    # Start decode node
    decode_process.start()

    # Wait for decode process to finish
    decode_process.join()
    print("[Main] Decode process done.")

    # Terminate the prefill node, and wait for it to finish
    process_close.set()
    prefill_process.join()
    print("[Main] Prefill process done.")
