
Commit 65909b2

[Perf][MoE] Improve shared experts multi-stream for w8a8 dynamic. (#1561)
This PR refines the shared-expert multi-stream parallelism of the w8a8-dynamic-quantized MoE stage to achieve better performance. The current multi-stream parallelism for the shared experts is shown in the following picture:

![image](https://github.com/user-attachments/assets/89760804-94e7-4231-ae70-fe5148b9b2c2)

Performance change:

Before:

![image](https://github.com/user-attachments/assets/364ac7d3-e43f-44e8-86ef-c1d48c729d5f)

After:

![image](https://github.com/user-attachments/assets/d34af4fa-e72c-46cd-97b7-a4a607dbc1ea)

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e89c59d commit 65909b2
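
The core of the change is that, when multistream MoE is enabled for a w8a8-dynamic model in MC2 mode, the router gate matmul on the default stream is overlapped with the per-token dynamic quantization of the shared-expert input on a secondary stream. Below is a minimal sketch of that overlap, assuming an Ascend NPU with torch_npu and the `npu_stream_switch` helper used in this diff (the import path is assumed); the function name and signature are illustrative, not the actual vllm-ascend API.

```python
import torch
import torch_npu
# Assumed import path for the stream helper used throughout this diff.
from vllm_ascend.utils import npu_stream_switch


def gate_and_quant_in_parallel(hidden_states: torch.Tensor, gate):
    # Default stream: router logits for expert selection (the gate returns a
    # (logits, bias) tuple, mirroring `router_logits, _ = gate(...)` in the diff).
    router_logits, _ = gate(hidden_states)

    # Secondary stream: per-token int8 quantization of the same activations,
    # later fed to the shared experts' quantized gate_up_proj. Both kernels
    # are issued asynchronously, so they can execute concurrently on device.
    with npu_stream_switch("moe_secondary", 0):
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(hidden_states)

    return router_logits, quantized_x, dynamic_scale
```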

3 files changed: 133 additions & 22 deletions


vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -310,7 +310,10 @@ def forward(
         is_prefill = False
 
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        if self.enable_multistream_moe:
+            router_logits = None
+        else:
+            router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -319,6 +322,7 @@ def forward(
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
+            gate=self.gate if self.enable_multistream_moe else None,
         )
 
         hidden_states = (
```

vllm_ascend/ops/fused_moe.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -1125,7 +1125,8 @@ def forward(self,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
-                shared_experts: Optional[Any] = None):
+                shared_experts: Optional[Any] = None,
+                gate: Optional[Any] = None):
         assert self.quant_method is not None
 
         if top_k:
@@ -1136,6 +1137,20 @@ def forward(self,
         num_tokens, hidden_size = hidden_states.shape
 
         fused_moe_state = get_forward_context().fused_moe_state
+        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
+        quantized_x_for_share, dynamic_scale_for_share = None, None
+        from vllm_ascend.quantization.w8a8_dynamic import \
+            AscendW8A8DynamicFusedMoEMethod
+        if self.enable_multistream_moe:
+            assert gate is not None
+            router_logits, _ = gate(hidden_states)
+            if isinstance(self.quant_method.quant_method,
+                          AscendW8A8DynamicFusedMoEMethod
+                          ) and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                        hidden_states)
+
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)
@@ -1192,6 +1207,8 @@ def forward(self,
             global_redundant_expert_num=self.global_redundant_expert_num,
             shared_experts=shared_experts if self.torchair_graph_enabled
             and self.enable_multistream_moe and not is_prefill else None,
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
         )
 
         if shared_experts:
```
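
The quantized activations and router logits produced above are consumed by the shared-expert stages in vllm_ascend/quantization/w8a8_dynamic.py below. As a consolidated, hedged sketch (not the verbatim code, which splits these steps across `apply()` and `fused_experts_with_mc2()`), the secondary-stream schedule looks roughly like this; each `npu_wait_tensor(a, b)` fences the secondary stream until the main-stream tensor `b` is ready, and the helper import path is assumed.

```python
# `shared_experts` is the quantized shared-expert module whose gate_up_proj /
# act_fn / down_proj accept (int8 tensor, scale) pairs, as in the diff below.
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def shared_experts_on_secondary_stream(shared_experts, quantized_x,
                                       dynamic_scale, router_logits,
                                       expand_x, down_out_list):
    with npu_stream_switch("moe_secondary", 0):
        # gate_up_proj overlaps with routed-expert dispatch; it only needs the
        # router logits to have been produced on the main stream.
        npu_wait_tensor(quantized_x, router_logits)
        up_out, _ = shared_experts.gate_up_proj((quantized_x, dynamic_scale))
        shared_gate_up, shared_dequant_scale = up_out[0], up_out[1]

        # swiglu overlaps with the routed experts' grouped matmuls; it waits
        # for the MC2-dispatched tokens (expand_x).
        npu_wait_tensor(shared_gate_up, expand_x)
        act_out = shared_experts.act_fn((shared_gate_up, shared_dequant_scale))
        shared_act, swiglu_out_scale = act_out[0], act_out[1]

        # down_proj overlaps with moeCombine; it waits for the routed experts'
        # down-projection output.
        npu_wait_tensor(shared_act, down_out_list)
        shared_output, _ = shared_experts.down_proj(
            (shared_act, swiglu_out_scale))
    return shared_output
```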

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 110 additions & 20 deletions
```diff
@@ -16,7 +16,7 @@
 #
 
 import math
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -32,6 +32,80 @@
     npu_stream_switch, npu_wait_tensor)
 
 
+def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    assert len(hidden_states_wrapper) == 1
+    hidden_states = hidden_states_wrapper.pop()
+    if dynamic_scale is None:
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
@@ -138,7 +212,9 @@ def fused_experts_with_mc2(
         shared_experts: Optional[Any] = None,
         is_torchair: bool = False,
         w1_scale_bias: torch.Tensor = None,
-        w2_scale_bias: torch.Tensor = None
+        w2_scale_bias: torch.Tensor = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -193,21 +269,19 @@ def fused_experts_with_mc2(
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up[0], expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+            npu_wait_tensor(quantized_x_for_share, expand_x)
+            shared_act_out = shared_experts.act_fn(
+                (quantized_x_for_share, dynamic_scale_for_share))
+            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
     # `expand_x` will be disposed in the `apply_mlp` function
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale,
-                              w1_scale_bias=w1_scale_bias,
-                              w2_scale_bias=w2_scale_bias)
+    down_out_list = apply_mlp_decode([expand_x],
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs_mc2 = {
@@ -244,8 +318,9 @@ def fused_experts_with_mc2(
         return hidden_states
     else:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act[0], down_out_list)
-            shared_output, _ = shared_experts.down_proj(shared_act)
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_output, _ = shared_experts.down_proj(
+                (shared_act, swiglu_out_scale))
         return hidden_states, shared_output
 
 
@@ -661,6 +736,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -695,6 +772,16 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        fused_moe_state = get_forward_context().fused_moe_state
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
@@ -703,13 +790,12 @@ def apply(
 
         topk_weights = topk_weights.to(x.dtype)
 
-        fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -719,7 +805,9 @@ def apply(
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled)
+                is_torchair=self.torchair_graph_enabled,
+                quantized_x_for_share=shared_gate_up,
+                dynamic_scale_for_share=shared_dequant_scale)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -764,6 +852,8 @@ def process_weights_after_loading(self, layer):
             layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
```
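
One smaller change above is that `process_weights_after_loading` now caches an fp32 copy of `w13_weight_scale`, so `fused_experts_with_mc2` can pass `w1_scale=layer.w13_weight_scale_fp32` without re-casting the scale on every forward step. A tiny, CPU-runnable sketch of that precompute pattern (the layer object and tensor shape here are stand-ins, not the real vLLM layer):

```python
import torch


class FakeQuantizedMoELayer:
    """Stand-in for the quantized MoE layer populated at weight-loading time."""


layer = FakeQuantizedMoELayer()
# Per-expert gate_up scales, typically stored in bf16 alongside int8 weights.
layer.w13_weight_scale = torch.rand(8, 512, dtype=torch.bfloat16)

# One-time cast cached next to the original scale, mirroring
# `layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32)`.
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.to(torch.float32)

assert layer.w13_weight_scale_fp32.dtype == torch.float32
assert layer.w13_weight_scale_fp32.shape == layer.w13_weight_scale.shape
```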
