Commit 78a00c3

support moe multistream in deepseek
Signed-off-by: David9857 <985700846@qq.com>
1 parent 17f05b1

File tree: 5 files changed, +106 −22 lines

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions

@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_CV_PARALLEL":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_CV_PARALLEL", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 40 additions & 9 deletions

@@ -30,6 +30,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -70,6 +71,7 @@
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
+VLLM_ENABLE_CV_PARALLEL: bool = envs_ascend.VLLM_ENABLE_CV_PARALLEL


 class CustomDeepseekV2MLP(nn.Module):
@@ -224,8 +226,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        cv_parallel = VLLM_ENABLE_CV_PARALLEL and not is_prefill
+
+        if self.n_shared_experts is not None:
+            if not cv_parallel:
+                shared_output = self.shared_experts(hidden_states)
+            else:
+                shared_hidden_states = hidden_states

         if self.tp_size > 1:
             # pass
@@ -247,13 +254,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)

-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekV2MoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+        if self.n_shared_experts is not None and cv_parallel:
+            with tng.scope.npu_stream_switch('cv'):
+                tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
+                x, dynamic_scale = torch_npu.npu_dynamic_quant(shared_hidden_states)
+                gate_up = torch_npu.npu_quant_matmul(
+                    x,
+                    self.shared_experts.gate_up_proj.weight,
+                    self.shared_experts.gate_up_proj.weight_scale,
+                    output_dtype=torch.int32,
+                )
+
+        if cv_parallel:
+            router_hidden_states, shared_output = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+                shared_experts=self.shared_experts,
+                shared_gate_up=gate_up,
+                shared_dynamic_scale=dynamic_scale
+            )
+            router_hidden_states = router_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+            ) * self.routed_scaling_factor

         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
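Note: during decode (is_prefill is False) with VLLM_ENABLE_CV_PARALLEL set, the shared-expert GEMMs are issued on a secondary torchair stream ('cv') and ordered behind the gate output with npu_wait_tensor, so they overlap with the routed-expert path instead of running serially before it. Below is a minimal sketch of that dual-stream pattern in isolation; it requires an Ascend NPU with torch_npu and torchair installed, and every name except the torchair scope calls is illustrative rather than the commit's API.

# Sketch only: overlap the shared-expert branch with the routed experts.
import torch
import torch_npu  # noqa: F401  (registers the NPU backend/ops)
import torchair as tng


def moe_forward_with_cv_overlap(hidden_states: torch.Tensor, gate, shared_experts,
                                routed_experts) -> torch.Tensor:
    router_logits, _ = gate(hidden_states)

    with tng.scope.npu_stream_switch('cv'):
        # Side stream: wait until the gate output exists, then run the shared
        # experts concurrently with the routed-expert dispatch below.
        tng.scope.npu_wait_tensor(hidden_states, router_logits)
        shared_output = shared_experts(hidden_states)

    routed_output = routed_experts(hidden_states, router_logits)
    return routed_output + shared_output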

vllm_ascend/ops/fused_moe.py

Lines changed: 10 additions & 2 deletions

@@ -42,6 +42,7 @@

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
+VLLM_ENABLE_CV_PARALLEL: bool = envs_ascend.VLLM_ENABLE_CV_PARALLEL


 def fused_experts_with_mc2(
@@ -694,7 +695,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
         assert self.quant_method is not None

         if top_k:
@@ -722,7 +724,11 @@ def forward(self,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance,
-            dp_size=self.dp_size)
+            dp_size=self.dp_size,
+            **kwargs)
+
+        if VLLM_ENABLE_CV_PARALLEL and not is_prefill:
+            final_hidden_states, shared_output = final_hidden_states

         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
@@ -731,4 +737,6 @@ def forward(self,
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

+        if VLLM_ENABLE_CV_PARALLEL and not is_prefill:
+            return final_hidden_states, shared_output
         return final_hidden_states
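Note: when the CV-parallel decode path is active, self.quant_method.apply returns a (routed, shared) pair instead of a single tensor, so forward unpacks it before the MC2/all-reduce handling and passes both tensors back to CustomDeepseekV2MoE. A tiny plain-Python sketch of that return-shape protocol; `unpack_moe_output` is a hypothetical helper, not a function in the commit.

# Sketch only: the single-tensor vs. (routed, shared) tuple protocol.
from typing import Optional, Tuple, Union

import torch

MoeOutput = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


def unpack_moe_output(out: MoeOutput,
                      cv_parallel: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if cv_parallel:
        routed, shared = out  # quant method returned (routed, shared_output)
        return routed, shared
    return out, None          # legacy single-tensor path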

vllm_ascend/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion

@@ -330,7 +330,7 @@ def apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance, dp_size)
+            is_prefill, enable_force_load_balance, dp_size, **kwargs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
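Note: the quantized path only needs this one-line change because **kwargs is forwarded verbatim down the chain (fused_moe forward -> quant_config apply -> w8a8_dynamic apply -> apply_mlp), so the shared-expert tensors reach apply_mlp without widening any intermediate signature. A toy sketch of the pass-through; the function names are placeholders, not the real call graph.

# Toy sketch: optional extras survive a **kwargs pass-through unchanged.
def w8a8_apply(x, **kwargs):
    # The shared-expert extras arrive intact at the bottom of the chain.
    return x, kwargs.get("shared_gate_up")


def quant_config_apply(x, **kwargs):
    # Matches the one-line change above: forward the extras unchanged.
    return w8a8_apply(x, **kwargs)


print(quant_config_apply(1, shared_gate_up="gate_up"))  # -> (1, 'gate_up')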

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 53 additions & 10 deletions

@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -36,7 +37,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj

@@ -68,6 +70,23 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     else:
         pertoken_scale = dynamic_scale

+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -96,6 +115,21 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states


@@ -110,11 +144,12 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: str = "",
+    **kwargs
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs1 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -145,9 +180,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs1.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs1)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -165,10 +200,15 @@ def fused_experts_with_mc2(
         w2,
         w2_scale,
         expert_token_nums,
-        dynamic_scale=dynamic_scale)
+        dynamic_scale=dynamic_scale,
+        **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list

     # moeCombine
-    kwargs = {
+    kwargs2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -192,10 +232,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs2)

+    if multi_stream:
+        return hidden_states, shared_output
     return hidden_states


@@ -634,7 +676,8 @@ def apply(
             topk_ids=topk_ids,
             top_k=top_k,
             expert_map=expert_map,
-            moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+            moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+            **kwargs)
         elif dp_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
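Note: apply_mlp now runs the shared expert's dequant-swiglu-quant and down_proj on side stream '1', fenced by npu_wait_tensor against the routed experts' grouped matmuls, and fused_experts_with_mc2 renames its local dispatch/combine dicts to kwargs1/kwargs2 so the incoming **kwargs (the shared-expert extras) are not shadowed before being forwarded to apply_mlp. A toy plain-Python illustration of why that rename matters; the names here are illustrative only.

# Toy illustration: rebinding the name `kwargs` inside a function that also
# declares `**kwargs` would discard the caller's extras; giving the local
# dispatch/combine dicts their own names avoids that.
def before(**kwargs):
    kwargs = {"x": "dispatch args"}   # clobbers the shared-expert extras
    return kwargs


def after(**kwargs):
    kwargs1 = {"x": "dispatch args"}  # local dict gets its own name ...
    return kwargs1, kwargs            # ... so the extras survive


print(before(shared_experts="se"))  # {'x': 'dispatch args'}  (extras lost)
print(after(shared_experts="se"))   # ({'x': 'dispatch args'}, {'shared_experts': 'se'})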
