
Commit 1ef0f68

w00800020 authored and sdmyzlp committed
Offload calculation shared experts to another stream
With the expected overlapping being:

| shared gate_up | shared act                |         | shared down |
| dispatch       | routed gate_up, act, down | combine |

Shared experts will be replicated regardless of TP, to avoid AllReduce comm.

Controlled by option VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT, defaulted to off.

Signed-off-by: w00800020 <weijinyi3@huawei.com>
1 parent 036a36e commit 1ef0f68
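
A minimal usage sketch (not part of this commit; it assumes the stock vllm.LLM entry point and mirrors the model and parallel settings of the test added below). The option is read from the environment via os.getenv in vllm_ascend/envs.py, so setting it before the engine is constructed is enough:

import os

# Hypothetical end-to-end usage; the option defaults to "0" (off) in envs.py.
os.environ["VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT"] = "1"

from vllm import LLM, SamplingParams

# Mirrors the new test: DeepSeek-V2-Lite with TP=4. With the option enabled the
# shared experts are replicated, so their path needs no AllReduce.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          tensor_parallel_size=4,
          dtype="half")
out = llm.generate(["vLLM is a high-throughput inference engine."],
                   SamplingParams(max_tokens=5))
print(out[0].outputs[0].text)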

4 files changed: +120 −34 lines changed

tests/multicard/test_offline_inference_distributed.py

Lines changed: 18 additions & 0 deletions
@@ -21,6 +21,7 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch
 
 import vllm  # noqa: F401
 
@@ -61,3 +62,20 @@ def test_models_distributed_DeepSeek():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+@patch.dict(os.environ, {"VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT": "1"})
+def test_models_distributed_multistream_shared_expert():
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 46 additions & 15 deletions
@@ -82,20 +82,35 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = ReplicatedLinear(
+                hidden_size,
+                intermediate_size * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
+
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -202,8 +217,12 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=envs_ascend.
+                VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT,
                 prefix=f"{prefix}.shared_experts",
             )
+        else:
+            self.shared_experts = None  # type: ignore
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
 
         self.dp_size = get_dp_group().world_size
@@ -224,8 +243,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
+        use_separated_shared_expert = (
+            self.n_shared_experts is not None
+            and not envs_ascend.VLLM_ENABLE_MULTISTREAM_SHARED_EXPERT)
 
-        if self.n_shared_experts is not None:
+        if use_separated_shared_expert:
             shared_output = self.shared_experts(hidden_states)
 
         if self.tp_size > 1:
@@ -248,13 +270,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)
 
-        router_hidden_states = self.experts(
+        experts_hidden_states = self.experts(
            hidden_states=local_hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=CustomDeepseekV2MoE.top_k,
            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+           shared_experts=(self.shared_experts
+                           if not use_separated_shared_expert else None),
+        )
+
+        if not isinstance(experts_hidden_states, tuple):
+            router_hidden_states = experts_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])
 
         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
@@ -265,7 +296,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             final_hidden_states = router_hidden_states
 
-        if shared_output is not None:
+        if use_separated_shared_expert:
             final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states.view(num_tokens, hidden_dim)

vllm_ascend/ops/fused_moe.py

Lines changed: 54 additions & 19 deletions
@@ -33,6 +33,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -47,6 +48,8 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
+    graph_mode: bool = False,
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
@@ -88,6 +91,10 @@ def fused_experts_with_mc2(
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -102,6 +109,11 @@ def fused_experts_with_mc2(
         group_list=group_list,
     )
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            npu_wait_tensor(shared_gate_up, expand_x, enabled=graph_mode)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     # TODO: Remove this in the future.
     gate_up_out = torch.cat(gate_up_out_list, dim=0)
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
@@ -145,7 +157,15 @@ def fused_experts_with_mc2(
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
 
-    return hidden_states
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            npu_wait_tensor(shared_act, down_out_list, enabled=graph_mode)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
+    if shared_experts is None:
+        return hidden_states
+    else:
+        return hidden_states, shared_hidden_states
 
 
 # currently expert parallelism implemented with all2all
@@ -587,6 +607,8 @@ def __init__(self, moe: MoEConfig = None):
         self.ep_size = ep_group.world_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
+        self.graph_mode = vllm_config.get("additional_config",
+                                          {}).get("enable_graph_mode", False)
 
         try:
             device_group = ep_group.device_group
@@ -624,6 +646,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
        **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -664,28 +687,37 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+                graph_mode=self.graph_mode,
+            )
         elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -815,7 +847,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
         assert self.quant_method is not None
 
         if top_k:
@@ -842,7 +875,9 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=shared_experts,
+        )
 
         if VLLM_ENABLE_MC2 and not is_prefill:
...
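
For context on the stream usage in fused_experts_with_mc2 above, the sketch below illustrates the same overlap idea with plain CUDA streams. This is an assumption-laden analogue, not the commit's code: the commit uses npu_stream_switch / npu_wait_tensor from vllm_ascend.utils on Ascend NPU and interleaves the waits with the dispatch and combine kernels rather than grouping all shared-expert ops together.

import torch

def overlapped_shared_experts(hidden_states, shared_experts, routed_moe):
    """Illustrative only: run the shared-expert MLP on a side stream while the
    routed-expert path runs on the current stream. `shared_experts` and
    `routed_moe` are placeholders, not symbols from this commit."""
    side = torch.cuda.Stream()

    # The side stream must see hidden_states (produced on the current stream)
    # before it starts the shared gate_up projection.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
        shared_act = shared_experts.act_fn(shared_gate_up)
        shared_out, _ = shared_experts.down_proj(shared_act)

    # Meanwhile: dispatch, routed gate_up/act/down, combine on the main stream.
    routed_out = routed_moe(hidden_states)

    # Join the streams before the caller mixes routed and shared outputs.
    torch.cuda.current_stream().wait_stream(side)
    return routed_out, shared_out

The essential point is that the secondary stream only needs ordering edges at the hand-off tensors, which in the actual diff are expand_x and down_out_list guarded by npu_wait_tensor.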
