Skip to content

Commit 7fb466d

Browse files
committed
feat: implement the connector based on LLMDataDist for v1
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 79538b5 commit 7fb466d

File tree

8 files changed

+1153
-2
lines changed

8 files changed

+1153
-2
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import json
4+
import os
5+
6+
import aiohttp
7+
from quart import Quart, make_response, request
8+
9+
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
10+
11+
app = Quart(__name__)
12+
13+
PREFILL_ENDPOINT = "localhost:8100"
14+
DECODE_ENDPOINT = "localhost:8200"
15+
16+
17+
async def forward_request(url, data, headers: dict):
    """POST *data* as JSON to *url* and stream back the response body.

    Args:
        url: Full target URL (e.g. ``http://host:port/v1/completions``).
        data: JSON-serializable request payload.
        headers: Extra HTTP headers to send; an ``Authorization`` bearer
            header (from ``OPENAI_API_KEY``) is added on top.

    Yields:
        Raw response-body chunks (bytes, up to 1024 bytes each).
    """
    # Copy instead of mutating the caller's dict: the original updated
    # `headers` in place, leaking the Authorization header back to callers.
    request_headers = dict(headers)
    request_headers["Authorization"] = \
        f"Bearer {os.environ.get('OPENAI_API_KEY')}"

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data,
                                headers=request_headers) as response:
            # NOTE(review): non-200 responses are silently swallowed (the
            # generator just yields nothing) — consider surfacing the status
            # to the caller.
            if response.status == 200:
                async for chunk_bytes in response.content.iter_chunked(1024):
                    yield chunk_bytes
29+
30+
31+
@app.route("/v1/completions", methods=["POST"])
32+
async def handle_request():
33+
try:
34+
original_request_data = await request.get_json()
35+
print(f"{request.headers.get('X-Request-ID')=}")
36+
37+
prefill_request = original_request_data.copy()
38+
# Change max_tokens = 1 to let it only do prefill
39+
prefill_request["max_tokens"] = 1
40+
41+
# Finish prefill
42+
async for prefill_result in forward_request(
43+
f"http://{PREFILL_ENDPOINT}/v1/completions",
44+
prefill_request,
45+
headers={
46+
"X-Request-ID": request.headers.get("X-Request-ID"),
47+
},
48+
):
49+
# Print the prefill result
50+
print(f"===== Prefill result =====")
51+
print(prefill_result.decode("utf-8"))
52+
print("==========================")
53+
response = json.loads(prefill_result.decode("utf-8"))
54+
continue
55+
56+
# Get the prefill result token, and add it to the decoding request
57+
decode_request = original_request_data.copy()
58+
for idx, choices in enumerate(response.get("choices")):
59+
decode_request["prompt"][idx] += choices.get("text")
60+
61+
# Return the decoding result
62+
generator = forward_request(
63+
f"http://{DECODE_ENDPOINT}/v1/completions",
64+
decode_request,
65+
headers={
66+
"X-Request-ID": request.headers.get("X-Request-ID"),
67+
},
68+
)
69+
response = await make_response(generator)
70+
response.timeout = None
71+
72+
return response
73+
74+
except Exception as e:
75+
import sys
76+
import traceback
77+
78+
exc_info = sys.exc_info()
79+
print("Error occurred in disagg prefill proxy server")
80+
print(e)
81+
print("".join(traceback.format_exception(*exc_info)))
82+
83+
84+
if __name__ == "__main__":
85+
app.run(port=8000)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling. We will
# launch 2 vllm instances (1 for prefill and 1 for decode), and then transfer
# the KV cache between them.

set -xe

current_dir=$(dirname "$0")

# vLLM Environment configuration
export VLLM_USE_V1=1

# vLLM-Ascend Environment configuration
export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json"
# The following environment variables are required for LLMDataDist.
export PROMPT_DEVICE_ID=0,1,2,3
export DECODE_DEVICE_ID=4,5,6,7
# Tensor-parallel size = number of devices = (number of commas) + 1.
export TENSOR_PARALLEL_SIZE=$(($(echo "${PROMPT_DEVICE_ID}" | grep -o ',' | wc -l) + 1))

# Model Configuration
export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Generate the global rank table
if [ ! -f "${GLOBAL_RANKTABLE}" ]; then
    echo "Generating global rank table..."
    # TODO(jianzs): Impl a tool to generate the global rank table automatically
else
    echo "Global rank table already exists."
fi

echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function: kill the vLLM server processes spawned below.
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Cleanup commands. `xargs -r` avoids a `kill` usage error when pgrep
    # matches nothing.
    pgrep python | xargs -r kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

# install quart first -- required for disagg prefill proxy server
if python3 -c "import quart" &>/dev/null; then
    echo "Quart is already installed."
else
    echo "Quart is not installed. Installing..."
    python3 -m pip install quart
fi

# Wait (up to 20 minutes) for a vLLM server to answer on the given port.
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
      until curl -s localhost:${port}/v1/completions > /dev/null; do
        sleep 1
      done" && return 0 || return 1
}

# Prefill instance: kv_producer on devices ${PROMPT_DEVICE_ID}, port 8100.
ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \
    --port 8100 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.9 \
    --trust-remote-code \
    --enforce-eager \
    --no-enable-prefix-caching \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --kv-transfer-config \
    '{
        "kv_connector": "AscendHcclConnectorV1",
        "kv_buffer_device": "npu",
        "kv_role": "kv_producer",
        "kv_rank": 0,
        "kv_parallel_size": 2,
        "kv_connector_extra_config": {
            "local_server_id": "server-0"
        }
    }' &

# Decode instance: kv_consumer on devices ${DECODE_DEVICE_ID}, port 8200.
ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID} vllm serve ${MODEL_NAME} \
    --port 8200 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.9 \
    --trust-remote-code \
    --enforce-eager \
    --no-enable-prefix-caching \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --kv-transfer-config \
    '{
        "kv_connector": "AscendHcclConnectorV1",
        "kv_buffer_device": "npu",
        "kv_role": "kv_consumer",
        "kv_rank": 1,
        "kv_parallel_size": 2,
        "kv_connector_extra_config": {
            "local_server_id": "server-1"
        }
    }' &

# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200

echo "🚧🚧 Warning: server started 🚧🚧"

python3 disagg_prefill_proxy_server.py
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash

# Make sure the model is same as the one used in the server
MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
REQUEST_ID=request$RANDOM

# Send one batched completion request through the disagg-prefill proxy
# (port 8000). The X-Request-ID header ties the prefill and decode halves
# of the request together on the server side. The payload is fed via a
# heredoc so ${MODEL_NAME} expands without quote juggling.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -H "X-Request-ID: ${REQUEST_ID}" \
    --data @- <<EOF
{
    "ignore_eos": false,
    "stream": false,
    "stop": "None",
    "temperature": 0.5,
    "top_k": -1,
    "top_p": 1,
    "model": "${MODEL_NAME}",
    "prompt": [
        "In 2020, who won the world series?",
        "In 2019, Who won the world series?"
    ],
    "max_tokens": 40
}
EOF

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ def build(self,
242242
max_seq_lens = seq_lens[:self._num_decodes].max().item()
243243
decode_metadata = AscendMLADecodeMetadata(
244244
input_positions=input_positions[:self._num_decode_tokens],
245-
block_table=block_table[:self._num_decode_tokens, ...],
246-
seq_lens=seq_lens[:self._num_decode_tokens],
245+
block_table=block_table[:self._num_decodes, ...],
246+
seq_lens=seq_lens[:self._num_decodes],
247247
max_seq_lens=max_seq_lens)
248248

249249
return self.metadata_cls( # type: ignore

vllm_ascend/distributed/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,8 @@
2525
KVConnectorFactory.register_connector(
2626
"AscendSimpleConnector",
2727
"vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector")
28+
29+
KVConnectorFactory.register_connector(
30+
"AscendHcclConnectorV1",
31+
"vllm_ascend.distributed.llmdatadist_connector_v1",
32+
"LLMDataDistConnectorV1")

0 commit comments

Comments
 (0)