Commit c59d69d

[PERF] support MERRouter (#1421)

Authored by angazenn
### What this PR does / why we need it?

This PR introduces an expert rearrange algorithm for the PanguProMoE model. Unlike the original grouped topk, it filters out the top experts that are allocated more tokens, so fewer experts need to be loaded when computing the grouped matmul (GMM). We have tested this algorithm with PanguProMoE-72B on the 300I Duo and 800I A2 platforms. On 300I Duo, we find that setting `num_voted_experts` to 5 achieves both good performance and accuracy, while on 800I A2 we keep it at 8, which falls back to the original Pangu grouped topk.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
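To make the idea concrete, here is a small self-contained sketch of the vote-and-filter routing described above. It is an illustration, not the kernel code from this PR; the toy sizes and variable names are assumptions made up for the example:

```python
import torch
import torch.nn.functional as F

# Toy sizes, assumed for illustration only.
num_tokens, num_groups, experts_per_group = 16, 8, 8
num_voted_experts = 5  # 8 would fall back to the original grouped topk

scores = F.softmax(
    torch.randn(num_tokens, num_groups * experts_per_group), dim=1)
grouped = scores.view(num_tokens, num_groups, experts_per_group)

# Each token "votes" for its best expert within every group.
votes = torch.argmax(grouped, dim=2)                   # (num_tokens, num_groups)
freq = F.one_hot(votes, experts_per_group).sum(dim=0)  # (num_groups, experts_per_group)

# Keep only the num_voted_experts most-voted experts per group ...
kept = torch.argsort(freq, dim=1, descending=True)[:, :num_voted_experts]
keep_mask = F.one_hot(kept, experts_per_group).sum(dim=1).bool()

# ... and re-select each token's best expert among the kept ones, so the
# grouped matmul only has to load the surviving experts.
masked = torch.where(keep_mask.unsqueeze(0), grouped, torch.zeros_like(grouped))
topk_weights, topk_ids_in_group = masked.max(dim=2)
```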
1 parent 8fa1881 commit c59d69d

File tree

3 files changed: +85 −38 lines changed

vllm_ascend/models/pangu_moe.py
vllm_ascend/ops/common_fused_moe.py
vllm_ascend/ops/fused_moe.py

vllm_ascend/models/pangu_moe.py

Lines changed: 74 additions & 28 deletions

```diff
@@ -57,6 +57,7 @@
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.utils import is_310p
 
 logger = init_logger(__name__)
 
@@ -339,41 +340,81 @@ def forward(self, x):
         return x
 
 
-class PanguProMoESparseMoeBlock(nn.Module):
+def topk_wrapper(num_voted_experts):
 
-    @staticmethod
     def pangu_group8_topk(
         hidden_states: torch.Tensor,
         gating_output: torch.Tensor,
         topk: int,
-        renormalize: bool,
+        renormalize: bool = False,
         num_expert_group: int = 0,
         topk_group: int = 0,
         global_num_experts: int = 0,
     ):
+        scores = F.softmax(gating_output, dim=1)
+        num_tokens = scores.shape[0]
+        router_scale = _ROUTER_SCALE.squeeze(  # type: ignore
+        )
+
         ep_size = get_ep_group().world_size
         local_num_experts = global_num_experts // ep_size
         local_num_group = topk // ep_size
-        router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
+        experts_per_group = global_num_experts // topk
+        local_group_start = get_ep_group().rank_in_group * local_num_experts
+        local_group_end = (get_ep_group().rank_in_group +
+                           1) * local_num_experts
         scores = F.softmax(gating_output, dim=1)
-        scores = scores[...,
-                        get_ep_group().rank_in_group *
-                        local_num_experts:(get_ep_group().rank_in_group + 1) *
-                        local_num_experts]
-
-        router_weights = router_scale[get_ep_group().rank_in_group *
-                                      local_num_experts:
-                                      (get_ep_group().rank_in_group + 1) *
-                                      local_num_experts]
-        topk_weights, topk_ids = torch.max(scores.view(scores.shape[0],
-                                                       local_num_group, -1),
-                                           dim=-1)
-        bias = torch.arange(0,
-                            local_num_experts,
-                            topk,
-                            device=scores.device,
-                            dtype=torch.int32).unsqueeze(0)
-        topk_ids = topk_ids.to(torch.int32) + bias
+        scores = scores[..., local_group_start:local_group_end]
+
+        router_weights = router_scale[local_group_start:local_group_end]
+
+        if num_voted_experts == 8:
+            # use original topk
+            topk_weights, topk_ids = torch.max(scores.view(
+                scores.shape[0], local_num_group, -1),
+                                               dim=-1)
+            bias = torch.arange(0,
+                                local_num_experts,
+                                experts_per_group,
+                                device=scores.device,
+                                dtype=torch.int32).unsqueeze(0)
+            topk_ids = topk_ids.to(torch.int32) + bias
+
+        else:
+            group_expert_indices = torch.arange(experts_per_group,
+                                                dtype=torch.int32,
+                                                device=scores.device).view(
+                                                    1, 1, -1)
+            group_expert_offset = (torch.arange(
+                local_num_group, dtype=torch.int32, device=scores.device) *
+                                   experts_per_group).unsqueeze(0)
+            expert_index_range = torch.arange(experts_per_group,
+                                              dtype=torch.int32,
+                                              device=scores.device)
+
+            scores_grouped = scores.view(num_tokens, local_num_group,
+                                         experts_per_group)
+            best_expert_idx = torch.argmax(scores_grouped,
+                                           dim=2)  # (num_tokens, num_groups)
+            vote_mask = (best_expert_idx.unsqueeze(-1).to(
+                torch.int32) == group_expert_indices)
+
+            expert_vote_freq = vote_mask.sum(dim=0)
+
+            sorted_indices = torch.argsort(expert_vote_freq,
+                                           dim=1,
+                                           descending=True).to(torch.int32)
+            topk_experts = sorted_indices[:, :num_voted_experts]
+            keep_mask = ((
+                topk_experts.unsqueeze(-1) == expert_index_range).any(
+                    dim=1)).unsqueeze(0)
+
+            masked_scores = torch.where(keep_mask, scores_grouped, 0)
+
+            topk_weights, best_pos_in_group = masked_scores.max(dim=2)
+            best_pos_in_group = best_pos_in_group.to(torch.int32)
+            topk_ids = (best_pos_in_group + group_expert_offset).to(
+                torch.int32)
 
         flatten_topk_ids = topk_ids.view(-1)
         router_weights = router_weights.index_select(0, flatten_topk_ids).view(
@@ -382,6 +423,11 @@ def pangu_group8_topk(
 
         return topk_weights, topk_ids
 
+    return pangu_group8_topk
+
+
+class PanguProMoESparseMoeBlock(nn.Module):
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -397,23 +443,23 @@ def __init__(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.local_num_group = config.num_experts_per_tok // get_ep_group(
-        ).world_size
         self.num_experts_per_tok = config.num_experts_per_tok
-        self.local_num_experts = config.num_experts // get_ep_group(
-        ).world_size
         self.router_scale = torch.nn.Parameter(
             torch.ones((1, self.num_experts)))
 
+        # on 300I Duo platform, we find that num_voted_experts set to 5 achieves
+        # good performance without sacrificing too much accuracy. for other
+        # platforms, this is set to 8 to use the original pangu grouped topk.
+        num_voted_experts = 5 if is_310p() else 8
+
         self.experts = FusedMoE(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             quant_config=quant_config,
-            custom_routing_function=PanguProMoESparseMoeBlock.
-            pangu_group8_topk,
+            custom_routing_function=topk_wrapper(num_voted_experts),
             prefix=f"{prefix}.experts",
         )
 
```
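Structurally, the routing function moves out of `PanguProMoESparseMoeBlock` and into a `topk_wrapper` closure: the platform-dependent `num_voted_experts` is captured once, while the returned function keeps the signature that `FusedMoE` expects for `custom_routing_function`. A minimal sketch of the pattern, with the routing body elided:

```python
def topk_wrapper(num_voted_experts):

    def pangu_group8_topk(hidden_states, gating_output, topk, **kwargs):
        # num_voted_experts is visible here via the closure, so the
        # signature seen by FusedMoE stays unchanged.
        ...

    return pangu_group8_topk

# 5 on the 300I Duo platform, 8 (original grouped topk) elsewhere.
routing_fn = topk_wrapper(5)
```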
vllm_ascend/ops/common_fused_moe.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -21,7 +21,7 @@
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
-from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
+from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
                                        select_experts)
 from vllm_ascend.utils import is_310p
 
@@ -58,9 +58,9 @@ def forward_oot(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    if is_310p():
+    if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None
-        return fused_experts_310p(
+        return fused_experts_moge(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
```
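The widened guard dispatches to `fused_experts_moge` not only on 310P but also whenever the routing function returned fewer selected experts per token than `top_k` — which the Pangu routing above does under expert parallelism, since each rank only handles `topk // ep_size` groups. A hedged reading of the condition as a standalone helper (`needs_moge_path` is a hypothetical name, not part of the PR):

```python
import torch

def needs_moge_path(topk_ids: torch.Tensor, top_k: int, on_310p: bool) -> bool:
    # topk_ids: (num_tokens, experts_selected_per_token). The custom routing
    # may select fewer experts per token than top_k, e.g. topk // ep_size.
    return topk_ids.shape[1] < top_k or on_310p

# Toy check: 2 experts selected per token while top_k is 8 (ep_size = 4).
assert needs_moge_path(torch.zeros(16, 2, dtype=torch.int32), 8, False)
```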

vllm_ascend/ops/fused_moe.py

Lines changed: 8 additions & 7 deletions

```diff
@@ -39,7 +39,7 @@
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               get_fused_moe_state, npu_stream_switch,
+                               get_fused_moe_state, is_310p, npu_stream_switch,
                                npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -548,8 +548,7 @@ def fused_experts_with_all2all_buffer(
     return final_hidden_states
 
 
-# Currently, fused_experts on 310p only supports PanguProMoE.
-def fused_experts_310p(
+def fused_experts_moge(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -614,8 +613,11 @@ def fused_experts_310p(
         group_list=group_list,
     )[0]
 
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
-        torch.float16)
+    if is_310p():
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
+            torch.float16)
+    else:
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
     gate_up_out *= topk_scales
 
     w2 = w2.transpose(1, 2)
@@ -628,8 +630,7 @@
         group_list=group_list,
     )[0]
 
-    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
-        torch.int32) + torch.Tensor([0]).to(torch.int32).npu()
+    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
     unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
     final_hidden_states = unsorted_hidden_states.reshape(
         bsz, top_k // ep_size, -1).sum(1)
```
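The last hunk drops a redundant `+ torch.Tensor([0]).to(torch.int32).npu()` term; what remains relies on the fact that `argsort` applied to a sorting permutation yields its inverse permutation, which restores the original token order after the expert-grouped matmuls. A quick plain-PyTorch illustration with toy values (assumed, CPU-only):

```python
import torch

topk_ids = torch.tensor([2, 0, 1, 0], dtype=torch.int32)

# Sort tokens by expert id, as done before the grouped matmuls.
sorted_topk_ids = torch.argsort(topk_ids.float())
grouped = topk_ids.index_select(0, sorted_topk_ids.to(torch.int32))

# argsort of the sorting permutation is its inverse: it undoes the sort.
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
restored = grouped.index_select(0, unsorted_topk_ids)
assert torch.equal(restored, topk_ids)
```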
