 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+
 import torch
+from typing import Dict, List
 import torch.distributed as dist
 import vllm.envs as envs
 from multiprocessing import Queue, Manager

 from vllm.logger import logger
 from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
 from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader
-
+from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils

 class EplbUpdator:

@@ -33,6 +35,7 @@ def set_adaptor(self, adaptor):
         self.adaptor = adaptor
         self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor)
         self.num_moe_layers = self.adaptor.num_moe_layers
+        self.global_expert_num = self.adaptor.global_expert_num

     def init_eplb(self, expert_map_path):
         self.num_expert_load_gather = 10
@@ -44,7 +47,7 @@ def init_eplb(self, expert_map_path):
             if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
                 self.num_expert_load_gather = self.num_iterations
         except Exception as e:
-            self.num_expert_load_gather = self.num_iterations
+            self.num_expert_load_gather = self.num_iterations

         self.weight_update_counter = 0
         self.expert_map_initialized = False
@@ -58,7 +61,7 @@ def init_eplb(self, expert_map_path):
         self.cur_iterations: torch.int64 = 0

         self.wait_worker_iterations: torch.int64 = 0
-        self.num_wait_worker_iterations: torch.int64 = 10
+        self.num_wait_worker_iterations: torch.int64 = 20

         self.planner_block_queue = Queue()
         self.block_update_queue = Queue(maxsize=1)
@@ -70,16 +73,16 @@ def init_eplb(self, expert_map_path):
            # expert hotness (load) statistics [num_layers, world_size, num_experts]
            "moe_load": None,
            # all expert maps [num_layers, world_size, num_experts]
-           "expert_maps": None
+           "expert_maps": None,
        })

        self.eplb = EplbProcess(
-           shared_dict=self.shared_dict,
-           planner_q=self.planner_block_queue,
-           block_update_q=self.block_update_queue,
-           redundant_enable=self.redundant_enable,
-           policy_type=6,
-           enable_d2d=True
+           shared_dict=self.shared_dict,
+           planner_q=self.planner_block_queue,
+           block_update_q=self.block_update_queue,
+           redundant_enable=self.redundant_enable,
+           policy_type=6,
+           enable_d2d=True
        )

        self.eplb_process = self.eplb._launch_process()
@@ -88,15 +91,14 @@ def init_eplb(self, expert_map_path):

     def get_update_iteration(self):
         self.cur_iterations = self.cur_iterations + 1
-        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
         return load_gather_iteration, upate_iteration

     def get_init_expert_map(self):
         try:
             if not self.expert_map_initialized:
-                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers,
-                                                                                             self.expert_map_path)
+                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
                 self.expert_map_initialized = True
         except Exception as e:
             logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", exc_info=True)
@@ -114,32 +116,31 @@ def forward_before(self):
             self.weight_loading = True

         if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
-            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
-                0)
+            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
             rank_id = torch.distributed.get_rank()
             self.eplb_loader.set_log2phy_map(log2phy_map)
             expert_send_info_this_rank = expert_send_info[rank_id] if rank_id in expert_send_info else []
             expert_recv_info_this_rank = expert_recv_info[rank_id] if rank_id in expert_recv_info else []
-            # logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
+            #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
             self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info_this_rank,
-                                                               expert_recv_info_this_rank, updated_expert_map,
-                                                               layer_id + 3)
+                                                               expert_recv_info_this_rank, updated_expert_map, layer_id + 3)
             self.weight_update_counter += 1
             if self.weight_update_counter == self.num_moe_layers:
                 self.weight_update_counter = 0
                 self.update_in_flight = False
                 self.update_info_all = []
-
         # set asynchronous stream for d2d expert weight update
         self.reqs = []
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

-    def forward_end(self, dummy_run=False):
-        self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
+
+    def forward_end(self, dummy_run=False):
+        self.adaptor.collect_topk_ids(dummy_run)
         if not self.update_in_flight:
             load_gather_iteration, update_iteration = self.get_update_iteration()
             if load_gather_iteration:
-                moe_load = self.compute_and_set_moe_load(dummy_run)
+                moe_load = self.compute_and_set_moe_load()
+                self.get_expert_load()
             if update_iteration:
                 self.wakeup_eplb_worker()
                 self.update_in_flight = True
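
For orientation, a hedged sketch of how these two hooks are expected to bracket each forward step so the device-to-device weight copies overlap with model execution. The runner object and its execute_model call are assumptions for illustration and are not part of this diff.

# Hypothetical driver loop (`runner` and `execute_model` are assumed names):
def run_one_step(runner, updator, dummy_run=False):
    updator.forward_before()        # stage at most one MoE layer's expert
                                    # send/recv and kick off async d2d copies
    runner.execute_model()          # forward pass overlaps with the transfer
    updator.forward_end(dummy_run)  # record top-k ids, commit finished copies,
                                    # periodically gather load / wake the worker
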
@@ -151,8 +152,9 @@ def forward_end(self, dummy_run=False):

         self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)

-    def compute_and_set_moe_load(self, dummy_run=False):
-        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
+    def compute_and_set_moe_load(self, dummy_run=False):
+        local_load = self.adaptor.get_rank_expert_workload()
+
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()
@@ -177,7 +179,7 @@ def compute_and_set_moe_load(self, dummy_run=False):
     def warm_up_eplb(self):

         self.get_init_expert_map()
-
+        self.adaptor.collect_topk_ids(dummy_run=False)
         self.compute_and_set_moe_load()

         src_tensor = torch.empty((1,), device=self.device)
@@ -197,7 +199,7 @@ def warm_up_eplb(self):
                 continue
             comm_op_list.append(
                 dist.P2POp(dist.irecv, src_tensor, src_rank)
-            )
+            )
         if comm_op_list:
             reqs = dist.batch_isend_irecv(comm_op_list)

@@ -210,7 +212,7 @@ def unpack_update_batch(self, packed_update_info):
         """
         send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info

-        maps = stacked_maps.unbind(0)
+        maps = stacked_maps.unbind(0)
         layer_ids = layer_id_tensor.tolist()

         if self.redundant_enable:
@@ -222,7 +224,7 @@ def unpack_update_batch(self, packed_update_info):
             _send = send_all
             _recv = recv_all
             _maps = maps
-            _l2p = log2phy_list
+            _l2p = log2phy_list
             _lids = layer_ids

         recovered = [
@@ -232,21 +234,14 @@ def unpack_update_batch(self, packed_update_info):
         ]
         return recovered

-    def get_expert_load(self) -> str:
-
-        # todo wjh: return the real value
-        # return self.shared_dict['moe_load']
-        # mock json_str
-        experts_load = ('{\"expert_load\":['
-                        '{\"ip\":\"141.xxx.xxx.181\",'
-                        '\"node_0\":'
-                        '{\"card_0\":'
-                        '[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
-                        'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
-                        '{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
-        return experts_load
+    def get_expert_load(self) -> torch.Tensor:
+        expert_maps = self.shared_dict["expert_maps"]
+        moe_load = self.shared_dict["moe_load"]  # Tensor [L, W, global_experts_num]
+        if moe_load is None:
+            return None
+        num_local_experts = expert_maps.max() + 1
+        load_info, _ = ExpertMapUtils.global2local_load(moe_load, expert_maps, num_local_experts)
+        return load_info

     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
         logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")