Skip to content

Commit 2e824cd

Browse files
authored
Merge pull request #73 from raindaywhu/br_wjh_eplb
Extract cal_moe_load from deepseek_v2
2 parents 2403b59 + 4bda9ba commit 2e824cd

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,28 @@ def get_expert_tensor(self, layer_id, global_expert_id_to_send):
6969
return [self.param_dict["model.layers." + str(layer_id) + ".mlp.experts." + name].data[local_expert_id]
7070
for name in self.expert_weight_names]
7171

72-
def get_rank_expert_workload(self, num_moe_layers):
73-
return self.model.get_all_moe_loads(num_moe_layers, self.global_expert_num)
72+
def get_rank_expert_workload(
    self,
    num_moe_layers: int,
) -> torch.Tensor:
    """Return per-layer expert load counts for this rank.

    Gathers the routing decisions (``topk_ids``) of every MoE layer via
    ``self.model.get_topk_ids`` and counts, for each layer, how many
    (token, top-k slot) hits each global expert received.

    Args:
        num_moe_layers: Number of MoE layers to collect loads from.

    Returns:
        An int64 tensor of shape ``[num_moe_layers, global_expert_num]``
        where entry ``(l, e)`` is the number of routed hits expert ``e``
        received in layer ``l``.
    """
    # Edge case: torch.stack([]) raises on an empty list, so return an
    # empty load table for zero layers instead of crashing.
    if num_moe_layers <= 0:
        return torch.zeros((0, self.global_expert_num), dtype=torch.int64)

    # Collect each layer's topk_ids -> list of [B, K] tensors.
    # NOTE(review): assumes every layer reports the same [B, K] shape —
    # torch.stack requires it; confirm against the model implementation.
    all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]

    # Stack & flatten -> ids2d: [L, B*K]. scatter_add_ needs an int64
    # index tensor, so cast explicitly.
    stacked = torch.stack(all_topk_ids, dim=0)  # [L, B, K]
    num_layers, batch_size, top_k = stacked.shape
    ids2d = stacked.view(num_layers, batch_size * top_k).to(torch.int64)

    moe_load = torch.zeros((num_layers, self.global_expert_num),
                           dtype=torch.int64, device=ids2d.device)

    # Count one hit per routed (token, slot) pair by scatter-adding a 1
    # at each expert id along the expert dimension. ones_like guarantees
    # src and index shapes match, so no shape asserts are needed.
    ones2d = torch.ones_like(ids2d)
    moe_load.scatter_add_(dim=1, index=ids2d, src=ones2d)
    return moe_load
7494

7595
def get_init_expert_map(self, num_moe_layers):
7696
expert_map = self.model.get_all_expert_map(num_moe_layers)

vllm_ascend/models/deepseek_v2.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -812,30 +812,6 @@ def get_all_expert_map(self,num_moe_layers):
812812

813813
def get_topk_ids(self,layer_id):
814814
return self.model.layers[layer_id+3].mlp.experts.topk_ids
815-
816-
def get_all_moe_loads(
817-
self,
818-
num_moe_layers: int,
819-
num_experts_per_layer: int
820-
) -> torch.Tensor:
821-
# 收集各层 topk_ids -> list of [B, K]
822-
all_topk_ids = [self.get_topk_ids(i) for i in range(num_moe_layers)]
823-
# stack & flatten -> ids2d: [L, B*K]
824-
stacked = torch.stack(all_topk_ids, dim=0) # [L, B, K]
825-
L, B, K = stacked.shape
826-
ids2d = stacked.view(L, B * K).to(torch.int64) # [L, N]
827-
828-
device = ids2d.device
829-
moe_load = torch.zeros((L, num_experts_per_layer),
830-
dtype=torch.int64, device=device)
831-
832-
ones2d = torch.ones_like(ids2d, dtype=torch.int64)
833-
834-
assert moe_load.dim() == 2 and ids2d.dim() == 2 and ones2d.dim() == 2
835-
assert ids2d.shape == ones2d.shape
836-
837-
moe_load.scatter_add_(dim=1, index=ids2d, src=ones2d)
838-
return moe_load
839815

840816
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
841817
pass

0 commit comments

Comments
 (0)