From 01168864c95a433c415fdf6c7d0b5ea81515004f Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Mon, 9 Jun 2025 08:51:20 +0000
Subject: [PATCH 1/2] [MLA][Graph] Improve assertion on Graph mode with MLA

Signed-off-by: MengqingCao
---
 docs/source/faqs.md                     | 10 ++++++++++
 vllm_ascend/attention/attention.py      | 11 +++++++++++
 vllm_ascend/attention/mla_v1.py         | 11 +++++++++++
 vllm_ascend/worker/multi_step_worker.py |  2 +-
 4 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/docs/source/faqs.md b/docs/source/faqs.md
index 2a8ba9042d..6ab04b2206 100644
--- a/docs/source/faqs.md
+++ b/docs/source/faqs.md
@@ -115,3 +115,13 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam
 
 - **Adjust `--gpu-memory-utilization`**: If unspecified, will use the default value of `0.9`. You can decrease this param to reserve more memory to reduce fragmentation risks. See more note in: [vLLM - Inference and Serving - Engine Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html#vllm.engine.arg_utils-_engine_args_parser-cacheconfig).
 - **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html).
+
+### 15. Failed to enable NPU graph mode when running DeepSeek?
+You may encounter the following error when running DeepSeek with NPU graph mode enabled. The allowed number of queries per kv when enabling both MLA and graph mode is {32, 64, 128}, **so DeepSeek-V2-Lite is not supported**, as it only has 16 attention heads. NPU graph mode support for DeepSeek-V2-Lite will be added in the future.
+
+If you're using DeepSeek-V3 or DeepSeek-R1, please make sure that, after the tensor parallel split, num_heads / num_kv_heads is in {32, 64, 128}.
+
+```bash
+[rank0]: RuntimeError: EZ9999: Inner Error!
+[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218]
+```
diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py
index 8f130e4241..a567cc5306 100644
--- a/vllm_ascend/attention/attention.py
+++ b/vllm_ascend/attention/attention.py
@@ -40,6 +40,8 @@
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 
+_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
+
 
 def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
     # Construct lower triangle matrix.
@@ -1005,6 +1007,15 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
+        # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+        if self.torchair_graph_enabled:
+            assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                ("The allowed number of queries per kv when enabling both MLA and graph mode"
+                 " is {32, 64, 128}, so DeepSeek-V2-Lite is not supported, as it only has"
+                 " 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1, please make"
+                 " sure that, after the tensor parallel split, num_heads / num_kv_heads is in"
+                 " {32, 64, 128}.")
+
     def exec_kv(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 91ddf43888..164d428a5a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -13,6 +13,7 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -458,10 +459,20 @@ def __init__(
         self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
+        # TODO: support numHeads / numKvHeads < 16 in MLA kernel
+        if self.torchair_graph_enabled:
+            assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \
+                ("The allowed number of queries per kv when enabling both MLA and graph mode"
+                 " is {32, 64, 128}, so DeepSeek-V2-Lite is not supported, as it only has"
+                 " 16 attention heads. If you're using DeepSeek-V3 or DeepSeek-R1, please make"
+                 " sure that, after the tensor parallel split, num_heads / num_kv_heads is in"
+                 " {32, 64, 128}.")
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
diff --git a/vllm_ascend/worker/multi_step_worker.py b/vllm_ascend/worker/multi_step_worker.py
index ba83f6b962..6d092805d5 100644
--- a/vllm_ascend/worker/multi_step_worker.py
+++ b/vllm_ascend/worker/multi_step_worker.py
@@ -119,7 +119,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
             # execute_model_req
             assert execute_model_req.last_sampled_token_ids is not None
             model_input.last_sampled_token_ids = (
-                execute_model_req.last_sampled_token_ids.cuda())
+                execute_model_req.last_sampled_token_ids.npu())
             model_input.add_sampler_output(
                 SamplerOutput(outputs=[], sampled_token_ids=None),
                 model_input.last_sampled_token_ids)

From 297fc88dce96a584b9f47d69f2aa460e71ab1527 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Mon, 9 Jun 2025 09:55:08 +0000
Subject: [PATCH 2/2] fix ruff

Signed-off-by: MengqingCao
---
 vllm_ascend/attention/mla_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 164d428a5a..0fd17e352d 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -12,8 +12,8 @@
                                                UnquantizedLinearMethod)
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
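
The rule the new assertions encode can be sanity-checked in isolation. The snippet below is a minimal standalone sketch, not part of this patch or of vllm-ascend: `graph_mode_mla_supported` and the per-rank head counts are hypothetical names and values. It simply restates that num_heads / num_kv_heads must fall in {32, 64, 128} after the tensor parallel split when MLA and graph mode are enabled together.

```python
# Standalone sketch of the MLA + graph mode constraint; the helper name and
# the example values are illustrative, not vllm-ascend APIs.
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]


def graph_mode_mla_supported(num_heads_per_rank: int,
                             num_kv_heads_per_rank: int) -> bool:
    """Return True if the per-rank head counts (i.e. after the tensor
    parallel split) yield an allowed number of queries per kv head."""
    return (num_heads_per_rank //
            num_kv_heads_per_rank) in _ALLOWED_NUM_QUERIES_PER_KV


# The failing case from the FAQ error message above: numHeads / numKvHeads = 8.
assert not graph_mode_mla_supported(num_heads_per_rank=8, num_kv_heads_per_rank=1)
# A supported split, e.g. 64 queries per kv head.
assert graph_mode_mla_supported(num_heads_per_rank=64, num_kv_heads_per_rank=1)
```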