Commit ee40d3d

use npu_moe_gating_top_k_softmax (#1355)
### What this PR does / why we need it?

For non-DeepSeek models, select_experts is optimized by replacing the softmax+topk+to sequence with the fused gating_topk_softmax kernel (torch_npu.npu_moe_gating_top_k_softmax), which cuts the routing step from 37 us to 14 us in bf16/fp16 on Qwen3-235B.

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@1a4f35e

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
1 parent 9d16c99 commit ee40d3d
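
For context, here is a minimal, hedged sketch of the two routing paths this change trades between. The helper names are illustrative, the reference path runs on CPU, and the fused path requires an Ascend NPU with torch_npu installed (its call signature is taken from the test added below).

import torch

# Reference path used before this change: softmax over the router logits,
# then top-k selection, then a dtype cast ("softmax + topk + to").
def reference_routing(router_logits: torch.Tensor, top_k: int):
    weights = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = weights.topk(top_k, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)

# Fused path this PR switches to: a single NPU kernel does the same work.
def fused_routing(router_logits: torch.Tensor, top_k: int):
    import torch_npu  # requires an Ascend NPU environment
    topk_weights, topk_ids, _row_idx = torch_npu.npu_moe_gating_top_k_softmax(
        router_logits, None, k=top_k)
    return topk_weights, topk_ids

logits = torch.rand(4, 8)
print(reference_routing(logits, top_k=2))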

File tree

4 files changed: +107 additions, -14 deletions
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import pytest
+import torch
+import torch_npu
+
+
+@pytest.mark.parametrize(
+    'B',
+    [1, 16, 64, 128, 32768],
+)
+@pytest.mark.parametrize(
+    'D',
+    [8, 16, 32, 64, 128],
+)
+@pytest.mark.parametrize(
+    'top_k',
+    [1, 2, 4, 8],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-3, 1e-3),
+    ],
+)
+def test_quant_fpx_linear(B: int, D: int, top_k: int, dtype, atol, rtol):
+    x = torch.rand((B, D), dtype=dtype).to("npu")
+    # finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
+    finished = None
+    # Fused NPU kernel: softmax over the logits followed by top-k selection.
+    y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
+                                                                    finished,
+                                                                    k=top_k)
+
+    # Reference path: plain softmax + topk on the same input.
+    topk_weights = x.softmax(dim=-1)
+    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
+    topk_ids = topk_ids.to(torch.int32)
+    assert torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
+    assert torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)

vllm_ascend/envs.py

Lines changed: 5 additions & 0 deletions
@@ -117,6 +117,11 @@
     # value to False to disable the optimized model.
     "USE_OPTIMIZED_MODEL":
     lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
+    # SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
+    # In theory, it should have better performance than select_experts.
+    # Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
+    "SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
+    lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
     # The tolerance of the kv cache size, if the difference between the
     # actual kv cache size and the cached kv cache size is less than this value,
     # then the cached kv cache size will be used.
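
As a usage note (not part of the diff): the lambda above reads the environment variable lazily, and the ops modules below capture it into a module-level boolean at import time, so the flag should be set before vLLM starts. A minimal sketch, assuming the default of '0' (disabled):

import os

# Illustrative only: opt in to the fused gating path before vllm_ascend is imported.
os.environ["SELECT_GATING_TOPK_SOTFMAX_EXPERTS"] = "1"

# Mirrors the lambda registered above: "0" -> False, any non-zero integer -> True.
enabled = bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", "0")))
print(enabled)  # True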

vllm_ascend/ops/common_fused_moe.py

Lines changed: 25 additions & 14 deletions
@@ -22,10 +22,13 @@
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
-                                       select_experts)
+                                       select_experts,
+                                       select_gating_top_k_softmax_experts)
 from vllm_ascend.utils import is_310p
 
+SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
 
@@ -54,19 +57,27 @@ def forward_oot(
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
 ) -> torch.Tensor:
-    topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
-        hidden_states=x,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias,
-    )
+
+    if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
+        topk_weights, topk_ids = select_gating_top_k_softmax_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            renormalize=renormalize)
+    else:
+        topk_weights, topk_ids = select_experts(
+            global_num_experts=global_num_experts,
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
 
     if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None

vllm_ascend/ops/fused_moe.py

Lines changed: 40 additions & 0 deletions
@@ -49,6 +49,7 @@
                               npu_stream_switch, npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
+SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
 
 
 def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
@@ -821,6 +822,39 @@ def fused_experts(
     return final_hidden_states
 
 
+def select_gating_top_k_softmax_experts(
+        hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
+        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Select top-k experts based on router logits.
+    Only supports float16, bfloat16, and float32.
+
+    Args:
+        hidden_states: Hidden states of shape (num_tokens, hidden_size).
+        router_logits: Router logits of shape (num_tokens, num_experts).
+        top_k: Number of experts to select.
+        renormalize: Whether to renormalize the routing weights.
+
+    Returns:
+        topk_weights: Routing weights of shape (num_tokens, top_k).
+        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
+    """
+    # Fused NPU kernel: softmax over router_logits followed by top-k selection.
+    topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
+        router_logits, None, k=top_k)
+
+    # # Required by npu_moe_init_routing
+    # topk_weights = topk_weights.to(hidden_states.dtype)
+    # topk_ids = topk_ids.to(torch.int32)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
 def native_grouped_topk(
         topk_weights: torch.Tensor,
         num_expert_group: Optional[int],
@@ -1013,6 +1047,12 @@ def apply(
             # y2_flag=False,  # old api; whether to output the third result
             routed_scaling_factor=1,
             eps=float(1e-20))
+        elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
+            topk_weights, topk_ids = select_gating_top_k_softmax_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                renormalize=renormalize)
         else:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
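
To make the semantics of the new select_gating_top_k_softmax_experts helper concrete, here is a hedged CPU-only sketch of the equivalent computation, with plain PyTorch standing in for the fused torch_npu kernel; the function name and shapes are illustrative.

import torch

def reference_gating_top_k_softmax(router_logits: torch.Tensor, top_k: int,
                                   renormalize: bool):
    # Softmax over the full expert dimension, then keep the top-k entries.
    weights = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = weights.topk(top_k, dim=-1)
    if renormalize:
        # Make the selected weights sum to 1 per token, as the new helper does.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

logits = torch.randn(2, 8)
weights, ids = reference_gating_top_k_softmax(logits, top_k=2, renormalize=True)
print(weights.sum(dim=-1))  # each row sums to 1.0 after renormalization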
