Skip to content

Commit db1288c

Browse files
zzhx1jianzs
authored and committed
PD example add automatic generation of ranktablefile
1 parent c2318fe commit db1288c

File tree

2 files changed

+234
-2
lines changed

2 files changed

+234
-2
lines changed

examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json"
1515
# The following environment variables are required for LLMDataDist.
1616
export PROMPT_DEVICE_ID=0,1,2,3
1717
export DECODE_DEVICE_ID=4,5,6,7
18+
export NUM_PROMPT_INSTANCE=1
19+
export NUM_DECODE_INSTANCE=1
20+
1821
export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1))
1922

2023
# Model Configuration
@@ -23,7 +26,32 @@ export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
2326
# Generate the global rank table
2427
if [ ! -f "${GLOBAL_RANKTABLE}" ]; then
  echo "Generating global rank table..."

  OUTPUT_DIR="${current_dir}"

  # generate_hccl ROLE INSTANCE_INDEX
  #
  # Write the per-instance rank table for one instance by calling
  # `rank_table_utils.py generate`.
  #   ROLE            "prefill" selects PROMPT_DEVICE_ID / NUM_PROMPT_INSTANCE;
  #                   any other value selects DECODE_DEVICE_ID / NUM_DECODE_INSTANCE.
  #   INSTANCE_INDEX  forwarded unchanged as --instance_rank.
  generate_hccl() {
    local role=$1
    local instance_index=$2
    local -a devices
    local num_instances

    # Pick the device list and the instance count with a plain string test.
    # Shell arithmetic cannot compare strings: inside $(( )) both `role` and
    # "prefill" are read as (unset) variable names and evaluate to 0, so an
    # arithmetic ternary would always take the prefill branch.
    if [ "$role" == "prefill" ]; then
      IFS=',' read -r -a devices <<< "$PROMPT_DEVICE_ID"
      num_instances=$NUM_PROMPT_INSTANCE
    else
      IFS=',' read -r -a devices <<< "$DECODE_DEVICE_ID"
      num_instances=$NUM_DECODE_INSTANCE
    fi

    # --device_num is a half-open range "[first,last+1)"; this assumes the
    # configured device ids are contiguous and listed in ascending order.
    local start=${devices[0]}
    local end=$((${devices[-1]} + 1))

    python rank_table_utils.py generate \
      --device_num="[$start,$end)" \
      --visible_devices="$(IFS=,; echo "${devices[*]}")" \
      --instance_role "$role" \
      --instance_rank "$instance_index" \
      --num_instances "$num_instances" \
      --output_dir="$OUTPUT_DIR"
  }

  generate_hccl "prefill" 1
  generate_hccl "decode" 1
  # The globs stay outside the quotes so they still expand; only the
  # directory part is quoted to survive paths containing spaces.
  python rank_table_utils.py merge "$OUTPUT_DIR"/prefill_*_rank_table_*.json "$OUTPUT_DIR"/decode_*_rank_table_*.json \
    --output_dir="$OUTPUT_DIR"
  echo "Global rank table generated."
else
  echo "Global rank table already exists."
fi
@@ -107,4 +135,4 @@ wait_for_server 8200
107135

108136
echo "🚧🚧 Warning: server started 🚧🚧"
109137

110-
python3 disagg_prefill_proxy_server.py
138+
python3 disagg_prefill_proxy_server.py
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""RANK_TABLE Configuration Utility Script"""
2+
import json
3+
import os
4+
import socket
5+
import sys
6+
from argparse import ArgumentParser
7+
from typing import Dict, Any
8+
9+
def parse_args():
    """
    Build the RANK_TABLE command line interface and parse the process arguments.

    Two sub-commands are registered: ``generate`` (write one rank table
    file) and ``merge`` (combine several rank table files). The chosen
    sub-command name is stored in ``command``.

    Returns:
        args: Namespace holding ``command`` and that sub-command's options
    """
    default_output_dir = os.getcwd()

    cli = ArgumentParser(
        description="RANK_TABLE Configuration Utility - Generate and merge RANK_TABLE config files")
    commands = cli.add_subparsers(dest='command', required=True)
    commands.default = 'generate'

    # (flag, type, default, help) for every option of the `generate` command,
    # registered in this order.
    generate_options = (
        ("--device_num", str, "[0,16)",
         "The number of the Ascend accelerators used. Must be continuous, e.g. [0,4) means using chips 0,1,2,3"),
        ("--visible_devices", str, "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15",
         "The visible devices according to the software system"),
        ("--server_ip", str, "127.0.0.1",
         "Set the server_ip manually, to avoid errors in auto detection"),
        ("--instance_role", str, "prefill",
         "Set the instance role, prefill or decode"),
        ("--instance_rank", int, 0,
         "Set the instance rank"),
        ("--num_instances", int, 1,
         "Set the number of instances"),
        ("--output_dir", str, default_output_dir,
         "Directory to save the generated rank_table config file"),
    )
    generate_cmd = commands.add_parser('generate', help='Generate single RANK TABLE config file')
    for flag, value_type, default_value, help_text in generate_options:
        generate_cmd.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # `merge` takes one or more input files plus an output directory.
    merge_cmd = commands.add_parser('merge', help='Merge multiple RANK_TABLE config files')
    merge_cmd.add_argument("file_list", type=str, nargs="+", help="RANK_TABLE file lists to merge")
    merge_cmd.add_argument("--output_dir", type=str, default=default_output_dir,
                           help="Directory to save the merged rank_table config file")

    return cli.parse_args()
46+
47+
def get_host_ip():
    """
    Get host IP address

    Resolves the local hostname to an address. The lookup is best-effort:
    a resolution failure is reported as ``None`` instead of raising.

    Returns:
        str: Host IP address, or None when the hostname cannot be resolved
    """
    try:
        hostname = socket.gethostname()
        return socket.gethostbyname(hostname)
    except (OSError, UnicodeError):
        # OSError covers socket.gaierror / socket.herror; UnicodeError is
        # raised for hostnames that cannot be IDNA-encoded. A bare `except:`
        # would also swallow KeyboardInterrupt and SystemExit.
        return None
59+
60+
def _parse_device_range(device_num):
    """Parse a half-open ``"[first,last)"`` range string into ``(first, last)``."""
    try:
        first_num = int(device_num.split('[')[1].split(',')[0])
        last_num = int(device_num.split(')')[0].split(',')[-1])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"Invalid device num {device_num!r}, expected the form '[first,last)'") from err
    if first_num > last_num:
        raise ValueError(
            f"First num {first_num} of device num {device_num} "
            f"must not be greater than last num {last_num}!")
    return first_num, last_num


def _query_device_ips(device_ids):
    """Map device id strings to device IPs via hccn_tool, else /etc/hccn.conf."""
    device_ips: Dict[str, str] = {}
    try:
        for device_id in device_ids:
            # Close the pipe deterministically instead of leaking it.
            with os.popen(f"hccn_tool -i {device_id} -ip -g") as pipe:
                ret = pipe.readlines()
            # Takes the text after the first ':' on the first output line;
            # empty output (e.g. tool not installed) raises IndexError.
            device_ips[str(device_id)] = ret[0].split(":")[1].strip()
    except IndexError:
        # Fallback: read `address_<id>=<ip>` entries from the hccn config.
        try:
            with open('/etc/hccn.conf', 'r') as fin:
                for hccn_item in fin:
                    if hccn_item.strip().startswith('address_'):
                        key, device_ip = hccn_item.split('=', 1)
                        device_ips[key.split('_')[1].strip()] = device_ip.strip()
        except OSError as err:
            raise SystemError("Failed to find information for rank_table") from err
    return device_ips


def generate_rank_table(args):
    """Generate single RANK_TABLE config file

    Writes ``<instance_role>_<instance_rank>_rank_table_<N>u.json`` into
    ``args.output_dir``, where N is the size of the ``device_num`` range.

    Args:
        args: Namespace with ``visible_devices``, ``server_ip``,
            ``device_num``, ``instance_role``, ``instance_rank``,
            ``num_instances`` and ``output_dir``.

    Raises:
        ValueError: no server ip is available, ``device_num`` is malformed
            or reversed, fewer visible devices than the range size were
            given, or a visible device has no known IP.
        SystemError: neither hccn_tool nor /etc/hccn.conf yielded device IPs.
    """
    visible_devices = [dev.strip() for dev in args.visible_devices.split(',')]

    # Only fall back to hostname resolution when no ip was supplied.
    server_id = args.server_ip or get_host_ip()
    if not server_id:
        raise ValueError("Please input server ip!")

    first_num, last_num = _parse_device_range(args.device_num)
    device_num_list = list(range(first_num, last_num))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(visible_devices) < len(device_num_list):
        raise ValueError(
            f"{len(visible_devices)} visible devices given but device num "
            f"{args.device_num} needs {len(device_num_list)}")

    device_ips = _query_device_ips(device_num_list)

    # One entry per device in the range; ranks are local to this file.
    device_list = []
    for rank_id, device_id in enumerate(visible_devices[:len(device_num_list)]):
        if device_id not in device_ips:
            raise ValueError(f"No device ip found for visible device {device_id!r}")
        device_list.append({
            'device_id': device_id,
            'device_ip': device_ips[device_id],
            'rank_id': str(rank_id)
        })

    # NOTE(review): the id is num_instances + instance_rank as in the original
    # code; merge_rank_table renumbers server ids anyway - confirm the intent.
    global_instance_rank = args.num_instances + args.instance_rank

    rank_table = {
        'version': '1.0',
        'status': 'completed',
        'group_id': '0',
        # NOTE(review): key kept as 'serve_count' for output compatibility;
        # looks like a typo of 'server_count' - confirm against consumers.
        'serve_count': "1",
        'server_list': [{
            'server_id': f"server-{global_instance_rank}",
            'server_ip': server_id,
            'device': device_list,
        }]
    }

    # Save rank_table to file
    table_fn = os.path.join(args.output_dir,
                            f'{args.instance_role}_{args.instance_rank}_rank_table_{len(device_num_list)}u.json')
    with open(table_fn, 'w') as table_fp:
        json.dump(rank_table, table_fp, indent=4)
    print(f"Completed: rank_table file was saved in: {table_fn}")
131+
132+
def merge_rank_table(args):
    """Merge multiple RANK_TABLE config files

    Reads every file in ``args.file_list`` and writes
    ``global_ranktable.json`` into ``args.output_dir``. The output has a
    fixed router group "0", then group "1" with the servers of all prefill
    files and group "2" with the servers of all decode files (each only
    when such files were given). Server ids are renumbered consecutively
    across groups "1" and "2"; device rank ids restart at 0 per server.

    Args:
        args: Namespace with ``file_list`` (paths of rank table JSON files
            whose file name contains "prefill" or "decode") and
            ``output_dir``.
    """
    prefill_jsons = []
    decode_jsons = []

    for f_name in args.file_list:
        with open(f_name) as f:
            f_json = json.load(f)
        # Classify by file name only: a directory component such as
        # "disaggregated-prefill-v1" must not turn decode files into prefill.
        base_name = os.path.basename(f_name)
        if "prefill" in base_name:
            prefill_jsons.append(f_json)
        elif "decode" in base_name:
            decode_jsons.append(f_json)
        else:
            print(f"Warning: skipping {f_name}: file name contains neither "
                  f"'prefill' nor 'decode'", file=sys.stderr)

    rank_table = {
        'version': '1.0',
        'status': "completed",
        'server_group_list': [
            {
                "group_id": "0",
                "server_count": "1",
                "server_list": [{
                    "server_id": "router",
                    "server_ip": "127.0.0.1"
                }]
            }
        ]
    }

    for group_id, role_jsons in (("1", prefill_jsons), ("2", decode_jsons)):
        if not role_jsons:
            continue
        server_list = [s for j in role_jsons for s in j['server_list']]
        rank_table['server_group_list'].append({
            "group_id": group_id,
            # Count the servers actually listed, not the number of input files.
            "server_count": str(len(server_list)),
            "server_list": server_list
        })

    server_id_counter = 0
    for group in rank_table['server_group_list']:
        if group['group_id'] == "0":
            # The router entry keeps its fixed id and has no devices.
            continue
        for server in group['server_list']:
            server['server_id'] = f"server-{server_id_counter}"
            server_id_counter += 1
            for local_rank_id, device in enumerate(server['device']):
                device['rank_id'] = str(local_rank_id)

    table_name = os.path.join(args.output_dir, 'global_ranktable.json')
    with open(table_name, 'w') as table_fp:
        json.dump(rank_table, table_fp, indent=4)
    print(f"Completed: rank_table file was saved in: {table_name}")
195+
196+
def main():
    """Parse the command line and run the handler of the chosen sub-command."""
    cli_args = parse_args()
    handlers = {
        'generate': generate_rank_table,
        'merge': merge_rank_table,
    }
    handler = handlers.get(cli_args.command)
    # Any other command value is ignored.
    if handler is not None:
        handler(cli_args)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)