Commit 5a8c482

[MLA][Graph] Improve assertion on Graph mode with MLA

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent: 7aa4f85

2 files changed: +11 -1

vllm_ascend/attention/attention.py (10 additions, 0 deletions)

@@ -947,6 +947,9 @@ def forward(
         return output.view(num_tokens, self.hidden_size)
 
 
+ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
+
+
 class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
 
     def __init__(
@@ -1005,6 +1008,13 @@ def __init__(
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
+            if self.enable_graph_mode:
+                assert self.num_queries_per_kv in ALLOWED_NUM_QUERIES_PER_KV, \
+                    ("When both MLA and graph mode are enabled, the number of queries"
+                     " per kv must be one of {32, 64, 128}. DeepSeek-V2-Lite is therefore"
+                     " not supported, as it only has 16 attention heads. If you're using"
+                     " DeepSeek-V3 or DeepSeek-R1, please make sure num_heads /"
+                     " num_kv_heads falls in {32, 64, 128} after the tensor parallel split.")
 
     def exec_kv(
         self,
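Why the allowed set matters: num_queries_per_kv is the ratio of query heads to kv heads after the tensor parallel split, so both the model's head count and the TP size determine whether the assertion passes. A minimal sketch of the constraint, assuming heads are split evenly across tensor-parallel ranks (check_mla_graph_support is a hypothetical helper, not vllm-ascend API; the 16-head figure for DeepSeek-V2-Lite comes from the assertion message):

    ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]

    def check_mla_graph_support(num_heads: int, num_kv_heads: int,
                                tp_size: int) -> int:
        # Per-rank head counts after the tensor parallel split.
        heads_per_rank = num_heads // tp_size
        kv_heads_per_rank = max(num_kv_heads // tp_size, 1)
        num_queries_per_kv = heads_per_rank // kv_heads_per_rank
        assert num_queries_per_kv in ALLOWED_NUM_QUERIES_PER_KV, (
            f"MLA with graph mode needs num_queries_per_kv in "
            f"{ALLOWED_NUM_QUERIES_PER_KV}, got {num_queries_per_kv}")
        return num_queries_per_kv

    check_mla_graph_support(num_heads=128, num_kv_heads=1, tp_size=2)   # 64: passes
    # check_mla_graph_support(num_heads=16, num_kv_heads=1, tp_size=1)  # 16: fails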

vllm_ascend/worker/multi_step_worker.py (1 addition, 1 deletion)

@@ -119,7 +119,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
         # execute_model_req
         assert execute_model_req.last_sampled_token_ids is not None
         model_input.last_sampled_token_ids = (
-            execute_model_req.last_sampled_token_ids.cuda())
+            execute_model_req.last_sampled_token_ids.npu())
         model_input.add_sampler_output(
             SamplerOutput(outputs=[], sampled_token_ids=None),
             model_input.last_sampled_token_ids)
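The .cuda() call would fail on Ascend hardware, where tensors live on the "npu" device instead. A minimal usage sketch, assuming torch_npu is installed (importing it registers the npu device and the Tensor.npu() method, the Ascend counterpart of Tensor.cuda()):

    import torch
    import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

    # Hypothetical stand-in for execute_model_req.last_sampled_token_ids.
    last_sampled_token_ids = torch.tensor([[101], [2048]])

    on_npu = last_sampled_token_ids.npu()  # move host tensor to the Ascend NPU
    back_on_host = on_npu.cpu()            # and back, symmetric with .cuda()/.cpu()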
