
Commit 5177bef

support fused_moe_allgather_ep (#1335)
### What this PR does / why we need it?

Support fused_moe_allgather_ep.

### How was this patch tested?

It was tested by UT.

Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
1 parent 917c6b7 commit 5177bef

File tree

5 files changed: +218 -14 lines changed

tests/multicard/test_fused_moe_allgather_ep.py
vllm_ascend/envs.py
vllm_ascend/ops/fused_moe.py
vllm_ascend/quantization/w8a8_dynamic.py
vllm_ascend/utils.py
tests/multicard/test_fused_moe_allgather_ep.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Execute the inference of fused_moe_allgather_ep and fused_moe_alltoall_ep.
+
+Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
+"""
+
+import os
+from unittest.mock import patch
+
+from modelscope import snapshot_download  # type: ignore
+from vllm import SamplingParams
+
+from tests.conftest import VllmRunner
+
+
+@patch.dict(
+    os.environ, {
+        "VLLM_USE_V1": "1",
+        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
+        "TASK_QUEUE_ENABLE": "1",
+        "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
+    })
+def test_generate_with_allgather():
+    example_prompts = ["Hello, my name is"]
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+
+    with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
+                    tensor_parallel_size=16,
+                    enforce_eager=True,
+                    max_model_len=1024,
+                    dtype="auto",
+                    enable_expert_parallel=True,
+                    additional_config={
+                        "ascend_scheduler_config": {
+                            "enabled": True,
+                            "chunked_prefill_enabled": False,
+                        },
+                        "expert_tensor_parallel_size": 1
+                    }) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
+@patch.dict(
+    os.environ, {
+        "VLLM_USE_V1": "1",
+        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
+        "TASK_QUEUE_ENABLE": "1"
+    })
+def test_generate_with_alltoall():
+    example_prompts = ["Hello, my name is"]
+    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+
+    with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
+                    tensor_parallel_size=16,
+                    enforce_eager=True,
+                    max_model_len=1024,
+                    dtype="auto",
+                    enable_expert_parallel=True,
+                    additional_config={
+                        "ascend_scheduler_config": {
+                            "enabled": True,
+                            "chunked_prefill_enabled": False,
+                        },
+                        "expert_tensor_parallel_size": 1
+                    }) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

vllm_ascend/envs.py

Lines changed: 5 additions & 0 deletions
@@ -99,6 +99,11 @@
     # Whether to enable the trace recompiles from pytorch.
     "VLLM_ASCEND_TRACE_RECOMPILES":
     lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
+    # Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
+    # GroupedMatmulFinalizeRouting operators are combined to implement EP.
+    "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
+                 ),
     "VLLM_ASCEND_ENABLE_DBO":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
     # Whether to enable the model execute time observe profile. Disable it when
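The flag follows the existing pattern in envs.py: it is parsed with bool(int(...)), so only the string "1" enables it, and it has to be present in the environment before the engine workers are spawned (the UT above injects it via patch.dict). A minimal sketch of setting and reading the flag the same way, outside of vLLM; the print call is purely illustrative:

```python
import os

# Export before importing/launching vLLM so spawned workers inherit it
# (assumption: the launching process is where the variable is set; any other
# mechanism that exports it before worker spawn works as well).
os.environ.setdefault("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", "1")

# Same parsing rule as envs.py: "1" -> True, "0" or unset -> False.
enabled = bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", "0")))
print(f"fused_experts_allgather_ep enabled: {enabled}")
```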

vllm_ascend/ops/fused_moe.py

Lines changed: 11 additions & 6 deletions
@@ -988,8 +988,9 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
+        is_deepseek_v3_r1 = global_num_experts == 256
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
+        if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=top_k,  # topk is currently hard-coded to 8
@@ -1025,7 +1026,7 @@ def apply(
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
-                                              is_prefill)
+                                              is_prefill, is_deepseek_v3_r1)
         if fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
@@ -1219,15 +1220,17 @@ def forward(self,
             real_top_k = self.top_k
 
         num_tokens, hidden_size = hidden_states.shape
+        is_deepseek_v3_r1 = self.global_num_experts == 256
 
         fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
-                                              is_prefill)
+                                              is_prefill, is_deepseek_v3_r1)
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)
 
         tp_size = get_tensor_model_parallel_world_size()
-        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
+        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
+                and fused_moe_state != FusedMoEState.AllGatherEP):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1285,7 +1288,8 @@ def forward(self,
         if isinstance(e_hidden_states, tuple):
             e_hidden_states, shared_hidden_states = e_hidden_states
 
-        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
+        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
+                and fused_moe_state != FusedMoEState.AllGatherEP):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
@@ -1303,7 +1307,8 @@ def forward(self,
         else:
             final_hidden_states = e_hidden_states
 
-        if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
+        if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
+                            or fused_moe_state == FusedMoEState.AllGatherEP):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
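The last two hunks change the combine step: under AllGather, and now AllGatherEP, each rank already holds a partial result for the whole batch (the contribution of its own experts), so the pad/chunk/all_gather path is skipped and a single all-reduce finishes the output. A toy, single-process sketch of why summing per-rank partials equals the dense MoE output; the tensor names and shapes are invented for the example and are not vLLM internals:

```python
import torch

num_tokens, hidden, num_experts, ep_size = 4, 8, 6, 2
expert_out = torch.randn(num_experts, num_tokens, hidden)  # per-expert outputs
weights = torch.rand(num_experts, num_tokens, 1)           # routing weights

# What a single device computing every expert would produce.
dense = (weights * expert_out).sum(dim=0)

# Each "rank" sums only its own slice of experts over the full batch.
per_rank = num_experts // ep_size
partials = []
for rank in range(ep_size):
    sl = slice(rank * per_rank, (rank + 1) * per_rank)
    partials.append((weights[sl] * expert_out[sl]).sum(dim=0))

# Summing the partials (what the all-reduce does) recovers the dense result.
assert torch.allclose(dense, sum(partials))
```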

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 111 additions & 6 deletions
@@ -22,12 +22,13 @@
 import torch_npu
 from vllm.distributed import GroupCoordinator
 
+import vllm_ascend.envs as envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               get_fused_moe_state, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
+                               dispose_tensor, get_fused_moe_state,
+                               npu_stream_switch, npu_wait_tensor)
 
 
 def apply_mlp(hidden_states: torch.Tensor,
@@ -346,6 +347,95 @@ def fused_experts_with_all2all(
     return final_hidden_states
 
 
+def fused_experts_with_allgather(hidden_states: torch.Tensor,
+                                 w1: torch.Tensor,
+                                 w1_scale: torch.Tensor,
+                                 w2: torch.Tensor,
+                                 w2_scale: torch.Tensor,
+                                 topk_weights: torch.Tensor,
+                                 topk_ids: torch.Tensor,
+                                 top_k: int,
+                                 expert_map: torch.Tensor = None):
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    num_tokens = hidden_states.shape[0]
+    batch_size, hidden_size = hidden_states.shape
+    topk_weights = topk_weights.to(hidden_states.dtype)
+
+    ep_group = get_ep_group().device_group
+    ep_rank = torch.distributed.get_rank(group=ep_group)
+    ep_size = torch.distributed.get_world_size(ep_group)
+
+    global_num_experts = len(expert_map)
+    local_num_experts = global_num_experts // ep_size
+
+    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+    hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
+        hidden_states,
+        topk_ids,
+        scale=pertoken_scale,
+        offset=None,
+        active_num=num_tokens * top_k,
+        expert_num=global_num_experts,
+        expert_tokens_num_type=1,
+        expert_tokens_num_flag=True,
+        active_expert_range=[
+            ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
+        ],
+        quant_mode=-1,
+        row_idx_type=1)
+    group_list_type = 1
+
+    sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
+                                            expanded_x_idx)
+    row_index = expanded_x_idx // topk_ids.shape[-1]
+    row_index = row_index.to(torch.int64)
+    share_input = torch.zeros((batch_size, hidden_size),
+                              dtype=torch.bfloat16,
+                              device="npu")
+
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale.to(torch.float32),
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=expert_tokens,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
+        hidden_states,
+        w2,
+        scale=w2_scale.to(torch.float32),
+        bias=None,
+        pertoken_scale=pertoken_scale.view(-1),
+        group_list=expert_tokens,
+        shared_input=share_input,
+        logit=sorted_topk_weight.to(torch.float32),
+        row_index=row_index,
+        output_bs=batch_size).to(torch.bfloat16)
+
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+
+    return final_hidden_states
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
@@ -623,8 +713,10 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
+        is_deepseek_v3_r1 = global_num_experts == 256
+
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
+        if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=top_k,  # topk is currently hard-coded to 8
@@ -661,8 +753,19 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)
 
         fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
-                                              is_prefill)
-        if fused_moe_state == FusedMoEState.MC2:
+                                              is_prefill, is_deepseek_v3_r1)
+        if fused_moe_state == FusedMoEState.AllGatherEP:
+            return fused_experts_with_allgather(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w1_scale=layer.w13_weight_scale,
+                w2=layer.w2_weight,
+                w2_scale=layer.w2_weight_scale,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map)
+        elif fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -713,6 +816,8 @@ def process_weights_after_loading(self, layer):
                 1, 2).contiguous()
             layer.w2_weight.data = layer.w2_weight.data.transpose(
                 1, 2).contiguous()
+        if envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
+            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
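fused_experts_with_allgather runs the whole batch on every EP rank but only materializes the experts inside that rank's contiguous slice (the active_expert_range passed to npu_moe_init_routing_v2); the per-rank partial outputs are merged later by the all-reduce in fused_moe.py. A small illustrative sketch of that partitioning; the helper names below are invented for the example and are not part of the vllm_ascend API:

```python
import torch


def local_expert_range(ep_rank: int, ep_size: int, global_num_experts: int):
    # Each rank owns a contiguous slice of the global experts.
    local_num_experts = global_num_experts // ep_size
    return ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts


def tokens_for_rank(topk_ids: torch.Tensor, ep_rank: int, ep_size: int,
                    global_num_experts: int) -> torch.Tensor:
    lo, hi = local_expert_range(ep_rank, ep_size, global_num_experts)
    # A token contributes work on this rank if any of its top-k experts falls
    # in [lo, hi); contributions from other ranks' experts are added back by
    # the all-reduce at the end of the MoE layer.
    mask = (topk_ids >= lo) & (topk_ids < hi)
    return mask.any(dim=-1)


topk_ids = torch.tensor([[0, 17, 200, 255], [3, 4, 5, 6]])
print(tokens_for_rank(topk_ids, ep_rank=1, ep_size=16, global_num_experts=256))
# -> tensor([ True, False]); with 16 experts per rank, rank 1 owns experts 16..31
```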

vllm_ascend/utils.py

Lines changed: 9 additions & 2 deletions
@@ -394,11 +394,18 @@ class FusedMoEState(Enum):
     AllGather = 0
     All2All = 1
     MC2 = 2
+    AllGatherEP = 3
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
-def get_fused_moe_state(ep_size: int, with_prefill: bool):
-    if ep_size == 1:
+def get_fused_moe_state(ep_size: int, with_prefill: bool,
+                        is_deepseek_v3_r1: bool):
+    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+            and is_deepseek_v3_r1):
+        return FusedMoEState.AllGatherEP
+    elif ep_size == 1:
        return FusedMoEState.AllGather
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
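The new branch is checked before all existing ones, so AllGatherEP wins whenever the env flag is set, EP is actually in use (ep_size > 1), and the model is DeepSeek V3/R1, regardless of prefill or EP size. A standalone sketch of that priority order; the tail of get_fused_moe_state is outside this hunk, so the final All2All/MC2 returns below are inferred from the NOTE comment rather than copied from the source:

```python
from enum import Enum


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    AllGatherEP = 3


def pick_state(ep_size: int, with_prefill: bool, is_deepseek_v3_r1: bool,
               allgather_ep_enabled: bool) -> FusedMoEState:
    # New highest-priority branch added by this commit.
    if allgather_ep_enabled and ep_size > 1 and is_deepseek_v3_r1:
        return FusedMoEState.AllGatherEP
    if ep_size == 1:
        return FusedMoEState.AllGather
    # MC2 needs ep_size >= 16 and decode-only batches (inferred from the NOTE).
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    return FusedMoEState.MC2


assert pick_state(16, True, True, True) is FusedMoEState.AllGatherEP
assert pick_state(16, False, True, False) is FusedMoEState.MC2
assert pick_state(8, True, False, True) is FusedMoEState.All2All
```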
