Commit 53c2d58

Handle with_prefill_across_dp for multistream mla (#1322)
### What this PR does / why we need it?
After #1094, decode may run in non-compiled mode even when `torchair_graph_config.enabled` is set. This breaks multistream MLA, which assumes the torchair compiled mode is used for decode whenever `torchair_graph_config.enabled == True`. Refine that assumption to fix this.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested both offline and by the graph-mode MLA e2e test case.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
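For orientation, here is a minimal standalone sketch of the gating this commit introduces. The condition mirrors the one added in `vllm_ascend/models/deepseek_v2.py` (see the diff below); the `_AttnMetadata` stand-in and the helper name are purely illustrative, not part of the vllm-ascend code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _AttnMetadata:
    # Simplified stand-in carrying only the two fields the gating reads.
    with_prefill_across_dp: bool
    num_decodes: int


def multistream_mla_enabled(config_enabled: bool,
                            attn_metadata: Optional[_AttnMetadata]) -> bool:
    # Multistream MLA is only used when this step really runs torchair-compiled
    # decode, i.e. no DP rank is still doing prefill for this step.
    return (config_enabled
            and attn_metadata is not None
            and not attn_metadata.with_prefill_across_dp
            and attn_metadata.num_decodes > 0)


# A step with prefill still in flight on some DP rank falls back to the
# single-stream path even though the config enables multistream MLA.
assert not multistream_mla_enabled(
    True, _AttnMetadata(with_prefill_across_dp=True, num_decodes=8))
assert multistream_mla_enabled(
    True, _AttnMetadata(with_prefill_across_dp=False, num_decodes=8))
```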
1 parent 2690697 commit 53c2d58

3 files changed (+82, -60 lines)

tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 67 additions & 46 deletions
@@ -20,6 +20,7 @@
 Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 """
 import os
+from typing import Dict
 
 import pytest
 
@@ -28,53 +29,73 @@
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
+def _deepseek_torchair_test_fixture(
+    additional_config: Dict,
+    *,
+    tensor_parallel_size=4,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # torchair is only work without chunked-prefill now
+    kwargs = {
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    additional_config.update(**kwargs)
+
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype="half",
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="mp",
+            enforce_eager=False,
+            additional_config=additional_config,
+    ) as vllm_model:
+        # use greedy sampler to make sure the generated results are fix
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
+    # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
+    # inaccurate. This will only change if accuracy improves with the
+    # official weights of DeepSeek-V3.
+    golden_results = [
+        'Hello, my name is feasibility伸 spazio debtor添',
+        'The president of the United States is begg"""\n杭州风和 bestimm',
+        'The capital of France is frequentlyশามalinkAllowed',
+        'The future of AI is deleting俯احت怎么样了حراف',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="torchair graph is not supported on v0")
-def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_MODELSCOPE", "True")
-        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+def test_e2e_deepseekv3_with_torchair():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config)
 
-        example_prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        dtype = "half"
-        max_tokens = 5
-        # torchair is only work without chunked-prefill now
-        with VllmRunner(
-                "vllm-ascend/DeepSeek-V3-Pruning",
-                dtype=dtype,
-                tensor_parallel_size=4,
-                distributed_executor_backend="mp",
-                additional_config={
-                    "torchair_graph_config": {
-                        "enabled": True,
-                    },
-                    "ascend_scheduler_config": {
-                        "enabled": True,
-                    },
-                    "refresh": True,
-                },
-                enforce_eager=False,
-        ) as vllm_model:
-            # use greedy sampler to make sure the generated results are fix
-            vllm_output = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
-            # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
-            # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
-            # inaccurate. This will only change if accuracy improves with the
-            # official weights of DeepSeek-V3.
-            golden_results = [
-                'Hello, my name is feasibility伸 spazio debtor添',
-                'The president of the United States is begg"""\n杭州风和 bestimm',
-                'The capital of France is frequentlyশามalinkAllowed',
-                'The future of AI is deleting俯احت怎么样了حراف',
-            ]
 
-            assert len(golden_results) == len(vllm_output)
-            for i in range(len(vllm_output)):
-                assert golden_results[i] == vllm_output[i][1]
-                print(f"Generated text: {vllm_output[i][1]!r}")
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair_ms_mla():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+            "enable_multistream_mla": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config)
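Both tests above are now thin wrappers around `_deepseek_torchair_test_fixture`. As a usage illustration only, a further torchair scenario placed in the same test module would follow the same pattern; the test name and the `enable_kv_nz` key below are hypothetical and not part of this commit.

```python
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="torchair graph is not supported on v0")
def test_e2e_deepseekv3_with_torchair_kv_nz():  # hypothetical example
    additional_config = {
        "torchair_graph_config": {
            "enabled": True,
            # Assumed knob, used here only to show the fixture pattern.
            "enable_kv_nz": True,
        },
    }
    _deepseek_torchair_test_fixture(additional_config)
```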

vllm_ascend/attention/mla_v1.py

Lines changed: 8 additions & 8 deletions
@@ -563,8 +563,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -863,6 +861,7 @@ def exec_kv(
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
+        enable_multistream_mla: bool = False,
     ):
 
         B = hidden_states.shape[0]
@@ -874,7 +873,7 @@ def exec_kv(
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         with npu_stream_switch("mla_secondary",
                                0,
-                               enabled=self.enable_multistream_mla):
+                               enabled=enable_multistream_mla):
             k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
                 kv,
                 self.kv_a_layernorm.weight,
@@ -1034,6 +1033,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1093,22 +1093,22 @@ def forward(
                 # KvRmsNormRopeCache and SingleRope.
                 npu_wait_tensor(decode_hs_or_q_c,
                                 cos,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
                 npu_wait_tensor(decode_hs_or_q_c,
                                 sin,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+                    attn_metadata.slot_mapping, enable_multistream_mla)
                 with npu_stream_switch("mla_secondary",
                                        0,
-                                       enabled=self.enable_multistream_mla):
+                                       enabled=enable_multistream_mla):
                     npu_wait_tensor(decode_q_pe,
                                     decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
+                                    enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
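The net effect of this diff is that `enable_multistream_mla` is no longer latched from the config in `__init__` but is threaded through `forward` and `exec_kv` per call. A minimal sketch of that pattern follows, with a no-op stand-in for `npu_stream_switch`; the class and helper below are illustrative, not the vllm-ascend implementation.

```python
from contextlib import contextmanager


@contextmanager
def npu_stream_switch_stub(enabled: bool):
    # Stand-in for npu_stream_switch("mla_secondary", 0, enabled=...):
    # the real helper switches to a secondary NPU stream only when enabled.
    yield enabled


class MLAImplSketch:
    """Sketch: the multistream flag is a per-call argument, not an attribute."""

    def exec_kv(self, kv, enable_multistream_mla: bool = False) -> str:
        with npu_stream_switch_stub(enable_multistream_mla) as switched:
            # kv rmsnorm + rope-cache work would run here, on the secondary
            # stream when `switched` is True.
            return "secondary-stream" if switched else "default-stream"

    def forward(self, q, kv, *, enable_multistream_mla: bool = False) -> str:
        # The caller computes the flag per step (see deepseek_v2.py below)
        # and passes it down, so a step that is not running the torchair
        # compiled decode path simply passes False.
        return self.exec_kv(kv, enable_multistream_mla)


impl = MLAImplSketch()
assert impl.forward(None, None) == "default-stream"
assert impl.forward(None, None, enable_multistream_mla=True) == "secondary-stream"
```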

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 6 deletions
@@ -555,20 +555,21 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and attn_metadata is not None
+                                  and not attn_metadata.with_prefill_across_dp
+                                  and attn_metadata.num_decodes > 0)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
             with npu_stream_switch("mla_secondary",
                                    0,
-                                   enabled=use_multistream_mla):
+                                   enabled=enable_multistream_mla):
                 hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,
