Commit 75992b9

Merge pull request #106 from raindaywhu/dev_whq_eplb2

collect moe load after dispatch

2 parents 0bab2cd + e4cba5e
7 files changed: +58 −78 lines changed

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 10 additions & 39 deletions
@@ -84,42 +84,13 @@ def init_expert_param_per_layer(self):
                  for name in self.expert_weight_names]
             )
 
-    def collect_topk_ids(self, dummy_run=False):
-        if dummy_run:
-            return
-        self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))
+    # def collect_topk_ids(self, dummy_run=False):
+    #     if dummy_run:
+    #         return
+    #     self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))
 
     def get_rank_expert_workload(self) -> torch.Tensor:
-        device = self.all_topk_ids[0][0].device
-        if not hasattr(self, "moe_load"):
-            self.moe_load = torch.zeros(
-                (self.num_moe_layers), self.global_expert_num,
-                dtype=torch.int64,
-                device=self.all_topk_ids[0][0].device,
-            )
-        else:
-            self.moe_load.zero_()
-        # pass
-        flat_list_per_layer = [[] for _ in range(self.num_moe_layers)]
-
-        for period_data in self.all_topk_ids:
-            for l in range(self.num_moe_layers):
-                t = period_data[l]
-                flat_list_per_layer[l].append(t.reshape(-1))
-
-        index_2d = torch.nn.utils.rnn.pad_sequence(
-            [torch.cat(flat_list_per_layer[l]) for l in range(self.num_moe_layers)],
-            batch_first=True, padding_value=-1
-        ).to(device)
-
-        mask = index_2d != -1
-        index_2d = index_2d.masked_select(mask).reshape(self.num_moe_layers, -1)
-        src_2d = torch.ones_like(index_2d, dtype=torch.int64)
-
-        self.moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)
-
-        if self.all_topk_ids:
-            self.all_topk_ids[:] = self.all_topk_ids[-1:]
+        self.moe_load = self.model.get_all_moe_loads()
         return self.moe_load
 
     def get_init_expert_map(self, num_moe_layers):
@@ -224,13 +195,13 @@ def determine_expert_map_all(self):
 
         for r in range(self.world_size):
             if r < self.world_size - 1:
-                start = r * local_num_experts
-                end = (r + 1) * local_num_experts
-                local_count = local_num_experts
+                start = r * local_num_experts
+                end = (r + 1) * local_num_experts
+                local_count = local_num_experts
             else:
-                start = r * local_num_experts
+                start = r * local_num_experts
                 end = self.global_expert_num
-                local_count = self.global_expert_num - r * local_num_experts
+                local_count = self.global_expert_num - r * local_num_experts
 
             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(self.num_moe_layers, -1)
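
With this file, get_rank_expert_workload no longer reconstructs per-expert counts from buffered topk_ids; it reads counters that each MoE layer now maintains right after dispatch (see fused_moe.py below). A minimal sketch of the old and new collection paths, with made-up shapes; the tensors and layer wiring here are stand-ins, not the real adaptor:

    import torch

    num_moe_layers, global_expert_num = 2, 8
    # Routing decisions for one collection period (the old approach buffered these).
    topk_ids = [torch.randint(0, global_expert_num, (16, 2)) for _ in range(num_moe_layers)]

    # Old path: rebuild per-expert counts from topk_ids with scatter_add_.
    load_old = torch.zeros(num_moe_layers, global_expert_num, dtype=torch.int64)
    for layer, ids in enumerate(topk_ids):
        flat = ids.reshape(-1)
        load_old[layer].scatter_add_(0, flat, torch.ones_like(flat, dtype=torch.int64))

    # New path: each layer already accumulated its own moe_load counter, so the
    # adaptor only stacks them (what get_all_moe_loads() does in deepseek_v2.py).
    per_layer_counters = [load_old[layer].clone() for layer in range(num_moe_layers)]
    load_new = torch.stack(per_layer_counters, dim=0)
    assert torch.equal(load_new, load_old)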

vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 2 additions & 9 deletions
@@ -61,7 +61,7 @@ def do_update(self):
             return
 
         # Derive the updated expert table from the load information
-        load_info, old_placement = self.global2local(load_info, self.old_expert_maps, self.num_local_experts)
+        old_placement = self.global2local(self.old_expert_maps, self.num_local_experts)
         changed, priority, new_placement = self.calculate_rebalance_experts(load_info, old_placement)
 
         if not torch.is_tensor(new_placement):
@@ -276,18 +276,13 @@ def update_expert_map(self, expert_maps):
         self.shared_dict["expert_maps"] = expert_maps
 
     def global2local(self,
-                     workload: torch.Tensor,
                      placement: torch.Tensor,
                      E_local: int
                      ) -> tuple[torch.Tensor, torch.Tensor]:
 
         L, G, _ = placement.shape
         device = placement.device
 
-        wt_local = torch.full((L, G, E_local),
-                              fill_value=-1,
-                              dtype=workload.dtype,
-                              device=device)
         pt_local = torch.full((L, G, E_local),
                               fill_value=-1,
                               dtype=torch.long,
@@ -297,12 +292,10 @@ def global2local(self,
         l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
 
         slot_idx = placement[l_idx, g_idx, k_idx]
-        values = workload[l_idx, g_idx, k_idx]
 
-        wt_local[l_idx, g_idx, slot_idx] = values
         pt_local[l_idx, g_idx, slot_idx] = k_idx
 
-        return wt_local, pt_local
+        return pt_local
 
 
     def local2global(self,
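
Since the workload now arrives pre-aggregated, global2local only has to invert the placement table (local slot → global expert id) and no longer scatters workload values. A toy run of the surviving pt_local computation, with hypothetical sizes (the `valid` mask is assumed to be `placement >= 0`, as suggested by the surrounding code):

    import torch

    L, G, E_local = 1, 2, 2  # layers, ranks, local slots per rank (illustrative)
    # placement[l, g, k]: local slot of global expert k on rank g, or -1 if absent.
    placement = torch.tensor([[[0, 1, -1, -1],
                               [-1, -1, 0, 1]]])

    pt_local = torch.full((L, G, E_local), -1, dtype=torch.long)
    valid = placement >= 0
    l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
    slot_idx = placement[l_idx, g_idx, k_idx]
    pt_local[l_idx, g_idx, slot_idx] = k_idx  # local slot -> global expert id

    print(pt_local)  # tensor([[[0, 1], [2, 3]]])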

vllm_ascend/eplb/eplb_updator.py

Lines changed: 5 additions & 7 deletions
@@ -80,8 +80,8 @@ def init_eplb(self, expert_map_path):
             shared_dict = self.shared_dict,
             planner_q = self.planner_block_queue,
             block_update_q = self.block_update_queue,
-            redundant_enable = self.redundant_enable,
-            policy_type = 6,
+            redundant_enable = self.redundant_enable,
+            policy_type = 1,
             enable_d2d = True
         )
 
@@ -91,8 +91,8 @@ def init_eplb(self, expert_map_path):
 
     def get_update_iteration(self):
         self.cur_iterations = self.cur_iterations + 1
-        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
         return load_gather_iteration, upate_iteration
 
     def get_init_expert_map(self):
@@ -135,7 +135,6 @@ def forward_before(self):
 
 
     def forward_end(self,dummy_run=False):
-        self.adaptor.collect_topk_ids(dummy_run)
         if not self.update_in_flight:
             load_gather_iteration, update_iteration = self.get_update_iteration()
             if load_gather_iteration:
@@ -174,12 +173,12 @@ def compute_and_set_moe_load(self,dummy_run=False):
             moe_load = local_load.unsqueeze(1)
         self.shared_dict["moe_load"] = moe_load.cpu()
         logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
+        self.adaptor.model.clear_all_moe_loads()
         return moe_load
 
     def warm_up_eplb(self):
 
         self.get_init_expert_map()
-        self.adaptor.collect_topk_ids(dummy_run=False)
         self.compute_and_set_moe_load()
 
         src_tensor = torch.empty((1,), device=self.device)
@@ -240,7 +239,6 @@ def get_expert_load(self) -> tuple:
         num_local_experts = expert_maps.max() + 1
         return moe_load, expert_maps, num_local_experts
 
-
     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
         logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")
         self.num_expert_load_gather = num_expert_load_gather
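
For reference, get_update_iteration drives two cadences off one counter: a load gather every num_expert_load_gather steps and a rebalance every num_iterations steps, with gate_eplb collapsing both to a single firing. A standalone sketch of that logic (hypothetical constants, same branching as the diff):

    def cadence(cur_iterations, num_expert_load_gather=10, num_iterations=30,
                gate_eplb=False):
        # With gate_eplb set, both events fire exactly once, at num_iterations.
        if gate_eplb:
            fire = cur_iterations == num_iterations
            return fire, fire
        return (cur_iterations % num_expert_load_gather == 0,
                cur_iterations % num_iterations == 0)

    fired = [step for step in range(1, 31) if any(cadence(step))]
    print(fired)  # [10, 20, 30] -- gathers at 10 and 20, gather + update at 30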

vllm_ascend/models/deepseek_v2.py

Lines changed: 16 additions & 11 deletions
@@ -727,6 +727,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        self.num_dense_layers = self.config.first_k_dense_replace
+        self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers
+
         self.model = CustomDeepseekV2Model(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "model"))
@@ -755,29 +758,31 @@ def forward(
                                    inputs_embeds)
         return hidden_states
 
-    def get_expert_map(self,layer_id):
+    def get_expert_map(self, layer_id):
         return self.model.layers[layer_id].mlp.experts.get_map()
 
-    def get_log2phy_map(self,layer_id):
+    def get_log2phy_map(self, layer_id):
         return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
 
-    def get_all_expert_map(self,num_moe_layers):
+    def get_all_expert_map(self, num_moe_layers):
         all_loads = []
         for layer_id in range(num_moe_layers):
             load_tensor = self.get_expert_map(3+layer_id)  # (num_experts_per_layer,)
             all_loads.append(load_tensor)
 
         return torch.stack(all_loads, dim=0)
 
-    def get_topk_ids(self,layer_id):
-        return self.model.layers[layer_id+3].mlp.experts.topk_ids
+    def get_all_moe_loads(self):
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + self.num_dense_layers].mlp.experts.moe_load \
+             for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+        return all_moe_loads
 
-    def get_all_topk_ids(self,num_moe_layers):
-        all_topk_id = []
-        for layer_id in range(num_moe_layers):
-            load_tensor = self.get_topk_ids(layer_id)
-            all_topk_id.append(load_tensor)
-        return all_topk_id
+    def clear_all_moe_loads(self):
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id + self.num_dense_layers].mlp.experts.clear_moe_load()
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
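
The new num_dense_layers attribute generalizes the hard-coded `3+layer_id` offset still visible in get_all_expert_map: DeepSeek-V2-style configs place first_k_dense_replace dense layers in front of the MoE stack, so MoE layer i lives at model.layers[i + num_dense_layers]. A quick check of the indexing with illustrative config values (not read from a real config):

    first_k_dense_replace = 3     # dense layers before the first MoE layer
    num_hidden_layers = 61        # illustrative value

    num_dense_layers = first_k_dense_replace
    num_moe_layers = num_hidden_layers - num_dense_layers

    absolute_indices = [i + num_dense_layers for i in range(num_moe_layers)]
    assert absolute_indices[0] == 3
    assert absolute_indices[-1] == num_hidden_layers - 1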

vllm_ascend/ops/fused_moe.py

Lines changed: 13 additions & 3 deletions
@@ -1012,8 +1012,6 @@ def __init__(
 
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
-        self.moe_load = None
-        self.topk_ids = None
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1103,6 +1101,10 @@ def __init__(
         local_num_experts = torch.sum(self.expert_map != -1) \
             if self.expert_map is not None else num_experts
 
+        self.moe_load = None
+        if self.dynamic_eplb:
+            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
+
         moe_quant_params = {
             "num_experts": local_num_experts,
             "hidden_size": hidden_size,
@@ -1176,7 +1178,7 @@ def forward(self,
             router_logits = get_dp_group().all_gather(router_logits, 0)
 
         # Matrix multiply.
-        e_hidden_states, self.topk_ids = self.quant_method.apply(
+        e_hidden_states, expert_token_num, group_list_type = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -1198,6 +1200,10 @@ def forward(self,
             and self.enable_multistream_moe and not is_prefill else None,
         )
 
+        if self.dynamic_eplb:
+            self.moe_load += expert_token_num if group_list_type else \
+                torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+
         if shared_experts:
             if isinstance(e_hidden_states, tuple):
                 e_hidden_states, shared_hidden_states = e_hidden_states
@@ -1267,3 +1273,7 @@ def get_map(self):
     def get_log2phy_map(self):
         return self.log2phy
 
+    def clear_moe_load(self):
+        self.moe_load.zero_()
+
+
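
The accumulation added to forward() copes with two layouts of expert_token_num: when group_list_type is nonzero the kernel reports per-expert token counts directly, and when it is zero the kernel reports a running cumulative sum, so adjacent differences recover the counts. A self-contained check of that torch.cat expression (semantics inferred from the diff):

    import torch

    counts = torch.tensor([4, 0, 7, 2], dtype=torch.int64)  # tokens per local expert
    cumsum = torch.cumsum(counts, dim=0)                     # tensor([4, 4, 11, 13])

    # group_list_type == 0: undo the running sum before accumulating.
    recovered = torch.cat([cumsum[:1], cumsum[1:] - cumsum[:-1]])
    assert torch.equal(recovered, counts)

    moe_load = torch.zeros(4, dtype=torch.int64)
    moe_load += recovered  # equivalent to adding raw counts on the other path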

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 8 additions & 7 deletions
@@ -208,13 +208,14 @@ def fused_experts_with_mc2(
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
+    group_list_type = 1
     if shared_experts is None:
-        return hidden_states
+        return hidden_states, expert_token_nums, group_list_type
     else:
         with npu_stream_switch("moe_secondary", 0):
             npu_wait_tensor(shared_act[0], down_out_list)
             shared_output, _ = shared_experts.down_proj(shared_act)
-            return hidden_states, shared_output
+            return hidden_states, shared_output, expert_token_nums, group_list_type
 
 
 # currently expert parallelism implemented with all2all
@@ -343,7 +344,7 @@ def fused_experts_with_all2all(
     )
     if len(original_shape) == 3:
         final_hidden_states = final_hidden_states.view(original_shape)
-    return final_hidden_states
+    return final_hidden_states, expert_tokens, group_list_type
 
 
 def fused_experts(hidden_states: torch.Tensor,
@@ -457,7 +458,7 @@
 
     if len(original_shape) == 3:
         final_hidden_states = final_hidden_states.view(original_shape)
-    return final_hidden_states
+    return final_hidden_states, expert_tokens, group_list_type
 
 
 class AscendW8A8DynamicLinearMethod:
@@ -677,7 +678,7 @@ def apply(
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
-                **kwargs), topk_ids
+                **kwargs)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -687,7 +688,7 @@ def apply(
                                  topk_weights=topk_weights,
                                  topk_ids=topk_ids,
                                  top_k=top_k,
-                                 expert_map=expert_map), topk_ids
+                                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
@@ -706,7 +707,7 @@ def apply(
                 ep_group=self.ep_group,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-            ), topk_ids
+            )
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
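
Net effect on the quant method: apply() stops tacking topk_ids onto every return and instead propagates the expert_tokens/group_list_type pair produced inside the fused-experts kernels, which AscendFusedMoE.forward unpacks (see fused_moe.py above). A stub of the new contract with stand-in values only; the shared-experts shape follows the mc2 hunk and is not the real API:

    def apply_stub(shared_experts=None):
        hidden_states = "activations"     # placeholder for the output tensor
        expert_tokens = [3, 1, 0, 5]      # counts (or cumulative sums, per kernel)
        group_list_type = 1               # 1: raw counts, 0: cumulative sums
        if shared_experts is None:
            return hidden_states, expert_tokens, group_list_type
        return hidden_states, "shared_output", expert_tokens, group_list_type

    e_hidden_states, expert_token_num, group_list_type = apply_stub()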

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 2 deletions
@@ -366,7 +366,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
 
-        #EPLB
+        #EPLB
         self.dynamic_eplb = ascend_config.dynamic_eplb
         if self.dynamic_eplb == True:
             self.eplb_adaptor = None
@@ -1240,7 +1240,7 @@ def execute_model(
 
         if self.dynamic_eplb:
             self.eplb_updator.forward_before()
-
+
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
          num_scheduled_tokens,
          sample_indices) = (self._process_reqs(scheduler_output,
@@ -1544,6 +1544,8 @@ def _dummy_run(
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds)
 
+        if is_profile_run and self.dynamic_eplb:
+            self.model.clear_all_moe_loads()
         if not is_compile and not is_profile_run and self.dynamic_eplb:
             dummy_run = True
             self.eplb_updator.forward_end(dummy_run)
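
Rationale for the _dummy_run hunk: profile runs push synthetic tokens through the experts, which would pollute the freshly added moe_load counters, so the counters are zeroed immediately afterwards, while ordinary dummy runs still tick the EPLB state machine. A condensed sketch of that control flow (simplified stand-ins for the runner's fields):

    def after_dummy_run(is_compile, is_profile_run, dynamic_eplb, model, eplb_updator):
        if is_profile_run and dynamic_eplb:
            model.clear_all_moe_loads()   # drop counts from synthetic tokens
        if not is_compile and not is_profile_run and dynamic_eplb:
            eplb_updator.forward_end(dummy_run=True)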
