
Commit 25e3d2c

support moe multistream in deepseek
1 parent 17f05b1 commit 25e3d2c

File tree

4 files changed: +103 -21 lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 39 additions & 8 deletions

@@ -30,6 +30,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -210,6 +211,8 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
 
+        self.enable_multi_stream = True
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
         # when profile runs, force experts to load balanced tokens
@@ -224,8 +227,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
 
-        if self.n_shared_experts is not None:
+        moe_multi_stream = self.enable_multi_stream and not is_prefill
+
+        if self.n_shared_experts is not None and not moe_multi_stream:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_hidden_states = hidden_states
 
         if self.tp_size > 1:
             # pass
@@ -244,16 +251,40 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             local_hidden_states = hidden_states
 
+        if self.n_shared_experts is not None and moe_multi_stream:
+            with tng.scope.npu_stream_switch('1'):
+                tng.scope.npu_wait_tensor(shared_hidden_states, shared_hidden_states)
+                x, dynamic_scale = torch_npu.npu_dynamic_quant(shared_hidden_states)
+                gate_up = torch_npu.npu_quant_matmul(
+                    x,
+                    self.shared_experts.gate_up_proj.weight,
+                    self.shared_experts.gate_up_proj.weight_scale,
+                    output_dtype=torch.int32,
+                )
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)
 
-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekV2MoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+        if moe_multi_stream:
+            router_hidden_states, shared_output = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+                shared_experts=self.shared_experts,
+                shared_gate_up=gate_up,
+                shared_dynamic_scale=dynamic_scale
+            )
+            router_hidden_states = router_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+            ) * self.routed_scaling_factor
 
         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
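
The torchair scopes introduced above are the heart of the change: operations issued inside tng.scope.npu_stream_switch('1') are captured onto a secondary stream when the graph is compiled, and tng.scope.npu_wait_tensor(a, b) orders the next consumer of a after b has been produced. During decode this lets the shared-expert GEMMs run concurrently with the routed-expert dispatch and grouped matmuls instead of after them. A minimal sketch of the pattern, assuming an Ascend NPU with torch_npu and torchair available (the two nn.Linear layers are hypothetical stand-ins for the real expert kernels, not the committed code):

```python
import torch
from torch import nn
import torch_npu  # noqa: F401  (registers the NPU backend)
import torchair as tng


class OverlappedMoEBlock(nn.Module):
    """Toy block showing the two-stream shared/routed expert overlap."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.routed = nn.Linear(hidden, hidden)  # stand-in for routed experts
        self.shared = nn.Linear(hidden, hidden)  # stand-in for shared experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Secondary stream '1': start the shared expert as soon as x exists.
        with tng.scope.npu_stream_switch('1'):
            # Wait only on x itself, so nothing serializes against the
            # routed-expert kernels issued on the default stream below.
            tng.scope.npu_wait_tensor(x, x)
            shared_out = self.shared(x)
        # Default stream: routed experts run concurrently with the above.
        routed_out = self.routed(x)
        # torchair inserts the event sync when the two results are combined.
        return routed_out + shared_out
```

In the committed code the split point is finer grained: only the quantized gate_up matmul of the shared expert is issued here in deepseek_v2.py, and the rest of the shared FFN is interleaved inside apply_mlp (see w8a8_dynamic.py below), so each shared-expert stage waits only on the routed tensor it actually races with.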

vllm_ascend/ops/fused_moe.py

Lines changed: 10 additions & 2 deletions

@@ -694,7 +694,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
         assert self.quant_method is not None
 
         if top_k:
@@ -722,7 +723,12 @@ def forward(self,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance,
-            dp_size=self.dp_size)
+            dp_size=self.dp_size,
+            **kwargs)
+
+        multi_stream = isinstance(final_hidden_states, tuple)
+        if multi_stream:
+            final_hidden_states, shared_output = final_hidden_states
 
         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
@@ -731,4 +737,6 @@ def forward(self,
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
+        if multi_stream:
+            return final_hidden_states, shared_output
         return final_hidden_states
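
The isinstance(final_hidden_states, tuple) check above is the convention the commit uses to signal, without touching every intermediate signature, that the layers further down also produced the shared-expert output. The same kwargs-forwarding plus tuple-unpacking idea in isolation, as a runnable toy that needs no NPU (all names here are illustrative, not vLLM APIs):

```python
from typing import Any, Callable, Dict, Optional, Tuple, Union


def apply_mlp_toy(x: float, **kwargs: Any) -> Union[float, Tuple[float, float]]:
    # Innermost layer: optionally also computes a "shared expert" result
    # when the caller threaded the extra objects through **kwargs.
    shared: Optional[Callable[[float], float]] = kwargs.get("shared_experts")
    routed = x * 2.0
    if shared is not None:
        return routed, shared(x)  # tuple signals the multi-stream path
    return routed


def fused_experts_toy(x: float, **kwargs: Any):
    return apply_mlp_toy(x, **kwargs)  # kwargs forwarded unchanged


def moe_forward_toy(x: float, use_shared: bool) -> Tuple[float, Optional[float]]:
    kwargs: Dict[str, Any] = {}
    if use_shared:
        kwargs["shared_experts"] = lambda v: v + 1.0
    out = fused_experts_toy(x, **kwargs)
    if isinstance(out, tuple):  # same check the commit adds in forward()
        routed, shared_out = out
    else:
        routed, shared_out = out, None
    return routed, shared_out


print(moe_forward_toy(3.0, use_shared=True))   # (6.0, 4.0)
print(moe_forward_toy(3.0, use_shared=False))  # (6.0, None)
```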

vllm_ascend/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion

@@ -330,7 +330,7 @@ def apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance, dp_size)
+            is_prefill, enable_force_load_balance, dp_size, **kwargs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 53 additions & 10 deletions

@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -36,7 +37,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj
 
@@ -68,6 +70,23 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     else:
         pertoken_scale = dynamic_scale
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -96,6 +115,21 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+        if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+            shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -110,11 +144,12 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: str = "",
+    **kwargs
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs1 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -145,9 +180,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs1.update(stage1_kwargs)
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs1)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -165,10 +200,15 @@ def fused_experts_with_mc2(
         w2,
         w2_scale,
         expert_token_nums,
-        dynamic_scale=dynamic_scale)
+        dynamic_scale=dynamic_scale,
+        **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list
 
     # moeCombine
-    kwargs = {
+    kwargs2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -192,10 +232,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs2)
 
+    if multi_stream:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -634,7 +676,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif dp_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
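
Putting the deepseek_v2.py and apply_mlp pieces together, the shared expert is executed on stream '1' as a fully quantized W8A8 pipeline (dynamic activation quant, int8 gate_up matmul, fused dequant+SwiGLU+requant, int8 down matmul), interleaved with the routed grouped matmuls on the default stream. A condensed sketch of that pipeline under the same assumptions the diff makes about the projection layers (int8 weight, weight_scale, and weight_scale_fp32 attributes); this is an illustration, not the committed function:

```python
import torch
import torch_npu
import torchair as tng


def shared_expert_w8a8(hidden_states: torch.Tensor,
                       gate_up_proj, down_proj,
                       wait_on: torch.Tensor) -> torch.Tensor:
    """Quantized shared-expert FFN issued on secondary stream '1'.

    `wait_on` is whichever routed-expert tensor the shared path must be
    ordered after; tng.scope.npu_wait_tensor adds that dependency.
    """
    with tng.scope.npu_stream_switch('1'):
        tng.scope.npu_wait_tensor(hidden_states, wait_on)
        # Activation quantization: bf16 -> int8 plus a per-token scale.
        x_int8, x_scale = torch_npu.npu_dynamic_quant(hidden_states)
        # gate_up projection in int8 with int32 accumulation.
        gate_up = torch_npu.npu_quant_matmul(
            x_int8,
            gate_up_proj.weight,
            gate_up_proj.weight_scale,
            output_dtype=torch.int32,
        )
        # Fused dequant + SwiGLU + requant back to int8.
        act_int8, act_scale = torch_npu.npu_dequant_swiglu_quant(
            x=gate_up,
            weight_scale=gate_up_proj.weight_scale_fp32,
            activation_scale=x_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=None,
            activate_left=True,
            quant_mode=1)
        # down projection, dequantized straight to bf16.
        return torch_npu.npu_quant_matmul(
            act_int8,
            down_proj.weight,
            down_proj.weight_scale,
            pertoken_scale=act_scale,
            output_dtype=torch.bfloat16,
        )
```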
