diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 39ae1701cb..3e65ae23f8 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -368,6 +368,7 @@ def __init__(
 
         self.ep_group = get_ep_group()
         self.params_dtype = torch.get_default_dtype()
+        self.rm_router_logits = self.experts.rm_router_logits
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -390,7 +391,9 @@ def forward(self,
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
 
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = None
+        if not self.rm_router_logits:
+            router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -399,6 +402,7 @@ def forward(self,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
+            gate=self.gate,
             replace_allreduce=replace_allreduce)
 
         hidden_states = (
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index ff777a143e..9e6ca13cb1 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -45,7 +45,8 @@
 
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
                                get_all_reduce_merge_state, get_fused_moe_state,
-                               is_310p, npu_stream_switch, npu_wait_tensor)
+                               get_rm_router_logits_state, is_310p,
+                               npu_stream_switch, npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -1148,6 +1149,8 @@ def __init__(
 
         self.global_redundant_expert_num = 0
 
         is_deepseek_v3_r1 = self.global_num_experts == 256
+        self.rm_router_logits = get_rm_router_logits_state(
+            self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
         self.all_reduce_merge = get_all_reduce_merge_state(
             self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
@@ -1240,7 +1243,9 @@ def forward(self,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
+                gate=None,
                 replace_allreduce: bool = False):
+
         assert self.quant_method is not None
 
         if top_k:
@@ -1277,6 +1282,7 @@ def forward(self,
                 tp_rank = get_tensor_model_parallel_rank()
                 hidden_states = chunk_hidden_states[tp_rank]
                 router_logits = chunk_router_logits[tp_rank]
+
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
@@ -1289,19 +1295,27 @@ def forward(self,
                             hidden_states,
                             (0, 0, 0,
                              max_num_tokens_across_dp - num_tokens))
-                        router_logits = nn.functional.pad(
-                            router_logits,
-                            (0, 0, 0,
-                             max_num_tokens_across_dp - num_tokens))
+                        if not self.rm_router_logits:
+                            router_logits = nn.functional.pad(
+                                router_logits,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
                 hidden_states = get_dp_group().all_gather(hidden_states, 0)
-                router_logits = get_dp_group().all_gather(router_logits, 0)
+                if self.rm_router_logits:
+                    router_logits, _ = gate(hidden_states)
+                else:
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+
             elif fused_moe_state == FusedMoEState.NaiveMulticast:
                 cu_tokens_across_dp_cpu = get_forward_context(
                 ).dp_metadata.cu_tokens_across_dp_cpu
                 hidden_states = self.naive_multicast(hidden_states,
                                                      cu_tokens_across_dp_cpu)
-                router_logits = self.naive_multicast(router_logits,
-                                                     cu_tokens_across_dp_cpu)
+                if self.rm_router_logits:
+                    router_logits, _ = gate(hidden_states)
+                else:
+                    router_logits = self.naive_multicast(
+                        router_logits, cu_tokens_across_dp_cpu)
 
         # Matrix multiply.
         e_hidden_states = self.quant_method.apply(
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 2fcc4f0c15..d2e61586c1 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -425,6 +425,22 @@ class FusedMoEState(Enum):
     NaiveMulticast = 4
 
 
+# TODO(ttanzhiqiang): rm_router_logits
+# dp>1 will trigger
+# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
+def get_rm_router_logits_state(ep_size: int, dp_size: int,
+                               is_deepseek_v3_r1: bool):
+    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if dp_size > 1:
+        if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+                and is_deepseek_v3_r1):
+            return True
+        elif ep_size == 1 and is_deepseek_v3_r1:
+            return True
+    return False
+
+
 # TODO(ttanzhiqiang): all_reduce merge
 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
 # Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
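
Note (editorial illustration, not part of the patch): the rm_router_logits TODO above argues that, under data parallelism, running the gate after a single all-gather of hidden_states produces the same router_logits as gating on each rank and then all-gathering the logits, while saving one collective. The sketch below is a minimal stand-alone approximation of that equivalence using plain torch only; the gate is modeled as a per-token linear layer and the DP all-gather is simulated with torch.cat, so all names here are illustrative rather than vllm-ascend APIs.

```python
# Editorial sketch (assumed names, plain torch): gathering hidden_states first
# and gating afterwards matches gating per DP "rank" and gathering the logits.
import torch

torch.manual_seed(0)
hidden_size, n_experts, dp_size, tokens_per_rank = 16, 8, 4, 3

# Stand-in for the router gate: a bias-free per-token linear projection.
gate = torch.nn.Linear(hidden_size, n_experts, bias=False)

# Each simulated DP rank holds its own slice of tokens.
per_rank_hidden = [
    torch.randn(tokens_per_rank, hidden_size) for _ in range(dp_size)
]

# Old flow: gate on every rank, then gather both tensors (two collectives,
# simulated here with torch.cat along the token dimension).
gathered_hidden = torch.cat(per_rank_hidden, dim=0)
gathered_logits_old = torch.cat([gate(h) for h in per_rank_hidden], dim=0)

# New flow (rm_router_logits): gather hidden_states once, gate afterwards.
gathered_logits_new = gate(gathered_hidden)

# The gate acts independently on each token, so both flows agree.
assert torch.allclose(gathered_logits_old, gathered_logits_new, atol=1e-5)
print("gate(all_gather(x)) matches all_gather(gate(x)) for a per-token gate")
```

This is also why the change is guarded by get_rm_router_logits_state(): the saving only applies on paths where the gathered hidden_states are available before routing (the AllGather/AllGatherEP/NaiveMulticast branches above), and the switch currently enables it only for the DeepSeek V3/R1 expert layout.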