
Commit e4fe832

Offload shared experts of MoE to another stream
With the expected overlapping being:

```
| shared gate_up | shared act |                           | shared down |
|    dispatch    |     routed gate_up, act, down          |   combine   |
```

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent eee09d8 commit e4fe832
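In other words, the shared-expert MLP is moved onto a secondary NPU stream so its gate_up, activation, and down projections can run concurrently with the MC2 dispatch, routed-expert compute, and combine on the default stream. The following is a minimal illustrative sketch of that pattern, not the actual implementation: `dispatch`, `routed_experts`, and `combine` are placeholders for the torch_npu MC2 calls used in the real `fused_experts_with_mc2` below, while `npu_stream_switch` and `npu_wait_tensor` are the helpers imported in the diff.

```python
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def moe_with_shared_overlap(hidden_states, topk_weights, shared_experts,
                            dispatch, routed_experts, combine):
    # Sketch only: dispatch/routed_experts/combine stand in for the MC2
    # dispatch, grouped routed-expert GEMMs, and combine of the real code.

    # Main stream: all-to-all dispatch of tokens to their routed experts.
    expand_x = dispatch(hidden_states, topk_weights)

    # Secondary stream: shared-expert gate_up + activation, overlapped with
    # the dispatch / routed-expert compute running on the main stream.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(hidden_states, topk_weights)  # wait until inputs exist
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
        npu_wait_tensor(shared_gate_up, expand_x)
        shared_act = shared_experts.act_fn(shared_gate_up)

    # Main stream: routed-expert gate_up, act, down, then combine.
    down_out = routed_experts(expand_x)
    hidden_states = combine(down_out)

    # Secondary stream: shared-expert down projection, overlapped with combine.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(shared_act, down_out)
        shared_hidden_states, _ = shared_experts.down_proj(shared_act)

    return hidden_states, shared_hidden_states
```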

File tree: 1 file changed (+29, −12 lines)


vllm_ascend/ops/fused_moe.py

```diff
@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -34,6 +34,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -104,15 +105,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
     return topk_ids_pad, unpad_indices
 
 
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: Optional[str] = None,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[Any] = None
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     global_bs = 0
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
@@ -152,6 +155,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
 
     group_list = expert_token_nums.to(torch.int64)
@@ -208,7 +218,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
-    return hidden_states
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        return hidden_states, shared_hidden_states
 
 
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -873,6 +889,7 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -922,7 +939,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
```
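Note that when `shared_experts` is passed, `fused_experts_with_mc2` now returns a `(routed, shared)` tuple instead of a single tensor, so call sites have to unpack it. A hypothetical caller-side helper (not part of this commit), assuming the shared-expert output is simply added to the routed output as in DeepSeek-style MoE layers:

```python
from typing import Tuple, Union

import torch


def merge_moe_output(
    out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> torch.Tensor:
    """Hypothetical helper: add the shared-expert output to the routed-expert
    output when fused_experts_with_mc2 returns a tuple, else pass through."""
    if isinstance(out, tuple):
        routed, shared = out
        return routed + shared
    return out
```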
