[Feature][1/2] Impl the connector based on the llmdatadist for v1 #684

Open

wants to merge 17 commits into base: main
Changes from 14 commits
85 changes: 85 additions & 0 deletions examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator

Can you add a quick README in the disaggregated-prefill-v1 folder describing how to use the example, i.e. run bash disaggregated_prefill_multi_prefill.sh and then xxxx


import json
import os

import aiohttp
from quart import Quart, make_response, request # type: ignore

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)

PREFILL_ENDPOINT = "localhost:8100"
DECODE_ENDPOINT = "localhost:8200"


async def forward_request(url, data, headers: dict):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers.update({
"Authorization":
f"Bearer {os.environ.get('OPENAI_API_KEY')}",
})

async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
print(f"{request.headers.get('X-Request-ID')=}")

prefill_request = original_request_data.copy()
# Set max_tokens to 1 so the prefill instance only performs the prefill step
prefill_request["max_tokens"] = 1

# Run the prefill stage on the prefill instance
async for prefill_result in forward_request(
f"http://{PREFILL_ENDPOINT}/v1/completions",
prefill_request,
headers={
"X-Request-ID": request.headers.get("X-Request-ID"),
},
):
# Print the prefill result
print("===== Prefill result =====")
print(prefill_result.decode("utf-8"))
print("==========================")
# Keep the parsed JSON of the last prefill chunk for building the decode request
response = json.loads(prefill_result.decode("utf-8"))

# Append the token generated during prefill to each prompt of the decode request
decode_request = original_request_data.copy()
for idx, choices in enumerate(response.get("choices")):
decode_request["prompt"][idx] += choices.get("text")

# Return the decoding result
generator = forward_request(
f"http://{DECODE_ENDPOINT}/v1/completions",
decode_request,
headers={
"X-Request-ID": request.headers.get("X-Request-ID"),
},
)
response = await make_response(generator)
response.timeout = None

return response

except Exception as e:
import sys
import traceback

exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))


if __name__ == "__main__":
app.run(port=8000)
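
For reference, a minimal client sketch for exercising this proxy once it is listening on port 8000. This is an illustrative assumption, not part of the PR: the requests dependency, the UUID request id, and the prompt/model values are placeholders, and the model name must match whatever the prefill/decode instances actually serve. Note that handle_request indexes into "prompt", so the prompt should be sent as a list.

import uuid

import requests  # assumed to be installed; any HTTP client works

proxy_url = "http://localhost:8000/v1/completions"
payload = {
    # Must match the model served by the prefill/decode vLLM instances.
    "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    # The proxy appends the prefill token to each prompt, so send a list.
    "prompt": ["Hello, how are you today?"],
    "max_tokens": 32,
    "temperature": 0,
}
headers = {"X-Request-ID": str(uuid.uuid4())}

resp = requests.post(proxy_url, json=payload, headers=headers, timeout=600)
print(resp.status_code)
print(resp.text)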
110 changes: 110 additions & 0 deletions examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh
@@ -0,0 +1,110 @@
#!/bin/bash
# This file demonstrates example usage of disaggregated prefilling. We launch
# 2 vLLM instances (1 for prefill and 1 for decode) and then transfer the KV
# cache between them.

set -xe

current_dir=$(dirname "$0")

# vLLM Environment configuration
export VLLM_USE_V1=1

# vLLM-Ascend Environment configuration
export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json"
# The following environment variables are required for LLMDataDist.
export PROMPT_DEVICE_ID=0,1,2,3
export DECODE_DEVICE_ID=4,5,6,7
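# The tensor parallel size is derived from the device list: number of commas + 1
# (e.g. PROMPT_DEVICE_ID=0,1,2,3 has 3 commas, so TENSOR_PARALLEL_SIZE=4).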
export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1))

# Model Configuration
export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Generate the global rank table
if [ ! -f "${GLOBAL_RANKTABLE}" ]; then
echo "Generating global rank table..."
# TODO(jianzs): Impl a tool to generate the global rank table automatically
else
echo "Global rank table already exists."
fi

echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}

# Install quart first -- required for the disagg prefill proxy server
if python3 -c "import quart" &>/dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi

# A helper that waits for a vLLM server on the given port to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}

ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--enforce-eager \
--no-enable-prefix-caching \
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
--kv-transfer-config \
'{
"kv_connector": "AscendHcclConnectorV1",
"kv_buffer_device": "npu",
"kv_role": "kv_producer",
"kv_rank": 0,
"kv_parallel_size": 2,
Collaborator

What does this kv_parallel_size do?

Collaborator Author

The v0 implementation needed this, but I'm unsure if it's still necessary.

"kv_connector_extra_config": {
"local_server_id": "server-0"
}
}' &

ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID} vllm serve ${MODEL_NAME} \
--port 8200 \
--max-model-len 100 \
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--enforce-eager \
--no-enable-prefix-caching \
--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
--kv-transfer-config \
'{
"kv_connector": "AscendHcclConnectorV1",
"kv_buffer_device": "npu",
"kv_role": "kv_consumer",
"kv_rank": 1,
"kv_parallel_size": 2,
"kv_connector_extra_config": {
"local_server_id": "server-1"
}
}' &

# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200

echo "🚧🚧 Warning: server started 🚧🚧"

python3 "${current_dir}/disagg_prefill_proxy_server.py"
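
Once the script reports that the servers have started, a quick sanity check is to query each instance's OpenAI-compatible /v1/models endpoint before sending traffic through the proxy. The snippet below is a sketch under the assumption that the requests library is available and the ports match the ones configured above.

import requests

for name, url in [
    ("prefill", "http://localhost:8100/v1/models"),
    ("decode", "http://localhost:8200/v1/models"),
]:
    # Each instance should report the model configured via MODEL_NAME.
    resp = requests.get(url, timeout=5)
    print(name, resp.status_code, resp.json())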
198 changes: 198 additions & 0 deletions examples/disaggregated-prefill-v1/offline_inference.py
@@ -0,0 +1,198 @@
"""
This file demonstrates the example usage of disaggregated prefilling We will
launch 2 vllm instances (NPU 0,1,3,4 for prefill and NPU 5,6,7,8 for decode),
and then transfer the KV cache between them.
"""

import multiprocessing as mp
import os
from multiprocessing import Event, Process, Queue
from typing import List, Literal


def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"],
local_server_id: str):
kv_rank = 0 if role == "kv_producer" else 1
return f"""{{
"kv_connector": "AscendHcclConnectorV1",
"kv_buffer_device": "npu",
"kv_role": "{role}",
"kv_rank": {kv_rank},
"kv_parallel_size": 2
}}"""


def clean_up():
import gc

import torch
from vllm.distributed.parallel_state import (
destroy_distributed_environment, destroy_model_parallel)

destroy_model_parallel()
destroy_distributed_environment()
gc.collect()
torch.npu.empty_cache()


def run_prefill(
prefill_done,
process_close,
prompt_q: Queue,
prompts: List[str],
model: str,
local_server_id: str,
visible_devices: str,
):
os.environ["VLLM_USE_V1"] = "1"
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices
tensor_parallel_size = len(visible_devices.split(","))

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
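# max_tokens=1 above: the prefill node only generates a single token; its real
# job is to compute the prompt's KV cache so it can be transferred to the decode node.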

ktc = KVTransferConfig.from_cli(
get_kv_transfer_config(
role="kv_producer",
local_server_id=local_server_id,
))

llm = LLM(
model=model,
trust_remote_code=True,
enforce_eager=True,
enable_prefix_caching=False,
kv_transfer_config=ktc,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=0.9,
max_model_len=40,
)

result = llm.generate(prompts, sampling_params)
for output in result:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"[Prefill] Prompt: {prompt!r}, Generated text: {generated_text!r}"
)
prompt_q.put(prompt + generated_text)
prompt_q.close()

print("[Prefill] DONE.")
prefill_done.set()

# Keep the prefill node alive until the decode node is done; otherwise this
# process may exit prematurely, causing incomplete decoding.
process_close.wait()

del llm
clean_up()


def run_decode(
prefill_done,
prompt_q: Queue,
num_prompts: int,
model: str,
local_server_id: str,
visible_devices: str,
):
os.environ["VLLM_USE_V1"] = "1"
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices
tensor_parallel_size = len(visible_devices.split(","))

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

sampling_params = SamplingParams(temperature=0, top_p=0.95)

ktc = KVTransferConfig.from_cli(
get_kv_transfer_config(
role="kv_consumer",
local_server_id=local_server_id,
))

llm = LLM(
model=model,
trust_remote_code=True,
enforce_eager=True,
enable_prefix_caching=False,
kv_transfer_config=ktc,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=0.9,
max_model_len=40,
)

# Wait for the producer (prefill) node to finish before starting to decode
print("[Decode] Waiting for prefill node to finish...")
prefill_done.wait()

# Get the prompts from the queue
prompts = []
for _ in range(num_prompts):
prompts.append(prompt_q.get())

# Once prefill_done is set, the KV cache should have been transferred to
# this decode node, so decoding can start.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"[Decode] Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("[Decode] DONE.")

# Must delete the llm instance, otherwise the process will not exit
del llm
clean_up()


if __name__ == "__main__":
mp.set_start_method("spawn")

model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Set the server id and device ids for prefill and decode nodes
prompt_server_id = "server-0"
prompt_device_ids = "0,1,2,3"
decode_server_id = "server-1"
decode_device_ids = "4,5,6,7"

prompts = [
"Hello, how are you today?",
"Hi, what is your name?",
"Tell me a very long story.",
"what is your favourite book?",
]
num_prompts = len(prompts)

prompt_q: Queue = Queue(num_prompts)
prefill_done = Event()
process_close = Event()

prefill_process = Process(
target=run_prefill,
args=(prefill_done, process_close, prompt_q, prompts, model,
prompt_server_id, prompt_device_ids),
)
decode_process = Process(
target=run_decode,
args=(prefill_done, prompt_q, num_prompts, model, decode_server_id,
decode_device_ids),
)

# Start prefill node
prefill_process.start()
# Start decode node
decode_process.start()

# Wait for decode process to finish
decode_process.join()
print("[Main] Decode process done.")

# Terminate the prefill node, and wait for it to finish
process_close.set()
prefill_process.join()
print("[Main] Prefill process done.")