Skip to content

Commit f984e3b

Browse files
zzhx1jianzs
authored andcommitted
chore: update the pd_example script and rank_table_utils.py
1 parent db1288c commit f984e3b

File tree

3 files changed

+85
-55
lines changed

3 files changed

+85
-55
lines changed

examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,21 @@ export VLLM_USE_V1=1
1313
# vLLM-Ascend Environment configuration
1414
export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json"
1515
# The following environment variables are required for LLMDataDist.
16-
export PROMPT_DEVICE_ID=0,1,2,3
17-
export DECODE_DEVICE_ID=4,5,6,7
16+
export PROMPT_DEVICE_ID_0=0,1,2,3
17+
export DECODE_DEVICE_ID_0=4,5,6,7
1818
export NUM_PROMPT_INSTANCE=1
1919
export NUM_DECODE_INSTANCE=1
2020

21-
export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1))
21+
export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID_0 | grep -o ',' | wc -l) + 1))
2222

2323
# Model Configuration
2424
export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
25-
2625
# Generate the global rank table
2726
if [ ! -f "${GLOBAL_RANKTABLE}" ]; then
2827
echo "Generating global rank table..."
29-
30-
OUTPUT_DIR="${current_dir}"
31-
generate_hccl() {
32-
local role=$1
33-
local intance_index=$2
34-
if [ "$role" == "prefill" ]; then
35-
devices=(${PROMPT_DEVICE_ID//,/ })
36-
else
37-
devices=(${DECODE_DEVICE_ID//,/ })
38-
fi
39-
local start=${devices[0]}
40-
local end=$((${devices[-1]}+1))
41-
python rank_table_utils.py generate \
42-
--device_num="[$start,$end)" \
43-
--visible_devices=$(IFS=,; echo "${devices[*]}") \
44-
--instance_role $role \
45-
--instance_rank $intance_index \
46-
--num_instances $((role == "prefill"? $NUM_PROMPT_INSTANCE : $NUM_DECODE_INSTANCE)) \
47-
--output_dir=$OUTPUT_DIR
48-
}
49-
50-
generate_hccl "prefill" 1
51-
generate_hccl "decode" 1
52-
python rank_table_utils.py merge $OUTPUT_DIR/prefill_*_rank_table_*.json $OUTPUT_DIR/decode_*_rank_table_*.json \
53-
--output_dir=$OUTPUT_DIR
28+
# run the script to generate the global rank table
29+
bash "${current_dir}/gen_rank_table.sh"
30+
5431
echo "Global rank table generated."
5532
else
5633
echo "Global rank table already exists."
@@ -89,7 +66,7 @@ wait_for_server() {
8966
done" && return 0 || return 1
9067
}
9168

92-
ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \
69+
ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID_0} vllm serve ${MODEL_NAME} \
9370
--port 8100 \
9471
--max-model-len 100 \
9572
--gpu-memory-utilization 0.9 \
@@ -109,7 +86,7 @@ ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \
10986
}
11087
}' &
11188

112-
ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID} vllm serve ${MODEL_NAME} \
89+
ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID_0} vllm serve ${MODEL_NAME} \
11390
--port 8200 \
11491
--max-model-len 100 \
11592
--gpu-memory-utilization 0.9 \
@@ -135,4 +112,4 @@ wait_for_server 8200
135112

136113
echo "🚧🚧 Warning: server started 🚧🚧"
137114

138-
python3 disagg_prefill_proxy_server.py
115+
python3 disagg_prefill_proxy_server.py
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
!/bin/bash
2+
3+
# export PROMPT_DEVICE_ID_0=0,1,2,3
4+
# export PROMPT_DEVICE_ID_1=4,5,6,7
5+
# export DECODE_DEVICE_ID_0=8,9,10,11
6+
# export DECODE_DEVICE_ID_1=12,13,14,15
7+
# export NUM_PROMPT_INSTANCE=2
8+
# export NUM_DECODE_INSTANCE=2
9+
10+
current_dir=$(dirname "$0")
11+
OUTPUT_DIR="${current_dir}"
12+
13+
generate_hccl() {
14+
local role=$1
15+
local instance_index=$2
16+
local start=${devices[0]}
17+
local end=$((${devices[-1]}+1))
18+
python rank_table_utils.py generate \
19+
--device_num="[$start,$end)" \
20+
--visible_devices=$(IFS=,; echo "${devices[*]}") \
21+
--instance_role $role \
22+
--instance_rank $instance_index \
23+
--output_dir=$OUTPUT_DIR
24+
}
25+
26+
for ((i=0; i<NUM_PROMPT_INSTANCE; i++)); do
27+
device_var_name="PROMPT_DEVICE_ID_${i}"
28+
devices=(${!device_var_name//,/ })
29+
generate_hccl "prefill" $i
30+
done
31+
32+
# 生成decode实例(索引从0开始)
33+
for ((i=0; i<NUM_DECODE_INSTANCE; i++)); do
34+
device_var_name="DECODE_DEVICE_ID_${i}"
35+
devices=(${!device_var_name//,/ })
36+
generate_hccl "decode" $i
37+
done
38+
39+
40+
python rank_table_utils.py merge $OUTPUT_DIR/prefill_*_rank_table_*.json $OUTPUT_DIR/decode_*_rank_table_*.json \
41+
--output_dir=$OUTPUT_DIR
42+
43+
44+

examples/disaggregated-prefill-v1/rank_table_utils.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def parse_args():
1717

1818
# Common arguments
1919
subparsers = parser.add_subparsers(dest='command', required=True)
20-
subparsers.default = 'generate'
2120

2221
# Generate command
2322
gen_parser = subparsers.add_parser('generate', help='Generate single RANK TABLE config file')
@@ -31,8 +30,6 @@ def parse_args():
3130
help="Set the instance role, prefill or decode")
3231
gen_parser.add_argument("--instance_rank", type=int, default=0,
3332
help="Set the instance rank")
34-
gen_parser.add_argument("--num_instances", type=int, default=1,
35-
help="Set the number of instances")
3633
gen_parser.add_argument("--output_dir", type=str, default=os.getcwd(),
3734
help="Directory to save the generated rank_table config file")
3835

@@ -64,8 +61,8 @@ def generate_rank_table(args):
6461

6562
# server_id
6663
ip = get_host_ip()
67-
server_id = args.server_ip if args.server_ip else ip
68-
if not server_id:
64+
server_ip = ip if ip else args.server_ip
65+
if not server_ip:
6966
raise ValueError("Please input server ip!")
7067

7168
# device_num
@@ -80,19 +77,24 @@ def generate_rank_table(args):
8077
# construct rank_table
8178
device_ips: Dict[Any, Any] = {}
8279
try:
83-
for device_id in device_num_list:
84-
ret = os.popen(f"hccn_tool -i {device_id} -ip -g").readlines()
85-
device_ips[str(device_id)] = ret[0].split(":")[1].replace('\n', '')
86-
except IndexError:
80+
# make sure the /etc/hccn.conf configuration file is correct
81+
with open('/etc/hccn.conf', 'r') as fin:
82+
for hccn_item in fin.readlines():
83+
if hccn_item.strip().startswith('address_'):
84+
device_id, device_ip = hccn_item.split('=')
85+
device_id = device_id.split('_')[1]
86+
device_ips[device_id] = device_ip.strip()
87+
except OSError:
8788
try:
88-
with open('/etc/hccn.conf', 'r') as fin:
89-
for hccn_item in fin.readlines():
90-
if hccn_item.strip().startswith('address_'):
91-
device_id, device_ip = hccn_item.split('=')
92-
device_id = device_id.split('_')[1]
93-
device_ips[device_id] = device_ip.strip()
94-
except OSError:
95-
raise SystemError("Failed to find information for rank_table")
89+
for device_id in device_num_list:
90+
ret = os.popen(f"hccn_tool -i {device_id} -ip -g").readlines()
91+
device_ips[str(device_id)] = ret[0].split(":")[1].replace('\n', '')
92+
except:
93+
raise SystemError(
94+
"Failed to get device IPs. Need either:\n"
95+
"1. hccn_tool in PATH\n"
96+
"2. /etc/hccn.conf configuration file"
97+
)
9698

9799
rank_table = {
98100
'version': '1.0',
@@ -115,10 +117,9 @@ def generate_rank_table(args):
115117
rank_id += 1
116118
device_list.append(device)
117119

118-
global_instance_rank = args.num_instances + args.instance_rank
119120
rank_table['server_list'].append({
120-
'server_id': f"server-{global_instance_rank}",
121-
'server_ip': server_id,
121+
'server_id': f"server-{args.instance_rank}",
122+
'server_ip': server_ip,
122123
'device': device_list,
123124
})
124125

@@ -158,18 +159,26 @@ def merge_rank_table(args):
158159
}
159160

160161
if prefill_jsons:
162+
prefill_servers = []
163+
for j in prefill_jsons:
164+
prefill_servers.extend(j['server_list'])
165+
161166
rank_table['server_group_list'].append({
162167
"group_id": "1",
163168
"server_count": str(len(prefill_jsons)),
164-
"server_list": [s for j in prefill_jsons for s in j['server_list']]
169+
"server_list": prefill_servers
165170
})
166171

167172
if decode_jsons:
173+
decode_servers = []
174+
for j in decode_jsons:
175+
decode_servers.extend(j['server_list'])
176+
168177
rank_table['server_group_list'].append({
169178
"group_id": "2",
170179
"server_count": str(len(decode_jsons)),
171-
"server_list": [s for j in decode_jsons for s in j['server_list']]
172-
})
180+
"server_list": decode_servers
181+
})
173182

174183
rank_id = 0
175184
server_id_counter = 0

0 commit comments

Comments
 (0)