
Commit 941269a

adjusting the communication method in graph mode (#1194)
### What this PR does / why we need it?

Communication performance optimization: replace the all_reduce in the MLA layer's TP group with reduce_scatter + all_gather, which removes the strided slice and all_gather in the MoE layer. When TP > 1, the new path is enabled in the decode phase of graph mode, provided enable_multistream_moe, MLA, use_v1, and MC2 are all in use. According to end-to-end RL inference test results, this PR brings about a 3% gain in the decode stage.

**Before Improvement**

Profiling kernel_details
![image](https://github.com/user-attachments/assets/1bb5dfa1-809b-410a-90c9-c5fd23cff003)

Evaluation
![image](https://github.com/user-attachments/assets/0b8ea0c7-88e7-410f-9ef4-f0cfe910cdc7)
![image](https://github.com/user-attachments/assets/94fde910-c125-4c2e-8de4-88fc3fafc057)

**After Improvement**

Profiling kernel_details
![image](https://github.com/user-attachments/assets/55fac0e0-11f2-4654-8fd4-287949e0b29e)

Evaluation
![image](https://github.com/user-attachments/assets/e923f74b-29c4-4171-9382-40a00cf05df0)
![image](https://github.com/user-attachments/assets/5dba7967-07ea-4926-a8be-804bfd34e3e4)

### Does this PR introduce _any_ user-facing change?

Users need to set `enable_multistream_moe=True` to enable the optimization.

### How was this patch tested?

Added an e2e test case to cover the new code path.

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
1 parent 205cb85 commit 941269a
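As background for the swap described in the PR description, here is a minimal, hypothetical sketch (not code from this PR; the function name `tp_output_sum` and the flag are illustrative) of why reduce_scatter + all_gather can stand in for a single all_reduce on the TP group while exposing a shard-sized intermediate that downstream ops, such as the MoE layer, can consume without an extra strided slice:

```python
# Minimal sketch, assuming torch.distributed is already initialized (e.g. via
# torchrun with an NCCL/HCCL backend) and num_tokens divides the TP world size.
import torch
import torch.distributed as dist


def tp_output_sum(partial_out: torch.Tensor, use_reduce_scatter: bool) -> torch.Tensor:
    """partial_out: this rank's partial matmul result, shape [num_tokens, hidden_size]."""
    world_size = dist.get_world_size()
    if not use_reduce_scatter or world_size == 1:
        # Baseline path: a single all_reduce leaves every rank with the full sum.
        dist.all_reduce(partial_out)
        return partial_out
    # Optimized path: each rank first receives only its token shard of the sum...
    shard = torch.empty_like(partial_out.chunk(world_size, dim=0)[0])
    dist.reduce_scatter_tensor(shard, partial_out)
    # ...and the full tensor is rebuilt by gathering the shards, which matches the
    # all_reduce result up to floating-point reduction order.
    full = torch.empty_like(partial_out)
    dist.all_gather_into_tensor(full, shard)
    return full
```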

File tree

6 files changed: +195, -37 lines


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -358,6 +358,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
           # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
           # To avoid oom, we need to run the test in a single process.
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 26 additions & 0 deletions
@@ -47,6 +47,32 @@ def test_models_distributed_QwQ():
         vllm_model.generate_greedy(example_prompts, max_tokens)


+def test_models_distributed_DeepSeek_multistream_moe():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+            additional_config={
+                "torchair_graph_config": {
+                    "enabled": True,
+                    "enable_multistream_moe": True,
+                },
+                "ascend_scheduler_config": {
+                    "enabled": True,
+                },
+                "refresh": True,
+            },
+            enforce_eager=False,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
 def test_models_distributed_DeepSeek():
     example_prompts = [
         "Hello, my name is",

vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 3 deletions
@@ -9,6 +9,7 @@
                                                MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -557,6 +558,7 @@ def __init__(
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.tp_size = get_tensor_model_parallel_world_size()

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -586,7 +588,7 @@ def _v_up_proj_and_o_proj(self, x):
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
+        return self.o_proj(x, is_prefill=False)[0]

     # Return `ql_nope`, `q_pe`
     def _q_proj_and_k_up_proj(self, x):
@@ -847,12 +849,12 @@ def _forward_prefill(

         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self.o_proj(attn_output)[0]
+            return self.o_proj(attn_output, is_prefill=True)[0]
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output)[0]
+                return self.o_proj(attn_output, is_prefill=True)[0]

     def exec_kv(
         self,

vllm_ascend/models/deepseek_dbo.py

Lines changed: 9 additions & 8 deletions
@@ -42,8 +42,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
+                                               ReplicatedLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -64,7 +63,8 @@

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
+                                            CustomDeepseekV2RowParallelLinear)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -325,11 +325,12 @@ def __init__(
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
-                                        self.hidden_size,
-                                        bias=False,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.o_proj")
+        self.o_proj = CustomDeepseekV2RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
