
Commit 64a6343

use tuple as kv cache instead of tensor
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
1 parent f96100f commit 64a6343

File tree

4 files changed: +89 -50 lines

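For orientation, here is a minimal sketch (illustrative sizes only, not taken from the commit) of the layout change: the per-layer KV cache goes from one fused tensor whose last dimension packs the latent ("nope") and rope parts to a tuple of two tensors, one per part.

import torch

# Illustrative sizes; real values come from the model config.
num_blocks, block_size, num_kv_heads = 8, 16, 1
nope_dim, rope_dim = 512, 64

# Before: one fused cache, last dim = nope_dim + rope_dim.
fused_cache = torch.zeros(num_blocks, block_size, num_kv_heads,
                          nope_dim + rope_dim)

# After: a tuple of separate caches, split along the last dim.
kv_cache = (
    torch.zeros(num_blocks, block_size, num_kv_heads, nope_dim),  # nope/latent
    torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim),  # rope
)
assert fused_cache.shape[-1] == kv_cache[0].shape[-1] + kv_cache[1].shape[-1]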

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions
@@ -274,7 +274,7 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        use_kv_cache_int8 = kv_cache.numel(
+        use_kv_cache_int8 = len(kv_cache
         ) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
@@ -315,7 +315,7 @@ def forward(
         # TODO: Remove this contiguous in the future.
         value = value.contiguous()

-        if kv_cache.numel() > 0:
+        if len(kv_cache) > 0:
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping

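With the tuple layout, the "cache not yet allocated" probe in attention_v1.py changes from Tensor.numel() to len() on the tuple. A minimal sketch of the new check; the helper name below is hypothetical and not part of the backend:

import torch
from typing import Tuple

def kv_cache_is_int8(kv_cache: Tuple[torch.Tensor, ...]) -> bool:
    # Mirrors the updated check: non-empty tuple whose first cache is int8.
    return len(kv_cache) > 0 and kv_cache[0].dtype == torch.int8

assert not kv_cache_is_int8(())                                  # no cache yet
assert kv_cache_is_int8((torch.zeros(2, 4, dtype=torch.int8),
                         torch.zeros(2, 4, dtype=torch.int8)))   # int8 caches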
vllm_ascend/attention/mla_v1.py

Lines changed: 38 additions & 29 deletions
@@ -664,12 +664,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _compute_prefill_context(
         self,
         query: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         rope_dim: int,
         attn_metadata: AscendMLAMetadata,
         prefix_output: torch.Tensor,
         prefix_lse: torch.Tensor,
     ):
+        assert len(kv_c_and_k_pe_cache) > 1
         prefill_metadata = attn_metadata.prefill
         if prefill_metadata is None or prefill_metadata.chunked_context is None:
             return prefix_output, prefix_lse
@@ -679,21 +680,23 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]

         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        cache_kv_c = kv_c_and_k_pe_cache[0]
+        cache_k_pe = kv_c_and_k_pe_cache[1]
+        num_heads = cache_k_pe.size(2)
+        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
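With the tuple, the latent and rope dimensions no longer have to be recovered by slicing one fused tensor; each falls out of its own cache. A small sketch with dummy shapes (illustrative only):

import torch

# Dummy caches shaped (num_blocks, block_size, num_kv_heads, dim).
cache_kv_c = torch.zeros(8, 16, 1, 512)   # latent / nope part
cache_k_pe = torch.zeros(8, 16, 1, 64)    # rope part
kv_c_and_k_pe_cache = (cache_kv_c, cache_k_pe)

# Old code: latent_kv_dim = fused.size(3) - rope_dim, then slice the last dim.
# New code: read each dimension from its own cache.
num_heads = kv_c_and_k_pe_cache[1].size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
assert (num_heads, latent_kv_dim) == (1, 512)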
@@ -743,10 +746,11 @@ def _forward_prefill(
         query: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
+        assert len(kv_c_and_k_pe_cache) > 1

         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
@@ -774,7 +778,7 @@ def _forward_prefill(
             vanilla_chunked_prefill_mla(
                 output=attn_output_torch,
                 query=query,
-                kv_cache=kv_c_and_k_pe_cache,
+                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                 block_tables=attn_metadata.prefill.block_table,
                 query_lens=attn_metadata.prefill.query_lens,
                 context_lens=attn_metadata.prefill.context_lens,
@@ -939,19 +943,14 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
         enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None

-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1010,13 +1009,21 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            attn_output = torch.empty(
+                [num_tokens, self.num_heads, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device)
+            k_cache = torch.cat(
+                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
             torch_npu._npu_paged_attention_mla(
                 query=q,
-                key_cache=kv_c_and_k_pe_cache,
+                key_cache=k_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
+                block_table=attn_metadata.decode.
+                block_table,  # type:ignore
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
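The NPU paged-attention MLA kernel still consumes a single fused key cache, so the non-graph decode path rebuilds it by concatenating the two halves along the last dimension. A runnable sketch of just that concatenation, with illustrative shapes:

import torch

nope_cache = torch.zeros(8, 16, 1, 512)  # (num_blocks, block_size, heads, kv_lora_rank)
rope_cache = torch.zeros(8, 16, 1, 64)   # (num_blocks, block_size, heads, rope_dim)

# Mirrors: k_cache = torch.cat([kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
k_cache = torch.cat([nope_cache, rope_cache], dim=-1)
assert k_cache.shape == (8, 16, 1, 576)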
@@ -1036,7 +1043,7 @@ def forward(
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
         hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
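A typing sketch for context, not a required change: Tuple[torch.Tensor] annotates a 1-tuple, so a pair or variable-length tuple of caches is usually spelled as below.

from typing import Tuple
import torch

KvCachePair = Tuple[torch.Tensor, torch.Tensor]  # exactly (nope_cache, rope_cache)
KvCacheTuple = Tuple[torch.Tensor, ...]          # any number of cache tensors

pair: KvCachePair = (torch.zeros(1), torch.zeros(1))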
@@ -1167,8 +1174,11 @@ def forward(
                     prefill_q_pe.contiguous(),
                     prefill_k_pe,
                     max_seq_len=attn_metadata.prefill.max_seq_lens)
+            assert len(
+                kv_cache
+            ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if len(kv_cache) > 0 and kv_cache[0].numel(
+                if kv_cache[0].numel(
                 ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1178,16 +1188,15 @@ def forward(
                         key_cache=kv_cache[0],
                         value_cache=kv_cache[1],
                         slot_indices=slots)
-            elif kv_cache.numel() > 0:
-                key = torch.cat([
-                    kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                    k_pe
-                ],
-                                dim=2)
-                torch_npu._npu_reshape_and_cache_siso(
-                    key=key,
-                    key_cache=kv_cache,
-                    slot_indices=attn_metadata.slot_mapping.flatten())
+            else:
+                kv_c_normed = kv_c_normed.view(
+                    [num_actual_toks, self.num_kv_heads, -1])
+                torch_npu._npu_reshape_and_cache(
+                    key=kv_c_normed,
+                    value=k_pe,
+                    key_cache=kv_cache[0],
+                    value_cache=kv_cache[1],
+                    slot_indices=attn_metadata.slot_mapping)
             if has_prefill:
                 # FIX: aicore move should be also placed on the comm stream in dbo,
                 # otherwise it may affect the accuracy
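The write path now stores the latent and rope parts into their own caches at the same slot indices. torch_npu._npu_reshape_and_cache is an NPU kernel; the pure-PyTorch helper below is a hypothetical CPU-only reference that only illustrates the scatter it performs:

import torch

def reshape_and_cache_ref(key: torch.Tensor, value: torch.Tensor,
                          key_cache: torch.Tensor, value_cache: torch.Tensor,
                          slot_indices: torch.Tensor) -> None:
    # Flatten (num_blocks, block_size, ...) into a slot-addressed view and
    # scatter token entries; nope and rope parts are written separately.
    num_blocks, block_size = key_cache.shape[:2]
    key_cache.view(num_blocks * block_size, *key_cache.shape[2:])[slot_indices] = key
    value_cache.view(num_blocks * block_size, *value_cache.shape[2:])[slot_indices] = value

key_cache = torch.zeros(4, 8, 1, 512)    # nope cache
value_cache = torch.zeros(4, 8, 1, 64)   # rope cache
kv_c_normed = torch.randn(3, 1, 512)     # 3 tokens
k_pe = torch.randn(3, 1, 64)
slots = torch.tensor([5, 9, 26])
reshape_and_cache_ref(kv_c_normed, k_pe, key_cache, value_cache, slots)
assert torch.equal(key_cache.view(32, 1, 512)[5], kv_c_normed[0])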

vllm_ascend/ops/attention.py

Lines changed: 14 additions & 13 deletions
@@ -138,7 +138,7 @@ def vanilla_chunked_prefill(
 def vanilla_chunked_prefill_mla(
         output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
         query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_cache: torch.Tensor,  # (num_blocks, block_size, latent_kv)
+        kv_c_and_k_pe_cache: tuple[torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
         block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
         query_lens: torch.Tensor,  # (batch_size)
         context_lens: torch.Tensor,  # (batch_size)
@@ -154,20 +154,21 @@ def vanilla_chunked_prefill_mla(
     batch_size = block_tables.size(0)
     assert query_lens.size(0) == batch_size
     num_heads = query.size(1)
-    block_size = kv_cache.size(1)
-    latent_kv_dim = kv_cache.size(3) - rope_dim
-    max_num_blocks_per_seq = block_tables.size(1)
-    batch_size = query_lens.size(0)
-    kv_cache = kv_cache.squeeze()
-    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
-    cache_kv_c_pe = kv_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        latent_kv_dim + rope_dim)[:, :max_context_len, :]
-    # get kv_c and k_pe
+    cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
+    cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()
+
     # cached_kv_c: [batch_size, max_context_len, latent_kv]
     # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
-    cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
+    batch_size = query_lens.size(0)
+    block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
+    max_num_blocks_per_seq = block_tables.size(1)
+    cache_kv_c = cache_kv_c[block_tables].view(
+        batch_size, max_num_blocks_per_seq * block_size,
+        latent_kv_dim)[:, :max_context_len, :]
+    cache_k_pe = cache_k_pe[block_tables].view(
+        batch_size, max_num_blocks_per_seq * block_size,
+        rope_dim)[:, :max_context_len, :]
+
     # get k_rope and v
     # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
     # value: [batch_size, max_context_len, num_heads, v_head_dim]
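The gather in vanilla_chunked_prefill_mla now indexes each cache separately with block_tables. A small runnable sketch of that gather pattern, with toy shapes:

import torch

num_blocks, block_size, latent_kv_dim = 6, 4, 8
batch_size, max_num_blocks_per_seq, max_context_len = 2, 3, 10

cache_kv_c = torch.randn(num_blocks, block_size, latent_kv_dim)
block_tables = torch.tensor([[0, 2, 4], [1, 3, 5]])

# cache[block_tables] -> (batch, max_blocks, block_size, dim); flatten the block
# axes and clip to max_context_len, as the updated function does per cache.
gathered = cache_kv_c[block_tables].view(
    batch_size, max_num_blocks_per_seq * block_size,
    latent_kv_dim)[:, :max_context_len, :]
assert gathered.shape == (2, 10, 8)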

vllm_ascend/worker/model_runner_v1.py

Lines changed: 35 additions & 6 deletions
@@ -1907,6 +1907,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads,
                        kv_cache_spec.head_size)
+<<<<<<< HEAD
                    if self.torchair_graph_enabled:
                        if len(kv_cache_shape) == 3:
                            # for non MLA attention backend that use torchair, we consider to pass kv_cache layout
@@ -1952,13 +1953,41 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                                torch_npu.npu_format_cast(
                                    kv_caches[layer_name][1], acl_format),
                            )
+=======
+                    dtype = kv_cache_spec.dtype
+                    if self.model_config.is_deepseek_mla:
+                        num_blocks, block_size, num_kv_heads, head_dim = kv_cache_shape
+                        rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+                        nope_dim = self.model_config.hf_text_config.kv_lora_rank
+                        assert head_dim == rope_dim + nope_dim, \
+                            f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
+                        nope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, nope_dim)
+                        rope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, rope_dim)
+                        nope_cache = torch.zeros(
+                            nope_cache_shape, dtype=dtype, device=self.device,
+                            pin_memory=True)
+                        rope_cache = torch.zeros(
+                            rope_cache_shape, dtype=dtype, device=self.device,
+                            pin_memory=True)
+                        kv_caches[layer_name] = (
+                            torch_npu.npu_format_cast(nope_cache, acl_format),
+                            torch_npu.npu_format_cast(rope_cache, acl_format),
+                        )
+>>>>>>> c848786 (use tuple as kv cache instead of tensor)
                    else:
-                        kv_caches[layer_name] = torch.zeros(
-                            kv_cache_shape,
-                            dtype=self.kv_cache_dtype,
-                            device=self.device)
-                        kv_caches[layer_name] = \
-                            torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
+                        num_caches = kv_cache_shape[0]
+                        kv_cache_list = []
+                        for i in range(num_caches):
+                            cache_shape = kv_cache_shape[1:]
+                            kv_cache = torch.zeros(cache_shape,
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            kv_cache = torch_npu.npu_format_cast(kv_cache,
                                                                 acl_format)
+                            kv_cache_list.append(kv_cache)
+                        kv_caches[layer_name] = kv_cache_list
                else:
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
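A condensed, hedged sketch of the MLA branch of the new allocation: two caches per layer, sized from kv_lora_rank and qk_rope_head_dim. The helper name is hypothetical and the tensors live on CPU here; the real code allocates on the NPU device, pins memory, and casts the ACL format.

import torch

def allocate_mla_kv_cache(num_blocks: int, block_size: int, num_kv_heads: int,
                          nope_dim: int, rope_dim: int,
                          dtype: torch.dtype = torch.float16):
    # One cache for the latent (nope) part, one for the rope part.
    nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, nope_dim,
                             dtype=dtype)
    rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim,
                             dtype=dtype)
    return nope_cache, rope_cache

kv_caches = {"layer.0": allocate_mla_kv_cache(8, 16, 1, nope_dim=512, rope_dim=64)}
assert kv_caches["layer.0"][0].shape[-1] + kv_caches["layer.0"][1].shape[-1] == 576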
