
[Refactor] Use tuple as kv cache instead of tensor #1594

Open. Wants to merge 2 commits into base: main.
6 changes: 3 additions & 3 deletions vllm_ascend/attention/attention_v1.py
@@ -274,8 +274,8 @@ def forward(
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
- use_kv_cache_int8 = kv_cache.numel(
- ) > 0 and kv_cache[0].dtype == torch.int8
+ use_kv_cache_int8 = len(
+ kv_cache) > 0 and kv_cache[0].dtype == torch.int8
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
@@ -315,7 +315,7 @@
# TODO: Remove this contiguous in the future.
value = value.contiguous()

- if kv_cache.numel() > 0:
+ if len(kv_cache) > 0:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
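For context, a minimal sketch of why the emptiness check changes once the kv cache is passed as a tuple of tensors instead of a single stacked tensor. The shapes, dtype, and the `kv_cache` variable below are illustrative stand-ins, not the exact vLLM-Ascend objects:

```python
import torch
from typing import Tuple

# Hypothetical stand-in for the kv cache handed to forward(); in this PR it
# becomes a tuple (key_cache, value_cache) instead of one stacked tensor.
kv_cache: Tuple[torch.Tensor, torch.Tensor] = (
    torch.zeros(4, 16, 8, 64, dtype=torch.int8),  # key cache
    torch.zeros(4, 16, 8, 64, dtype=torch.int8),  # value cache
)

# kv_cache.numel() > 0   # old tensor-based check; a tuple has no .numel()

# New style: emptiness is len(), and the dtype probe moves to the first element.
use_kv_cache_int8 = len(kv_cache) > 0 and kv_cache[0].dtype == torch.int8
print(use_kv_cache_int8)  # True for this int8 example
```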
64 changes: 36 additions & 28 deletions vllm_ascend/attention/mla_v1.py
@@ -659,12 +659,13 @@
def _compute_prefill_context(
self,
query: torch.Tensor,
- kv_c_and_k_pe_cache: torch.Tensor,
+ kv_c_and_k_pe_cache: Tuple[torch.Tensor],
rope_dim: int,
attn_metadata: AscendMLAMetadata,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
):
+ assert len(kv_c_and_k_pe_cache) > 1
Copilot AI (Jul 2, 2025):
[nitpick] This assertion allows any length ≥ 2. To enforce exactly two cache parts, consider assert len(kv_c_and_k_pe_cache) == 2 for clearer intent.
Suggested change:
- assert len(kv_c_and_k_pe_cache) > 1
+ assert len(kv_c_and_k_pe_cache) == 2

prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse
@@ -674,21 +675,23 @@
q_nope = query[..., :self.qk_nope_head_dim]

seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
- latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
- cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
- cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+ cache_kv_c = kv_c_and_k_pe_cache[0]
+ cache_k_pe = kv_c_and_k_pe_cache[1]
+ num_heads = cache_k_pe.size(2)
+ latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)

for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]

seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([seq_len1, seq_len2])
kv_c_normed = torch.empty(toks,
- kv_c_and_k_pe_cache.size(2),
+ num_heads,
latent_kv_dim,
dtype=query.dtype,
device=query.device)
k_pe = torch.empty(toks,
- kv_c_and_k_pe_cache.size(2),
+ num_heads,
rope_dim,
dtype=query.dtype,
device=query.device)
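As a shape sketch of the split used above, assuming the (nope, rope) cache layout that the model runner allocates elsewhere in this PR; the concrete dimensions below are illustrative, not taken from a real config:

```python
import torch

# Illustrative MLA dimensions (assumed, not from a real model config).
num_blocks, block_size, num_kv_heads = 8, 128, 1
kv_lora_rank, rope_dim = 512, 64

# New layout: two separate caches instead of one tensor whose last dim is
# kv_lora_rank + rope_dim.
kv_c_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
k_pe_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim)
kv_c_and_k_pe_cache = (kv_c_cache, k_pe_cache)

# Values previously derived by slicing the last dim are now read directly:
num_heads = kv_c_and_k_pe_cache[1].size(2)        # == num_kv_heads
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)   # == kv_lora_rank
print(num_heads, latent_kv_dim)  # 1 512
```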
@@ -738,10 +741,11 @@
query: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
- kv_c_and_k_pe_cache: torch.Tensor,
+ kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
+ assert len(kv_c_and_k_pe_cache) > 1

num_tokens = query.size(0)
attn_output = torch.empty(num_tokens,
@@ -769,7 +773,7 @@
vanilla_chunked_prefill_mla(
output=attn_output_torch,
query=query,
- kv_cache=kv_c_and_k_pe_cache,
+ kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
@@ -938,18 +942,13 @@
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
- kv_c_and_k_pe_cache: torch.Tensor,
+ kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None

- q = torch.cat([q_nope, q_pe], dim=-1)
- num_tokens = q.size(0)
- attn_output = torch.empty(
- [num_tokens, self.num_heads, self.kv_lora_rank],
- dtype=q.dtype,
- device=q.device)
+ num_tokens = q_nope.size(0)
if self.running_in_graph:
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1008,9 +1007,16 @@
actual_seq_lengths_kv=decode_meta.seq_lens_list,
)
else:
+ q = torch.cat([q_nope, q_pe], dim=-1)
+ attn_output = torch.empty(
+ [num_tokens, self.num_heads, self.kv_lora_rank],
+ dtype=q.dtype,
+ device=q.device)
+ k_cache = torch.cat(
Collaborator:
I am worried about the extra NPU memory consumption this will bring.

Author @lidenghui1110 (Jul 3, 2025):
Yes, it will use extra NPU memory in torch.cat; do you have a better suggestion?
I noticed that in prefill we also use cat for this kind of thing (e.g. line 830), so the consumption should already have been accounted for in the warmup stage.
And in DeepSeek-R1 the consumption should be the per-layer kv cache size, roughly 15 GB / 61 ≈ 251 MB, so it may be affordable.

Collaborator:
Actually I have no good suggestion either; it seems we must do this concat. But I think we can remove it once ring attention can be enabled, right? If so, this change is acceptable. Also cc @ganyi1996ppo.

Collaborator @ganyi1996ppo (Jul 4, 2025):
This line can be wiped out once ring MLA is publicly accessible, and most of the changes in this PR seem to already be contained in PR #950. Can you refactor this after that PR merges? Otherwise there will be a lot of conflicts, which may add more barriers to getting #950 merged. cc @wangxiyuan @Yikun

Author @lidenghui1110 (Jul 4, 2025):
@ganyi1996ppo That's OK. If all the changes in this PR are already covered by #950, this PR can be closed. If so, I hope it can be merged ASAP, since other work that depends on the kv cache will rely on it.
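For reference, a rough back-of-envelope version of the estimate quoted above. Every number below is an assumption for a DeepSeek-R1-like setup, not a measured value:

```python
# Rough per-layer cost of the temporary torch.cat result in the decode path.
# All numbers below are assumptions for illustration only.
num_blocks = 2048          # cache blocks allocated per layer (assumed)
block_size = 128           # tokens per block (assumed)
num_kv_heads = 1           # MLA stores a single latent KV head
kv_lora_rank = 512         # DeepSeek-R1 latent dim
rope_dim = 64              # DeepSeek-R1 rope dim
bytes_per_elem = 2         # bf16

concat_bytes = (num_blocks * block_size * num_kv_heads *
                (kv_lora_rank + rope_dim) * bytes_per_elem)
# The thread above estimates this as total KV cache / num layers:
# roughly 15 GB / 61 layers ≈ 251 MB per layer for DeepSeek-R1.
print(f"{concat_bytes / 2**20:.0f} MiB per layer")  # ~288 MiB with these numbers
```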

+ [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)

Check failure (GitHub Actions / lint (3.10)) on line 1016 in vllm_ascend/attention/mla_v1.py: Tuple index out of range [misc]
torch_npu._npu_paged_attention_mla(
query=q,
- key_cache=kv_c_and_k_pe_cache,
+ key_cache=k_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
@@ -1033,7 +1039,7 @@
hidden_states_or_q_c: torch.Tensor, # query in unified attn
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
- kv_cache: torch.Tensor,
+ kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
output: Optional[torch.Tensor] = None,
enable_multistream_mla: bool = False,
@@ -1153,8 +1159,11 @@
prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)
+ assert len(
+ kv_cache
+ ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
if self.torchair_graph_enabled:
- if len(kv_cache) > 0 and kv_cache[0].numel(
+ if kv_cache[0].numel(
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
slots = attn_metadata.slot_mapping
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1164,16 +1173,15 @@
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots)
- elif kv_cache.numel() > 0:
- key = torch.cat([
- kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
- k_pe
- ],
- dim=2)
- torch_npu._npu_reshape_and_cache_siso(
- key=key,
- key_cache=kv_cache,
- slot_indices=attn_metadata.slot_mapping.flatten())
+ else:
+ kv_c_normed = kv_c_normed.view(
+ [num_actual_toks, self.num_kv_heads, -1])
+ torch_npu._npu_reshape_and_cache(
+ key=kv_c_normed,
+ value=k_pe,
+ key_cache=kv_cache[0],
+ value_cache=kv_cache[1],
+ slot_indices=attn_metadata.slot_mapping)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
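The lint failure above, and the identical one flagged later in vllm_ascend/ops/attention.py, comes from the new annotation Tuple[torch.Tensor]: in typing terms that describes a tuple of exactly one tensor, so mypy rejects kv_cache[1]. A minimal sketch of annotations that would satisfy the checker, assuming the cache really is a (nope_cache, rope_cache) pair; this is one possible fix, not necessarily the one the PR adopts:

```python
from typing import Tuple

import torch

# Tuple[torch.Tensor] means "a tuple with exactly one tensor", so mypy flags
# kv_cache[1] as "Tuple index out of range". Either spelling below describes
# the intended (nope_cache, rope_cache) pair:
KVCachePair = Tuple[torch.Tensor, torch.Tensor]   # exactly two tensors
KVCacheVariadic = Tuple[torch.Tensor, ...]        # any number of tensors

def concat_pair(kv_cache: KVCachePair) -> torch.Tensor:
    # Unpacking (or indexing kv_cache[1]) now type-checks cleanly.
    nope_cache, rope_cache = kv_cache
    return torch.cat([nope_cache, rope_cache], dim=-1)
```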
27 changes: 14 additions & 13 deletions vllm_ascend/ops/attention.py
@@ -138,7 +138,8 @@
def vanilla_chunked_prefill_mla(
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
- kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv)
+ kv_c_and_k_pe_cache: tuple[
+ torch.Tensor], # (num_blocks, block_size, latent_kv/rope_dim)
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
query_lens: torch.Tensor, # (batch_size)
context_lens: torch.Tensor, # (batch_size)
@@ -154,20 +155,20 @@
batch_size = block_tables.size(0)
assert query_lens.size(0) == batch_size
num_heads = query.size(1)
- block_size = kv_cache.size(1)
- latent_kv_dim = kv_cache.size(3) - rope_dim
- max_num_blocks_per_seq = block_tables.size(1)
- batch_size = query_lens.size(0)
- kv_cache = kv_cache.squeeze()
- # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
- cache_kv_c_pe = kv_cache[block_tables].view(
- batch_size, max_num_blocks_per_seq * block_size,
- latent_kv_dim + rope_dim)[:, :max_context_len, :]
+ # get kv_c and k_pe
+ cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
+ cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()

Check failure (GitHub Actions / lint (3.10)) on line 159 in vllm_ascend/ops/attention.py: Tuple index out of range [misc]

# cached_kv_c: [batch_size, max_context_len, latent_kv]
# cached_k_pe: [batch_size, max_context_len, rope_dim]
- cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
- cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
+ block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
+ max_num_blocks_per_seq = block_tables.size(1)
+ cache_kv_c = cache_kv_c[block_tables].view(
+ batch_size, max_num_blocks_per_seq * block_size,
+ latent_kv_dim)[:, :max_context_len, :]
+ cache_k_pe = cache_k_pe[block_tables].view(
+ batch_size, max_num_blocks_per_seq * block_size,
+ rope_dim)[:, :max_context_len, :]

# get k_rope and v
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
# value: [batch_size, max_context_len, num_heads, v_head_dim]
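A small, self-contained shape sketch of the block-table gather introduced above, with made-up sizes and the kv-head axis already squeezed out; the real shapes come from the cache config:

```python
import torch

# Illustrative sizes (assumed).
num_blocks, block_size, latent_kv_dim, rope_dim = 16, 4, 8, 2
batch_size, max_num_blocks_per_seq, max_context_len = 2, 3, 10

cache_kv_c = torch.randn(num_blocks, block_size, latent_kv_dim)
cache_k_pe = torch.randn(num_blocks, block_size, rope_dim)
block_tables = torch.randint(0, num_blocks,
                             (batch_size, max_num_blocks_per_seq))

# Indexing with block_tables gathers each sequence's blocks, view() flattens
# (blocks, block_size) into a token axis, and the slice trims padding tokens.
kv_c = cache_kv_c[block_tables].view(
    batch_size, max_num_blocks_per_seq * block_size,
    latent_kv_dim)[:, :max_context_len, :]
k_pe = cache_k_pe[block_tables].view(
    batch_size, max_num_blocks_per_seq * block_size,
    rope_dim)[:, :max_context_len, :]
print(kv_c.shape, k_pe.shape)  # torch.Size([2, 10, 8]) torch.Size([2, 10, 2])
```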
48 changes: 28 additions & 20 deletions vllm_ascend/worker/model_runner_v1.py
@@ -2131,35 +2131,43 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
- if self.torchair_graph_enabled:
+ dtype = kv_cache_spec.dtype
+ if self.model_config.is_deepseek_mla:
+ num_blocks, block_size, num_kv_heads, head_dim = kv_cache_shape
+ rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+ nope_dim = self.model_config.hf_text_config.kv_lora_rank
+ assert head_dim == rope_dim + nope_dim, \
+ f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
+ layer_kv_cache_nope_shape = (num_blocks, block_size,
+ num_kv_heads, nope_dim)
+ layer_kv_cache_pe_shape = (num_blocks, block_size,
+ num_kv_heads, rope_dim)
layer_kv_cache_nope = torch.zeros(
- kv_cache_shape[:-1] +
- (self.model_config.hf_text_config.kv_lora_rank, ),
- dtype=self.dtype,
- pin_memory=True,
+ layer_kv_cache_nope_shape,
+ dtype=dtype,
device=self.device)
layer_kv_cache_pe = torch.zeros(
- kv_cache_shape[:-1] +
- (self.model_config.hf_text_config.qk_rope_head_dim,
- ),
- dtype=self.dtype,
- pin_memory=True,
+ layer_kv_cache_pe_shape,
+ dtype=dtype,
device=self.device)
- kv_caches[layer_name] = (layer_kv_cache_nope,
- layer_kv_cache_pe)
kv_caches[layer_name] = (
- torch_npu.npu_format_cast(kv_caches[layer_name][0],
+ torch_npu.npu_format_cast(layer_kv_cache_nope,
acl_format),
- torch_npu.npu_format_cast(kv_caches[layer_name][1],
+ torch_npu.npu_format_cast(layer_kv_cache_pe,
acl_format),
)
else:
- kv_caches[layer_name] = torch.zeros(
- kv_cache_shape,
- dtype=self.kv_cache_dtype,
- device=self.device)
- kv_caches[layer_name] = \
- torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
+ num_caches = kv_cache_shape[0]
+ kv_cache_list = []
+ for i in range(num_caches):
+ cache_shape = kv_cache_shape[1:]
+ kv_cache = torch.zeros(cache_shape,
+ dtype=dtype,
+ device=self.device)
+ kv_cache = torch_npu.npu_format_cast(
+ kv_cache, acl_format)
+ kv_cache_list.append(kv_cache)
+ kv_caches[layer_name] = kv_cache_list
Copilot AI (Jul 2, 2025):
kv_cache_list is a list, but other parts of the code expect a tuple of (nope_cache, rope_cache). Consider converting to a tuple: kv_caches[layer_name] = tuple(kv_cache_list).
Suggested change:
- kv_caches[layer_name] = kv_cache_list
+ kv_caches[layer_name] = tuple(kv_cache_list)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
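A minimal sketch of the per-layer MLA allocation above, using placeholder dimensions and skipping torch_npu.npu_format_cast so it runs anywhere; the sizes are assumptions, not values from a real config:

```python
import torch

# Placeholder values standing in for kv_cache_spec / hf_text_config (assumed).
num_blocks, block_size, num_kv_heads = 32, 128, 1
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim
dtype, device = torch.bfloat16, "cpu"

assert head_dim == qk_rope_head_dim + kv_lora_rank

# MLA path: allocate the nope and rope parts separately and store them as a
# tuple, so attention code can address kv_cache[0] / kv_cache[1] directly.
layer_kv_cache_nope = torch.zeros(
    (num_blocks, block_size, num_kv_heads, kv_lora_rank),
    dtype=dtype, device=device)
layer_kv_cache_pe = torch.zeros(
    (num_blocks, block_size, num_kv_heads, qk_rope_head_dim),
    dtype=dtype, device=device)
kv_caches = {"layer.0": (layer_kv_cache_nope, layer_kv_cache_pe)}
print(kv_caches["layer.0"][0].shape, kv_caches["layer.0"][1].shape)
```

As the Copilot comment above notes, returning a tuple rather than a list for the non-MLA branch would match how the attention backends index into the per-layer cache.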