
Commit 9e099a5

Handle with_prefill_across_dp for multistream mla (#1322)
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent e1d282d commit 9e099a5
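
In short: whether multistream MLA runs is now decided once per forward pass from the forward context's prefill flag (which, per the commit title, reflects prefill state across data-parallel ranks), rather than read as a static config attribute inside the attention backend. The core of the new gate, taken from the vllm_ascend/models/deepseek_v2.py hunk below:

enable_multistream_mla = (self.enable_multistream_mla
                          and not get_forward_context().with_prefill)

The flag is then threaded explicitly through the MLA backend's forward() and exec_kv().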

File tree (3 files changed: +85 −56 lines):

tests/multicard/test_torchair_graph_mode.py
vllm_ascend/attention/mla_v1.py
vllm_ascend/models/deepseek_v2.py

tests/multicard/test_torchair_graph_mode.py

Lines changed: 72 additions & 43 deletions
@@ -20,6 +20,7 @@
 Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 """
 import os
+from typing import Dict

 import pytest

@@ -28,6 +29,55 @@
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"


+def _deepseek_torchair_test_fixture(
+        additional_config: Dict,
+        *,
+        tensor_parallel_size=4,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # torchair only works without chunked prefill for now
+    kwargs = {
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    additional_config.update(**kwargs)
+
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype="half",
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="mp",
+            enforce_eager=False,
+            additional_config=additional_config,
+    ) as vllm_model:
+        # use greedy sampling to make sure the generated results are fixed
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random-weight variant of
+    # DeepSeek-V3 with 2 hidden layers, so the golden results look
+    # inaccurate. This will only change if accuracy improves with the
+    # official weights of DeepSeek-V3.
+    golden_results = [
+        'Hello, my name is下载早点向前很有่อง',
+        'The president of the United States isSender)## physiological Albany',
+        'The capital of France is Rocky转角 hospitalizedinterval sparked',
+        'The future of AI is её asegο BIOS一扫',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="torchair graph is not supported on v0")
 @pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"])
@@ -38,46 +88,25 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch,
         m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
         m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO)

-        example_prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        dtype = "half"
-        max_tokens = 5
-        # torchair is only work without chunked-prefill now
-        with VllmRunner(
-                "vllm-ascend/DeepSeek-V3-Pruning",
-                dtype=dtype,
-                tensor_parallel_size=4,
-                distributed_executor_backend="mp",
-                additional_config={
-                    "torchair_graph_config": {
-                        "enabled": True,
-                    },
-                    "ascend_scheduler_config": {
-                        "enabled": True,
-                    },
-                    "refresh": True,
-                },
-                enforce_eager=False,
-        ) as vllm_model:
-            # use greedy sampler to make sure the generated results are fix
-            vllm_output = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
-            # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
-            # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
-            # inaccurate. This will only change if accuracy improves with the
-            # official weights of DeepSeek-V3.
-            golden_results = [
-                'Hello, my name is下载早点向前很有่อง',
-                'The president of the United States isSender)## physiological Albany',
-                'The capital of France is Rocky转角 hospitalizedinterval sparked',
-                'The future of AI is её asegο BIOS一扫',
-            ]
-
-            assert len(golden_results) == len(vllm_output)
-            for i in range(len(vllm_output)):
-                assert golden_results[i] == vllm_output[i][1]
-                print(f"Generated text: {vllm_output[i][1]!r}")
+        additional_config = {
+            "torchair_graph_config": {
+                "enabled": True,
+            },
+        }
+        _deepseek_torchair_test_fixture(additional_config)
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair_ms_mla(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_MODELSCOPE", "True")
+        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+
+        additional_config = {
+            "torchair_graph_config": {
+                "enabled": True,
+                "enable_multistream_mla": True,
+            },
+        }
+        _deepseek_torchair_test_fixture(additional_config)
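
For context: after _deepseek_torchair_test_fixture merges in its scheduler defaults ("ascend_scheduler_config" and "refresh"), the new multistream-MLA test effectively constructs VllmRunner with the merged config below (a sketch assembled from the hunks above, not a line from the commit):

additional_config = {
    "torchair_graph_config": {
        "enabled": True,
        "enable_multistream_mla": True,
    },
    "ascend_scheduler_config": {
        "enabled": True,
    },
    "refresh": True,
}

The new case can be run on its own with pytest tests/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv3_with_torchair_ms_mla.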

vllm_ascend/attention/mla_v1.py

Lines changed: 8 additions & 8 deletions
@@ -588,8 +588,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla

         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -883,6 +881,7 @@ def exec_kv(
        sin: torch.Tensor,
        kv_cache: Tuple,
        slots: torch.Tensor,
+       enable_multistream_mla: bool = False,
    ):

        B = hidden_states.shape[0]
@@ -894,7 +893,7 @@ def exec_kv(
        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
        with npu_stream_switch("mla_secondary",
                               0,
-                              enabled=self.enable_multistream_mla):
+                              enabled=enable_multistream_mla):
            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
                kv,
                self.kv_a_layernorm.weight,
@@ -1066,6 +1065,7 @@ def forward(
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
+       enable_multistream_mla: bool = False,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
@@ -1127,22 +1127,22 @@
                # KvRmsNormRopeCache and SingleRope.
                npu_wait_tensor(decode_hs_or_q_c,
                                cos,
-                               enabled=self.enable_multistream_mla)
+                               enabled=enable_multistream_mla)
                npu_wait_tensor(decode_hs_or_q_c,
                                sin,
-                               enabled=self.enable_multistream_mla)
+                               enabled=enable_multistream_mla)
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            if self.running_in_graph:
                decode_k_pe, decode_k_nope = self.exec_kv(
                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                   attn_metadata.slot_mapping)
+                   attn_metadata.slot_mapping, enable_multistream_mla)
                with npu_stream_switch("mla_secondary",
                                       0,
-                                      enabled=self.enable_multistream_mla):
+                                      enabled=enable_multistream_mla):
                    npu_wait_tensor(decode_q_pe,
                                    decode_k_pe,
-                                   enabled=self.enable_multistream_mla)
+                                   enabled=enable_multistream_mla)
                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
            else:
                decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
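
The backend no longer caches enable_multistream_mla at __init__ time; the flag now arrives as an argument on each forward()/exec_kv() call, so the model layer can turn the optimization off for any step that includes prefill. Both npu_stream_switch and npu_wait_tensor accept an enabled= keyword so call sites stay branch-free; a minimal sketch of that pattern (a hypothetical stand-in, not vllm-ascend's actual implementation):

from contextlib import contextmanager, nullcontext

@contextmanager
def maybe_stream_switch(stream_ctx, enabled: bool):
    # With enabled=False this is a no-op pass-through, so callers can
    # always write `with maybe_stream_switch(...):` without branching;
    # with enabled=True the body runs under the given stream context.
    with (stream_ctx if enabled else nullcontext()):
        yield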

vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 5 deletions
@@ -470,15 +470,15 @@ def forward(
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+       enable_multistream_mla = (self.enable_multistream_mla
+                                 and not get_forward_context().with_prefill)
+       forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
-           use_multistream_mla = (self.enable_multistream_mla
-                                  and attn_metadata is not None
-                                  and attn_metadata.num_decodes > 0)
-           npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+           npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
            with npu_stream_switch("mla_secondary",
                                   0,
-                                  enabled=use_multistream_mla):
+                                  enabled=enable_multistream_mla):
                hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
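
The gate itself also changes meaning. A side-by-side of the old and new conditions (copied from the hunk above, comments added):

# Before: gated on this rank's local batch containing decode requests
use_multistream_mla = (self.enable_multistream_mla
                       and attn_metadata is not None
                       and attn_metadata.num_decodes > 0)

# After: gated on the forward context's with_prefill flag, which (per the
# commit title) accounts for prefill across data-parallel ranks, so the
# multistream path is skipped whenever any prefill work is in flight
enable_multistream_mla = (self.enable_multistream_mla
                          and not get_forward_context().with_prefill)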
