@@ -1,3 +1,4 @@
+ #
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,6 +23,7 @@
from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader

+
class EplbUpdator:

    def __init__(self, expert_map_path):
@@ -42,7 +44,7 @@ def init_eplb(self, expert_map_path):
            if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
                self.num_expert_load_gather = self.num_iterations
        except Exception as e:
-             self.num_expert_load_gather = self.num_iterations
+             self.num_expert_load_gather = self.num_iterations

        self.weight_update_counter = 0
        self.expert_map_initialized = False
@@ -72,19 +74,18 @@ def init_eplb(self, expert_map_path):
        })

        self.eplb = EplbProcess(
-             shared_dict=self.shared_dict,
-             planner_q=self.planner_block_queue,
-             block_update_q=self.block_update_queue,
-             redundant_enable=self.redundant_enable,
-             policy_type=6,
-             enable_d2d=True
+             shared_dict=self.shared_dict,
+             planner_q=self.planner_block_queue,
+             block_update_q=self.block_update_queue,
+             redundant_enable=self.redundant_enable,
+             policy_type=6,
+             enable_d2d=True
        )

        self.eplb_process = self.eplb._launch_process()

        logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})")

-
    def get_update_iteration(self):
        self.cur_iterations = self.cur_iterations + 1
        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
@@ -94,7 +95,8 @@ def get_update_iteration(self):
    def get_init_expert_map(self):
        try:
            if not self.expert_map_initialized:
-                 self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
+                 self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers,
+                                                                                              self.expert_map_path)
                self.expert_map_initialized = True
        except Exception as e:
            logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", exc_info=True)
@@ -103,6 +105,7 @@ def wakeup_eplb_worker(self):
        self.planner_block_queue.put(1)

    def forward_before(self):
+
        # Batch after eplb process being triggered, get update info provided by eplb process
        if self.update_in_flight and self.weight_update_counter == 0 and self.wait_worker_iterations == self.num_wait_worker_iterations:
            self.wait_worker_iterations = 0
@@ -111,14 +114,16 @@ def forward_before(self):
            self.weight_loading = True

        if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
-             (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
+             (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
+                 0)
            rank_id = torch.distributed.get_rank()
            self.eplb_loader.set_log2phy_map(log2phy_map)
            expert_send_info_this_rank = expert_send_info[rank_id] if rank_id in expert_send_info else []
            expert_recv_info_this_rank = expert_recv_info[rank_id] if rank_id in expert_recv_info else []
-             #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
+             # logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
            self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info_this_rank,
-                                                                expert_recv_info_this_rank, updated_expert_map, layer_id + 3)
+                                                                expert_recv_info_this_rank, updated_expert_map,
+                                                                layer_id + 3)
            self.weight_update_counter += 1
            if self.weight_update_counter == self.num_moe_layers:
                self.weight_update_counter = 0
@@ -129,8 +134,8 @@ def forward_before(self):
            self.reqs = []
        self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

-     def forward_end(self,dummy_run=False):
-         self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+     def forward_end(self, dummy_run=False):
+         self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
        if not self.update_in_flight:
            load_gather_iteration, update_iteration = self.get_update_iteration()
            if load_gather_iteration:
@@ -146,8 +151,8 @@ def forward_end(self,dummy_run=False):

        self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)

-     def compute_and_set_moe_load(self,dummy_run=False):
-         local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+     def compute_and_set_moe_load(self, dummy_run=False):
+         local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
        self._gather_buffer = None
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
@@ -192,7 +197,7 @@ def warm_up_eplb(self):
                continue
            comm_op_list.append(
                dist.P2POp(dist.irecv, src_tensor, src_rank)
-             )
+             )
        if comm_op_list:
            reqs = dist.batch_isend_irecv(comm_op_list)

@@ -205,7 +210,7 @@ def unpack_update_batch(self, packed_update_info):
        """
        send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info

-         maps = stacked_maps.unbind(0)
+         maps = stacked_maps.unbind(0)
        layer_ids = layer_id_tensor.tolist()

        if self.redundant_enable:
@@ -217,7 +222,7 @@ def unpack_update_batch(self, packed_update_info):
            _send = send_all
            _recv = recv_all
            _maps = maps
-             _l2p = log2phy_list
+             _l2p = log2phy_list
            _lids = layer_ids

        recovered = [
@@ -249,7 +254,6 @@ def update_expert_load_statistical_period(self, num_expert_load_gather: int, num
        self.num_iterations = num_iterations
        logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")

-
    def shutdown(self):
        """
        Clean up the EPLB process.