
Commit e9ada68

weijinqian0 and weijinqian_v1 authored
[CI]Moe alltoall communication optimization (#1067)
The DeepSeek V3/R1 model has 256 routed experts. During parallel inference, a heavily loaded EP rank slows down the overall communication and computation, so uneven load distribution becomes a weakness of expert parallelism. In the prefill phase the data volume is large, and both the inter-card communication time and the computation time scale closely with that volume, so a small non-linear precision loss can be traded for a near-linear performance improvement.

Because communication involves global synchronization, lightly loaded cards finish their computation first and then wait for the most heavily loaded card to finish; when the load is unbalanced, the busiest card dominates the overall latency. Significant performance gains can therefore be achieved by discarding a small number of tokens. That is unacceptable in precision-sensitive scenarios, but, much like quantization, it is a technique that accepts a bounded precision loss in exchange for performance where the trade-off is appropriate, and the balance between performance and precision can be tuned by configuring the proportion of discarded tokens.

Tested on A3 with batch size 8 (B), prompt length 3.5K tokens (S), and the parallel configuration AttnDP=2, AttnTP=8, MoeTP=1, MoeEP=16, this change gives a 10%-15% performance gain. A follow-up version will add an alltoallv-based MoE.

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
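The optimization is gated behind a new environment variable, MOE_ALL2ALL_BUFFER (see the envs.py hunk below). A minimal sketch of opting in, assuming a vllm-ascend build that contains this commit; the model name and engine configuration are placeholders:

import os

# Opt in to the buffered all-to-all MoE path. vllm_ascend/envs.py parses the
# flag with bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))), and
# vllm_ascend/ops/fused_moe.py reads it at import time, so set it before
# importing vLLM.
os.environ["MOE_ALL2ALL_BUFFER"] = "1"

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-R1")  # placeholder engine configuration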
1 parent a2552e1 commit e9ada68

File tree

2 files changed: +273 -12 lines changed


vllm_ascend/envs.py

Lines changed: 5 additions & 0 deletions
@@ -112,6 +112,11 @@
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
                  ),
+    # MOE_ALL2ALL_BUFFER:
+    # 0: default, normal init.
+    # 1: enable moe_all2all_buffer.
+    "MOE_ALL2ALL_BUFFER":
+    lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
     # VLLM_ASCEND_ACL_OP_INIT_MODE:
     # 0: default, normal init.
     # 1: delay init until launch aclops.

vllm_ascend/ops/fused_moe.py

Lines changed: 268 additions & 12 deletions
@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -37,6 +37,71 @@
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
+MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
+
+
+def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
+                     max_row_per_ep_rank: int, num_tokens: int,
+                     top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
+    original_total_elements = num_tokens * top_k
+    device = topk_ids.device
+    original_dtype = topk_ids.dtype
+
+    if original_total_elements == 0:
+        output_len = ep_size * max_row_per_ep_rank
+        topk_ids_pad = torch.full((output_len, ),
+                                  expert_num,
+                                  dtype=original_dtype,
+                                  device=device)
+        unpad_indices = torch.full((original_total_elements, ),
+                                   -1,
+                                   dtype=torch.long,
+                                   device=device)
+        return topk_ids_pad, unpad_indices
+
+    experts_per_ep_rank_val = expert_num // ep_size
+    if experts_per_ep_rank_val == 0:
+        raise ValueError(
+            "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
+            "Ensure expert_num >= ep_size.")
+
+    assigned_ep_rank = (topk_ids.float() /
+                        experts_per_ep_rank_val).to(original_dtype)
+    indices_arange = torch.arange(topk_ids.shape[0], device=device)
+
+    is_new_segment = torch.cat((torch.tensor([True], device=device),
+                                assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
+    temp_start_markers = torch.full_like(indices_arange,
+                                         -1,
+                                         dtype=indices_arange.dtype)
+    temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
+    start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
+    token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
+    is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
+    cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
+    indices_in_rec_cond_list_for_all = cumsum_kept - 1
+    unpad_indices = torch.where(
+        is_kept_mask, indices_in_rec_cond_list_for_all,
+        torch.tensor(-1, device=device, dtype=torch.long))
+    output_len = ep_size * max_row_per_ep_rank
+    topk_ids_pad = torch.full((output_len, ),
+                              expert_num,
+                              dtype=original_dtype,
+                              device=device)
+    if topk_ids.shape[0] > 0:
+        all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
+        temp_pad_buffer = torch.full((output_len + 1, ),
+                                     expert_num,
+                                     dtype=original_dtype,
+                                     device=device)
+        output_len_tensor = torch.tensor(output_len,
+                                         dtype=torch.long,
+                                         device=device)
+        scatter_indices = torch.where(is_kept_mask, all_destination_indices,
+                                      output_len_tensor)
+        temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
+        topk_ids_pad = temp_pad_buffer[:output_len]
+    return topk_ids_pad, unpad_indices
 
 
 def fused_experts_with_mc2(hidden_states: torch.Tensor,
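For reference, a small self-contained sketch (not part of the commit) of what process_topk_ids does: it pads the expert-sorted routing ids so that every EP rank receives exactly max_row_per_ep_rank rows, drops any overflow, uses expert_num as the padding sentinel, and returns unpad_indices for restoring the kept rows later. The tensor values are made up for illustration, and the import assumes a build containing this commit:

import torch

from vllm_ascend.ops.fused_moe import process_topk_ids

# 4 tokens, top_k=2, 8 experts split across ep_size=2 ranks (4 experts each).
# The ids are already expert-sorted, as produced by npu_moe_init_routing.
expanded_expert_idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)

topk_ids_pad, unpad_indices = process_topk_ids(expanded_expert_idx,
                                               expert_num=8,
                                               ep_size=2,
                                               max_row_per_ep_rank=3,
                                               num_tokens=4,
                                               top_k=2)

# Each rank's segment is capped at 3 rows, so the 4th entry of each rank is
# dropped; any unused padding slot would carry the sentinel value 8 (== expert_num).
print(topk_ids_pad)   # tensor([0, 1, 2, 4, 5, 6], dtype=torch.int32)
print(unpad_indices)  # tensor([0, 1, 2, -1, 3, 4, 5, -1]); -1 marks dropped entries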
@@ -146,8 +211,62 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     return hidden_states
 
 
-# currently expert parallelism implemented with all2all
-# is under-optimized.
+def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
+              w1: torch.Tensor,
+              w2: torch.Tensor,
+              group_list: torch.Tensor,
+              group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+
+    Args:
+        hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    assert len(hidden_states_wrapper) == 1
+    hidden_states = hidden_states_wrapper.pop()
+
+    w1 = w1.transpose(1, 2)
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+    )
+
+    hidden_states = torch.cat(hidden_states, dim=0)
+    hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+    w2 = w2.transpose(1, 2)
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+    )
+
+    hidden_states = torch.cat(hidden_states, dim=0)
+    return hidden_states
+
+
 def fused_experts_with_all2all(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
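Since npu_grouped_matmul and npu_swiglu are Ascend-specific kernels, here is an illustrative plain-PyTorch reference (not part of the commit) of the computation apply_mlp performs. It assumes the cumsum group_list convention used by the buffered path below (group_list_type=0), the common silu(gate) * up SwiGLU split, and weights already in the transposed layouts that apply_mlp produces internally:

import torch
import torch.nn.functional as F


def reference_grouped_mlp(hidden_states: torch.Tensor,
                          w1: torch.Tensor,
                          w2: torch.Tensor,
                          group_list: torch.Tensor) -> torch.Tensor:
    """Per-expert gate_up_proj -> SwiGLU -> down_proj on expert-sorted tokens.

    hidden_states: (num_tokens, hidden), tokens of the same expert contiguous.
    w1: (num_experts, hidden, 2 * intermediate)   gate_up projection
    w2: (num_experts, intermediate, hidden)       down projection
    group_list: cumulative token count per expert (cumsum mode).
    """
    outputs = []
    start = 0
    for expert_id, end in enumerate(group_list.tolist()):
        x = hidden_states[start:end]          # tokens routed to this expert
        gate, up = (x @ w1[expert_id]).chunk(2, dim=-1)
        x = F.silu(gate) * up                 # SwiGLU activation
        outputs.append(x @ w2[expert_id])     # project back to hidden size
        start = end
    return torch.cat(outputs, dim=0)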
@@ -283,6 +402,133 @@ def fused_experts_with_all2all(
     return final_hidden_states
 
 
+# currently expert parallelism implemented with all2all
+# is under-optimized.
+def fused_experts_with_all2all_buffer(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        top_k: int,
+        max_model_len: int,
+        global_batch_size: int,
+        expert_map: torch.Tensor = None,
+        ep_group: GroupCoordinator = None,
+):
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    num_tokens, _ = hidden_states.shape
+    device = hidden_states.device
+
+    global_num_experts = len(expert_map)
+    local_num_experts = global_num_experts // ep_group.world_size
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
+                            device=device).view(top_k,
+                                                -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
+                           max_model_len // ep_group.world_size +
+                           1) * top_k * 2
+    expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
+        expanded_expert_idx, global_num_experts, ep_group.world_size,
+        max_row_per_ep_rank, num_tokens, top_k)
+    hidden_states_pad_idx = torch.zeros(
+        expert_idx_buffer_scatter.shape,
+        dtype=expert_idx_buffer_scatter.dtype,
+        device=expert_idx_buffer_scatter.device)
+    non_pad_len = torch.sum(
+        (expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
+    hidden_states_pad_idx[
+        expert_idx_buffer_scatter != global_num_experts] = torch.arange(
+            non_pad_len,
+            dtype=expert_idx_buffer_scatter.dtype,
+            device=hidden_states.device)
+
+    hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
+    expert_idx_buffer_gather = torch.empty_like(
+        expert_idx_buffer_scatter,
+        dtype=expert_idx_buffer_scatter.dtype,
+        device=expert_idx_buffer_scatter.device)
+    hidden_states_buffer_gather = torch.empty_like(
+        hidden_states_buffer_scatter,
+        dtype=hidden_states_buffer_scatter.dtype,
+        device=hidden_states_buffer_scatter.device)
+    dist.all_to_all_single(expert_idx_buffer_gather,
+                           expert_idx_buffer_scatter,
+                           group=ep_group.device_group)
+    dist.all_to_all_single(hidden_states_buffer_gather,
+                           hidden_states_buffer_scatter,
+                           group=ep_group.device_group)
+    mask = expert_idx_buffer_gather != global_num_experts
+    local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
+        global_num_experts // ep_group.world_size)
+    hidden_states = hidden_states_buffer_gather[mask]
+    idx_type = local_expert_idx.dtype
+    sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
+    sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
+
+    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+        sorted_local_expert_idx, local_num_experts).to(torch.int64)
+    hidden_states = hidden_states[sorted_idx]
+    group_list_type = 0
+
+    hidden_states_wrapper = [hidden_states]
+    del hidden_states
+
+    hidden_states = apply_mlp(hidden_states_wrapper,
+                              w1,
+                              w2,
+                              expert_tokens,
+                              group_list_type=group_list_type)
+
+    resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
+    hidden_states = hidden_states[resorted_idx]
+    hidden_states_scatter = torch.zeros(
+        (mask.shape[0], hidden_states.shape[1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device)
+    hidden_states_scatter[mask] = hidden_states
+    hidden_states_gatter = torch.empty_like(
+        hidden_states_scatter,
+        dtype=hidden_states_scatter.dtype,
+        device=hidden_states_scatter.device)
+    dist.all_to_all_single(hidden_states_gatter,
+                           hidden_states_scatter,
+                           group=ep_group.device_group)
+    hidden_states_gatter = hidden_states_gatter[
+        expert_idx_buffer_scatter != global_num_experts]
+    if hidden_states_gatter.shape[0] != row_idx_len:
+        hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
+                                    dtype=hidden_states.dtype,
+                                    device=hidden_states.device)
+        hidden_states[unpad_indices != -1] = hidden_states_gatter
+    else:
+        # TODO: Reorder device memory 2 times here, replace the current
+        hidden_states = hidden_states_gatter
+    final_hidden_states = torch_npu.npu_moe_finalize_routing(
+        hidden_states,
+        skip1=None,
+        skip2=None,
+        bias=None,
+        scales=topk_weights,
+        expanded_src_to_dst_row=expanded_row_idx,
+        export_for_source_row=topk_ids,
+    )
+
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
 
 
 def fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
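The buffered path above trades a bounded number of tokens for fixed-size all-to-all buffers: max_row_per_ep_rank caps how many routed rows each EP rank can receive, and process_topk_ids drops anything beyond it (the precision/performance trade-off described in the commit message). A small illustrative calculation of that budget, using hypothetical values rather than the configuration benchmarked above:

# Hypothetical values, only to show how the per-rank row budget is derived.
global_batch_size = 16   # scheduler_config.max_num_seqs
ep_world_size = 16       # e.g. MoeEP=16
max_model_len = 4096     # model_config.max_model_len
top_k = 8                # DeepSeek V3/R1 routes each token to 8 experts

max_row_per_ep_rank = (-(-global_batch_size // ep_world_size)  # ceil division
                       * max_model_len // ep_world_size + 1) * top_k * 2
print(max_row_per_ep_rank)  # (1 * 4096 // 16 + 1) * 8 * 2 = 4112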
@@ -585,6 +831,7 @@ def __init__(self, moe: MoEConfig = None):
         self.ep_size = ep_group.world_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
+        self.max_model_len = vllm_config.model_config.max_model_len
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -613,21 +860,22 @@ def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
         router_logits: torch.Tensor,
+        top_k: int,
         renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
+        use_grouped_topk: bool = False,
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
         **kwargs,
-    ):
+    ) -> torch.Tensor:
+
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
@@ -683,11 +931,19 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map)
+        elif MOE_ALL2ALL_BUFFER:
+            return fused_experts_with_all2all_buffer(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                max_model_len=self.max_model_len,
+                global_batch_size=self.global_batch_size,
+                expert_map=expert_map,
+                ep_group=get_ep_group())
         else:
-            # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are feed into fused_moe module.
-            # Therefore, all2all is needed no matter how dp/tp is set so as to
-            # dispatch/combine tokens.
             return fused_experts_with_all2all(hidden_states=x,
                                               w1=layer.w13_weight,
                                               w2=layer.w2_weight,
