Skip to content

Commit 0bab2cd

Browse files
Authored: Merge pull request #99 from raindaywhu/lt_expert_load
expert load collecting
2 parents 1b7b87b + 3465ad6 commit 0bab2cd

File tree

4 files changed

+8
-27
lines changed

4 files changed

+8
-27
lines changed

vllm_ascend/eplb/eplb_updator.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def forward_end(self,dummy_run=False):
140140
load_gather_iteration, update_iteration = self.get_update_iteration()
141141
if load_gather_iteration:
142142
moe_load = self.compute_and_set_moe_load()
143+
self.get_expert_load()
143144
if update_iteration:
144145
self.wakeup_eplb_worker()
145146
self.update_in_flight = True
@@ -233,28 +234,12 @@ def unpack_update_batch(self, packed_update_info):
233234
]
234235
return recovered
235236

236-
def get_expert_load(self):
237+
def get_expert_load(self) -> tuple:
237238
expert_maps = self.shared_dict["expert_maps"]
238-
moe_load = self.shared_dict["moe_load"] # Tensor [L, W, global_experts_num]
239+
moe_load = self.shared_dict["moe_load"] # Tensor [L, W, global_experts_num]
239240
num_local_experts = expert_maps.max() + 1
240-
load_info, _ = ExpertMapUtils.global2local_load(moe_load, expert_maps, num_local_experts)
241-
242-
L, W, _ = load_info.shape
243-
244-
expert_load: Dict[str, List[dict]] = {}
245-
for c in range(W):
246-
layers: List[dict] = []
247-
for l in range(L):
248-
counts_1d = load_info[l, c]
249-
250-
layer_val = {
251-
f"expert_{e}": int(v)
252-
for e, v in enumerate(counts_1d.tolist())
253-
}
254-
layers.append({f"layer_{l}": layer_val})
255-
expert_load[f"card_{c}"] = layers
256-
257-
return {"expert_load": expert_load}
241+
return moe_load, expert_maps, num_local_experts
242+
258243

259244
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
260245
logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")

vllm_ascend/eplb/tool/eplb_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ def global2local_load(self,
9090
placement: torch.Tensor,
9191
E_local: int
9292
) -> tuple[torch.Tensor, torch.Tensor]:
93-
9493
L, G, _ = placement.shape
9594
device = placement.device
9695

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,7 @@ def profile_run(self) -> None:
15881588
self.encoder_cache.clear()
15891589
gc.collect()
15901590

1591-
def do_get_expert_load(self) -> str:
1591+
def do_get_expert_load(self) -> tuple:
15921592
return self.eplb_updator.get_expert_load()
15931593

15941594
def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,8 @@ def compile_or_warm_up_model(self) -> None:
209209
# the model initialization and profiling.
210210
set_random_seed(self.model_config.seed)
211211

212-
def get_expert_load(self) -> str:
213-
""" todo 一共几个worker"""
214-
moe_load = self.model_runner.do_get_expert_load()
215-
return moe_load
216-
212+
def get_expert_load(self) -> tuple:
213+
return self.model_runner.do_get_expert_load()
217214
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
218215
self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations)
219216

0 commit comments

Comments (0)