
Commit 34df77e

support cv parallel for float model
Signed-off-by: David9857 <985700846@qq.com>
1 parent 8f2e33e commit 34df77e

3 files changed, +38 -17 lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 12 additions & 8 deletions
@@ -262,14 +262,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None and cv_parallel:
             with tng.scope.npu_stream_switch('cv'):
                 tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
-                x, dynamic_scale = torch_npu.npu_dynamic_quant(
-                    shared_hidden_states)
-                gate_up = torch_npu.npu_quant_matmul(
-                    x,
-                    self.shared_experts.gate_up_proj.weight,
-                    self.shared_experts.gate_up_proj.weight_scale,
-                    output_dtype=torch.int32,
-                )
+                dynamic_scale = None
+                if self.shared_experts.is_dynamic_quant:
+                    x, dynamic_scale = torch_npu.npu_dynamic_quant(
+                        shared_hidden_states)
+                    gate_up = torch_npu.npu_quant_matmul(
+                        x,
+                        self.shared_experts.gate_up_proj.weight,
+                        self.shared_experts.gate_up_proj.weight_scale,
+                        output_dtype=torch.int32,
+                    )
+                else:
+                    gate_up, _ = self.gate_up_proj(shared_hidden_states)
 
         if cv_parallel:
             router_hidden_states, shared_output = self.experts(
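The deepseek_v2.py change keeps the shared experts on the side 'cv' stream while branching on how their gate_up projection is computed: the W8A8 path still uses npu_dynamic_quant plus npu_quant_matmul, and float models now fall through to the ordinary linear projection. Below is a minimal, CPU-runnable sketch of that branch; the NPU kernels are emulated with a plain nn.Linear, and the class and flag names are illustrative rather than taken from the repo.

import torch
import torch.nn as nn


class SharedExpertGateUpSketch(nn.Module):
    """Illustrative stand-in for the branch added above (not the repo's class).

    The quantized path emulates dynamic per-token quantization in float;
    the float path is just the ordinary projection, as in the new `else`
    branch of the diff.
    """

    def __init__(self, hidden: int, intermediate: int, is_dynamic_quant: bool):
        super().__init__()
        self.is_dynamic_quant = is_dynamic_quant
        self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)

    def forward(self, x: torch.Tensor):
        dynamic_scale = None
        if self.is_dynamic_quant:
            # Emulate npu_dynamic_quant: per-token scale, round, then matmul.
            dynamic_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
            x_q = torch.round(x / dynamic_scale).clamp(-128, 127)
            gate_up = self.gate_up_proj(x_q * dynamic_scale)
        else:
            # Float model: plain linear projection, no quantization step.
            gate_up = self.gate_up_proj(x)
        return gate_up, dynamic_scale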

vllm_ascend/ops/fused_moe.py

Lines changed: 24 additions & 7 deletions
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng  # type: ignore
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator,
                               get_tensor_model_parallel_world_size,
@@ -47,10 +48,11 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    **kwargs
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -81,13 +83,20 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, expand_x)
+            shared_x = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -116,10 +125,15 @@ def fused_experts_with_mc2(
         group_list=group_list,
     )
 
+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, down_out_list)
+            shared_output = shared_experts.down_proj(shared_x)
+
     down_out_list = torch.cat(down_out_list, dim=0)
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -141,10 +155,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -664,7 +680,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
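In fused_moe.py, renaming the MC2 argument dict to kwargs_mc2 frees the function's own **kwargs to carry the optional shared-expert tensors, whose activation and down projection are then overlapped with the routed-expert kernels on the 'cv' stream; when shared experts are supplied, the function returns a (routed, shared) pair. A rough sequential sketch of just that plumbing follows, with the dispatch, grouped matmul and combine steps stubbed out and the helper name assumed for illustration.

from typing import Optional

import torch


def fused_experts_with_mc2_sketch(hidden_states: torch.Tensor,
                                  moe_all_to_all_group_name: Optional[str] = None,
                                  **kwargs):
    """Control-flow sketch only: the real MC2 dispatch / grouped-matmul /
    combine kernels are replaced by placeholders, and the 'cv' stream
    overlap is shown as ordinary sequential calls."""
    # Optional shared-expert inputs ride along in **kwargs.
    shared_experts = kwargs.get('shared_experts', None)
    shared_gate_up = kwargs.get('shared_gate_up', None)

    # ... npu_moe_distribute_dispatch / grouped matmul / combine would run here ...
    routed_output = hidden_states  # placeholder for the combined MoE result

    shared_output = None
    if shared_experts is not None:
        # On the NPU this runs under tng.scope.npu_stream_switch('cv') after
        # tng.scope.npu_wait_tensor(...), so it overlaps the routed path.
        shared_x = shared_experts.act_fn(shared_gate_up)
        shared_output = shared_experts.down_proj(shared_x)

    # New contract: a pair is returned when shared experts were provided.
    if shared_experts is not None:
        return routed_output, shared_output
    return routed_output

On the caller side this matches the cv_parallel branch in deepseek_v2.py above, where the result is unpacked as router_hidden_states, shared_output = self.experts(...).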

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
         shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
-        with tng.scope.npu_stream_switch('1'):
+        with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
             shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
                 x=shared_gate_up,
@@ -117,7 +117,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         output_dtype=w2_scale.dtype)[0]
 
     if shared_experts:
-        with tng.scope.npu_stream_switch('1'):
+        with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_x, hidden_states)
             shared_output = torch_npu.npu_quant_matmul(
                 shared_x,
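The w8a8_dynamic.py edit only renames the side stream from '1' to 'cv', so the quantized shared-expert path now uses the same stream tag as the float path and the MC2 path above. The recurring pattern in all three files can be summed up in a small helper; this is a hypothetical convenience wrapper assuming a torchair/NPU environment, not code from the commit.

import torchair as tng  # type: ignore


def run_shared_on_cv(shared_input, produced_on_main_stream, fn):
    # Hypothetical helper, not part of the commit: queue shared-expert work on
    # the secondary 'cv' stream once a tensor from the main stream is ready,
    # so it can overlap with the routed-expert kernels that follow.
    with tng.scope.npu_stream_switch('cv'):
        tng.scope.npu_wait_tensor(shared_input, produced_on_main_stream)
        return fn(shared_input)

Written this way, the activation step in fused_experts_with_mc2 would read shared_x = run_shared_on_cv(shared_gate_up, expand_x, shared_experts.act_fn). Keeping one tag across call sites so all shared-expert work lands on the same side stream is presumably why the tag here was aligned with 'cv'.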
