
Commit 1074413

support moe multistream in deepseek
Signed-off-by: David9857 <985700846@qq.com>

use additional_config to enable cv parallel

Signed-off-by: David9857 <985700846@qq.com>

rename kwargs1 in fused_experts_with_mc2

Signed-off-by: David9857 <985700846@qq.com>
1 parent 6eddbd2 commit 1074413

4 files changed: +122 −33 lines changed
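For context, the new flag is read from vLLM's `additional_config`, both in `CustomDeepseekV2MoE.__init__` (via `get_current_vllm_config()`) and in `AscendFusedMoE.__init__`, so a deployment opts in through that dict when creating the engine. A minimal sketch, assuming the flag is passed through the `LLM` entry point; the model name and parallel sizes below are placeholders, only the "enable_cv_parallel" key comes from this commit:

from vllm import LLM

# Hypothetical launch snippet: "enable_cv_parallel" is the key this commit
# reads from additional_config; every other argument is illustrative only.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    tensor_parallel_size=2,                # placeholder parallelism
    additional_config={"enable_cv_parallel": True},
)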

vllm_ascend/models/deepseek_v2.py

Lines changed: 44 additions & 8 deletions
@@ -30,6 +30,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng  # type: ignore
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -177,6 +178,12 @@ def __init__(
         else:
             self.gate.e_score_correction_bias = None

+        self.enable_cv_parallel = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_cv_parallel = additional_config.get(
+                "enable_cv_parallel", False)
+
         self.experts = AscendFusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -224,8 +231,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape

+        cv_parallel = self.enable_cv_parallel and not is_prefill
+
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+            if not cv_parallel:
+                shared_output = self.shared_experts(hidden_states)
+            else:
+                shared_hidden_states = hidden_states

         if self.tp_size > 1:
             # pass
@@ -247,13 +259,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)

-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekV2MoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+        if self.n_shared_experts is not None and cv_parallel:
+            with tng.scope.npu_stream_switch('cv'):
+                tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
+                x, dynamic_scale = torch_npu.npu_dynamic_quant(
+                    shared_hidden_states)
+                gate_up = torch_npu.npu_quant_matmul(
+                    x,
+                    self.shared_experts.gate_up_proj.weight,
+                    self.shared_experts.gate_up_proj.weight_scale,
+                    output_dtype=torch.int32,
+                )
+
+        if cv_parallel:
+            router_hidden_states, shared_output = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+                shared_experts=self.shared_experts,
+                shared_gate_up=gate_up,
+                shared_dynamic_scale=dynamic_scale)
+            router_hidden_states = router_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+            ) * self.routed_scaling_factor

         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
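Read as a whole, the decode-path change above is the torchair two-stream pattern: shared-expert work is queued under `tng.scope.npu_stream_switch('cv')` so it overlaps with routing and the routed-expert path, while `tng.scope.npu_wait_tensor` keeps it ordered behind the tensors it depends on. A stripped-down sketch of that pattern, with generic callables standing in for the real quantized projections (only the two torchair scope calls are taken from the diff; everything else is illustrative):

import torchair as tng  # torchair graph-mode scopes, as used in the diff


def moe_forward_sketch(hidden_states, gate, shared_expert, routed_experts):
    # Default stream: token routing.
    router_logits = gate(hidden_states)

    # Secondary stream 'cv': shared-expert compute is issued here so it can
    # overlap with the routed-expert dispatch/compute below.
    with tng.scope.npu_stream_switch('cv'):
        # Same ordering hint the diff uses: do not consume hidden_states on
        # the side stream before router_logits has been produced.
        tng.scope.npu_wait_tensor(hidden_states, router_logits)
        shared_output = shared_expert(hidden_states)

    routed_output = routed_experts(hidden_states, router_logits)
    return routed_output + shared_output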

vllm_ascend/ops/fused_moe.py

Lines changed: 14 additions & 2 deletions
@@ -810,12 +810,18 @@ def __init__(

         self.quant_method.create_weights(layer=self, **moe_quant_params)

+        self.enable_cv_parallel = False
+        if vllm_config.additional_config:
+            self.enable_cv_parallel = vllm_config.additional_config.get(
+                "enable_cv_parallel", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
         assert self.quant_method is not None

         if top_k:
@@ -842,7 +848,11 @@ def forward(self,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill,
-           enable_force_load_balance=enable_force_load_balance)
+           enable_force_load_balance=enable_force_load_balance,
+           **kwargs)
+
+        if self.enable_cv_parallel and not is_prefill:
+            final_hidden_states, shared_output = final_hidden_states

         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
@@ -851,4 +861,6 @@
         final_hidden_states = tensor_model_parallel_all_reduce(
             final_hidden_states)

+        if self.enable_cv_parallel and not is_prefill:
+            return final_hidden_states, shared_output
         return final_hidden_states
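Note that the return type of `AscendFusedMoE.forward` now depends on the flag: a plain tensor normally, and a `(final_hidden_states, shared_output)` tuple when cv parallel is active and the request is not a prefill. The DeepSeek forward above already branches on `cv_parallel`; a caller that cannot see the flag could unpack defensively with a small helper like this (hypothetical, not part of the commit):

def unpack_moe_output(out):
    # forward() returns (final_hidden_states, shared_output) when cv parallel
    # multistream is active on decode, and a single tensor otherwise.
    if isinstance(out, tuple):
        return out
    return out, None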

vllm_ascend/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion
@@ -329,7 +329,7 @@ def apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias,
-           is_prefill, enable_force_load_balance)
+           is_prefill, enable_force_load_balance, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 63 additions & 22 deletions
@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng  # type: ignore
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -36,7 +37,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj

@@ -68,6 +70,23 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     else:
         pertoken_scale = dynamic_scale

+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -96,25 +115,39 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+        if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+            shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states


-def fused_experts_with_mc2(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        top_k: int,
-        expert_map: torch.Tensor = None,
-        moe_all_to_all_group_name: str = "",
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           w1_scale: torch.Tensor,
+                           w2_scale: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: str = "",
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -145,9 +178,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -165,10 +198,15 @@ def fused_experts_with_mc2(
         w2,
         w2_scale,
         expert_token_nums,
-        dynamic_scale=dynamic_scale)
+        dynamic_scale=dynamic_scale,
+        **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list

     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -192,10 +230,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

+    if multi_stream:
+        return hidden_states, shared_output
     return hidden_states


@@ -633,7 +673,8 @@ def apply(
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
-               moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+               moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+               **kwargs)
        elif self.ep_group.world_size == 1:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
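The `kwargs` to `kwargs_mc2` rename in `fused_experts_with_mc2` matters because the function now also accepts `**kwargs` carrying the shared-expert tensors; rebinding the name `kwargs` to the dispatch/combine argument dicts would have clobbered those extras before they were forwarded to `apply_mlp`. A toy illustration of that shadowing hazard (illustrative helpers, not the real ops):

def mc2_like(x, **kwargs):
    # BAD: writing `kwargs = {...}` here would rebind the name and silently
    # drop the caller-supplied extras (shared_experts, shared_gate_up, ...).
    # GOOD: keep a separate dict for the dispatch arguments, as the diff does
    # with kwargs_mc2, and pass **kwargs through untouched.
    kwargs_mc2 = {"x": x}
    return kwargs_mc2, kwargs


_, extras = mc2_like(0, shared_gate_up="gate_up")
assert "shared_gate_up" in extras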
