Commit 1b7b87b

Merge pull request #105 from raindaywhu/br_main_into_eplb_wjh
fix get_expert_load
2 parents bfa07cf + 6a0a05e

File tree

3 files changed: +36 −6 lines changed

vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 1 addition & 2 deletions

@@ -62,14 +62,13 @@ def do_update(self):
 
         # Get the updated expert table based on the load information
        load_info, old_placement = self.global2local(load_info, self.old_expert_maps, self.num_local_experts)
-        self.shared_dict["load_info"] = load_info
        changed, priority, new_placement = self.calculate_rebalance_experts(load_info, old_placement)
 
        if not torch.is_tensor(new_placement):
            new_placement = torch.tensor(new_placement)
        self.check_expert_placement(old_placement, new_placement)
        new_expert_maps = self.local2global(new_placement)
-
+        self.update_expert_map(new_expert_maps)
        logger.debug(f"[EPLB Process new_map differs, performing D2D")
 
        update_info = self.compose_expert_update_info_bipartite(new_expert_maps, self.old_expert_maps)\
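Net effect of this hunk: the worker still computes load_info locally for the rebalance step, but it no longer publishes the result through shared_dict; as the next file shows, get_expert_load now reconstructs the per-rank load on demand instead of reading a cached copy.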

vllm_ascend/eplb/eplb_updator.py

Lines changed: 5 additions & 4 deletions

@@ -74,8 +74,6 @@ def init_eplb(self, expert_map_path):
            "moe_load": None,
            # All expert tables [num_layers, world_size, num_experts]
            "expert_maps": None,
-            # Expert hotness (load) information [num_layers, world_size, local_num_experts]
-            "load_info": None,
        })
 
        self.eplb = EplbProcess(

@@ -235,9 +233,12 @@ def unpack_update_batch(self, packed_update_info):
        ]
        return recovered
 
-    def get_expert_load(self) -> str:
+    def get_expert_load(self):
+        expert_maps = self.shared_dict["expert_maps"]
+        moe_load = self.shared_dict["moe_load"]  # Tensor [L, W, global_experts_num]
+        num_local_experts = expert_maps.max() + 1
+        load_info, _ = ExpertMapUtils.global2local_load(moe_load, expert_maps, num_local_experts)
 
-        load_info = self.shared_dict["load_info"]  # Tensor [L, W, local_experts_num]
        L, W, _ = load_info.shape
 
        expert_load: Dict[str, List[dict]] = {}
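Why num_local_experts = expert_maps.max() + 1 works: each entry of expert_maps is either -1 (the expert is not hosted on that worker) or a zero-based local slot index, so the largest slot index plus one is the number of local expert slots per worker. A minimal sketch with toy tensors (shapes and values are illustrative, not from the repo):

    import torch

    # Toy map: L=1 layer, W=2 workers, 4 global experts; entry [l, w, e] is
    # the local slot of global expert e on worker w, or -1 if absent.
    expert_maps = torch.tensor([[[0, 1, -1, -1],
                                 [-1, -1, 0, 1]]])
    num_local_experts = int(expert_maps.max()) + 1  # -> 2 slots per worker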

vllm_ascend/eplb/tool/eplb_utils.py

Lines changed: 30 additions & 0 deletions

@@ -83,3 +83,33 @@ def global2local(cls,
            pt_local[g_idx, slot_idx] = k_idx
 
        return pt_local
+
+    @classmethod
+    def global2local_load(self,
+                          workload: torch.Tensor,
+                          placement: torch.Tensor,
+                          E_local: int
+                          ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        L, G, _ = placement.shape
+        device = placement.device
+
+        wt_local = torch.full((L, G, E_local),
+                              fill_value=-1,
+                              dtype=workload.dtype,
+                              device=device)
+        pt_local = torch.full((L, G, E_local),
+                              fill_value=-1,
+                              dtype=torch.long,
+                              device=device)
+
+        valid = placement >= 0
+        l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
+
+        slot_idx = placement[l_idx, g_idx, k_idx]
+        values = workload[l_idx, g_idx, k_idx]
+
+        wt_local[l_idx, g_idx, slot_idx] = values
+        pt_local[l_idx, g_idx, slot_idx] = k_idx
+
+        return wt_local, pt_local
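The new helper is a vectorized scatter from global expert ids to local slots: for every (layer, rank, expert) triple with a valid placement, it writes that expert's workload into its local slot and records which global expert occupies the slot. A standalone sketch of the same scatter on toy tensors (all shapes and values assumed for illustration):

    import torch

    # placement[l, g, e]: local slot of global expert e on rank g, or -1.
    placement = torch.tensor([[[0, 1, -1, -1],
                               [-1, -1, 0, 1]]])     # [L=1, G=2, E_global=4]
    workload = torch.tensor([[[10, 20, 0, 0],
                              [0, 0, 30, 40]]])      # per-expert load counts
    E_local = 2

    L, G, _ = placement.shape
    wt_local = torch.full((L, G, E_local), -1, dtype=workload.dtype)
    pt_local = torch.full((L, G, E_local), -1, dtype=torch.long)

    valid = placement >= 0
    l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)  # coordinates of hosted experts
    slot_idx = placement[l_idx, g_idx, k_idx]           # their local slots

    wt_local[l_idx, g_idx, slot_idx] = workload[l_idx, g_idx, k_idx]
    pt_local[l_idx, g_idx, slot_idx] = k_idx

    print(wt_local)  # tensor([[[10, 20], [30, 40]]]) -> load per local slot
    print(pt_local)  # tensor([[[0, 1], [2, 3]]])     -> global id per local slot

One stylistic note: global2local_load is decorated with @classmethod but names its first parameter self rather than cls (compare global2local above). The class object is bound either way, so behavior is unaffected, though cls would be the conventional name.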
