Commit 78431b3

[perf] Support MoE multi-stream in DeepSeek (#947)
### What this PR does / why we need it?

Support multi-stream execution inside the MoE layer for DeepSeek models: when enabled, the shared-expert computation is issued on a secondary stream so it can overlap with the routed-expert computation. This feature requires torchair graph mode with MC2 enabled.

Signed-off-by: David9857 <985700846@qq.com>
1 parent 908a851 commit 78431b3
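
For context, a minimal usage sketch (not part of this commit): it turns the new option on through the `torchair_graph_config` keys added in `ascend_config.py` below. The model name, the `LLM(additional_config=...)` entry point, and the MC2 note are assumptions, mirroring the `VllmRunner(..., enforce_eager=False, additional_config=...)` call exercised in the updated test.

    # Illustrative sketch only -- names below are assumptions, not part of this commit.
    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # example DeepSeek MoE checkpoint
        enforce_eager=False,                   # torchair graph mode must stay enabled
        additional_config={
            "torchair_graph_config": {
                "enabled": True,
                "enable_multistream_shared_expert": True,
            },
        },
    )
    # MC2 must also be enabled (vllm-ascend reads VLLM_ENABLE_MC2 from the environment).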

File tree

6 files changed (+133, -45 lines)


tests/singlecard/test_ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -114,5 +114,6 @@ def test_ascend_config_load_error():
            },
        }
        with VllmRunner("facebook/opt-125m",
+                        enforce_eager=False,
                        additional_config=input_additional_config_fake_2):
            pass

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 1 deletion
@@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
            "graph_batch_sizes", [])
        self.graph_batch_sizes_init = torchair_graph_config.get(
            "graph_batch_sizes_init", False)
+        self.enable_multistream_shared_expert = torchair_graph_config.get(
+            "enable_multistream_shared_expert", False)

        if not isinstance(self.graph_batch_sizes, list):
            raise TypeError("graph_batch_sizes must be list[int]")

@@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
    ascend_config = get_ascend_config()

    # Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
-    if ascend_config.torchair_graph_config.enabled and not enforce_eager:
+    if ascend_config.torchair_graph_config.enabled and enforce_eager:
        raise RuntimeError(
            "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
        )

vllm_ascend/models/deepseek_v2.py

Lines changed: 19 additions & 2 deletions
@@ -216,6 +216,8 @@ def __init__(

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert

    def forward(
        self,

@@ -238,6 +240,8 @@ def forward(

        num_tokens, hidden_size = hidden_states.shape

+        multistream = self.enable_multistream_shared_expert and not is_prefill
+
        old_hidden_states = hidden_states.clone()

        if self.tp_size > 1:

@@ -259,13 +263,25 @@
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": old_hidden_states
+            })
+
        hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=CustomDeepseekV2MoE.top_k,
            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor

        if self.tp_size > 1:
            if self.torchair_graph_enabled:

@@ -288,7 +304,8 @@ def forward(
            hidden_states = hidden_states[:-num_padding_tokens]

        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(old_hidden_states)
+            if not multistream:
+                shared_output = self.shared_experts(old_hidden_states)

        if shared_output is not None:
            hidden_states = hidden_states + shared_output
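
To restate the decode-path contract the hunks above introduce (a reading of this diff, not code from the commit): with the option on and outside prefill, the fused experts compute the shared-expert branch internally and return a pair, so the later eager `self.shared_experts(old_hidden_states)` call is skipped.

    # Reading of the decode path after this change (names taken from the diff above).
    multistream = self.enable_multistream_shared_expert and not is_prefill
    out = self.experts(hidden_states=hidden_states,
                       router_logits=router_logits,
                       is_prefill=is_prefill,
                       top_k=CustomDeepseekV2MoE.top_k,
                       enable_force_load_balance=enable_force_load_balance,
                       **kwargs)
    if multistream:
        hidden_states, shared_output = out  # shared expert already ran on the side stream
    else:
        hidden_states = out                 # shared_output is still computed eagerly later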

vllm_ascend/ops/fused_moe.py

Lines changed: 28 additions & 19 deletions
@@ -39,19 +39,18 @@
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM


-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: Optional[str] = None,
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: Optional[str] = None,
+                           **kwargs) -> torch.Tensor:
    global_bs = 0
    moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,

@@ -81,9 +80,9 @@ def fused_experts_with_mc2(
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
    # comm_stream.wait_stream(torch.npu.current_stream())
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]

@@ -119,7 +118,7 @@ def fused_experts_with_mc2(
    down_out_list = torch.cat(down_out_list, dim=0)

    # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,

@@ -141,9 +140,9 @@ def fused_experts_with_mc2(
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

    return hidden_states

@@ -675,7 +674,8 @@ def apply(
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
        elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,

@@ -772,6 +772,8 @@ def __init__(

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "

@@ -818,7 +820,8 @@ def forward(self,
                router_logits: torch.Tensor,
                is_prefill: bool,
                enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
        assert self.quant_method is not None

        if top_k:

@@ -862,7 +865,11 @@ def forward(self,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if self.enable_multistream_shared_expert and not is_prefill:
+            hidden_states, shared_output = hidden_states

        if self.dp_size > 1:
            if VLLM_ENABLE_MC2 and not is_prefill:

@@ -886,4 +893,6 @@ def forward(self,
        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

+        if self.enable_multistream_shared_expert and not is_prefill:
+            return hidden_states, shared_output
        return hidden_states

vllm_ascend/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion
@@ -329,7 +329,7 @@ def apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance)
+            is_prefill, enable_force_load_balance, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 81 additions & 22 deletions
@@ -20,7 +20,8 @@
import torch
import torch.distributed as dist
import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng  # type: ignore
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config

@@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
    """
    apply MLP: gate_up_proj -> swiglu -> down_proj

@@ -72,6 +74,23 @@ def apply_mlp(hidden_states: torch.Tensor,
    else:
        pertoken_scale = dynamic_scale

+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
    # gmm1: gate_up_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],

@@ -100,25 +119,39 @@ def apply_mlp(hidden_states: torch.Tensor,
        group_type=0,
        group_list=group_list,
        output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
    return hidden_states


-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: str = "",
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           w1_scale: torch.Tensor,
+                           w2_scale: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: str = "",
+                           **kwargs) -> torch.Tensor:
    global_bs = 0
    moe_expert_num = len(expert_map)
    # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs_mc2 = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,

@@ -149,9 +182,27 @@ def fused_experts_with_mc2(
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
+
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_hidden_states = kwargs.get('shared_hidden_states', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+                shared_hidden_states)
+            shared_gate_up = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.gate_up_proj.weight,
+                shared_experts.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+        kwargs.update({
+            "shared_gate_up": shared_gate_up,
+            "shared_dynamic_scale": shared_dynamic_scale,
+        })

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
    # comm_stream.wait_stream(torch.npu.current_stream())
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]

@@ -166,10 +217,15 @@ def fused_experts_with_mc2(
                              w2,
                              w2_scale,
                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+                              dynamic_scale=dynamic_scale,
+                              **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list

    # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,

@@ -193,10 +249,12 @@ def fused_experts_with_mc2(
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

+    if multi_stream:
+        return hidden_states, shared_output
    return hidden_states


@@ -634,7 +692,8 @@ def apply(
            topk_ids=topk_ids,
            top_k=top_k,
            expert_map=expert_map,
-            moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+            moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+            **kwargs)
        elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
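
For readers unfamiliar with the torchair calls used in the hunks above, here is the side-stream overlap pattern in isolation, a minimal sketch that uses only the two scope APIs imported in this file (`tng.scope.npu_stream_switch`, `tng.scope.npu_wait_tensor`); the helper name and tensors are hypothetical, not part of the commit.

    # Minimal sketch of the overlap pattern (hypothetical helper, illustrative only).
    import torchair as tng

    def run_on_side_stream(side_input, gating_tensor, side_fn):
        # Ops recorded in this scope are placed on the secondary ('cv') stream when
        # the model is compiled in torchair graph mode.
        with tng.scope.npu_stream_switch('cv'):
            # Delay the side branch until gating_tensor is produced, then let it run
            # concurrently with the routed-expert work on the default stream.
            tng.scope.npu_wait_tensor(side_input, gating_tensor)
            return side_fn(side_input)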
