
Commit 60519c7

shared_experts+router_experts merge all_reduce(Improve TTOP 5ms) (#1395)
### What this PR does / why we need it?

When all_reduce_merge is enabled, shared_experts no longer performs its own all_reduce inside the MLP. The all_reduce is instead deferred until the shared_experts and router_experts outputs have been combined, so a single all_reduce covers both. In both prefill and decode, merging the shared_experts + router_experts all_reduce into one collective yields a benefit.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh

- vLLM version: v0.9.1
- vLLM main: vllm-project/vllm@977180c

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
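The correctness of this merge rests on the linearity of all_reduce: combining the rank-local shared output with the scaled rank-local routed output first and reducing once gives the same result as reducing each tensor separately and then combining. A minimal, self-contained sketch of that identity (plain tensors stand in for rank-local partial sums; the `simulate` helper, its shapes, and the seed are illustrative, not part of this patch):

```python
import torch


def simulate(num_ranks: int = 4, hidden: int = 8, scale: float = 2.5) -> None:
    torch.manual_seed(0)
    # Rank-local partial sums that would normally each be all_reduced.
    shared_parts = [torch.randn(hidden) for _ in range(num_ranks)]  # shared_experts outputs
    routed_parts = [torch.randn(hidden) for _ in range(num_ranks)]  # router_experts outputs

    # Separate path: all_reduce the shared output inside the MLP and the routed
    # output on its own, then combine (the Python sums stand in for all_reduce).
    separate = scale * sum(routed_parts) + sum(shared_parts)

    # Merged path: each rank combines its local outputs first, then a single
    # all_reduce runs over the combined tensor.
    merged = sum(scale * r + s for r, s in zip(routed_parts, shared_parts))

    assert torch.allclose(separate, merged, atol=1e-5)


if __name__ == "__main__":
    simulate()
    print("one merged all_reduce matches two separate all_reduce calls")
```

This is why the patch can construct the shared experts' MLP with reduce_results=False whenever all_reduce_merge is set, and issue a single tensor_model_parallel_all_reduce after the combine in CustomDeepseekV2MoE.forward.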
1 parent: 997f156

5 files changed (+32 lines, -7 lines)

examples/run_dp_attention_etp16.sh

Lines changed: 2 additions & 1 deletion

```diff
@@ -3,9 +3,10 @@ export TASK_QUEUE_ENABLE=1
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 source /usr/local/Ascend/nnal/atb/set_env.sh
 export ASCEND_LAUNCH_BLOCKING=0
-export VLLM_VERSION=0.9.0
+export VLLM_VERSION=0.9.1
 
 nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
+--served-model-name auto \
 --quantization ascend \
 --trust-remote-code \
 --distributed-executor-backend=mp \
```

examples/run_dp_attention_etp16_benmark.sh

Lines changed: 2 additions & 1 deletion

```diff
@@ -21,7 +21,8 @@ for concurrency in "${concurrency_array[@]}"; do
 python /mnt/deepseek/vllm/benchmarks/benchmark_serving.py \
 --backend vllm \
 --trust-remote-code \
---model /mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
+--model auto \
+--tokenizer /mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
 --dataset-name random \
 --random-input-len 4096 \
 --random-output-len 1536 \
```

vllm_ascend/models/deepseek_v2.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -303,7 +303,6 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -345,14 +344,16 @@ def __init__(
             e_score_correction_bias=self.gate.e_score_correction_bias)
 
         if config.n_shared_experts is not None:
+            self.all_reduce_merge = self.experts.all_reduce_merge
+            reduce_results = not self.all_reduce_merge
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
             self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=True,
+                reduce_results=reduce_results,
                 force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
@@ -403,6 +404,9 @@ def forward(self,
         hidden_states = (
             experts_hidden_states[0] * self.routed_scaling_factor +
             experts_hidden_states[1])
+        if self.all_reduce_merge:
+            # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         return hidden_states
 
```
vllm_ascend/ops/fused_moe.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -44,8 +44,8 @@
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               get_fused_moe_state, is_310p, npu_stream_switch,
-                               npu_wait_tensor)
+                               get_all_reduce_merge_state, get_fused_moe_state,
+                               is_310p, npu_stream_switch, npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
 
@@ -1146,6 +1146,10 @@ def __init__(
         self.log2phy = None
         self.global_redundant_expert_num = 0
 
+        is_deepseek_v3_r1 = self.global_num_experts == 256
+        self.all_reduce_merge = get_all_reduce_merge_state(
+            self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
+
         ascend_config = get_ascend_config()
         expert_map_path = ascend_config.expert_map_path
         if expert_map_path and os.path.exists(expert_map_path):
@@ -1250,6 +1254,7 @@ def forward(self,
             is_prefill, is_deepseek_v3_r1)
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
+                # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
                 shared_hidden_states = shared_experts(hidden_states)
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -1351,7 +1356,7 @@ def forward(self,
         else:
             final_hidden_states = e_hidden_states
 
-        if tp_size > 1 and fused_moe_state in [
+        if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
         ]:
```

vllm_ascend/utils.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -425,6 +425,20 @@ class FusedMoEState(Enum):
     NaiveMulticast = 4
 
 
+# TODO(ttanzhiqiang): all_reduce merge
+# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
+def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
+    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+            and is_deepseek_v3_r1):
+        return True
+    elif ep_size == 1 and is_deepseek_v3_r1:
+        return True
+    return False
+
+
 # TODO(zzzzwwjj): add soc_version to choose branch
 def get_fused_moe_state(ep_size: int, with_prefill: bool,
                         is_deepseek_v3_r1: bool):
```
