
Commit 9861dc5

MengqingCao authored and wangxiaoxin (A) committed
[MLA][Graph] Improve assertion on Graph mode with MLA (vllm-project#933)
### What this PR does / why we need it?
Improve the assertion on graph mode with MLA. When running DeepSeek with graph mode, the fused MLA op only supports `numHeads / numKvHeads ∈ {32, 64, 128}`, so we improve the assertion message to avoid confusing users (see the sketch after the file summary below).

### Does this PR introduce _any_ user-facing change?
Adjusting the tensor parallel size is required when running DeepSeek-V3/R1 with graph mode. DeepSeek-V2-Lite is not supported in graph mode.

### How was this patch tested?
Tested locally, as the CI machine could not run V3 due to the HBM limits.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiaoxin (A) <wangxiaoxin7@huawei.com>
1 parent 7eb9f23 commit 9861dc5

File tree

4 files changed: +33 −1 lines changed
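The constraint described above can be checked up front. Below is a minimal sketch of the ratio check that the new assertion enforces; it assumes attention heads are split evenly across tensor-parallel ranks and a single KV head on the MLA decode path, and the helper name and head counts are illustrative rather than part of vllm-ascend.

```python
# Hypothetical pre-flight check mirroring the assertion added in this PR (illustrative only).
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]


def check_mla_graph_mode(num_attention_heads: int, tp_size: int, num_kv_heads: int = 1) -> int:
    # Heads are assumed to be split evenly across tensor-parallel ranks.
    heads_per_rank = num_attention_heads // tp_size
    num_queries_per_kv = heads_per_rank // num_kv_heads
    assert num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, (
        f"numHeads / numKvHeads = {num_queries_per_kv}, but the fused MLA op in graph "
        f"mode only supports {_ALLOWED_NUM_QUERIES_PER_KV}; adjust --tensor-parallel-size."
    )
    return num_queries_per_kv


if __name__ == "__main__":
    print(check_mla_graph_mode(num_attention_heads=128, tp_size=4))  # 32 -> OK
```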

docs/source/faqs.md

Lines changed: 10 additions & 0 deletions
@@ -113,3 +113,13 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam
 - **Adjust `--gpu-memory-utilization`**: If unspecified, will use the default value of `0.9`. You can decrease this param to reserve more memory to reduce fragmentation risks. See more note in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).
 
 - **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html).
+
+### 15. Failed to enable NPU graph mode when running DeepSeek?
+You may encounter the following error when running DeepSeek with NPU graph mode enabled. The allowed number of queries per kv when enabling both MLA and graph mode is {32, 64, 128}, **so DeepSeek-V2-Lite is not supported**, as it only has 16 attention heads. NPU graph mode support for DeepSeek-V2-Lite will be added in the future.
+
+If you're using DeepSeek-V3 or DeepSeek-R1, please make sure that `num_heads / num_kv_heads` is in {32, 64, 128} after the tensor parallel split.
+
+```bash
+[rank0]: RuntimeError: EZ9999: Inner Error!
+[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218]
+```
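As a quick way to choose `--tensor-parallel-size` before launching, one can enumerate the TP degrees that keep this ratio in the allowed set. This is a sketch under the same assumptions as above (heads split evenly across TP ranks, one KV head on the MLA decode path); the helper is illustrative and not part of vllm-ascend.

```python
# Illustrative helper: list TP sizes that satisfy the MLA graph-mode constraint.
ALLOWED_NUM_QUERIES_PER_KV = {32, 64, 128}


def valid_tp_sizes(num_attention_heads: int, num_kv_heads: int = 1, max_tp: int = 16) -> list[int]:
    sizes = []
    for tp in range(1, max_tp + 1):
        if num_attention_heads % tp:
            continue  # heads must divide evenly across TP ranks
        if (num_attention_heads // tp) // num_kv_heads in ALLOWED_NUM_QUERIES_PER_KV:
            sizes.append(tp)
    return sizes


print(valid_tp_sizes(128))  # e.g. a 128-head model such as DeepSeek-V3/R1: [1, 2, 4]
print(valid_tp_sizes(16))   # e.g. a 16-head model such as DeepSeek-V2-Lite: [] -> graph mode unsupported
```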

vllm_ascend/attention/attention.py

Lines changed: 11 additions & 0 deletions
@@ -40,6 +40,8 @@
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 
+_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
+
 
 def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
     # Construct lower triangle matrix.
@@ -1005,6 +1007,15 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
+        # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+        if self.torchair_graph_enabled:
+            assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                ("The allowed number of queries per kv when enabling both MLA and graph mode"
+                 " is {32, 64, 128}, so DeepSeek-V2-Lite is not supported, as it only has"
+                 " 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1, please"
+                 " make sure that num_heads / num_kv_heads is in {32, 64, 128} after the"
+                 " tensor parallel split.")
+
     def exec_kv(
         self,
         hidden_states: torch.Tensor,

vllm_ascend/attention/mla_v1.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,7 @@
 from vllm.utils import cdiv, round_down
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
@@ -551,6 +552,7 @@ def __init__(
         self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -560,6 +562,15 @@ def __init__(
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
 
+        # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+        if self.torchair_graph_enabled:
+            assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                ("The allowed number of queries per kv when enabling both MLA and graph mode"
+                 " is {32, 64, 128}, so DeepSeek-V2-Lite is not supported, as it only has"
+                 " 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1, please"
+                 " make sure that num_heads / num_kv_heads is in {32, 64, 128} after the"
+                 " tensor parallel split.")
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

vllm_ascend/worker/multi_step_worker.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
             # execute_model_req
             assert execute_model_req.last_sampled_token_ids is not None
             model_input.last_sampled_token_ids = (
-                execute_model_req.last_sampled_token_ids.cuda())
+                execute_model_req.last_sampled_token_ids.npu())
             model_input.add_sampler_output(
                 SamplerOutput(outputs=[], sampled_token_ids=None),
                 model_input.last_sampled_token_ids)
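For context on the `.cuda()` → `.npu()` change: with the `torch_npu` plugin installed, tensors are moved to the Ascend device via `Tensor.npu()`, the NPU counterpart of `Tensor.cuda()`. A minimal sketch, assuming `torch` and `torch_npu` are available on an Ascend machine:

```python
# Minimal sketch: placing a tensor on the NPU device with torch_npu (illustrative).
import torch
import torch_npu  # noqa: F401  -- registers the NPU backend and Tensor.npu()

last_sampled_token_ids = torch.tensor([[1], [7], [42]], dtype=torch.int64)
# Same spirit as the change above: use .npu() instead of .cuda() on Ascend.
last_sampled_token_ids = last_sampled_token_ids.npu()
print(last_sampled_token_ids.device)  # e.g. npu:0
```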
