# limitations under the License.
# This file is a part of the vllm-ascend project.
#
+
import torch
+from typing import Dict, List
import torch.distributed as dist
import vllm.envs as envs
from multiprocessing import Queue, Manager

from vllm.logger import logger
from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader
+from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils

class EplbUpdator:

@@ -32,6 +35,7 @@ def set_adaptor(self, adaptor):
        self.adaptor = adaptor
        self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor)
        self.num_moe_layers = self.adaptor.num_moe_layers
+        self.global_expert_num = self.adaptor.global_expert_num

    def init_eplb(self, expert_map_path):
        self.num_expert_load_gather = 10
@@ -69,7 +73,9 @@ def init_eplb(self, expert_map_path):
            # Expert load (hotness) information [num_layers, world_size, num_experts]
            "moe_load": None,
            # All expert maps [num_layers, world_size, num_experts]
-            "expert_maps": None
+            "expert_maps": None,
+            # Expert load (hotness) information [num_layers, world_size, local_num_experts]
+            "load_info": None,
        })

        self.eplb = EplbProcess(
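Note: the new "load_info" slot follows the shape stated in the comment above it, one count per local expert, per rank, per MoE layer. A minimal sketch of that layout, with sizes that are illustrative assumptions rather than values taken from this change:

    import torch

    # Assumed, illustrative sizes; the real values come from the model and deployment config.
    num_layers, world_size, local_num_experts = 2, 4, 8
    load_info = torch.zeros(num_layers, world_size, local_num_experts, dtype=torch.int64)
    # load_info[l, r, e] = how many tokens layer l routed to local expert e on rank r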
@@ -125,11 +131,11 @@ def forward_before(self):
            self.weight_update_counter = 0
            self.update_in_flight = False
            self.update_info_all = []
-
        # set asynchronous stream for d2d expert weight update
        self.reqs = []
        self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

+
    def forward_end(self,dummy_run=False):
        self.adaptor.collect_topk_ids(dummy_run)
        if not self.update_in_flight:
@@ -149,6 +155,7 @@ def forward_end(self,dummy_run=False):

    def compute_and_set_moe_load(self,dummy_run=False):
        local_load = self.adaptor.get_rank_expert_workload()
+
        self._gather_buffer = None
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
@@ -229,28 +236,31 @@ def unpack_update_batch(self, packed_update_info):
        return recovered

    def get_expert_load(self) -> str:
-
-        # todo wjh: return the actual value
-        # return self.shared_dict['moe_load']
-        # mock json_str
-        experts_load = ('{\"expert_load\":['
-                        '{\"ip\":\"141.xxx.xxx.181\",'
-                        '\"node_0\":'
-                        '{\"card_0\":'
-                        '[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
-                        'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
-                        '{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
-        return experts_load
+
+        load_info = self.shared_dict["load_info"]  # Tensor [L, W, local_experts_num]
+        L, W, _ = load_info.shape
+
+        expert_load: Dict[str, List[dict]] = {}
+        for c in range(W):
+            layers: List[dict] = []
+            for l in range(L):
+                counts_1d = load_info[l, c]
+
+                layer_val = {
+                    f"expert_{e}": int(v)
+                    for e, v in enumerate(counts_1d.tolist())
+                }
+                layers.append({f"layer_{l}": layer_val})
+            expert_load[f"card_{c}"] = layers
+
+        return {"expert_load": expert_load}

    def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
        logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")
        self.num_expert_load_gather = num_expert_load_gather
        self.num_iterations = num_iterations
        logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")

-
    def shutdown(self):
        """
        Clean up the EPLB process.