Skip to content

Commit 594721a

Browse files
committed
add todo and faq
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 5a8c482 commit 594721a

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

docs/source/faqs.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,13 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam
119119
- **Adjust `--gpu-memory-utilization`**: If unspecified, will use the default value of `0.9`. You can decrease this param to reserve more memory to reduce fragmentation risks. See more note in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).
120120

121121
- **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html).
122+
123+
### 15. Failed to enable NPU graph mode when running DeepSeek?
124+
You may encounter the following error if running DeepSeek with NPU graph mode enabled. The allowed number of queries per kv when enabling both MLA and Graph mode only support {32, 64, 128}, **Thus this is not supported for DeepSeek-V2-Lite**, as it only has 16 attention heads. The NPU graph mode support on DeepSeek-V2-Lite will be done in the future.
125+
126+
And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tensor parallel split, num_heads / num_kv_heads in {32, 64, 128}.
127+
128+
```bash
129+
[rank0]: RuntimeError: EZ9999: Inner Error!
130+
[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218]
131+
```

vllm_ascend/attention/attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,7 @@ def __init__(
10081008
if additional_config:
10091009
self.enable_graph_mode = additional_config.get(
10101010
"enable_graph_mode", False)
1011+
# TODO: support numHeads / numKvHeads < 16 in MLA kernel
10111012
if self.enable_graph_mode:
10121013
assert self.num_queries_per_kv in ALLOWED_NUM_QUERIES_PER_KV, \
10131014
("The allowed number of queries per kv when enabling both MLA and Graph mode"

0 commit comments

Comments
 (0)