
Commit 2c0fde8

wanghanqingLYT authored and yangcheng (AJ) committed
improve the implementation of communication between the main process and the eplb process
1 parent 66f7388 commit 2c0fde8

File tree: 4 files changed, 23 additions and 14 deletions


vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ def __init__(self, model, **args):
         self.param_dict = dict(self.model.named_parameters())
         self.num_dense_layers = self.model.config.first_k_dense_replace
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
-        self.global_expert_num = 256
+        self.global_expert_num = self.model.config.n_routed_experts

         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
         self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
@@ -92,7 +92,7 @@ def do_update_expert_weight(self, layer_id, expert_id_before_replace, buffer_ten
         expert_tensor = self.param_dict[complete_name].data[local_expert_id]
         expert_tensor.copy_(self.buffer_tensor_dict[name][buffer_tensor_id])

-    def generate_index_dicts(self,tensor_2d):
+    def generate_index_dicts(self, tensor_2d):
         dict_list = []
         current_idx = 0
@@ -137,7 +137,7 @@ def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
         rank_id = torch.distributed.get_rank()
         if self.log2phy_map_per_layer[layer_id] is not None:
             self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[rank_id])
-
+
     def global2local(self,
                      placement: torch.Tensor,
                      E_local: int
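
Note on this hunk: the global expert count is no longer hard-coded to 256 but read from the model config's n_routed_experts, so the adaptor follows whatever the loaded checkpoint declares. A minimal sketch of the derived values, using a stand-in config object with illustrative DeepSeek-V3-style numbers (the SimpleNamespace below is not the real model config):

# Sketch only: stand-in config with illustrative DeepSeek-V3-style values.
from types import SimpleNamespace

config = SimpleNamespace(
    num_hidden_layers=61,      # total transformer layers
    first_k_dense_replace=3,   # leading dense (non-MoE) layers
    n_routed_experts=256,      # routed experts per MoE layer
)

num_dense_layers = config.first_k_dense_replace
num_moe_layers = config.num_hidden_layers - num_dense_layers
global_expert_num = config.n_routed_experts  # was the hard-coded 256 before this commit

print(num_dense_layers, num_moe_layers, global_expert_num)  # 3 58 256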

vllm_ascend/eplb/core/policy/policy_factory.py

Lines changed: 1 addition & 1 deletion
@@ -10,6 +10,6 @@ class PolicyFactory:
     def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
         policy = {
             0:MockLoadBalance , # MockLoadBalance
-            1:DynamicEP, # When real eplb algorithm is ready, recover this
+            1:DynamicEP,
         }
         return policy.get(policy_type, MockLoadBalance)(config)
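
For context, generate_policy dispatches on an integer policy type and falls back to MockLoadBalance for unknown values; this hunk only drops a stale comment next to DynamicEP. A self-contained sketch of that dispatch pattern, with placeholder stub classes standing in for the real eplb policy classes:

# Sketch of the dict-based dispatch with a fallback; the classes here are illustrative stubs.
class DynamicConfig: ...

class EplbPolicy:
    def __init__(self, config: DynamicConfig):
        self.config = config

class MockLoadBalance(EplbPolicy): ...
class DynamicEP(EplbPolicy): ...

def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
    policy = {
        0: MockLoadBalance,
        1: DynamicEP,
    }
    return policy.get(policy_type, MockLoadBalance)(config)

assert isinstance(generate_policy(1, DynamicConfig()), DynamicEP)
assert isinstance(generate_policy(99, DynamicConfig()), MockLoadBalance)  # unknown type -> fallback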

vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 12 additions & 9 deletions
@@ -128,7 +128,7 @@ def compose_expert_update_info(self, updated_expert_maps, current_expert_maps):

             if not torch.isin(torch.tensor(expert_id), experts_to_send).any():
                 # if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
-                candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)
+                candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0]
             else:
                 candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id]
@@ -245,7 +245,7 @@ def __init__(self, shared_dict, planner_q, block_update_q, policy_type: int = 0,
         self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d)


-    def worker_process(self,planner_q,block_update_q):
+    def worker_process(self, planner_q, block_update_q):
         """
         Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
         """
@@ -254,14 +254,17 @@ def worker_process(self,planner_q,block_update_q):

                 planner_q.get()

-                update_info = self.worker.do_update()
+                update_info_generator = self.worker.do_update()
+                update_info_list = []

-                for (a,b,c,d) in update_info:
-                    while True:
-                        if not block_update_q.empty():
-                            continue
-                        block_update_q.put((a,b,c,d))
-                        break
+                for (send_info , recv_info , new_expert_map, layer_id) in update_info_generator:
+                    update_info_list.append((send_info , recv_info , new_expert_map, layer_id))
+
+                while True:
+                    if not block_update_q.empty():
+                        continue
+                    block_update_q.put(update_info_list)
+                    break

             except Exception as e:
                 logger.warning(f"[EPLB subprocess Exiting due to error: {e}", exc_info=True)
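
Two things are worth noting about this file. First, the [0] appended to the torch.where call matters because torch.where with only a condition returns a tuple of index tensors, one per dimension; indexing [0] yields the plain 1-D tensor of candidate source ranks, matching the shape produced by the src_rank_indices branch. A toy illustration (the expert map values are made up):

# torch.where(cond) returns a tuple of index tensors; [0] extracts the 1-D rank indices.
import torch

# Toy per-layer map: rows = ranks, columns = global experts, -1 = expert not held.
current_expert_maps_this_layer = torch.tensor([[ 0, -1],
                                               [-1,  1],
                                               [ 2, -1]])
expert_id = 0

as_tuple = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)
as_tensor = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0]
print(as_tuple)   # (tensor([0, 2]),)  -- a tuple, awkward to index like a tensor
print(as_tensor)  # tensor([0, 2])     -- ranks currently holding expert 0

Second, and this is the point of the commit message: the worker no longer puts one (send_info, recv_info, new_expert_map, layer_id) tuple on block_update_q per MoE layer; it drains the do_update() generator into a list and enqueues the whole batch with a single put. A reduced sketch of the producer side, where fake_do_update is a stand-in for EplbWorker.do_update():

# Producer-side sketch: collect all per-layer update tuples, then enqueue once.
from multiprocessing import Queue

def fake_do_update(num_moe_layers: int = 4):
    # Stand-in for EplbWorker.do_update(): yields one tuple per MoE layer.
    for layer_id in range(num_moe_layers):
        yield ({}, {}, f"expert_map_layer_{layer_id}", layer_id)

def worker_side(block_update_q: Queue) -> None:
    update_info_list = []
    for (send_info, recv_info, new_expert_map, layer_id) in fake_do_update():
        update_info_list.append((send_info, recv_info, new_expert_map, layer_id))
    # One put per EPLB round instead of one put (plus busy-wait) per layer.
    block_update_q.put(update_info_list)

if __name__ == "__main__":
    q = Queue()
    worker_side(q)
    print(len(q.get()))  # 4: all layers' update info arrives as a single message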

vllm_ascend/eplb/eplb_updator.py

Lines changed: 7 additions & 1 deletion
@@ -41,6 +41,7 @@ def init_eplb(self):
         self.update_in_flight = False

         self.reqs = []
+        self.update_info_all = []

         self.cur_iterations: torch.int64 = 0

@@ -88,8 +89,12 @@ def wakeup_eplb_worker(self):
     def forward_before(self):
         self.get_init_expert_map()

+        # Batch after eplb process being triggered, get update info provided by eplb process
+        if self.update_in_flight and self.weight_update_counter == 0:
+            self.update_info_all = self.block_update_queue.get()
+
         if self.update_in_flight and self.weight_update_counter < self.num_moe_layers:
-            (expert_send_info, expert_recv_info, updated_expert_map, layer_id) = self.block_update_queue.get()
+            (expert_send_info, expert_recv_info, updated_expert_map, layer_id) = self.update_info_all.pop(0)
             rank_id = torch.distributed.get_rank()
             expert_send_info_this_rank = expert_send_info[rank_id] if rank_id in expert_send_info else []
             expert_recv_info_this_rank = expert_recv_info[rank_id] if rank_id in expert_recv_info else []
@@ -100,6 +105,7 @@ def forward_before(self):
             if self.weight_update_counter == self.num_moe_layers:
                 self.weight_update_counter = 0
                 self.update_in_flight = False
+                self.update_info_all = []

         # set asynchronous stream for d2d expert weight update
         self.reqs = []
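
The consumer side mirrors the worker change: forward_before now does a single blocking get when an update is in flight and no layer has been applied yet (weight_update_counter == 0), then pops one layer's info per forward step until all MoE layers are done, at which point update_info_all is cleared. A condensed, self-contained sketch of that state machine (queue contents and num_moe_layers are placeholders, and the actual D2D weight copy is omitted):

# Consumer-side sketch of the batched protocol driven from forward_before().
import queue

num_moe_layers = 4
block_update_queue = queue.Queue()
block_update_queue.put([({}, {}, f"map_{i}", i) for i in range(num_moe_layers)])

update_in_flight = True
weight_update_counter = 0
update_info_all = []

for _ in range(num_moe_layers):  # each iteration ~ one forward_before() call
    if update_in_flight and weight_update_counter == 0:
        update_info_all = block_update_queue.get()  # one blocking get per EPLB round
    if update_in_flight and weight_update_counter < num_moe_layers:
        send_info, recv_info, new_map, layer_id = update_info_all.pop(0)
        weight_update_counter += 1                  # the real code updates this layer's weights here
        if weight_update_counter == num_moe_layers:
            weight_update_counter = 0
            update_in_flight = False
            update_info_all = []

print(update_in_flight)  # False: every layer consumed from the single batched message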
