[Feature][1/2] Impl the connector based on the llmdatadist for v1 #684
Open

jianzs wants to merge 17 commits into vllm-project:main from jianzs:zhengsj/datadist-conn-v1
Changes from all commits (17 commits)
- 884ed70 feat: impl the connector based on the llmdatadist for v1
- 08e158e feat: resolve npu_reshape_and_cache error
- e0bb3fb chore: lint code
- 82e62e9 feat: add offline inference example
- b04a6c7 feat: simplify 1p1d startup
- 73489fd chore: rename file
- 3ae97db chore: lint code
- f26c0e0 fix: resolve import issue when running with vllm 0.8.4
- d967a87 chore: remove v0.8.4 patch
- 0a1ddd2 chore: refine the init optins of datadist
- 5ff5a06 feat: refine linking logic
- 0d90c24 fix: manage KV cache buffer lifecycle to prevent premature deallocation
- 791993a fix: correct finding the kv cache shape for mha
- 5c752ed fix: reverse iteration over link_rets to safely remove clusters
- cd57ee5 fix: correct block_ids assignment
- 85eb470 fix: typo
- 540a57e avoid memory lack

All commits are by jianzs.
examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py (new file, 85 additions)
```python
# SPDX-License-Identifier: Apache-2.0

import json
import os

import aiohttp
from quart import Quart, make_response, request  # type: ignore

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)

PREFILL_ENDPOINT = "localhost:8100"
DECODE_ENDPOINT = "localhost:8200"


async def forward_request(url, data, headers: dict):
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers.update({
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        })

        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status == 200:
                async for chunk_bytes in response.content.iter_chunked(1024):
                    yield chunk_bytes


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    try:
        original_request_data = await request.get_json()
        print(f"{request.headers.get('X-Request-ID')=}")

        prefill_request = original_request_data.copy()
        # Set max_tokens = 1 so this request only performs the prefill step
        prefill_request["max_tokens"] = 1

        # Finish prefill
        async for prefill_result in forward_request(
                f"http://{PREFILL_ENDPOINT}/v1/completions",
                prefill_request,
                headers={
                    "X-Request-ID": request.headers.get("X-Request-ID"),
                },
        ):
            # Print the prefill result
            print("===== Prefill result =====")
            print(prefill_result.decode("utf-8"))
            print("==========================")
            response = json.loads(prefill_result.decode("utf-8"))
            continue

        # Take the token generated during prefill and append it to each
        # prompt of the decode request
        decode_request = original_request_data.copy()
        for idx, choices in enumerate(response.get("choices")):
            decode_request["prompt"][idx] += choices.get("text")

        # Return the decoding result
        generator = forward_request(
            f"http://{DECODE_ENDPOINT}/v1/completions",
            decode_request,
            headers={
                "X-Request-ID": request.headers.get("X-Request-ID"),
            },
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))


if __name__ == "__main__":
    app.run(port=8000)
```
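The decode request above is built by appending each prefill completion back onto its prompt before the request is re-sent to the decode instance. A minimal, self-contained sketch of that step with a mocked prefill response (the prompts and response contents below are invented for illustration; the field names follow the OpenAI completions schema the proxy consumes):

```python
import json

# Mocked prefill response in the /v1/completions shape (contents invented).
prefill_result = json.dumps({
    "choices": [
        {"text": " Paris."},
        {"text": " Berlin."},
    ]
}).encode("utf-8")

original_request_data = {
    "prompt": ["The capital of France is", "The capital of Germany is"],
    "max_tokens": 16,
}

# Same logic as handle_request: append each generated text to its prompt.
response = json.loads(prefill_result.decode("utf-8"))
decode_request = original_request_data.copy()
for idx, choices in enumerate(response.get("choices")):
    decode_request["prompt"][idx] += choices.get("text")

print(decode_request["prompt"])
# → ['The capital of France is Paris.', 'The capital of Germany is Berlin.']
```

Note that `dict.copy()` is shallow, so the `prompt` list is shared with `original_request_data`; in the proxy this is harmless because the prefill request has already been sent by the time the prompts are mutated.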
examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh (new file, 110 additions)
@@ -0,0 +1,110 @@ | ||
#!/bin/bash | ||
# This file demonstrates the example usage of disaggregated prefilling We will | ||
# launch 2 vllm instances (1 for prefill and 1 for decode), and then transfer | ||
# the KV cache between them. | ||
|
||
set -xe | ||
|
||
current_dir=$(dirname "$0") | ||
|
||
# vLLM Environment configuration | ||
export VLLM_USE_V1=1 | ||
|
||
# vLLM-Ascend Environment configuration | ||
export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json" | ||
# The following environment variables are required for LLMDataDist. | ||
export PROMPT_DEVICE_ID=0,1,2,3 | ||
export DECODE_DEVICE_ID=4,5,6,7 | ||
export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1)) | ||
|
||
# Model Configuration | ||
export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" | ||
|
||
# Generate the global rank table | ||
if [ ! -f "${GLOBAL_RANKTABLE}" ]; then | ||
echo "Generating global rank table..." | ||
# TODO(jianzs): Impl a tool to generate the global rank table automatically | ||
else | ||
echo "Global rank table already exists." | ||
fi | ||
|
||
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" | ||
sleep 1 | ||
|
||
# Trap the SIGINT signal (triggered by Ctrl+C) | ||
trap 'cleanup' INT | ||
|
||
# Cleanup function | ||
cleanup() { | ||
echo "Caught Ctrl+C, cleaning up..." | ||
# Cleanup commands | ||
pgrep python | xargs kill -9 | ||
pkill -f python | ||
echo "Cleanup complete. Exiting." | ||
exit 0 | ||
} | ||
|
||
# install quart first -- required for disagg prefill proxy serve | ||
if python3 -c "import quart" &>/dev/null; then | ||
echo "Quart is already installed." | ||
else | ||
echo "Quart is not installed. Installing..." | ||
python3 -m pip install quart | ||
fi | ||
|
||
# a function that waits vLLM server to start | ||
wait_for_server() { | ||
local port=$1 | ||
timeout 1200 bash -c " | ||
until curl -s localhost:${port}/v1/completions > /dev/null; do | ||
sleep 1 | ||
done" && return 0 || return 1 | ||
} | ||
|
||
ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \ | ||
--port 8100 \ | ||
--max-model-len 100 \ | ||
--gpu-memory-utilization 0.9 \ | ||
--trust-remote-code \ | ||
--enforce-eager \ | ||
--no-enable-prefix-caching \ | ||
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \ | ||
--kv-transfer-config \ | ||
'{ | ||
"kv_connector": "AscendHcclConnectorV1", | ||
"kv_buffer_device": "npu", | ||
"kv_role": "kv_producer", | ||
"kv_rank": 0, | ||
"kv_parallel_size": 2, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What dose this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The v0 implementation needed this, but I'm unsure if it's still necessary. |
||
"kv_connector_extra_config": { | ||
"local_server_id": "server-0" | ||
} | ||
}' & | ||
|
||
ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID} vllm serve ${MODEL_NAME} \ | ||
--port 8200 \ | ||
--max-model-len 100 \ | ||
--gpu-memory-utilization 0.9 \ | ||
--trust-remote-code \ | ||
--enforce-eager \ | ||
--no-enable-prefix-caching \ | ||
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \ | ||
--kv-transfer-config \ | ||
'{ | ||
"kv_connector": "AscendHcclConnectorV1", | ||
"kv_buffer_device": "npu", | ||
"kv_role": "kv_consumer", | ||
"kv_rank": 1, | ||
"kv_parallel_size": 2, | ||
"kv_connector_extra_config": { | ||
"local_server_id": "server-1" | ||
} | ||
}' & | ||
|
||
# wait until prefill and decode instances are ready | ||
wait_for_server 8100 | ||
wait_for_server 8200 | ||
|
||
echo "🚧🚧 Warning: server started 🚧🚧" | ||
|
||
python3 disagg_prefill_proxy_server.py |
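The script derives `TENSOR_PARALLEL_SIZE` by counting the commas in `PROMPT_DEVICE_ID` and adding one, i.e. the number of devices in the list. The same arithmetic expressed in Python (the helper name is illustrative, not part of the script):

```python
def tensor_parallel_size(device_ids: str) -> int:
    """Number of devices in a comma-separated ID list.

    Mirrors the shell arithmetic
    $(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1)):
    count the commas, then add one.
    """
    return device_ids.count(",") + 1

print(tensor_parallel_size("0,1,2,3"))  # → 4
print(tensor_parallel_size("7"))        # → 1
```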
New file, 198 additions (filename not shown in this view)
```python
"""
This file demonstrates the example usage of disaggregated prefilling. We will
launch 2 vllm instances (NPUs 0,1,2,3 for prefill and NPUs 4,5,6,7 for decode),
and then transfer the KV cache between them.
"""

import multiprocessing as mp
import os
from multiprocessing import Event, Process, Queue
from typing import List, Literal


def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"],
                           local_server_id: str):
    kv_rank = 0 if role == "kv_producer" else 1
    return f"""{{
        "kv_connector": "AscendHcclConnectorV1",
        "kv_buffer_device": "npu",
        "kv_role": "{role}",
        "kv_rank": {kv_rank},
        "kv_parallel_size": 2
    }}"""


def clean_up():
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


def run_prefill(
    prefill_done,
    process_close,
    prompt_q: Queue,
    prompts: List[str],
    model: str,
    local_server_id: str,
    visible_devices: str,
):
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices
    tensor_parallel_size = len(visible_devices.split(","))

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        get_kv_transfer_config(
            role="kv_producer",
            local_server_id=local_server_id,
        ))

    llm = LLM(
        model=model,
        trust_remote_code=True,
        enforce_eager=True,
        enable_prefix_caching=False,
        kv_transfer_config=ktc,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.9,
        max_model_len=40,
    )

    result = llm.generate(prompts, sampling_params)
    for output in result:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"[Prefill] Prompt: {prompt!r}, Generated text: {generated_text!r}"
        )
        prompt_q.put(prompt + generated_text)
    prompt_q.close()

    print("[Prefill] DONE.")
    prefill_done.set()

    # Keep the prefill node running until the decode node is done;
    # otherwise, the script might exit prematurely, causing incomplete decoding.
    process_close.wait()

    del llm
    clean_up()


def run_decode(
    prefill_done,
    prompt_q: Queue,
    num_prompts: int,
    model: str,
    local_server_id: str,
    visible_devices: str,
):
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices
    tensor_parallel_size = len(visible_devices.split(","))

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    ktc = KVTransferConfig.from_cli(
        get_kv_transfer_config(
            role="kv_consumer",
            local_server_id=local_server_id,
        ))

    llm = LLM(
        model=model,
        trust_remote_code=True,
        enforce_eager=True,
        enable_prefix_caching=False,
        kv_transfer_config=ktc,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.9,
        max_model_len=40,
    )

    # Wait for the producer to signal the consumer
    print("[Decode] Waiting for prefill node to finish...")
    prefill_done.wait()

    # Get the prompts from the queue
    prompts = []
    for _ in range(num_prompts):
        prompts.append(prompt_q.get())

    # Once prefill_done is set, the kv-cache should have been
    # transferred to this decode node, so we can start decoding.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"[Decode] Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("[Decode] DONE.")

    # Must delete the llm instance, otherwise the process will not exit
    del llm
    clean_up()


if __name__ == "__main__":
    # Use spawn so each worker process initializes its NPU state cleanly
    mp.set_start_method("spawn")

    model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    # Set the server id and device ids for prefill and decode nodes
    prompt_server_id = "server-0"
    prompt_device_ids = "0,1,2,3"
    decode_server_id = "server-1"
    decode_device_ids = "4,5,6,7"

    prompts = [
        "Hello, how are you today?",
        "Hi, what is your name?",
        "Tell me a very long story.",
        "what is your favourite book?",
    ]
    num_prompts = len(prompts)

    prompt_q: Queue = Queue(num_prompts)
    prefill_done = Event()
    process_close = Event()

    prefill_process = Process(
        target=run_prefill,
        args=(prefill_done, process_close, prompt_q, prompts, model,
              prompt_server_id, prompt_device_ids),
    )
    decode_process = Process(
        target=run_decode,
        args=(prefill_done, prompt_q, num_prompts, model, decode_server_id,
              decode_device_ids),
    )

    # Start prefill node
    prefill_process.start()
    # Start decode node
    decode_process.start()

    # Wait for decode process to finish
    decode_process.join()
    print("[Main] Decode process done.")

    # Terminate the prefill node, and wait for it to finish
    process_close.set()
    prefill_process.join()
    print("[Main] Prefill process done.")
```
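The Event/Queue choreography above (prefill fills `prompt_q`, sets `prefill_done`, then blocks on `process_close` until decode finishes) can be sketched without vLLM. Threads stand in for the two processes here purely for brevity; `threading.Event` and `queue.Queue` expose the same interface as their multiprocessing counterparts, and the `<tok>`/`<rest>` markers are invented placeholders for generated text:

```python
from queue import Queue
from threading import Event, Thread

prompts = ["Hello,", "Hi,"]
prompt_q: Queue = Queue(len(prompts))
prefill_done = Event()
process_close = Event()
results = []

def prefill_worker():
    # Stand-in for run_prefill: emit each prompt plus its first token.
    for p in prompts:
        prompt_q.put(p + " <tok>")
    prefill_done.set()    # tell the decode side the handoff data is ready
    process_close.wait()  # keep "running" until decode has finished

def decode_worker():
    # Stand-in for run_decode: block until the producer signals readiness.
    prefill_done.wait()
    for _ in range(len(prompts)):
        results.append(prompt_q.get() + " <rest>")

t_prefill = Thread(target=prefill_worker)
t_decode = Thread(target=decode_worker)
t_prefill.start()
t_decode.start()

t_decode.join()      # wait for the decode side to finish first
process_close.set()  # then release the prefill worker
t_prefill.join()
print(results)  # → ['Hello, <tok> <rest>', 'Hi, <tok> <rest>']
```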
Review comment: Can you add a quick README in the disaggregated-prefill-v1 folder to describe how to use the example, i.e. run `bash disaggregated_prefill_multi_prefill.sh` and then xxxx