Skip to content

Commit bd3de9f

Browse files
sharonyunyunsdmyzlp
authored and committed
adjusting the communication method in graph mode (#1194)
Signed-off-by: sharonyunyun <zhangying134@huawei.com>
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent c5be2fc commit bd3de9f

File tree

5 files changed

+185
-35
lines changed

5 files changed

+185
-35
lines changed

tests/multicard/test_torchair_graph_mode.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,19 @@ def test_e2e_deepseekv3_with_torchair_ms_mla(monkeypatch: pytest.MonkeyPatch):
110110
},
111111
}
112112
_deepseek_torchair_test_fixture(additional_config)
113+
114+
115+
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
116+
reason="torchair graph is not supported on v0")
117+
def test_e2e_deepseekv3_with_torchair_ms_moe(monkeypatch: pytest.MonkeyPatch):
118+
with monkeypatch.context() as m:
119+
m.setenv("VLLM_USE_MODELSCOPE", "True")
120+
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
121+
122+
additional_config = {
123+
"torchair_graph_config": {
124+
"enabled": True,
125+
"enable_multistream_moe": True,
126+
},
127+
}
128+
_deepseek_torchair_test_fixture(additional_config)

vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
MLAAttentionImpl)
1010
from vllm.attention.backends.utils import PAD_SLOT_ID
1111
from vllm.config import get_current_vllm_config
12+
from vllm.distributed import get_tensor_model_parallel_world_size
1213
from vllm.model_executor.layers.linear import (LinearBase,
1314
UnquantizedLinearMethod)
1415
from vllm.utils import cdiv, round_down
@@ -554,6 +555,7 @@ def __init__(
554555
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
555556
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
556557
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
558+
self.tp_size = get_tensor_model_parallel_world_size()
557559

558560
ascend_config = get_ascend_config()
559561
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -572,7 +574,7 @@ def _v_up_proj_and_o_proj(self, x):
572574
x = torch.bmm(x, self.W_UV)
573575
# Convert from (N, B, V) to (B, N * V)
574576
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
575-
return self.o_proj(x)[0]
577+
return self.o_proj(x, is_prefill=False)[0]
576578

577579
# Return `ql_nope`, `q_pe`
578580
def _q_proj_and_k_up_proj(self, x):
@@ -834,12 +836,12 @@ def _forward_prefill(
834836

835837
current_ms_metadata = get_multistream_comm_context()
836838
if current_ms_metadata is None:
837-
return self.o_proj(attn_output)[0]
839+
return self.o_proj(attn_output, is_prefill=True)[0]
838840
else:
839841
current_ms_metadata.before_comm_event.record()
840842
with torch.npu.stream(current_ms_metadata.comm_stream):
841843
current_ms_metadata.before_comm_event.wait()
842-
return self.o_proj(attn_output)[0]
844+
return self.o_proj(attn_output, is_prefill=True)[0]
843845

844846
def exec_kv(
845847
self,

vllm_ascend/models/deepseek_dbo.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from vllm.model_executor.layers.layernorm import RMSNorm
4545
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
4646
ReplicatedLinear,
47-
RowParallelLinear,
4847
UnquantizedLinearMethod)
4948
from vllm.model_executor.layers.logits_processor import LogitsProcessor
5049
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,7 +65,8 @@
6665

6766
import vllm_ascend.envs as envs_ascend
6867
from vllm_ascend.ascend_config import get_ascend_config
69-
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
68+
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
69+
CustomDeepseekV2RowParallelLinear)
7070
from vllm_ascend.multistream.base import MSEventKey
7171
from vllm_ascend.multistream.context import (
7272
advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -331,11 +331,12 @@ def __init__(
331331
bias=False,
332332
quant_config=quant_config,
333333
prefix=f"{prefix}.kv_b_proj")
334-
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
335-
self.hidden_size,
336-
bias=False,
337-
quant_config=quant_config,
338-
prefix=f"{prefix}.o_proj")
334+
self.o_proj = CustomDeepseekV2RowParallelLinear(
335+
self.num_heads * self.v_head_dim,
336+
self.hidden_size,
337+
bias=False,
338+
quant_config=quant_config,
339+
prefix=f"{prefix}.o_proj")
339340

340341
if rope_scaling:
341342
rope_scaling["rope_type"] = 'deepseek_yarn'

0 commit comments

Comments (0)