Commit ac998d5

MengqingCao authored and yangcheng (AJ) committed
[MLA][Graph] Improve assertion on Graph mode with MLA (vllm-project#933)
Improve assertion on Graph mode with MLA. When running DeepSeek with graph mode, the fused MLA op only supports `numHeads / numKvHeads ∈ {32, 64, 128}`, so the assertion message is improved to avoid confusing users. Adjusting the tensor-parallel size is required when running DeepSeek-V3/R1 with graph mode; DeepSeek-V2-Lite is not supported in graph mode.

Tested locally, as the CI machine cannot run V3 due to the HBM limits.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
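To make the constraint concrete, below is a minimal sketch (not part of this commit) of which tensor-parallel sizes satisfy it. It assumes MLA leaves a single KV head per rank, so num_queries_per_kv equals the per-rank query-head count, i.e. total attention heads divided by the TP size; the head counts (128 for DeepSeek-V3/R1, 16 for DeepSeek-V2-Lite) come from the models' public configs, and the helper name graph_mode_tp_ok is made up for illustration.

    # Sketch only: which TP sizes the fused MLA op accepts in graph mode.
    # Assumption: with MLA there is effectively one KV head per rank, so
    # num_queries_per_kv == num_attention_heads // tp_size.
    ALLOWED_NUM_QUERIES_PER_KV = (32, 64, 128)

    def graph_mode_tp_ok(num_attention_heads: int, tp_size: int) -> bool:
        """Return True if this head count / TP split passes the new assertion."""
        if num_attention_heads % tp_size != 0:
            return False
        return num_attention_heads // tp_size in ALLOWED_NUM_QUERIES_PER_KV

    if __name__ == "__main__":
        # DeepSeek-V3 / R1: 128 attention heads -> TP 1, 2, 4 pass; TP 8 gives 16 and fails.
        for tp in (1, 2, 4, 8):
            print(f"DeepSeek-V3/R1, tp={tp}: {graph_mode_tp_ok(128, tp)}")
        # DeepSeek-V2-Lite: 16 attention heads -> no TP split can reach 32.
        print(f"DeepSeek-V2-Lite, tp=1: {graph_mode_tp_ok(16, 1)}")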
1 parent 16058dc commit ac998d5

2 files changed: +21 −0 lines changed

vllm_ascend/attention/attention.py

Lines changed: 11 additions & 0 deletions

@@ -40,6 +40,8 @@
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 
+_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
+
 
 def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
     # Construct lower triangle matrix.

@@ -1005,6 +1007,15 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
+        # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+        if self.torchair_graph_enabled:
+            assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                ("The allowed number of queries per kv when enabling both MLA and Graph mode"
+                 " is {32, 64, 128}; thus DeepSeek-V2-Lite is not supported, as it only has"
+                 " 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1, please make"
+                 " sure num_heads / num_kv_heads is in {32, 64, 128} after the tensor parallel"
+                 " split.")
+
     def exec_kv(
         self,
         hidden_states: torch.Tensor,

vllm_ascend/attention/mla_v1.py

Lines changed: 10 additions & 0 deletions

@@ -15,6 +15,7 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata

@@ -585,6 +586,15 @@ def __init__(
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
 
+        # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+        if self.torchair_graph_enabled:
+            assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                ("The allowed number of queries per kv when enabling both MLA and Graph mode"
+                 " is {32, 64, 128}; thus DeepSeek-V2-Lite is not supported, as it only has"
+                 " 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1, please make"
+                 " sure num_heads / num_kv_heads is in {32, 64, 128} after the tensor parallel"
+                 " split.")
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
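As a follow-up on the commit message's note that the tensor-parallel size must be adjusted, here is a hedged launch sketch for DeepSeek-R1 with graph mode. The torchair_graph_config.enabled knob is taken from this diff; wrapping it in additional_config and passing it to vLLM's LLM constructor reflects a common vllm-ascend usage pattern but is an assumption here, as is the choice of tensor_parallel_size=4 (128 heads / 4 = 32, which satisfies the check; memory sizing is out of scope).

    # Sketch under the assumptions stated above; the exact config surface may differ by version.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="deepseek-ai/DeepSeek-R1",   # 128 attention heads
        tensor_parallel_size=4,            # 128 / 4 = 32, within {32, 64, 128}
        additional_config={
            # Key name mirrors ascend_config.torchair_graph_config.enabled from this diff;
            # the additional_config wrapping is assumed, not taken from the commit.
            "torchair_graph_config": {"enabled": True},
        },
    )

    print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))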
