### What this PR does / why we need it?
The previous code was:

```python
router_logits, _ = self.gate(hidden_states)
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
```

This PR replaces the two all_gathers with a single one, saving one all_gather communication, by gathering the hidden states first and computing the router logits on the gathered tensor:

```python
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits, _ = self.gate(hidden_states)
```
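For intuition, here is a minimal standalone sketch (not the vllm-ascend code: it fakes the DP all_gather with `torch.cat` and assumes the gate is a plain linear projection) showing why the reordering is safe, i.e. gathering first and then gating yields the same router logits as gating first and then gathering:

```python
import torch

# Hypothetical stand-ins: a linear gate and two DP ranks' local hidden states.
gate = torch.nn.Linear(16, 4, bias=False)
rank0_hidden = torch.randn(3, 16)
rank1_hidden = torch.randn(3, 16)

# Old order: gate on each rank, then gather both tensors (two communications).
logits_then_gather = torch.cat([gate(rank0_hidden), gate(rank1_hidden)], dim=0)

# New order: gather hidden states once, then gate the gathered tensor (one communication).
gather_then_logits = gate(torch.cat([rank0_hidden, rank1_hidden], dim=0))

# The gate acts row-wise, so both orders produce the same router logits.
assert torch.allclose(logits_then_gather, gather_then_logits, atol=1e-5)
```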
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- `bash examples/run_dp_attention_etp16.sh`
- `bash examples/run_dp_attention_etp16_benmark.sh`
- gsm8k accuracy verification:
<img width="1809" alt="Screenshot 2025-06-24 21 53 24" src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b" />
- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@77f77a9
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
**vllm_ascend/utils.py** (16 additions, 0 deletions)
```diff
@@ -439,6 +439,22 @@ class FusedMoEState(Enum):
     NaiveMulticast = 4
 
 
+# TODO(ttanzhiqiang): rm_router_logits
+# dp>1 will trigger
+# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
+def get_rm_router_logits_state(ep_size: int, dp_size: int,
+                               is_deepseek_v3_r1: bool):
+    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if dp_size > 1:
+        if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+                and is_deepseek_v3_r1):
+            return True
+        elif ep_size == 1 and is_deepseek_v3_r1:
+            return True
+    return False
+
+
 # TODO(ttanzhiqiang): all_reduce merge
 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
 # Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
```
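For context, a hedged sketch of how a switch like this could be consumed when routing tokens; the function name `route`, its parameters, and the `rm_router_logits` flag wiring here are illustrative assumptions, not the actual vllm-ascend call site:

```python
# Illustrative only: a standalone sketch, not vllm-ascend's API. `dp_group` is assumed to
# expose all_gather(tensor, dim), and `gate` to return (router_logits, aux) like self.gate above.
def route(hidden_states, gate, dp_group, rm_router_logits: bool):
    if rm_router_logits:
        # New path (switch on): one all_gather, then the gate runs on the gathered tensor.
        hidden_states = dp_group.all_gather(hidden_states, 0)
        router_logits, _ = gate(hidden_states)
    else:
        # Old path (switch off): gate per DP rank, then gather hidden_states and router_logits separately.
        router_logits, _ = gate(hidden_states)
        hidden_states = dp_group.all_gather(hidden_states, 0)
        router_logits = dp_group.all_gather(router_logits, 0)
    return hidden_states, router_logits
```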