
Commit d6286aa

Author: w00800020 (committed)
Offload calculation shared experts to another stream
With the expected overlapping being (top row: shared experts on the secondary stream; bottom row: routed experts on the main stream):

| shared gate_up | shared act |  | shared down |  |
| dispatch | routed gate_up, act, down | combine |

Shared experts will be replicated regardless of TP, to avoid AllReduce communication. Controlled by the option VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT, which defaults to off.

Signed-off-by: w00800020 <weijinyi3@huawei.com>
1 parent 1bb1fcf commit d6286aa

3 files changed: +94 −34 lines

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 46 additions & 15 deletions
@@ -82,20 +82,35 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = ReplicatedLinear(
+                hidden_size,
+                intermediate_size * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
+
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")

@@ -202,8 +217,12 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=envs_ascend.
+                VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT,
                 prefix=f"{prefix}.shared_experts",
             )
+        else:
+            self.shared_experts = None  # type: ignore
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

         self.dp_size = get_dp_group().world_size

@@ -224,8 +243,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         is_prefill = attn_metadata.num_prefills > 0
         enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
+        use_separated_shared_expert = (
+            self.n_shared_experts is not None
+            and not envs_ascend.VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT)

-        if self.n_shared_experts is not None:
+        if use_separated_shared_expert:
             shared_output = self.shared_experts(hidden_states)

         if self.tp_size > 1:

@@ -248,13 +270,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)

-        router_hidden_states = self.experts(
+        experts_hidden_states = self.experts(
             hidden_states=local_hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_expert else None),
+        )
+
+        if not isinstance(experts_hidden_states, tuple):
+            router_hidden_states = experts_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])

         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,

@@ -265,7 +296,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             final_hidden_states = router_hidden_states

-        if shared_output is not None:
+        if use_separated_shared_expert:
             final_hidden_states = final_hidden_states + shared_output

         return final_hidden_states.view(num_tokens, hidden_dim)
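With VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT on, CustomDeepseekV2MoE.forward no longer runs the shared experts itself: self.experts(...) returns a (routed, shared) tuple, and only the routed half is scaled by routed_scaling_factor before the two are summed. A stripped-down sketch of that combination step (shapes and the scaling factor below are made up for illustration):

import torch

def combine_expert_outputs(experts_hidden_states, routed_scaling_factor):
    # Mirrors the new branch in CustomDeepseekV2MoE.forward: a plain tensor means
    # the shared experts were computed separately (or do not exist); a tuple means
    # the fused path already produced (routed_output, shared_output).
    if not isinstance(experts_hidden_states, tuple):
        return experts_hidden_states * routed_scaling_factor
    routed_output, shared_output = experts_hidden_states
    return routed_output * routed_scaling_factor + shared_output

# Toy usage with made-up token/hidden sizes.
routed = torch.randn(4, 16)
shared = torch.randn(4, 16)
out = combine_expert_outputs((routed, shared), routed_scaling_factor=2.5)
print(out.shape)  # torch.Size([4, 16])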

vllm_ascend/ops/fused_moe.py

Lines changed: 46 additions & 19 deletions
@@ -33,6 +33,7 @@

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM

@@ -47,6 +48,7 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)

@@ -83,11 +85,20 @@ def fused_experts_with_mc2(
     }
     kwargs.update(stage1_kwargs)

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,

@@ -118,6 +129,11 @@ def fused_experts_with_mc2(

     down_out_list = torch.cat(down_out_list, dim=0)

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
     # moeCombine
     kwargs = {
         "expand_x": down_out_list,

@@ -145,7 +161,7 @@ def fused_experts_with_mc2(

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)

-    return hidden_states
+    return hidden_states, shared_hidden_states if shared_experts is not None else None


 # currently expert parallelism implemented with all2all

@@ -624,6 +640,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern

@@ -664,28 +681,35 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+            )
         elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(hidden_states=x,
+                                                              w1=layer.w13_weight,
+                                                              w2=layer.w2_weight,
+                                                              topk_weights=topk_weights,
+                                                              topk_ids=topk_ids,
+                                                              top_k=top_k,
+                                                              expert_map=expert_map,
+                                                              ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)


 class AscendFusedMoE(FusedMoE):

@@ -815,7 +839,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
         assert self.quant_method is not None

         if top_k:

@@ -842,7 +867,9 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=shared_experts,
+        )

         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
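npu_stream_switch and npu_wait_tensor come from vllm_ascend.utils and their implementation is not shown in this diff; what the hunks above rely on is the usual secondary-stream recipe: launch the shared gate_up before the dispatch kernel, let the shared act wait on the dispatched tokens, and let the shared down_proj overlap the combine. A rough CUDA-stream analogue of that recipe (an illustration of the pattern only, not the Ascend implementation; shared_mlp and heavy_routed_path are placeholders):

import torch

def overlapped_moe_step(x, shared_mlp, heavy_routed_path):
    """Run shared_mlp on a side stream while heavy_routed_path runs on the default stream."""
    side_stream = torch.cuda.Stream()
    # The side stream must observe x (produced on the default stream) before using it.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        shared_output = shared_mlp(x)      # overlaps the routed work launched below
    routed_output = heavy_routed_path(x)   # e.g. dispatch + routed experts + combine
    # Re-join: the default stream must not read shared_output before the side stream is done.
    torch.cuda.current_stream().wait_stream(side_stream)
    return routed_output, shared_output

In the Ascend path above, the npu_wait_tensor calls play the role of these stream waits, ordering each shared-expert stage against a specific tensor produced on the main stream (expand_x before the shared activation, down_out_list before the shared down_proj); allocator lifetime concerns that torch.cuda users would handle with record_stream are left to the helpers.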

0 commit comments
