
Commit 2dba24d

Author: lt
Merge branch 'br_main_into_eplb' of https://github.com/raindaywhu/vllm-ascend into br_main_into_eplb
2 parents: 83f2d51 + 5225f3c

File tree: 3 files changed (+31, −19 lines)

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 23 additions & 12 deletions

```diff
@@ -85,25 +85,36 @@ def init_expert_param_per_layer(self):
     def get_rank_expert_workload(
         self,
         num_moe_layers: int,
+        dummy_run = False
     ) -> torch.Tensor:
-        # Collect topk_ids of each layer -> list of [B, K]
+
         all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
-        # stack & flatten -> ids2d: [L, B*K]
-        stacked = torch.stack(all_topk_ids, dim=0)  # [L, B, K]
+        stacked = torch.stack(all_topk_ids, dim=0)
         L, B, K = stacked.shape
-        ids2d = stacked.view(L, B * K).to(torch.int64)  # [L, N]
+        N = B * K
+        device = stacked.device
+        G = self.global_expert_num
+
+        if not hasattr(self, "cum_moe_load") or self.cum_moe_load is None:
+            self.cum_moe_load = torch.zeros((L, G),
+                                            dtype=torch.int64,
+                                            device=device)
+
+        if dummy_run:
+            return self.cum_moe_load
+
+        ids1d = stacked.view(-1).to(torch.int64)

-        device = ids2d.device
-        moe_load = torch.zeros((L, self.global_expert_num),
-                               dtype=torch.int64, device=device)
+        row_idx = torch.arange(L, device=device).repeat_interleave(N)

-        ones2d = torch.ones_like(ids2d, dtype=torch.int64)
+        combined = row_idx * G + ids1d

-        assert moe_load.dim() == 2 and ids2d.dim() == 2 and ones2d.dim() == 2
-        assert ids2d.shape == ones2d.shape
+        counts = torch.bincount(combined, minlength=L * G)
+        workload = counts.view(L, G)

-        moe_load.scatter_add_(dim=1, index=ids2d, src=ones2d)
-        return moe_load
+        self.cum_moe_load.add_(workload)
+
+        return self.cum_moe_load

     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
```
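
The hunk above replaces the per-layer `scatter_add_` with a single bincount over a cumulative buffer. A minimal runnable sketch of that counting scheme, assuming toy shapes and CPU tensors (this is an illustration, not the adaptor code): each (layer, expert) pair is mapped to its own bin, so one `torch.bincount` over the flattened topk_ids yields the per-layer expert workload.

```python
import torch

# Toy sizes: L layers, B tokens, K experts per token, G global experts (assumed values).
L, B, K, G = 2, 4, 2, 8
stacked = torch.randint(0, G, (L, B, K))        # stand-in for the stacked topk_ids

N = B * K
ids1d = stacked.view(-1).to(torch.int64)                        # [L * N] expert ids
row_idx = torch.arange(L).repeat_interleave(N)                  # layer index for each id
combined = row_idx * G + ids1d                                  # unique bin per (layer, expert)
workload = torch.bincount(combined, minlength=L * G).view(L, G)

# Same totals as the old scatter_add_ formulation: one count row per layer.
assert int(workload.sum()) == L * N
```

In the adaptor this `workload` is then added into `self.cum_moe_load`, and `dummy_run=True` returns that running total without accumulating.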

vllm_ascend/eplb/eplb_updator.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -130,11 +130,12 @@ def forward_before(self):
         self.reqs = []
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

-    def forward_end(self):
+    def forward_end(self,dummy_run=False):
+        self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
         if not self.update_in_flight:
             load_gather_iteration, update_iteration = self.get_update_iteration()
             if load_gather_iteration:
-                self.moe_load = self.compute_and_set_moe_load()
+                moe_load = self.compute_and_set_moe_load(dummy_run)
             if update_iteration:
                 self.wakeup_eplb_worker()
                 self.update_in_flight = True
@@ -146,9 +147,8 @@ def forward_end(self):

         self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)

-    def compute_and_set_moe_load(self):
-        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers)
-
+    def compute_and_set_moe_load(self,dummy_run=False):
+        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()
@@ -161,7 +161,7 @@ def compute_and_set_moe_load(self):

             dist.all_gather_into_tensor(self._gather_buffer, local_load)

-            moe_load = self._gather_buffer.permute(1, 0, 2).contiguous()
+            moe_load = self._gather_buffer.permute(1, 0, 2)
             self.shared_dict["moe_load"] = moe_load.cpu()
             logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
         else:
```
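
For context, a hedged sketch of the gather step this hunk touches (it assumes an already-initialized `torch.distributed` process group; `gather_moe_load` and its shapes are illustrative, not the updator's API): every rank contributes its `[L, G]` counts, and the stacked buffer is reordered to `[L, world_size, G]`.

```python
import torch
import torch.distributed as dist

def gather_moe_load(local_load: torch.Tensor) -> torch.Tensor:
    """Gather per-rank [L, G] expert loads into a [L, world_size, G] tensor."""
    world_size = dist.get_world_size()
    L, G = local_load.shape
    buf = torch.zeros((world_size, L, G),
                      dtype=local_load.dtype,
                      device=local_load.device)
    # Each rank's [L, G] counts land in its slot of buf -> [world_size, L, G].
    dist.all_gather_into_tensor(buf, local_load)
    # Reorder to [L, world_size, G]; .cpu() copies data regardless of layout,
    # which is what makes dropping the explicit .contiguous() safe here.
    return buf.permute(1, 0, 2)
```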

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1544,7 +1544,8 @@ def _dummy_run(
                              inputs_embeds=inputs_embeds)

         if not is_compile and not is_profile_run and self.dynamic_eplb:
-            self.eplb_updator.forward_end()
+            dummy_run = True
+            self.eplb_updator.forward_end(dummy_run)
         return hidden_states

     def profile_run(self) -> None:
```
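
A minimal toy sketch of what threading `dummy_run=True` from `_dummy_run` buys (the `ToyAdaptor` class and its names are illustrative, not vllm-ascend code): a warm-up pass still reaches `forward_end`, but the cumulative expert load is returned untouched, so dummy batches never skew the statistics EPLB rebalances on.

```python
import torch

class ToyAdaptor:
    def __init__(self, num_layers: int = 2, num_experts: int = 8):
        self.cum_moe_load = torch.zeros((num_layers, num_experts), dtype=torch.int64)

    def get_rank_expert_workload(self, workload: torch.Tensor, dummy_run: bool = False):
        if dummy_run:
            return self.cum_moe_load        # dummy pass: no accumulation
        self.cum_moe_load.add_(workload)    # real pass: accumulate counts
        return self.cum_moe_load

adaptor = ToyAdaptor()
real = torch.ones_like(adaptor.cum_moe_load)
adaptor.get_rank_expert_workload(real)                   # counted
adaptor.get_rank_expert_workload(real, dummy_run=True)   # ignored
assert int(adaptor.cum_moe_load.sum()) == real.numel()
```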
