
Commit d893a91 (parent 9fb3d55)

use tuple as kv cache instead of tensor

File tree: 4 files changed (+84, -70 lines)

vllm_ascend/attention/attention_v1.py
vllm_ascend/attention/mla_v1.py
vllm_ascend/ops/attention.py
vllm_ascend/worker/model_runner_v1.py
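
The change replaces the single per-layer KV-cache tensor with a tuple of two tensors (a latent/nope cache and a rope cache). Below is a minimal sketch of the data-structure change, not part of the diff; the dimension values are hypothetical, chosen only for illustration:

    import torch
    from typing import Tuple

    # Hypothetical sizes for illustration.
    num_blocks, block_size, num_kv_heads = 8, 128, 1
    nope_dim, rope_dim = 512, 64

    # Before: one tensor holding the concatenated nope + rope parts per slot.
    kv_cache_old: torch.Tensor = torch.zeros(
        num_blocks, block_size, num_kv_heads, nope_dim + rope_dim)

    # After: a tuple of two tensors, one per component.
    kv_cache_new: Tuple[torch.Tensor, torch.Tensor] = (
        torch.zeros(num_blocks, block_size, num_kv_heads, nope_dim),  # nope cache
        torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim),  # rope cache
    )

    # Emptiness is now checked with len(...) rather than Tensor.numel().
    assert len(kv_cache_new) > 1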

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions

@@ -274,7 +274,7 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        use_kv_cache_int8 = kv_cache.numel(
+        use_kv_cache_int8 = len(kv_cache
         ) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
@@ -315,7 +315,7 @@ def forward(
         # TODO: Remove this contiguous in the future.
         value = value.contiguous()

-        if kv_cache.numel() > 0:
+        if len(kv_cache) > 0:
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
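
A tuple has no numel(), so the guard switches to len(); a tiny sketch of the new check under the tuple layout (shapes and dtypes are illustrative):

    import torch

    kv_cache = (torch.zeros(8, 128, 1, 512), torch.zeros(8, 128, 1, 64))
    use_kv_cache_int8 = len(kv_cache) > 0 and kv_cache[0].dtype == torch.int8
    print(use_kv_cache_int8)  # False here, since the sketch caches are float32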

vllm_ascend/attention/mla_v1.py

Lines changed: 38 additions & 29 deletions

@@ -659,12 +659,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _compute_prefill_context(
         self,
         query: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         rope_dim: int,
         attn_metadata: AscendMLAMetadata,
         prefix_output: torch.Tensor,
         prefix_lse: torch.Tensor,
     ):
+        assert len(kv_c_and_k_pe_cache) > 1
         prefill_metadata = attn_metadata.prefill
         if prefill_metadata is None or prefill_metadata.chunked_context is None:
             return prefix_output, prefix_lse
@@ -674,21 +675,23 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]

         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        cache_kv_c = kv_c_and_k_pe_cache[0]
+        cache_k_pe = kv_c_and_k_pe_cache[1]
+        num_heads = cache_k_pe.size(2)
+        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
@@ -738,10 +741,11 @@ def _forward_prefill(
         query: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
+        assert len(kv_c_and_k_pe_cache) > 1

         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
@@ -769,7 +773,7 @@ def _forward_prefill(
             vanilla_chunked_prefill_mla(
                 output=attn_output_torch,
                 query=query,
-                kv_cache=kv_c_and_k_pe_cache,
+                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                 block_tables=attn_metadata.prefill.block_table,
                 query_lens=attn_metadata.prefill.query_lens,
                 context_lens=attn_metadata.prefill.context_lens,
@@ -938,18 +942,13 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None

-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1008,13 +1007,21 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            attn_output = torch.empty(
+                [num_tokens, self.num_heads, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device)
+            k_cache = torch.cat(
+                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
             torch_npu._npu_paged_attention_mla(
                 query=q,
-                key_cache=kv_c_and_k_pe_cache,
+                key_cache=k_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
+                block_table=attn_metadata.decode.
+                block_table,  # type:ignore
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
@@ -1033,7 +1040,7 @@ def forward(
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
         hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
@@ -1153,8 +1160,11 @@ def forward(
                 prefill_q_pe.contiguous(),
                 prefill_k_pe,
                 max_seq_len=attn_metadata.prefill.max_seq_lens)
+        assert len(
+            kv_cache
+        ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
+            if kv_cache[0].numel(
             ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1164,16 +1174,15 @@ def forward(
                     key_cache=kv_cache[0],
                     value_cache=kv_cache[1],
                     slot_indices=slots)
-        elif kv_cache.numel() > 0:
-            key = torch.cat([
-                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                k_pe
-            ],
-                            dim=2)
-            torch_npu._npu_reshape_and_cache_siso(
-                key=key,
-                key_cache=kv_cache,
-                slot_indices=attn_metadata.slot_mapping.flatten())
+        else:
+            kv_c_normed = kv_c_normed.view(
+                [num_actual_toks, self.num_kv_heads, -1])
+            torch_npu._npu_reshape_and_cache(
+                key=kv_c_normed,
+                value=k_pe,
+                key_cache=kv_cache[0],
+                value_cache=kv_cache[1],
+                slot_indices=attn_metadata.slot_mapping)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
            # otherwise it may affect the accuracy
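
In the non-graph decode branch the two halves are concatenated back into the single key-cache layout the paged-attention kernel expects. A minimal pure-PyTorch sketch of that concatenation (shapes are illustrative; the torch_npu kernel call itself is not reproduced here):

    import torch

    nope_cache = torch.zeros(8, 128, 1, 512)  # plays the role of kv_c_and_k_pe_cache[0]
    rope_cache = torch.zeros(8, 128, 1, 64)   # plays the role of kv_c_and_k_pe_cache[1]

    # Rebuild the combined layout [..., nope_dim + rope_dim] for the kernel.
    k_cache = torch.cat([nope_cache, rope_cache], dim=-1)
    assert k_cache.shape == (8, 128, 1, 512 + 64)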

vllm_ascend/ops/attention.py

Lines changed: 14 additions & 13 deletions

@@ -138,7 +138,7 @@ def vanilla_chunked_prefill(
 def vanilla_chunked_prefill_mla(
         output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
         query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_cache: torch.Tensor,  # (num_blocks, block_size, latent_kv)
+        kv_c_and_k_pe_cache: tuple[torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
         block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
         query_lens: torch.Tensor,  # (batch_size)
         context_lens: torch.Tensor,  # (batch_size)
@@ -154,20 +154,21 @@ def vanilla_chunked_prefill_mla(
     batch_size = block_tables.size(0)
     assert query_lens.size(0) == batch_size
     num_heads = query.size(1)
-    block_size = kv_cache.size(1)
-    latent_kv_dim = kv_cache.size(3) - rope_dim
-    max_num_blocks_per_seq = block_tables.size(1)
-    batch_size = query_lens.size(0)
-    kv_cache = kv_cache.squeeze()
-    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
-    cache_kv_c_pe = kv_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        latent_kv_dim + rope_dim)[:, :max_context_len, :]
-    # get kv_c and k_pe
+    cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
+    cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()
+
     # cached_kv_c: [batch_size, max_context_len, latent_kv]
     # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
-    cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
+    batch_size = query_lens.size(0)
+    block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
+    max_num_blocks_per_seq = block_tables.size(1)
+    cache_kv_c = cache_kv_c[block_tables].view(
+        batch_size, max_num_blocks_per_seq * block_size,
+        latent_kv_dim)[:, :max_context_len, :]
+    cache_k_pe = cache_k_pe[block_tables].view(
+        batch_size, max_num_blocks_per_seq * block_size,
+        rope_dim)[:, :max_context_len, :]
+
     # get k_rope and v
     # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
     # value: [batch_size, max_context_len, num_heads, v_head_dim]
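
With the tuple layout, vanilla_chunked_prefill_mla gathers context for each component separately through the block table. A small runnable sketch of that gather pattern, assuming toy sizes and a kv-head axis already squeezed out (all values below are illustrative):

    import torch

    # Illustrative sizes: 4 blocks of 4 slots each, latent dim 8, rope dim 2.
    num_blocks, block_size, latent_kv_dim, rope_dim = 4, 4, 8, 2
    cache_kv_c = torch.randn(num_blocks, block_size, latent_kv_dim)
    cache_k_pe = torch.randn(num_blocks, block_size, rope_dim)

    batch_size, max_num_blocks_per_seq, max_context_len = 2, 2, 6
    block_tables = torch.tensor([[0, 1], [2, 3]])

    # Index by block table, flatten blocks into a token axis, then trim to the context length.
    ctx_kv_c = cache_kv_c[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size,
        latent_kv_dim)[:, :max_context_len, :]
    ctx_k_pe = cache_k_pe[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size,
        rope_dim)[:, :max_context_len, :]
    print(ctx_kv_c.shape, ctx_k_pe.shape)  # torch.Size([2, 6, 8]) torch.Size([2, 6, 2])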

vllm_ascend/worker/model_runner_v1.py

Lines changed: 30 additions & 26 deletions

@@ -2131,35 +2131,39 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size)
-                    if self.torchair_graph_enabled:
-                        layer_kv_cache_nope = torch.zeros(
-                            kv_cache_shape[:-1] +
-                            (self.model_config.hf_text_config.kv_lora_rank, ),
-                            dtype=self.dtype,
-                            pin_memory=True,
-                            device=self.device)
-                        layer_kv_cache_pe = torch.zeros(
-                            kv_cache_shape[:-1] +
-                            (self.model_config.hf_text_config.qk_rope_head_dim,
-                             ),
-                            dtype=self.dtype,
-                            pin_memory=True,
-                            device=self.device)
-                        kv_caches[layer_name] = (layer_kv_cache_nope,
-                                                 layer_kv_cache_pe)
+                    dtype = kv_cache_spec.dtype
+                    if self.model_config.is_deepseek_mla:
+                        num_blocks, block_size, num_kv_heads, head_dim = kv_cache_shape
+                        rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+                        nope_dim = self.model_config.hf_text_config.kv_lora_rank
+                        assert head_dim == rope_dim + nope_dim, \
+                            f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
+                        nope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, nope_dim)
+                        rope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, rope_dim)
+                        nope_cache = torch.zeros(
+                            nope_cache_shape, dtype=dtype, device=self.device,
+                            pin_memory=True)
+                        rope_cache = torch.zeros(
+                            rope_cache_shape, dtype=dtype, device=self.device,
+                            pin_memory=True)
                         kv_caches[layer_name] = (
-                            torch_npu.npu_format_cast(kv_caches[layer_name][0],
-                                                      acl_format),
-                            torch_npu.npu_format_cast(kv_caches[layer_name][1],
-                                                      acl_format),
+                            torch_npu.npu_format_cast(nope_cache, acl_format),
+                            torch_npu.npu_format_cast(rope_cache, acl_format),
                         )
                     else:
-                        kv_caches[layer_name] = torch.zeros(
-                            kv_cache_shape,
-                            dtype=self.kv_cache_dtype,
-                            device=self.device)
-                        kv_caches[layer_name] = \
-                            torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
+                        num_caches = kv_cache_shape[0]
+                        kv_cache_list = []
+                        for i in range(num_caches):
+                            cache_shape = kv_cache_shape[1:]
+                            kv_cache = torch.zeros(cache_shape,
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            kv_cache = torch_npu.npu_format_cast(kv_cache,
+                                                                 acl_format)
+                            kv_cache_list.append(kv_cache)
+                        kv_caches[layer_name] = kv_cache_list
                 else:
                     # TODO: add new branches when introducing more types of
                    # KV cache specs.
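
A minimal CPU-only sketch of the new MLA allocation path, splitting head_dim into kv_lora_rank and qk_rope_head_dim per layer (the config values and layer name are hypothetical; pin_memory, device placement, and npu_format_cast are omitted here):

    import torch

    # Hypothetical config values for illustration.
    kv_lora_rank, qk_rope_head_dim = 512, 64
    kv_cache_shape = (16, 128, 1, kv_lora_rank + qk_rope_head_dim)
    dtype = torch.bfloat16

    num_blocks, block_size, num_kv_heads, head_dim = kv_cache_shape
    assert head_dim == kv_lora_rank + qk_rope_head_dim

    # One cache tensor per component instead of a single combined tensor.
    nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads,
                             kv_lora_rank, dtype=dtype)
    rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads,
                             qk_rope_head_dim, dtype=dtype)

    kv_caches = {"layers.0.attn": (nope_cache, rope_cache)}  # hypothetical layer name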
