1
+ """RANK_TABLE Configuration Utility Script"""
2
+ import json
3
+ import os
4
+ import socket
5
+ import sys
6
+ from argparse import ArgumentParser
7
+ from typing import Dict , Any
8
+
9
def parse_args():
    """
    Build the RANK_TABLE command line interface and parse ``sys.argv``.

    Two sub-commands are registered: ``generate`` (write one rank table
    config file) and ``merge`` (combine several rank table config files).

    Returns:
        argparse.Namespace: the parsed options; ``command`` holds the name
        of the selected sub-command.
    """
    cli = ArgumentParser(description="RANK_TABLE Configuration Utility - Generate and merge RANK_TABLE config files")

    # A sub-command is mandatory because of required=True; the ``default``
    # attribute is set exactly as before and does not relax that.
    commands = cli.add_subparsers(dest='command', required=True)
    commands.default = 'generate'

    # ``generate`` options as (flag, type, default, help) rows, registered in order.
    generate_cmd = commands.add_parser('generate', help='Generate single RANK TABLE config file')
    generate_options = (
        ("--device_num", str, "[0,16)",
         "The number of the Ascend accelerators used. Must be continuous, e.g. [0,4) means using chips 0,1,2,3"),
        ("--visible_devices", str, "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15",
         "The visible devices according to the software system"),
        ("--server_ip", str, "127.0.0.1",
         "Set the server_ip manually, to avoid errors in auto detection"),
        ("--instance_role", str, "prefill",
         "Set the instance role, prefill or decode"),
        ("--instance_rank", int, 0,
         "Set the instance rank"),
        ("--num_instances", int, 1,
         "Set the number of instances"),
        ("--output_dir", str, os.getcwd(),
         "Directory to save the generated rank_table config file"),
    )
    for flag, value_type, fallback, description in generate_options:
        generate_cmd.add_argument(flag, type=value_type, default=fallback, help=description)

    # ``merge`` takes one or more input files plus an output directory.
    merge_cmd = commands.add_parser('merge', help='Merge multiple RANK_TABLE config files')
    merge_cmd.add_argument("file_list", type=str, nargs="+", help="RANK_TABLE file lists to merge")
    merge_cmd.add_argument("--output_dir", type=str, default=os.getcwd(),
                           help="Directory to save the merged rank_table config file")

    return cli.parse_args()
46
+
47
def get_host_ip():
    """
    Get the host IP address by resolving the local hostname.

    Returns:
        str | None: the address returned by ``socket.gethostbyname`` for the
        local hostname, or ``None`` when the lookup fails.
    """
    try:
        return socket.gethostbyname(socket.gethostname())
    except (OSError, UnicodeError):
        # Only lookup failures are treated as "no address": resolver errors
        # (socket.gaierror / socket.herror are OSError subclasses) and
        # hostnames that cannot be IDNA-encoded. A bare ``except`` would also
        # swallow KeyboardInterrupt and SystemExit.
        return None
59
+
60
def _parse_device_range(device_num):
    """Parse a half-open ``[first,last)`` device range string into ``(first, last)``."""
    text = device_num.strip()
    bounds = text[1:-1].split(',') if text.startswith('[') and text.endswith(')') else []
    if len(bounds) != 2:
        raise ValueError(f"Invalid device num {device_num!r}: expected the form [first,last), e.g. [0,4)")
    try:
        first_num, last_num = int(bounds[0]), int(bounds[1])
    except ValueError as err:
        raise ValueError(f"Invalid device num {device_num!r}: bounds must be integers") from err
    if first_num > last_num:
        raise ValueError(f"First num {first_num} of device num {device_num} must less than last num {last_num}!")
    return first_num, last_num


def _collect_device_ips(device_ids):
    """Map device id (as str) to device IP using ``hccn_tool``, else ``/etc/hccn.conf``."""
    device_ips: Dict[str, str] = {}
    try:
        for device_id in device_ids:
            # Close the pipe deterministically instead of leaking it.
            with os.popen(f"hccn_tool -i {device_id} -ip -g") as pipe:
                ret = pipe.readlines()
            # The value is taken from after the first ':' of the first output
            # line; strip() drops the trailing newline and surrounding blanks.
            device_ips[str(device_id)] = ret[0].split(":")[1].strip()
    except IndexError:
        # No usable tool output (empty output or no ':' on the first line):
        # fall back to the ``address_<id>=<ip>`` entries of the config file.
        try:
            with open('/etc/hccn.conf', 'r') as fin:
                for hccn_item in fin:
                    if not hccn_item.strip().startswith('address_'):
                        continue
                    key, sep, value = hccn_item.partition('=')
                    if not sep:
                        continue  # malformed entry without '='
                    device_ips[key.strip().split('_')[1]] = value.strip()
        except OSError as err:
            raise SystemError("Failed to find information for rank_table") from err
    return device_ips


def generate_rank_table(args):
    """
    Generate a single RANK_TABLE config file.

    Args:
        args: parsed ``generate`` options. Reads ``visible_devices``,
            ``server_ip``, ``device_num``, ``instance_role``,
            ``instance_rank``, ``num_instances`` and ``output_dir``.

    Raises:
        ValueError: no server ip is available, ``device_num`` is malformed or
            reversed, or fewer visible devices than requested devices.
        KeyError: no device ip is known for one of the visible devices.
        SystemError: device ips can be read neither from ``hccn_tool`` nor
            from ``/etc/hccn.conf``.

    Side effects:
        Writes ``<role>_<rank>_rank_table_<n>u.json`` into ``output_dir``
        (created if missing) and prints the saved path.
    """
    # visible_devices (blanks around entries are tolerated, e.g. "0, 1")
    visible_devices = [dev.strip() for dev in args.visible_devices.split(',')]

    # server_id: only fall back to hostname resolution when no ip was given
    server_id = args.server_ip or get_host_ip()
    if not server_id:
        raise ValueError("Please input server ip!")

    # device_num
    first_num, last_num = _parse_device_range(args.device_num)
    device_num_list = list(range(first_num, last_num))

    # Explicit check instead of ``assert``, which is removed under ``python -O``.
    if len(visible_devices) < len(device_num_list):
        raise ValueError(
            f"visible_devices has {len(visible_devices)} entries but device num "
            f"{args.device_num} needs {len(device_num_list)}")

    # construct rank_table
    device_ips = _collect_device_ips(device_num_list)

    device_list = []
    for rank_id, device_id in enumerate(visible_devices[:len(device_num_list)]):
        if device_id not in device_ips:
            raise KeyError(f"No device ip found for visible device {device_id!r}")
        device_list.append({
            'device_id': device_id,
            'device_ip': device_ips[device_id],
            'rank_id': str(rank_id),
        })

    # NOTE(review): the id is offset by num_instances, as in the original;
    # merge_rank_table overwrites server_id when files are merged.
    global_instance_rank = args.num_instances + args.instance_rank
    rank_table = {
        'version': '1.0',
        'status': 'completed',
        'group_id': '0',
        # NOTE(review): key is spelled 'serve_count' here while the merged
        # table uses 'server_count'; left unchanged, confirm with consumers.
        'serve_count': "1",
        'server_list': [{
            'server_id': f"server-{global_instance_rank}",
            'server_ip': server_id,
            'device': device_list,
        }],
    }

    # Save rank_table to file
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    table_fn = os.path.join(
        args.output_dir,
        f'{args.instance_role}_{args.instance_rank}_rank_table_{len(device_num_list)}u.json')
    with open(table_fn, 'w', encoding='utf-8') as table_fp:
        json.dump(rank_table, table_fp, indent=4)
    print(f"Completed: rank_table file was saved in: {table_fn}")
131
+
132
def merge_rank_table(args):
    """
    Merge multiple RANK_TABLE config files into ``global_ranktable.json``.

    Each input file is assigned a role from its name: "prefill" or "decode"
    is looked up in the base file name first, then in the full path. Files
    matching neither are loaded but left out, with a warning on stderr.

    The merged table always contains group "0" (the router entry). Group "1"
    holds the servers of all prefill files and group "2" those of all decode
    files; each is present only if at least one file of that role was given.
    Servers are renumbered ``server-0``, ``server-1``, ... across groups, and
    the devices of every server get ``rank_id`` values restarting at "0".

    Args:
        args: parsed ``merge`` options. Reads ``file_list`` and ``output_dir``.

    Side effects:
        Writes ``global_ranktable.json`` into ``output_dir`` (created if
        missing) and prints the saved path.
    """
    prefill_jsons = []
    decode_jsons = []

    for f_name in args.file_list:
        with open(f_name, encoding='utf-8') as f:
            f_json = json.load(f)

        # Prefer the base name so a directory called e.g. "prefill_stage"
        # cannot misclassify a decode file; the full path is only a fallback.
        candidates = (os.path.basename(f_name), f_name)
        role = next(
            (name for text in candidates for name in ("prefill", "decode") if name in text),
            None,
        )
        if role == "prefill":
            prefill_jsons.append(f_json)
        elif role == "decode":
            decode_jsons.append(f_json)
        else:
            print(f"Warning: {f_name} is neither a prefill nor a decode file, skipped",
                  file=sys.stderr)

    rank_table = {
        'version': '1.0',
        'status': "completed",
        'server_group_list': [
            {
                "group_id": "0",
                "server_count": "1",
                "server_list": [{
                    "server_id": "router",
                    "server_ip": "127.0.0.1"
                }]
            }
        ]
    }

    for group_id, role_jsons in (("1", prefill_jsons), ("2", decode_jsons)):
        if not role_jsons:
            continue
        server_list = [s for j in role_jsons for s in j['server_list']]
        rank_table['server_group_list'].append({
            "group_id": group_id,
            # Count the servers actually listed rather than the input files.
            "server_count": str(len(server_list)),
            "server_list": server_list,
        })

    server_id_counter = 0
    for group in rank_table['server_group_list']:
        if group['group_id'] == "0":
            continue  # the router entry has no devices to renumber

        for server in group['server_list']:
            server['server_id'] = f"server-{server_id_counter}"
            server_id_counter += 1
            for local_rank_id, device in enumerate(server['device']):
                device['rank_id'] = str(local_rank_id)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    table_name = os.path.join(args.output_dir, 'global_ranktable.json')
    with open(table_name, 'w', encoding='utf-8') as table_fp:
        json.dump(rank_table, table_fp, indent=4)
    print(f"Completed: rank_table file was saved in: {table_name}")
195
+
196
def main():
    """Parse the command line and run the handler of the selected sub-command."""
    cli_args = parse_args()
    handlers = {
        'generate': generate_rank_table,
        'merge': merge_rank_table,
    }
    handler = handlers.get(cli_args.command)
    # Any other command value is ignored, as before.
    if handler is not None:
        handler(cli_args)


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments