Commit 7863de0

fix: correct finding the kv cache shape for mha
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 78873f7 commit 7863de0

1 file changed, +2 -2 lines changed

vllm_ascend/distributed/llmdatadist_connector_v1.py

Lines changed: 2 additions & 2 deletions
@@ -785,11 +785,11 @@ def _extract_kv_from_layer(
         """
         if is_mla:
             num_heads, head_dim = kv_cache_layer.shape[
-                2], kv_cache_layer.shape[3]
+                -2], kv_cache_layer.shape[-1]
             return kv_cache_layer.view(-1, num_heads, head_dim)[slot_mapping,
                                                                 ...]

-        num_heads, head_dim = kv_cache_layer.shape[2], kv_cache_layer.shape[3]
+        num_heads, head_dim = kv_cache_layer.shape[-2], kv_cache_layer.shape[-1]
         return kv_cache_layer.view(2, -1, num_heads, head_dim)[:, slot_mapping,
                                                                ...]
