Commit 000050a
Offload shared experts of MoE to another stream
With the expected overlapping being:

```
| shared gate_up | shared act |                           | shared down |
|    dispatch    |       routed gate_up, act, down        |   combine   |
```

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 9d27fcb commit 000050a
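The overlap relies on the two helpers this commit imports from `vllm_ascend.utils`. Below is a minimal sketch of the pattern, not the committed code: it assumes `npu_stream_switch(tag, priority)` runs the enclosed ops on a named secondary stream and that `npu_wait_tensor(x, dep)` delays further use of `x` on the current stream until `dep` has been produced. `dispatch_fn`, `routed_experts_fn`, and `combine_fn` are hypothetical stand-ins for the MC2 dispatch, the routed-expert MLP, and the MC2 combine.

```python
# Sketch only; assumptions about helper semantics are noted in the lead-in.
from typing import Callable, Tuple

import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def moe_with_shared_offload(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    shared_experts,  # module exposing gate_up_proj / act_fn / down_proj
    dispatch_fn: Callable[[torch.Tensor], torch.Tensor],       # hypothetical
    routed_experts_fn: Callable[[torch.Tensor], torch.Tensor],  # hypothetical
    combine_fn: Callable[[torch.Tensor], torch.Tensor],         # hypothetical
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Main stream: MC2 dispatch of tokens to the routed experts.
    expand_x = dispatch_fn(hidden_states)

    # Secondary stream: shared gate_up overlaps dispatch; shared act overlaps
    # the start of the routed computation.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(hidden_states, topk_weights)
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
        npu_wait_tensor(shared_gate_up, expand_x)
        shared_act = shared_experts.act_fn(shared_gate_up)

    # Main stream: routed gate_up, act, down, then MC2 combine.
    down_out = routed_experts_fn(expand_x)
    routed_out = combine_fn(down_out)

    # Secondary stream: shared down projection overlaps combine.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(shared_act, down_out)
        shared_out, _ = shared_experts.down_proj(shared_act)

    return routed_out, shared_out
```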

File tree

1 file changed: +29 -12 lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 29 additions & 12 deletions
```diff
@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py

 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -36,6 +36,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -106,15 +107,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
     return topk_ids_pad, unpad_indices


-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: Optional[str] = None,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[Any] = None
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     global_bs = 0
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
@@ -154,6 +157,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)

     group_list = expert_token_nums.to(torch.int64)
@@ -210,7 +220,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

-    return hidden_states
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        return hidden_states, shared_hidden_states


 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -875,6 +891,7 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -924,7 +941,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
```
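With `shared_experts` passed in, `fused_experts_with_mc2` now returns a `(hidden_states, shared_hidden_states)` tuple instead of a single tensor. A hypothetical caller-side sketch of handling the widened return type follows; the `merge_moe_output` helper and the final summation are illustrative assumptions, not part of this diff.

```python
# Sketch only; helper name and merge-by-addition are assumptions.
from typing import Tuple, Union

import torch


def merge_moe_output(
    out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> torch.Tensor:
    # Single tensor: no shared experts were offloaded on the MC2 path.
    if isinstance(out, torch.Tensor):
        return out
    # Tuple: routed output plus shared-expert output. Summing them mirrors
    # the usual routed + shared composition in MoE layers, but how the real
    # caller merges them is outside this commit.
    routed_hidden_states, shared_hidden_states = out
    return routed_hidden_states + shared_hidden_states
```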
