16
16
#
17
17
import torch
18
18
import torch .distributed as dist
19
+ import vllm .envs as envs
19
20
from multiprocessing import Queue , Manager
20
21
21
22
from vllm .logger import logger
@@ -33,11 +34,17 @@ def set_adaptor(self, adaptor):
33
34
self .num_moe_layers = self .adaptor .num_moe_layers
34
35
35
36
def init_eplb (self , expert_map_path ):
36
-
37
+ self . num_expert_load_gather = 10
37
38
self .redundant_enable = (expert_map_path != None )
38
39
self .num_iterations : torch .int64 = 130
39
40
self .expert_map_path = expert_map_path
40
41
42
+ try :
43
+ if not envs .VLLM_ALLOW_EXPERT_LOAD_COLLECTING :
44
+ self .num_expert_load_gather = self .num_iterations
45
+ except Exception as e :
46
+ self .num_expert_load_gather = self .num_iterations
47
+
41
48
self .weight_update_counter = 0
42
49
self .expert_map_initialized = False
43
50
self .update_in_flight = False
@@ -70,7 +77,7 @@ def init_eplb(self, expert_map_path):
70
77
planner_q = self .planner_block_queue ,
71
78
block_update_q = self .block_update_queue ,
72
79
redundant_enable = self .redundant_enable ,
73
- policy_type = 2 ,
80
+ policy_type = 6 ,
74
81
enable_d2d = True
75
82
)
76
83
@@ -80,10 +87,9 @@ def init_eplb(self, expert_map_path):
80
87
81
88
def get_update_iteration(self):
    """Advance the per-step counter and report which periodic actions are due.

    Two periods are tracked: expert-load gathering (every
    ``num_expert_load_gather`` steps) and EPLB rebalancing (every
    ``num_iterations`` steps). When ``gate_eplb`` is set, both actions
    fire exactly once, on the final iteration, instead of periodically.

    Returns:
        tuple[bool, bool]: ``(load_gather_iteration, update_iteration)`` —
        whether to gather expert-load statistics this step, and whether to
        trigger a rebalance this step.
    """
    self.cur_iterations = self.cur_iterations + 1
    if self.gate_eplb:
        # Gated mode: a single one-shot trigger at the last iteration.
        is_final_iteration = self.cur_iterations == self.num_iterations
        return is_final_iteration, is_final_iteration
    # Periodic mode: the two actions run on independent cadences.
    load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0
    update_iteration = self.cur_iterations % self.num_iterations == 0
    return load_gather_iteration, update_iteration
87
93
88
94
def get_init_expert_map (self ):
89
95
try :
@@ -125,12 +131,15 @@ def forward_before(self):
125
131
self .eplb_loader .asyn_expert_weight_transfer (self .reqs )
126
132
127
133
def forward_end (self ):
128
- if not self .update_in_flight and self .get_update_iteration ():
129
- moe_load = self .compute_and_set_moe_load ()
130
- self .wakeup_eplb_worker ()
131
- self .update_in_flight = True
132
- self .wait_worker_iterations = 0
133
- self .weight_loading = False
134
+ if not self .update_in_flight :
135
+ load_gather_iteration , update_iteration = self .get_update_iteration ()
136
+ if load_gather_iteration :
137
+ self .moe_load = self .compute_and_set_moe_load ()
138
+ if update_iteration :
139
+ self .wakeup_eplb_worker ()
140
+ self .update_in_flight = True
141
+ self .wait_worker_iterations = 0
142
+ self .weight_loading = False
134
143
135
144
if self .update_in_flight :
136
145
self .wait_worker_iterations = self .wait_worker_iterations + 1
@@ -220,9 +229,27 @@ def unpack_update_batch(self, packed_update_info):
220
229
return recovered
221
230
222
231
def get_expert_load(self) -> str:
    """Return per-card expert load statistics as a JSON string.

    NOTE(review): this currently returns a hard-coded mock payload; the
    real implementation is expected to read ``self.shared_dict['moe_load']``
    (see the TODO below) — confirm the shared-dict value's format before
    wiring it up.
    """
    # TODO: return the real value once available:
    # return self.shared_dict['moe_load']
    # Mock JSON payload describing expert load per ip / node / card / layer.
    mock_payload = (
        '{"expert_load":['
        '{"ip":"141.xxx.xxx.181",'
        '"node_0":'
        '{"card_0":'
        '[{"layer_4":{"expert_0":3,"expert_2":1}},{"layer_5":{"expert_0":3,"expert_2":1}}],'
        '"card_1":[{"layer_4":{"expert_1":3,"expert_3":1},"layer_5":{"expert_0":3,"'
        'expert_2":1}}]}},{"ip":"141.xxx.xxx.177","node_0":{"card_0":[{"layer_4":'
        '{"expert_0":3,"expert_2":1}},{"layer_5":{"expert_0":3,"expert_2":1}}],'
        '"card_1":[{"layer_4":{"expert_1":3,"expert_3":1}}]}}]}'
    )
    return mock_payload
246
+
247
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
    """Reconfigure the expert-load gathering and rebalance periods at runtime.

    Args:
        num_expert_load_gather: new gather period — iterations between
            expert-load statistic collections.
        num_iterations: new rebalance period / total iteration count.
    """
    # Log old values before and new values after the update; both messages
    # now use the f-string `=` debug specifier consistently (the original
    # only applied it to the first field).
    logger.info(f"start update {self.num_expert_load_gather=}, {self.num_iterations=}...")
    self.num_expert_load_gather = num_expert_load_gather
    self.num_iterations = num_iterations
    logger.info(f"update {self.num_expert_load_gather=}, {self.num_iterations=} success...")
252
+
226
253
227
254
def shutdown (self ):
228
255
"""
0 commit comments