[Refactor] Use tuple as kv cache instead of tensor #1594
base: main
vllm_ascend/attention/mla_v1.py
@@ -659,12 +659,13 @@
     def _compute_prefill_context(
         self,
         query: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         rope_dim: int,
         attn_metadata: AscendMLAMetadata,
         prefix_output: torch.Tensor,
         prefix_lse: torch.Tensor,
     ):
+        assert len(kv_c_and_k_pe_cache) > 1
         prefill_metadata = attn_metadata.prefill
         if prefill_metadata is None or prefill_metadata.chunked_context is None:
             return prefix_output, prefix_lse
@@ -674,21 +675,23 @@
         q_nope = query[..., :self.qk_nope_head_dim]

         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        cache_kv_c = kv_c_and_k_pe_cache[0]
+        cache_k_pe = kv_c_and_k_pe_cache[1]
+        num_heads = cache_k_pe.size(2)
+        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
@@ -738,10 +741,11 @@
         query: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
+        assert len(kv_c_and_k_pe_cache) > 1

         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
@@ -769,7 +773,7 @@
             vanilla_chunked_prefill_mla(
                 output=attn_output_torch,
                 query=query,
-                kv_cache=kv_c_and_k_pe_cache,
+                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                 block_tables=attn_metadata.prefill.block_table,
                 query_lens=attn_metadata.prefill.query_lens,
                 context_lens=attn_metadata.prefill.context_lens,
@@ -938,18 +942,13 @@
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None

-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1008,9 +1007,16 @@
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            attn_output = torch.empty(
+                [num_tokens, self.num_heads, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device)
+            k_cache = torch.cat(
+                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
             torch_npu._npu_paged_attention_mla(
                 query=q,
-                key_cache=kv_c_and_k_pe_cache,
+                key_cache=k_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,

Check failure on line 1016 in vllm_ascend/attention/mla_v1.py

Review thread on the k_cache concatenation:

I am worried about the extra NPU memory consumption this will bring.

Yes, it will use extra NPU memory in torch.cat. Do you have any better suggestion?

Actually, I have no good suggestions either; it seems we must do this concat. But I think we can remove it once ring attention can be enabled, right? If so, I think this change is acceptable. Also cc @ganyi1996ppo

This line can be wiped out once ringmla is publicly accessible. Also, most of the changes in this PR are already contained in PR #950; can you refactor this after that PR merges? Otherwise there will be lots of conflicts, which may add more barriers to getting #950 merged. cc @wangxiyuan @Yikun

@ganyi1996ppo It's OK. If all the changes in this PR have been done in #950, this PR can be closed. If so, I hope it gets merged ASAP; other work that relies on the kv cache will depend on it.
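To make the memory concern concrete, here is a minimal standalone sketch (shapes are made up for illustration; nope_cache and rope_cache stand in for kv_c_and_k_pe_cache[0] and kv_c_and_k_pe_cache[1]) showing that torch.cat materializes a new buffer rather than viewing the two caches in place:

import torch

# Hypothetical block-cache shapes: (num_blocks, block_size, num_kv_heads, dim).
# The real dims come from kv_lora_rank / qk_rope_head_dim; these are placeholders.
nope_cache = torch.zeros(4, 128, 1, 512)  # stands in for kv_c_and_k_pe_cache[0]
rope_cache = torch.zeros(4, 128, 1, 64)   # stands in for kv_c_and_k_pe_cache[1]

# torch.cat cannot view two separately allocated tensors as one contiguous
# tensor, so it allocates a third buffer the size of both caches combined.
# On the NPU, that third buffer is the extra memory the reviewers discuss.
k_cache = torch.cat([nope_cache, rope_cache], dim=-1)

print(k_cache.shape)                                # torch.Size([4, 128, 1, 576])
print(k_cache.data_ptr() == nope_cache.data_ptr())  # False: a new allocation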
@@ -1033,7 +1039,7 @@
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
         hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
@@ -1153,8 +1159,11 @@
                 prefill_q_pe.contiguous(),
                 prefill_k_pe,
                 max_seq_len=attn_metadata.prefill.max_seq_lens)
+        assert len(
+            kv_cache
+        ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
+            if kv_cache[0].numel(
             ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1164,16 +1173,15 @@
                     key_cache=kv_cache[0],
                     value_cache=kv_cache[1],
                     slot_indices=slots)
-        elif kv_cache.numel() > 0:
-            key = torch.cat([
-                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                k_pe
-            ],
-                            dim=2)
-            torch_npu._npu_reshape_and_cache_siso(
-                key=key,
-                key_cache=kv_cache,
-                slot_indices=attn_metadata.slot_mapping.flatten())
+        else:
+            kv_c_normed = kv_c_normed.view(
+                [num_actual_toks, self.num_kv_heads, -1])
+            torch_npu._npu_reshape_and_cache(
+                key=kv_c_normed,
+                value=k_pe,
+                key_cache=kv_cache[0],
+                value_cache=kv_cache[1],
+                slot_indices=attn_metadata.slot_mapping)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
@@ -2131,35 +2131,43 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     num_blocks, kv_cache_spec.block_size,
                     kv_cache_spec.num_kv_heads,
                     kv_cache_spec.head_size)
-                if self.torchair_graph_enabled:
+                dtype = kv_cache_spec.dtype
+                if self.model_config.is_deepseek_mla:
+                    num_blocks, block_size, num_kv_heads, head_dim = kv_cache_shape
+                    rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+                    nope_dim = self.model_config.hf_text_config.kv_lora_rank
+                    assert head_dim == rope_dim + nope_dim, \
+                        f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
+                    layer_kv_cache_nope_shape = (num_blocks, block_size,
+                                                 num_kv_heads, nope_dim)
+                    layer_kv_cache_pe_shape = (num_blocks, block_size,
+                                               num_kv_heads, rope_dim)
                     layer_kv_cache_nope = torch.zeros(
-                        kv_cache_shape[:-1] +
-                        (self.model_config.hf_text_config.kv_lora_rank, ),
-                        dtype=self.dtype,
-                        pin_memory=True,
+                        layer_kv_cache_nope_shape,
+                        dtype=dtype,
                         device=self.device)
                     layer_kv_cache_pe = torch.zeros(
-                        kv_cache_shape[:-1] +
-                        (self.model_config.hf_text_config.qk_rope_head_dim,
-                         ),
-                        dtype=self.dtype,
-                        pin_memory=True,
+                        layer_kv_cache_pe_shape,
+                        dtype=dtype,
                         device=self.device)
-                    kv_caches[layer_name] = (layer_kv_cache_nope,
-                                             layer_kv_cache_pe)
                     kv_caches[layer_name] = (
-                        torch_npu.npu_format_cast(kv_caches[layer_name][0],
+                        torch_npu.npu_format_cast(layer_kv_cache_nope,
                                                   acl_format),
-                        torch_npu.npu_format_cast(kv_caches[layer_name][1],
+                        torch_npu.npu_format_cast(layer_kv_cache_pe,
                                                   acl_format),
                     )
                 else:
-                    kv_caches[layer_name] = torch.zeros(
-                        kv_cache_shape,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device)
-                    kv_caches[layer_name] = \
-                        torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
+                    num_caches = kv_cache_shape[0]
+                    kv_cache_list = []
+                    for i in range(num_caches):
+                        cache_shape = kv_cache_shape[1:]
+                        kv_cache = torch.zeros(cache_shape,
+                                               dtype=dtype,
+                                               device=self.device)
+                        kv_cache = torch_npu.npu_format_cast(
+                            kv_cache, acl_format)
+                        kv_cache_list.append(kv_cache)
+                    kv_caches[layer_name] = kv_cache_list
             else:
                 # TODO: add new branches when introducing more types of
                 # KV cache specs.

Copilot: kv_cache_list is a list, but other parts of the code expect a tuple of (nope_cache, rope_cache). Consider converting it to a tuple.
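The body of Copilot's suggested change is not preserved above; a plausible standalone sketch of the conversion it describes (names and shapes are illustrative, not taken from the PR) might look like:

import torch

# Hypothetical per-layer caches: a nope cache and a rope cache built in a loop.
kv_cache_list = [torch.zeros(2, 16, 1, 512), torch.zeros(2, 16, 1, 64)]

# Converting the list to a tuple yields the immutable (nope_cache, rope_cache)
# pair that the attention backend indexes as kv_cache[0] / kv_cache[1].
kv_caches = {"model.layers.0.attn": tuple(kv_cache_list)}

nope_cache, rope_cache = kv_caches["model.layers.0.attn"]
print(nope_cache.shape, rope_cache.shape)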
Copilot: [nitpick] This assertion allows any length ≥ 2. To enforce exactly two cache parts, consider

    assert len(kv_c_and_k_pe_cache) == 2

for clearer intent.
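A tiny self-contained sketch of that stricter check, assuming the cache is always exactly a (nope_cache, rope_cache) pair (the helper name check_kv_cache is hypothetical):

from typing import Tuple

import torch


def check_kv_cache(kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...]) -> None:
    # "== 2" rejects, for example, a 3-element tuple that "> 1" would silently accept.
    assert len(kv_c_and_k_pe_cache) == 2, (
        "expected exactly (nope_cache, rope_cache), got "
        f"{len(kv_c_and_k_pe_cache)} tensors")


check_kv_cache((torch.zeros(1, 16, 1, 512), torch.zeros(1, 16, 1, 64)))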