@@ -17,7 +17,6 @@ def parse_args():
17
17
18
18
# Common arguments
19
19
subparsers = parser .add_subparsers (dest = 'command' , required = True )
20
- subparsers .default = 'generate'
21
20
22
21
# Generate command
23
22
gen_parser = subparsers .add_parser ('generate' , help = 'Generate single RANK TABLE config file' )
@@ -31,8 +30,6 @@ def parse_args():
31
30
help = "Set the instance role, prefill or decode" )
32
31
gen_parser .add_argument ("--instance_rank" , type = int , default = 0 ,
33
32
help = "Set the instance rank" )
34
- gen_parser .add_argument ("--num_instances" , type = int , default = 1 ,
35
- help = "Set the number of instances" )
36
33
gen_parser .add_argument ("--output_dir" , type = str , default = os .getcwd (),
37
34
help = "Directory to save the generated rank_table config file" )
38
35
@@ -64,8 +61,8 @@ def generate_rank_table(args):
64
61
65
62
# server_id
66
63
ip = get_host_ip ()
67
- server_id = args . server_ip if args . server_ip else ip
68
- if not server_id :
64
+ server_ip = ip if ip else args . server_ip
65
+ if not server_ip :
69
66
raise ValueError ("Please input server ip!" )
70
67
71
68
# device_num
@@ -80,19 +77,24 @@ def generate_rank_table(args):
80
77
# construct rank_table
81
78
device_ips : Dict [Any , Any ] = {}
82
79
try :
83
- for device_id in device_num_list :
84
- ret = os .popen (f"hccn_tool -i { device_id } -ip -g" ).readlines ()
85
- device_ips [str (device_id )] = ret [0 ].split (":" )[1 ].replace ('\n ' , '' )
86
- except IndexError :
80
+ # make sure the /etc/hccn.conf configuration file is correct
81
+ with open ('/etc/hccn.conf' , 'r' ) as fin :
82
+ for hccn_item in fin .readlines ():
83
+ if hccn_item .strip ().startswith ('address_' ):
84
+ device_id , device_ip = hccn_item .split ('=' )
85
+ device_id = device_id .split ('_' )[1 ]
86
+ device_ips [device_id ] = device_ip .strip ()
87
+ except OSError :
87
88
try :
88
- with open ('/etc/hccn.conf' , 'r' ) as fin :
89
- for hccn_item in fin .readlines ():
90
- if hccn_item .strip ().startswith ('address_' ):
91
- device_id , device_ip = hccn_item .split ('=' )
92
- device_id = device_id .split ('_' )[1 ]
93
- device_ips [device_id ] = device_ip .strip ()
94
- except OSError :
95
- raise SystemError ("Failed to find information for rank_table" )
89
+ for device_id in device_num_list :
90
+ ret = os .popen (f"hccn_tool -i { device_id } -ip -g" ).readlines ()
91
+ device_ips [str (device_id )] = ret [0 ].split (":" )[1 ].replace ('\n ' , '' )
92
+ except :
93
+ raise SystemError (
94
+ "Failed to get device IPs. Need either:\n "
95
+ "1. hccn_tool in PATH\n "
96
+ "2. /etc/hccn.conf configuration file"
97
+ )
96
98
97
99
rank_table = {
98
100
'version' : '1.0' ,
@@ -115,10 +117,9 @@ def generate_rank_table(args):
115
117
rank_id += 1
116
118
device_list .append (device )
117
119
118
- global_instance_rank = args .num_instances + args .instance_rank
119
120
rank_table ['server_list' ].append ({
120
- 'server_id' : f"server-{ global_instance_rank } " ,
121
- 'server_ip' : server_id ,
121
+ 'server_id' : f"server-{ args . instance_rank } " ,
122
+ 'server_ip' : server_ip ,
122
123
'device' : device_list ,
123
124
})
124
125
@@ -158,18 +159,26 @@ def merge_rank_table(args):
158
159
}
159
160
160
161
if prefill_jsons :
162
+ prefill_servers = []
163
+ for j in prefill_jsons :
164
+ prefill_servers .extend (j ['server_list' ])
165
+
161
166
rank_table ['server_group_list' ].append ({
162
167
"group_id" : "1" ,
163
168
"server_count" : str (len (prefill_jsons )),
164
- "server_list" : [ s for j in prefill_jsons for s in j [ 'server_list' ]]
169
+ "server_list" : prefill_servers
165
170
})
166
171
167
172
if decode_jsons :
173
+ decode_servers = []
174
+ for j in decode_jsons :
175
+ decode_servers .extend (j ['server_list' ])
176
+
168
177
rank_table ['server_group_list' ].append ({
169
178
"group_id" : "2" ,
170
179
"server_count" : str (len (decode_jsons )),
171
- "server_list" : [ s for j in decode_jsons for s in j [ 'server_list' ]]
172
- })
180
+ "server_list" : decode_servers
181
+ })
173
182
174
183
rank_id = 0
175
184
server_id_counter = 0
0 commit comments