15 | 15 | # This file is a part of the vllm-ascend project.
16 | 16 | # Adapted from vllm/tests/kernels/test_moe.py
17 | 17 |
18 |    | -from typing import Any, Callable, List, Optional
   | 18 | +from typing import Any, Callable, List, Optional, Tuple, Union
19 | 19 |
20 | 20 | import torch
21 | 21 | import torch.distributed as dist
34 | 34 | import vllm_ascend.envs as envs_ascend
35 | 35 | from vllm_ascend.ascend_config import get_ascend_config
36 | 36 | from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
   | 37 | +from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
37 | 38 |
38 | 39 | VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
39 | 40 | USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM

@@ -104,15 +105,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
104 | 105 |     return topk_ids_pad, unpad_indices
105 | 106 |
106 | 107 |
107 |     | -def fused_experts_with_mc2(hidden_states: torch.Tensor,
108 |     | -                           w1: torch.Tensor,
109 |     | -                           w2: torch.Tensor,
110 |     | -                           topk_weights: torch.Tensor,
111 |     | -                           topk_ids: torch.Tensor,
112 |     | -                           top_k: int,
113 |     | -                           expert_map: torch.Tensor = None,
114 |     | -                           moe_all_to_all_group_name: Optional[str] = None,
115 |     | -                           **kwargs) -> torch.Tensor:
    | 108 | +def fused_experts_with_mc2(
    | 109 | +    hidden_states: torch.Tensor,
    | 110 | +    w1: torch.Tensor,
    | 111 | +    w2: torch.Tensor,
    | 112 | +    topk_weights: torch.Tensor,
    | 113 | +    topk_ids: torch.Tensor,
    | 114 | +    top_k: int,
    | 115 | +    expert_map: torch.Tensor = None,
    | 116 | +    moe_all_to_all_group_name: Optional[str] = None,
    | 117 | +    shared_experts: Optional[Any] = None
    | 118 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
116 | 119 |     global_bs = 0
117 | 120 |     moe_expert_num = len(expert_map)
118 | 121 |     kwargs_mc2 = {
@@ -152,6 +155,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
152 | 155 |     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
153 | 156 |         0:5]
154 | 157 |
    | 158 | +    if shared_experts is not None:
    | 159 | +        with npu_stream_switch("moe_secondary", 0):
    | 160 | +            npu_wait_tensor(hidden_states, topk_weights)
    | 161 | +            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
    | 162 | +            npu_wait_tensor(shared_gate_up, expand_x)
    | 163 | +            shared_act = shared_experts.act_fn(shared_gate_up)
    | 164 | +
155 | 165 |     w1 = w1.transpose(1, 2)
156 | 166 |
157 | 167 |     group_list = expert_token_nums.to(torch.int64)
@@ -208,7 +218,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
208 | 218 |
209 | 219 |     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
210 | 220 |
211 |     | -    return hidden_states
    | 221 | +    if shared_experts is None:
    | 222 | +        return hidden_states
    | 223 | +    else:
    | 224 | +        with npu_stream_switch("moe_secondary", 0):
    | 225 | +            npu_wait_tensor(shared_act, down_out_list)
    | 226 | +            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
    | 227 | +        return hidden_states, shared_hidden_states
212 | 228 |
213 | 229 |
214 | 230 | def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
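
On the return path, the shared experts' down projection stays on the same secondary stream and waits only on `down_out_list` (the routed experts' down-projection result), so it can overlap with `npu_moe_distribute_combine` on the default stream. The function then returns either the routed output alone or a `(routed, shared)` pair. A hypothetical call site might unpack the pair like this (the element-wise addition at the end is an assumed merge step in the surrounding model code, not something this diff shows):

```python
# Hypothetical call-site sketch; argument names mirror the new signature, but
# the final addition of routed and shared outputs is an assumption.
result = fused_experts_with_mc2(
    hidden_states=x,
    w1=w13_weight,
    w2=w2_weight,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    top_k=top_k,
    expert_map=expert_map,
    moe_all_to_all_group_name=group_name,
    shared_experts=shared_experts)  # passing None returns a plain tensor instead

if isinstance(result, tuple):
    routed_out, shared_out = result
    final_hidden_states = routed_out + shared_out  # assumed merge
else:
    final_hidden_states = result
```
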
@@ -873,6 +889,7 @@ def apply(
873 | 889 |         e_score_correction_bias: Optional[torch.Tensor] = None,
874 | 890 |         is_prefill: bool = False,
875 | 891 |         enable_force_load_balance: bool = False,
    | 892 | +        shared_experts: Optional[Any] = None,
876 | 893 |         **kwargs,
877 | 894 |     ) -> torch.Tensor:
878 | 895 |
@@ -922,7 +939,7 @@ def apply(
922 | 939 |                 top_k=top_k,
923 | 940 |                 expert_map=expert_map,
924 | 941 |                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
925 |     | -                **kwargs)
    | 942 | +                shared_experts=shared_experts)
926 | 943 |         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
927 | 944 |             return fused_experts(hidden_states=x,
928 | 945 |                                  w1=layer.w13_weight,