|
| 1 | +""" |
| 2 | +This file demonstrates the example usage of disaggregated prefilling We will |
| 3 | +launch 2 vllm instances (NPU 0,1,3,4 for prefill and NPU 5,6,7,8 for decode), |
| 4 | +and then transfer the KV cache between them. |
| 5 | +""" |
| 6 | + |
| 7 | +import multiprocessing as mp |
| 8 | +import os |
| 9 | +from multiprocessing import Event, Process, Queue |
| 10 | +from typing import List, Literal |
| 11 | + |
| 12 | + |
| 13 | +def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"], |
| 14 | + local_server_id: str): |
| 15 | + kv_rank = 0 if role == "kv_producer" else 1 |
| 16 | + return f"""{{ |
| 17 | + "kv_connector": "AscendHcclConnectorV1", |
| 18 | + "kv_buffer_device": "npu", |
| 19 | + "kv_role": "{role}", |
| 20 | + "kv_rank": {kv_rank}, |
| 21 | + "kv_parallel_size": 2, |
| 22 | + "kv_connector_extra_config": {{ |
| 23 | + "local_server_id": "{local_server_id}" |
| 24 | + }} |
| 25 | + }}""" |
| 26 | + |
| 27 | + |
| 28 | +def clean_up(): |
| 29 | + import gc |
| 30 | + |
| 31 | + import torch |
| 32 | + from vllm.distributed.parallel_state import ( |
| 33 | + destroy_distributed_environment, destroy_model_parallel) |
| 34 | + |
| 35 | + destroy_model_parallel() |
| 36 | + destroy_distributed_environment() |
| 37 | + gc.collect() |
| 38 | + torch.npu.empty_cache() |
| 39 | + |
| 40 | + |
| 41 | +def run_prefill( |
| 42 | + prefill_done, |
| 43 | + process_close, |
| 44 | + prompt_q: Queue, |
| 45 | + prompts: List[str], |
| 46 | + model: str, |
| 47 | + local_server_id: str, |
| 48 | + visible_devices: str, |
| 49 | +): |
| 50 | + os.environ["VLLM_USE_V1"] = "1" |
| 51 | + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices |
| 52 | + tensor_parallel_size = len(visible_devices.split(",")) |
| 53 | + |
| 54 | + from vllm import LLM, SamplingParams |
| 55 | + from vllm.config import KVTransferConfig |
| 56 | + |
| 57 | + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) |
| 58 | + |
| 59 | + ktc = KVTransferConfig.from_cli( |
| 60 | + get_kv_transfer_config( |
| 61 | + role="kv_producer", |
| 62 | + local_server_id=local_server_id, |
| 63 | + )) |
| 64 | + |
| 65 | + llm = LLM( |
| 66 | + model=model, |
| 67 | + trust_remote_code=True, |
| 68 | + enforce_eager=True, |
| 69 | + enable_prefix_caching=False, |
| 70 | + kv_transfer_config=ktc, |
| 71 | + tensor_parallel_size=tensor_parallel_size, |
| 72 | + gpu_memory_utilization=0.9, |
| 73 | + max_model_len=40, |
| 74 | + ) |
| 75 | + |
| 76 | + result = llm.generate(prompts, sampling_params) |
| 77 | + for output in result: |
| 78 | + prompt = output.prompt |
| 79 | + generated_text = output.outputs[0].text |
| 80 | + print( |
| 81 | + f"[Prefill] Prompt: {prompt!r}, Generated text: {generated_text!r}" |
| 82 | + ) |
| 83 | + prompt_q.put(prompt + generated_text) |
| 84 | + prompt_q.close() |
| 85 | + |
| 86 | + print("[Prefill] DONE.") |
| 87 | + prefill_done.set() |
| 88 | + |
| 89 | + # To keep the prefill node running in case the decode node is not done; |
| 90 | + # otherwise, the script might exit prematurely, causing incomplete decoding. |
| 91 | + process_close.wait() |
| 92 | + |
| 93 | + del llm |
| 94 | + clean_up() |
| 95 | + |
| 96 | + |
| 97 | +def run_decode( |
| 98 | + prefill_done, |
| 99 | + prompt_q: Queue, |
| 100 | + num_prompts: int, |
| 101 | + model: str, |
| 102 | + local_server_id: str, |
| 103 | + visible_devices: str, |
| 104 | +): |
| 105 | + os.environ["VLLM_USE_V1"] = "1" |
| 106 | + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices |
| 107 | + tensor_parallel_size = len(visible_devices.split(",")) |
| 108 | + |
| 109 | + from vllm import LLM, SamplingParams |
| 110 | + from vllm.config import KVTransferConfig |
| 111 | + |
| 112 | + sampling_params = SamplingParams(temperature=0, top_p=0.95) |
| 113 | + |
| 114 | + ktc = KVTransferConfig.from_cli( |
| 115 | + get_kv_transfer_config( |
| 116 | + role="kv_consumer", |
| 117 | + local_server_id=local_server_id, |
| 118 | + )) |
| 119 | + |
| 120 | + llm = LLM( |
| 121 | + model=model, |
| 122 | + trust_remote_code=True, |
| 123 | + enforce_eager=True, |
| 124 | + enable_prefix_caching=False, |
| 125 | + kv_transfer_config=ktc, |
| 126 | + tensor_parallel_size=tensor_parallel_size, |
| 127 | + gpu_memory_utilization=0.9, |
| 128 | + max_model_len=40, |
| 129 | + ) |
| 130 | + |
| 131 | + # Wait for the producer to start the comsumer |
| 132 | + print("[Decode] Waiting for prefill node to finish...") |
| 133 | + prefill_done.wait() |
| 134 | + |
| 135 | + # Get the prompts from the queue |
| 136 | + prompts = [] |
| 137 | + for _ in range(num_prompts): |
| 138 | + prompts.append(prompt_q.get()) |
| 139 | + |
| 140 | + # At this point when the prefill_done is set, the kv-cache should have been |
| 141 | + # transferred to this decode node, so we can start decoding. |
| 142 | + outputs = llm.generate(prompts, sampling_params) |
| 143 | + for output in outputs: |
| 144 | + prompt = output.prompt |
| 145 | + generated_text = output.outputs[0].text |
| 146 | + print( |
| 147 | + f"[Decode] Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 148 | + print("[Decode] DONE.") |
| 149 | + |
| 150 | + # Must delete the llm instance, otherwise the process will not exit |
| 151 | + del llm |
| 152 | + clean_up() |
| 153 | + |
| 154 | + |
| 155 | +if __name__ == "__main__": |
| 156 | + mp.get_context("spawn") |
| 157 | + |
| 158 | + model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" |
| 159 | + |
| 160 | + # Set the server id and device ids for prefill and decode nodes |
| 161 | + prompt_server_id = "server-0" |
| 162 | + prompt_deivce_ids = "0,1,2,3" |
| 163 | + decode_server_id = "server-1" |
| 164 | + decode_device_ids = "4,5,6,7" |
| 165 | + |
| 166 | + prompts = [ |
| 167 | + "Hello, how are you today?", |
| 168 | + "Hi, what is your name?", |
| 169 | + "Tell me a very long story.", |
| 170 | + "what is your favourite book?", |
| 171 | + ] |
| 172 | + num_prompts = len(prompts) |
| 173 | + |
| 174 | + prompt_q: Queue = Queue(num_prompts) |
| 175 | + prefill_done = Event() |
| 176 | + process_close = Event() |
| 177 | + |
| 178 | + prefill_process = Process( |
| 179 | + target=run_prefill, |
| 180 | + args=(prefill_done, process_close, prompt_q, prompts, model, |
| 181 | + prompt_server_id, prompt_deivce_ids), |
| 182 | + ) |
| 183 | + decode_process = Process( |
| 184 | + target=run_decode, |
| 185 | + args=(prefill_done, prompt_q, num_prompts, model, decode_server_id, |
| 186 | + decode_device_ids), |
| 187 | + ) |
| 188 | + |
| 189 | + # Start prefill node |
| 190 | + prefill_process.start() |
| 191 | + # Start decode node |
| 192 | + decode_process.start() |
| 193 | + |
| 194 | + # Wait for decode process to finish |
| 195 | + decode_process.join() |
| 196 | + print("[Main] Decode process done.") |
| 197 | + |
| 198 | + # Terminate the prefill node, and wait for it to finish |
| 199 | + process_close.set() |
| 200 | + prefill_process.join() |
| 201 | + print("[Main] Prefill process done.") |
0 commit comments