Skip to content

Commit e4cba5e

Browse files
Merge branch 'br_main_into_eplb' into dev_whq_eplb2
2 parents ad5e7e1 + 0bab2cd commit e4cba5e

File tree

5 files changed

+39
-28
lines changed

5 files changed

+39
-28
lines changed

vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def do_update(self):
6868
new_placement = torch.tensor(new_placement)
6969
self.check_expert_placement(old_placement, new_placement)
7070
new_expert_maps = self.local2global(new_placement)
71-
71+
self.update_expert_map(new_expert_maps)
7272
logger.debug(f"[EPLB Process new_map differs, performing D2D")
7373

7474
update_info = self.compose_expert_update_info_bipartite(new_expert_maps, self.old_expert_maps)\

vllm_ascend/eplb/eplb_updator.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ def init_eplb(self, expert_map_path):
7474
"moe_load": None,
7575
# 所有的专家表[num_layers, world_size, num_experts]
7676
"expert_maps": None,
77-
# 热度负载信息 [num_layers, world_size, local_num_experts]
78-
"load_info": None,
7977
})
8078

8179
self.eplb = EplbProcess(
@@ -141,6 +139,7 @@ def forward_end(self,dummy_run=False):
141139
load_gather_iteration, update_iteration = self.get_update_iteration()
142140
if load_gather_iteration:
143141
moe_load = self.compute_and_set_moe_load()
142+
self.get_expert_load()
144143
if update_iteration:
145144
self.wakeup_eplb_worker()
146145
self.update_in_flight = True
@@ -234,25 +233,11 @@ def unpack_update_batch(self, packed_update_info):
234233
]
235234
return recovered
236235

237-
def get_expert_load(self) -> str:
238-
239-
load_info = self.shared_dict["load_info"] # Tensor [L, W, local_experts_num]
240-
L, W, _ = load_info.shape
241-
242-
expert_load: Dict[str, List[dict]] = {}
243-
for c in range(W):
244-
layers: List[dict] = []
245-
for l in range(L):
246-
counts_1d = load_info[l, c]
247-
248-
layer_val = {
249-
f"expert_{e}": int(v)
250-
for e, v in enumerate(counts_1d.tolist())
251-
}
252-
layers.append({f"layer_{l}": layer_val})
253-
expert_load[f"card_{c}"] = layers
254-
255-
return {"expert_load": expert_load}
236+
def get_expert_load(self) -> tuple:
237+
expert_maps = self.shared_dict["expert_maps"]
238+
moe_load = self.shared_dict["moe_load"] # Tensor [L, W, global_experts_num]
239+
num_local_experts = expert_maps.max() + 1
240+
return moe_load, expert_maps, num_local_experts
256241

257242
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
258243
logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")

vllm_ascend/eplb/tool/eplb_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,32 @@ def global2local(cls,
8383
pt_local[g_idx, slot_idx] = k_idx
8484

8585
return pt_local
86+
87+
@classmethod
88+
def global2local_load(self,
89+
workload: torch.Tensor,
90+
placement: torch.Tensor,
91+
E_local: int
92+
) -> tuple[torch.Tensor, torch.Tensor]:
93+
L, G, _ = placement.shape
94+
device = placement.device
95+
96+
wt_local = torch.full((L, G, E_local),
97+
fill_value=-1,
98+
dtype=workload.dtype,
99+
device=device)
100+
pt_local = torch.full((L, G, E_local),
101+
fill_value=-1,
102+
dtype=torch.long,
103+
device=device)
104+
105+
valid = placement >= 0
106+
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
107+
108+
slot_idx = placement[l_idx, g_idx, k_idx]
109+
values = workload[l_idx, g_idx, k_idx]
110+
111+
wt_local[l_idx, g_idx, slot_idx] = values
112+
pt_local[l_idx, g_idx, slot_idx] = k_idx
113+
114+
return wt_local, pt_local

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1590,7 +1590,7 @@ def profile_run(self) -> None:
15901590
self.encoder_cache.clear()
15911591
gc.collect()
15921592

1593-
def do_get_expert_load(self) -> str:
1593+
def do_get_expert_load(self) -> tuple:
15941594
return self.eplb_updator.get_expert_load()
15951595

15961596
def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,8 @@ def compile_or_warm_up_model(self) -> None:
209209
# the model initialization and profiling.
210210
set_random_seed(self.model_config.seed)
211211

212-
def get_expert_load(self) -> str:
213-
""" todo 一共几个worker"""
214-
moe_load = self.model_runner.do_get_expert_load()
215-
return moe_load
216-
212+
def get_expert_load(self) -> tuple:
213+
return self.model_runner.do_get_expert_load()
217214
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
218215
self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations)
219216

0 commit comments

Comments (0)