
Commit c5be2fc

Handle with_prefill_across_dp for multistream mla (#1322)
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 0c99cf7 commit c5be2fc

File tree

3 files changed: +85 -57 lines

tests/multicard/test_torchair_graph_mode.py
vllm_ascend/attention/mla_v1.py
vllm_ascend/models/deepseek_v2.py

tests/multicard/test_torchair_graph_mode.py

Lines changed: 72 additions & 43 deletions
@@ -20,6 +20,7 @@
 Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 """
 import os
+from typing import Dict
 
 import pytest
 
@@ -28,6 +29,55 @@
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
+def _deepseek_torchair_test_fixture(
+    additional_config: Dict,
+    *,
+    tensor_parallel_size=4,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # torchair only works without chunked-prefill for now
+    kwargs = {
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    additional_config.update(**kwargs)
+
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype="half",
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="mp",
+            enforce_eager=False,
+            additional_config=additional_config,
+    ) as vllm_model:
+        # use greedy sampling to make sure the generated results are fixed
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
+    # DeepSeek-V3 with 2 hidden layers, thus the golden results seem
+    # inaccurate. This will only change if accuracy improves with the
+    # official weights of DeepSeek-V3.
+    golden_results = [
+        'Hello, my name is feasibility伸 spazio debtor添',
+        'The president of the United States is begg"""\n杭州风和 bestimm',
+        'The capital of France is frequentlyশามalinkAllowed',
+        'The future of AI is deleting俯احت怎么样了حراف',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="torchair graph is not supported on v0")
 @pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"])
@@ -38,46 +88,25 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch,
         m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
         m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO)
 
-        example_prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        dtype = "half"
-        max_tokens = 5
-        # torchair is only work without chunked-prefill now
-        with VllmRunner(
-                "vllm-ascend/DeepSeek-V3-Pruning",
-                dtype=dtype,
-                tensor_parallel_size=4,
-                distributed_executor_backend="mp",
-                additional_config={
-                    "torchair_graph_config": {
-                        "enabled": True,
-                    },
-                    "ascend_scheduler_config": {
-                        "enabled": True,
-                    },
-                    "refresh": True,
-                },
-                enforce_eager=False,
-        ) as vllm_model:
-            # use greedy sampler to make sure the generated results are fix
-            vllm_output = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
-        # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
-        # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
-        # inaccurate. This will only change if accuracy improves with the
-        # official weights of DeepSeek-V3.
-        golden_results = [
-            'Hello, my name is feasibility伸 spazio debtor添',
-            'The president of the United States is begg"""\n杭州风和 bestimm',
-            'The capital of France is frequentlyশามalinkAllowed',
-            'The future of AI is deleting俯احت怎么样了حراف',
-        ]
-
-        assert len(golden_results) == len(vllm_output)
-        for i in range(len(vllm_output)):
-            assert golden_results[i] == vllm_output[i][1]
-            print(f"Generated text: {vllm_output[i][1]!r}")
+        additional_config = {
+            "torchair_graph_config": {
+                "enabled": True,
+            },
+        }
+        _deepseek_torchair_test_fixture(additional_config)
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair_ms_mla(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_MODELSCOPE", "True")
+        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+
+        additional_config = {
+            "torchair_graph_config": {
+                "enabled": True,
+                "enable_multistream_mla": True,
+            },
+        }
+        _deepseek_torchair_test_fixture(additional_config)
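
The two tests now differ only in the torchair_graph_config they pass; everything else is merged inside the shared fixture. Below is a minimal, stand-alone sketch of that merge. The helper name merge_torchair_config and the __main__ harness are illustrative, not part of the commit; only the dict shapes come from the test file above.

from typing import Dict


def merge_torchair_config(additional_config: Dict) -> Dict:
    # Mirror of the fixture's merge: force the ascend scheduler on and
    # request a config refresh before handing the dict to VllmRunner.
    additional_config.update(**{
        "ascend_scheduler_config": {"enabled": True},
        "refresh": True,
    })
    return additional_config


if __name__ == "__main__":
    cfg = merge_torchair_config({
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_mla": True,
        },
    })
    print(cfg)  # torchair graph + multistream MLA, scheduler enabled, refresh set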

vllm_ascend/attention/mla_v1.py

Lines changed: 8 additions & 8 deletions
@@ -558,8 +558,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -850,6 +848,7 @@ def exec_kv(
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
+        enable_multistream_mla: bool = False,
     ):
 
         B = hidden_states.shape[0]
@@ -861,7 +860,7 @@ def exec_kv(
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         with npu_stream_switch("mla_secondary",
                                0,
-                               enabled=self.enable_multistream_mla):
+                               enabled=enable_multistream_mla):
             k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
                 kv,
                 self.kv_a_layernorm.weight,
@@ -1034,6 +1033,7 @@ def forward(
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1093,22 +1093,22 @@ def forward(
             # KvRmsNormRopeCache and SingleRope.
             npu_wait_tensor(decode_hs_or_q_c,
                             cos,
-                            enabled=self.enable_multistream_mla)
+                            enabled=enable_multistream_mla)
             npu_wait_tensor(decode_hs_or_q_c,
                             sin,
-                            enabled=self.enable_multistream_mla)
+                            enabled=enable_multistream_mla)
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+                    attn_metadata.slot_mapping, enable_multistream_mla)
                 with npu_stream_switch("mla_secondary",
                                        0,
-                                       enabled=self.enable_multistream_mla):
+                                       enabled=enable_multistream_mla):
                     npu_wait_tensor(decode_q_pe,
                                     decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
+                                    enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
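
Net effect of this file: the MLA backend no longer latches enable_multistream_mla from the ascend config at __init__; forward and exec_kv receive it as a per-call keyword argument, so the caller can disable the secondary stream for steps that contain prefill. A hardware-free sketch of that calling convention follows; stream_switch and wait_tensor are hypothetical no-op stand-ins for npu_stream_switch / npu_wait_tensor, and MLASketch is not a real class in the repo.

from contextlib import contextmanager


@contextmanager
def stream_switch(tag: str, priority: int, enabled: bool):
    # Stand-in: the real helper moves work onto a secondary NPU stream
    # only when `enabled` is True; here it is a no-op either way.
    yield


def wait_tensor(waiter, waitee, enabled: bool):
    # Stand-in: the real helper makes `waiter` wait on `waitee` when enabled.
    return waiter


class MLASketch:
    def exec_kv(self, kv, enable_multistream_mla: bool = False):
        # The flag arrives as an argument, not as self.enable_multistream_mla.
        with stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
            return kv  # placeholder for the rmsnorm/rope-cache kernel

    def forward(self, q_c, kv, enable_multistream_mla: bool = False):
        q_c = wait_tensor(q_c, kv, enabled=enable_multistream_mla)
        return self.exec_kv(kv, enable_multistream_mla)


# Per-step choice: the same backend instance can run with or without
# multistream MLA on consecutive steps.
backend = MLASketch()
backend.forward("q", "kv", enable_multistream_mla=True)   # decode-only step
backend.forward("q", "kv", enable_multistream_mla=False)  # step with prefill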

vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 6 deletions
@@ -466,20 +466,19 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and not get_forward_context().with_prefill)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
             with npu_stream_switch("mla_secondary",
                                    0,
-                                   enabled=use_multistream_mla):
+                                   enabled=enable_multistream_mla):
                 hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,
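
On the caller side, the gate changes from a per-metadata check (attn_metadata.num_decodes > 0) to the forward context's with_prefill flag, which, per the commit title, accounts for prefill across data-parallel ranks, and the decision is forwarded to attention via forward_kwargs. A minimal sketch of the predicate (the function name is illustrative; only the two inputs come from the diff above):

def use_multistream_mla(config_enabled: bool, with_prefill: bool) -> bool:
    # Multistream MLA only applies to pure-decode steps; `with_prefill`
    # is read from the forward context rather than from attn_metadata.
    return config_enabled and not with_prefill


assert use_multistream_mla(True, False) is True    # decode-only step
assert use_multistream_mla(True, True) is False    # prefill present in the DP group
assert use_multistream_mla(False, False) is False  # feature disabled in config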
