 # Adapted from vllm/tests/kernels/test_moe.py
 
 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -36,6 +36,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -106,15 +107,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
     return topk_ids_pad, unpad_indices
 
 
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: Optional[str] = None,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[Any] = None
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     global_bs = 0
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
@@ -154,6 +157,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
 
     group_list = expert_token_nums.to(torch.int64)
@@ -210,7 +220,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
-    return hidden_states
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        return hidden_states, shared_hidden_states
 
 
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -875,6 +891,7 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
+        shared_experts: Optional[Any] = None,
        **kwargs,
     ) -> torch.Tensor:
 
@@ -924,7 +941,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
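Note also the return-type change: when `shared_experts` is passed, `fused_experts_with_mc2` now returns a `(routed, shared)` tuple instead of a single tensor, so callers have to unpack it. A minimal caller-side sketch follows; the `merge_moe_output` helper and the final summation are assumptions for illustration, not part of this diff.

```python
from typing import Any, Optional, Tuple, Union

import torch


def merge_moe_output(
    out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    shared_experts: Optional[Any] = None,
) -> torch.Tensor:
    """Hypothetical helper showing how a caller could handle both shapes."""
    if shared_experts is None:
        # Previous behaviour: a single routed-expert output tensor.
        return out
    routed_out, shared_out = out
    # DeepSeek-style MoE layers typically add the shared-expert branch onto
    # the routed-expert output; adjust to the model's actual semantics.
    return routed_out + shared_out
```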