
Commit cfbe8b1

Merge pull request #107 from raindaywhu/dev_whq_eplb2
modify serialization of eplb process
2 parents 75992b9 + 89bcf04

File tree

2 files changed: +18, -25 lines


vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 10 additions & 18 deletions
@@ -336,30 +336,22 @@ def pack_update_info(self, update_info_generator):
 
         for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
 
-            send_all.append(send_info)
-            recv_all.append(recv_info)
+            send_info_this_rank = send_info[self.rank_id] if self.rank_id in send_info else []
+            recv_info_this_rank = recv_info[self.rank_id] if self.rank_id in recv_info else []
+            send_all.append(send_info_this_rank)
+            recv_all.append(recv_info_this_rank)
 
-            maps.append(new_expert_map[self.rank_id])
+            maps.append(new_expert_map[self.rank_id].numpy().tolist())
 
-            if self.redundant_enable is not None:
+            if self.redundant_enable:
                 log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map)
-                log2phy_all.append(log2phy_map)
+                log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())
+            else:
+                log2phy_all.append([])
 
             layer_ids.append(layer_id)
 
-        # Stack the list of Tensors into one big Tensor
-        stacked_maps = torch.stack(maps, dim=0)
-        layer_id_tensor = torch.as_tensor(layer_ids, dtype=torch.int64)
-        stacked_maps.share_memory_()
-        layer_id_tensor.share_memory_()
-
-        if self.redundant_enable:
-            stacked_log2phy = torch.stack(log2phy_all, dim=0)
-            stacked_log2phy.share_memory_()
-        else:
-            stacked_log2phy = None
-
-        return send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor
+        return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))
 
 class EplbProcess:
     def __init__(self, shared_dict, planner_q, block_update_q, redundant_enable, policy_type: int = 0, enable_d2d: bool = True):
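
With this change, pack_update_info no longer builds shared-memory tensors: each worker keeps only its own rank's slice and converts it to plain Python lists, so the packed result can be pickled straight through the inter-process queue. Below is a minimal standalone sketch of the new packing shape; the planner output values are made up for illustration and are not the actual vllm-ascend data.

    import torch

    rank_id = 0
    # Hypothetical planner output: per-layer send/recv maps keyed by rank and an
    # expert-map tensor with one row per rank (values are illustrative only).
    update_info_generator = [
        ({0: [(0, 1)]}, {1: [(0, 1)]}, torch.tensor([[0, 1, 2], [3, 4, 5]]), 0),
        ({1: [(2, 0)]}, {0: [(2, 0)]}, torch.tensor([[1, 0, 2], [3, 5, 4]]), 1),
    ]

    send_all, recv_all, maps, log2phy_all, layer_ids = [], [], [], [], []
    for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
        # Keep only this rank's slice and turn tensors into nested Python lists.
        send_all.append(send_info[rank_id] if rank_id in send_info else [])
        recv_all.append(recv_info[rank_id] if rank_id in recv_info else [])
        maps.append(new_expert_map[rank_id].numpy().tolist())
        log2phy_all.append([])  # stays empty when redundant experts are disabled
        layer_ids.append(layer_id)

    # One picklable tuple per layer instead of shared-memory tensors.
    packed = list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))
    print(packed[0])  # ([(0, 1)], [], [0, 1, 2], [], 0)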

vllm_ascend/eplb/eplb_updator.py

Lines changed: 8 additions & 7 deletions
@@ -16,6 +16,7 @@
 #
 
 import torch
+import numpy
 from typing import Dict, List
 import torch.distributed as dist
 import vllm.envs as envs
@@ -111,19 +112,19 @@ def forward_before(self):
         # Batch after eplb process being triggered, get update info provided by eplb process
         if self.update_in_flight and self.weight_update_counter == 0 and self.wait_worker_iterations == self.num_wait_worker_iterations:
             self.wait_worker_iterations = 0
-            packed_update_info = self.block_update_queue.get()
-            self.update_info_all = self.unpack_update_batch(packed_update_info)
+            self.update_info_all = self.block_update_queue.get()
             self.weight_loading = True
 
         if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
             (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
             rank_id = torch.distributed.get_rank()
-            self.eplb_loader.set_log2phy_map(log2phy_map)
-            expert_send_info_this_rank = expert_send_info[rank_id] if rank_id in expert_send_info else []
-            expert_recv_info_this_rank = expert_recv_info[rank_id] if rank_id in expert_recv_info else []
+            if self.redundant_enable:
+                log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
+                self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
+            updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
             #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
-            self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info_this_rank,
-                expert_recv_info_this_rank, updated_expert_map, layer_id + 3)
+            self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info, expert_recv_info,
+                updated_expert_map_this_rank, layer_id + self.adaptor.num_dense_layers)
             self.weight_update_counter += 1
             if self.weight_update_counter == self.num_moe_layers:
                 self.weight_update_counter = 0
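
On the consumer side, forward_before now pops the per-layer tuples directly off the queue and rebuilds tensors from the plain lists before handing them to the weight loader. A small sketch of that conversion step, using made-up map values rather than the actual updator state:

    import numpy
    import torch

    # One packed entry in the new format: (send, recv, expert_map, log2phy, layer_id);
    # the list contents here are placeholders.
    update_info_all = [([(0, 1)], [], [0, 1, 2], [5, 4, 3], 0)]
    redundant_enable = True

    (expert_send_info, expert_recv_info, updated_expert_map,
     log2phy_map, layer_id) = update_info_all.pop(0)

    if redundant_enable:
        # The log2phy map is only rebuilt when redundant experts are enabled.
        log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
        print(log2phy_map_this_rank)     # tensor([5, 4, 3])

    # The expert map is always rebuilt before the d2d transfer task is generated.
    updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
    print(updated_expert_map_this_rank)  # tensor([0, 1, 2])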
