
Commit a550e39

[MLA][Graph] Improve assertion on Graph mode with MLA
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 3442fbd

File tree

3 files changed: +21 -1 lines changed

- docs/source/faqs.md
- vllm_ascend/attention/attention.py
- vllm_ascend/worker/multi_step_worker.py

docs/source/faqs.md

Lines changed: 10 additions & 0 deletions

@@ -115,3 +115,13 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam
 - **Adjust `--gpu-memory-utilization`**: If unspecified, will use the default value of `0.9`. You can decrease this param to reserve more memory to reduce fragmentation risks. See more note in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).
 
 - **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html).
+
+### 15. Failed to enable NPU graph mode when running DeepSeek?
+You may encounter the following error when running DeepSeek with NPU graph mode enabled. When both MLA and graph mode are enabled, the number of queries per kv head must be one of {32, 64, 128}; **thus this is not supported for DeepSeek-V2-Lite**, as it has only 16 attention heads. NPU graph mode support for DeepSeek-V2-Lite will be added in the future.
+
+If you are using DeepSeek-V3 or DeepSeek-R1, please make sure that after the tensor parallel split, num_heads / num_kv_heads is in {32, 64, 128}.
+
+```bash
+[rank0]: RuntimeError: EZ9999: Inner Error!
+[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218]
+```
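For reference, here is a minimal sketch, not part of vllm-ascend, of the arithmetic behind this constraint; the helper name and the per-rank head counts in the examples are illustrative only.

```python
# Illustrative sketch of the MLA + graph mode head-ratio check described above.
# The helper and the example head counts are hypothetical, not vllm-ascend APIs.
_ALLOWED_NUM_QUERIES_PER_KV = {32, 64, 128}


def mla_graph_mode_supported(num_heads_per_rank: int,
                             num_kv_heads_per_rank: int) -> bool:
    """Return True if num_heads / num_kv_heads after the TP split is allowed."""
    num_queries_per_kv = num_heads_per_rank // num_kv_heads_per_rank
    return num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV


# Matches the error above: a ratio of 8 is rejected by the MLA kernel.
print(mla_graph_mode_supported(num_heads_per_rank=8, num_kv_heads_per_rank=1))   # False
# A ratio of 64 (e.g. 128 heads split across two ranks with one kv head) passes.
print(mla_graph_mode_supported(num_heads_per_rank=64, num_kv_heads_per_rank=1))  # True
```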

vllm_ascend/attention/attention.py

Lines changed: 10 additions & 0 deletions

@@ -40,6 +40,8 @@
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 
+_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
+
 
 def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
     # Construct lower triangle matrix.

@@ -1005,6 +1007,14 @@ def __init__(
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
+            # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+            if self.enable_graph_mode:
+                assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                    ("The number of queries per kv when enabling both MLA and graph mode"
+                     " must be in {32, 64, 128}; thus this is not supported for DeepSeek-V2-Lite,"
+                     " as it only has 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1,"
+                     " please make sure that after the tensor parallel split, num_heads / num_kv_heads is in"
+                     " {32, 64, 128}.")
 
     def exec_kv(
         self,
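For context, a hedged usage sketch of how this assertion is reached. It assumes graph mode is switched on through vLLM's `additional_config` engine argument, from which vllm-ascend reads `enable_graph_mode`; the model name and tensor parallel size below are illustrative, not a recommendation.

```python
# Hedged sketch, not a verified recipe: enabling graph mode via additional_config
# so the new assertion fires at engine construction instead of surfacing later as
# the opaque EZ9999 kernel error shown in the FAQ above.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",          # 16 attention heads
    tensor_parallel_size=2,
    trust_remote_code=True,
    additional_config={"enable_graph_mode": True},  # read by vllm-ascend
)
# Expected with this commit: an AssertionError explaining that num_heads /
# num_kv_heads must be in {32, 64, 128} when MLA and graph mode are combined.
```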

vllm_ascend/worker/multi_step_worker.py

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
             # execute_model_req
             assert execute_model_req.last_sampled_token_ids is not None
             model_input.last_sampled_token_ids = (
-                execute_model_req.last_sampled_token_ids.cuda())
+                execute_model_req.last_sampled_token_ids.npu())
             model_input.add_sampler_output(
                 SamplerOutput(outputs=[], sampled_token_ids=None),
                 model_input.last_sampled_token_ids)
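A brief note on the `.cuda()` to `.npu()` change above: with `torch_npu` imported, PyTorch tensors expose an `.npu()` method that moves them to the Ascend device, mirroring `.cuda()` for GPUs. A minimal sketch with an illustrative tensor follows.

```python
# Minimal sketch of the device move behind this change; the tensor values are
# illustrative. Importing torch_npu registers the "npu" device and Tensor.npu().
import torch
import torch_npu  # noqa: F401

last_sampled_token_ids = torch.tensor([[42], [7]], dtype=torch.int64)
on_npu = last_sampled_token_ids.npu()   # analogous to .cuda() on GPU systems
print(on_npu.device)                    # e.g. npu:0
```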
