
Commit 9d16c99

rm router logits Improve TTOP 3ms (#1407)
### What this PR does / why we need it?

The previous code was:

    router_logits, _ = self.gate(hidden_states)
    hidden_states = get_dp_group().all_gather(hidden_states, 0)
    router_logits = get_dp_group().all_gather(router_logits, 0)

This PR replaces the two all_gather calls with a single one, saving one all_gather communication:

    hidden_states = get_dp_group().all_gather(hidden_states, 0)
    router_logits, _ = self.gate(hidden_states)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

    bash examples/run_dp_attention_etp16.sh
    bash examples/run_dp_attention_etp16_benmark.sh

gsm8k accuracy verification:

<img width="1809" alt="Screenshot 2025-06-24 21 53 24" src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b" />

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@77f77a9

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
1 parent 0fc9b56 commit 9d16c99
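
Why the reordering is safe: the gate is a per-token linear projection, so running it on the gathered hidden states yields the same logits the old code obtained by gathering per-rank logits. Below is a minimal, self-contained sketch of that equivalence in plain PyTorch; `torch.cat` and `nn.Linear` are stand-ins for the DP `all_gather` and the router gate, chosen purely for illustration and not taken from the vLLM Ascend code.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size, n_experts, dp_size, tokens_per_rank = 16, 8, 4, 3

# Stand-in for the router gate: a per-token linear projection (assumption for this sketch).
gate = nn.Linear(hidden_size, n_experts, bias=False)

# Per-rank activations, as each data-parallel rank would see them before communication.
per_rank_hidden = [torch.randn(tokens_per_rank, hidden_size) for _ in range(dp_size)]

# Old flow: gate on every rank, then "all_gather" both tensors (two gathers).
old_logits = torch.cat([gate(h) for h in per_rank_hidden], dim=0)
gathered_hidden = torch.cat(per_rank_hidden, dim=0)

# New flow: "all_gather" hidden_states once, then gate the gathered tensor (one gather).
new_logits = gate(gathered_hidden)

assert torch.allclose(old_logits, new_logits, atol=1e-6)
print("identical router logits; one all_gather eliminated")
```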

3 files changed: +43 additions, -9 deletions

- vllm_ascend/models/deepseek_v2.py
- vllm_ascend/ops/fused_moe.py
- vllm_ascend/utils.py

vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 1 deletion
@@ -367,6 +367,7 @@ def __init__(
         self.ep_group = get_ep_group()

         self.params_dtype = torch.get_default_dtype()
+        self.rm_router_logits = self.experts.rm_router_logits

     def forward(self,
                 hidden_states: torch.Tensor,
@@ -389,7 +390,9 @@ def forward(self,
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = None
+        if not self.rm_router_logits:
+            router_logits, _ = self.gate(hidden_states)

         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -398,6 +401,7 @@ def forward(self,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
+            gate=self.gate,
             replace_allreduce=replace_allreduce)

         hidden_states = (

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 8 deletions
@@ -45,7 +45,8 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
                                get_all_reduce_merge_state, get_fused_moe_state,
-                               is_310p, npu_stream_switch, npu_wait_tensor)
+                               get_rm_router_logits_state, is_310p,
+                               npu_stream_switch, npu_wait_tensor)

 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER

@@ -1148,6 +1149,8 @@ def __init__(
         self.global_redundant_expert_num = 0

         is_deepseek_v3_r1 = self.global_num_experts == 256
+        self.rm_router_logits = get_rm_router_logits_state(
+            self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
         self.all_reduce_merge = get_all_reduce_merge_state(
             self.moe_parallel_config.ep_size, is_deepseek_v3_r1)

@@ -1240,7 +1243,9 @@ def forward(self,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
+                gate=None,
                 replace_allreduce: bool = False):
+
         assert self.quant_method is not None

         if top_k:
@@ -1277,6 +1282,7 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
@@ -1289,19 +1295,27 @@ def forward(self,
                     hidden_states,
                     (0, 0, 0,
                      max_num_tokens_across_dp - num_tokens))
-                router_logits = nn.functional.pad(
-                    router_logits,
-                    (0, 0, 0,
-                     max_num_tokens_across_dp - num_tokens))
+                if not self.rm_router_logits:
+                    router_logits = nn.functional.pad(
+                        router_logits,
+                        (0, 0, 0,
+                         max_num_tokens_across_dp - num_tokens))
                 hidden_states = get_dp_group().all_gather(hidden_states, 0)
-                router_logits = get_dp_group().all_gather(router_logits, 0)
+                if self.rm_router_logits:
+                    router_logits, _ = gate(hidden_states)
+                else:
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+
             elif fused_moe_state == FusedMoEState.NaiveMulticast:
                 cu_tokens_across_dp_cpu = get_forward_context(
                 ).dp_metadata.cu_tokens_across_dp_cpu
                 hidden_states = self.naive_multicast(hidden_states,
                                                      cu_tokens_across_dp_cpu)
-                router_logits = self.naive_multicast(router_logits,
-                                                     cu_tokens_across_dp_cpu)
+                if self.rm_router_logits:
+                    router_logits, _ = gate(hidden_states)
+                else:
+                    router_logits = self.naive_multicast(
+                        router_logits, cu_tokens_across_dp_cpu)

         # Matrix multiply.
         e_hidden_states = self.quant_method.apply(
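
To make the branch logic above easier to follow, here is a hedged, single-process sketch of the AllGather path: `torch.cat` again stands in for `get_dp_group().all_gather`, `nn.Linear` for the gate, and `rm_router_logits` toggles between the old path (pad and gather the logits separately) and the new one (gate after the single gather). This is an illustration under those stand-in assumptions, not the vLLM Ascend implementation.

```python
import torch
from torch import nn


def allgather_branch(per_rank_hidden, gate, rm_router_logits, max_tokens):
    """Toy AllGather path: pad each rank's chunk, gather, then obtain router logits."""
    padded_hidden, padded_logits = [], []
    for h in per_rank_hidden:
        pad = (0, 0, 0, max_tokens - h.shape[0])
        if not rm_router_logits:
            # Old path: gate per rank, then pad and gather the logits separately.
            padded_logits.append(nn.functional.pad(gate(h), pad))
        padded_hidden.append(nn.functional.pad(h, pad))
    hidden_states = torch.cat(padded_hidden, dim=0)       # the hidden_states "all_gather"
    if rm_router_logits:
        router_logits = gate(hidden_states)               # new path: gate after one gather
    else:
        router_logits = torch.cat(padded_logits, dim=0)   # old path: a second "all_gather"
    return hidden_states, router_logits


torch.manual_seed(0)
gate = nn.Linear(16, 8, bias=False)
chunks = [torch.randn(n, 16) for n in (2, 3)]             # DP ranks with ragged token counts
h_old, l_old = allgather_branch(chunks, gate, rm_router_logits=False, max_tokens=3)
h_new, l_new = allgather_branch(chunks, gate, rm_router_logits=True, max_tokens=3)
assert torch.equal(h_old, h_new) and torch.allclose(l_old, l_new)
```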

vllm_ascend/utils.py

Lines changed: 16 additions & 0 deletions
@@ -439,6 +439,22 @@ class FusedMoEState(Enum):
     NaiveMulticast = 4


+# TODO(ttanzhiqiang): rm_router_logits
+# Triggered when dp > 1.
+# In theory this optimization only applies to AllGather and AllGatherEP: in the dp scenario the previous flow was gate + two communications, and it is now one communication + gate, which saves some communication time. All MoE AllGather and AllGatherEP paths could follow this logic, but the dp paths of other MoE models (e.g. qwen3-235b) have not been adapted yet, so a switch is used to guard against breakage.
+def get_rm_router_logits_state(ep_size: int, dp_size: int,
+                               is_deepseek_v3_r1: bool):
+    # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if dp_size > 1:
+        if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+                and is_deepseek_v3_r1):
+            return True
+        elif ep_size == 1 and is_deepseek_v3_r1:
+            return True
+    return False
+
+
 # TODO(ttanzhiqiang): all_reduce merge
 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
 # Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
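
The helper above enables the optimization only for DeepSeek V3/R1 with dp > 1: either the fused AllGatherEP path (environment switch on and ep > 1) or the ep == 1 data-parallel path. Below is a standalone sketch of that decision table; `fused_ep_enabled` stands in for `envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP` and the function name is hypothetical, used here only for illustration.

```python
def rm_router_logits_enabled(ep_size: int, dp_size: int,
                             is_deepseek_v3_r1: bool,
                             fused_ep_enabled: bool) -> bool:
    """Sketch mirroring the decision logic of get_rm_router_logits_state."""
    if dp_size > 1:
        if fused_ep_enabled and ep_size > 1 and is_deepseek_v3_r1:
            return True   # AllGatherEP path; the fused routing op supports V3/R1 only
        if ep_size == 1 and is_deepseek_v3_r1:
            return True   # data-parallel path without expert parallelism
    return False


# dp == 1 never triggers; other MoE models (e.g. qwen3-235b) keep the old path.
assert rm_router_logits_enabled(16, 2, True, fused_ep_enabled=True)
assert rm_router_logits_enabled(1, 4, True, fused_ep_enabled=False)
assert not rm_router_logits_enabled(16, 2, True, fused_ep_enabled=False)
assert not rm_router_logits_enabled(1, 1, True, fused_ep_enabled=False)
assert not rm_router_logits_enabled(1, 4, False, fused_ep_enabled=False)
```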
