From c8487864e7f6086cd29358edda6a619154ac5014 Mon Sep 17 00:00:00 2001
From: lidenghui
Date: Wed, 2 Jul 2025 21:14:56 +0800
Subject: [PATCH 1/2] use tuple as kv cache instead of tensor

Signed-off-by: lidenghui
---
 vllm_ascend/attention/attention_v1.py |  4 +-
 vllm_ascend/attention/mla_v1.py       | 67 +++++++++++++++------------
 vllm_ascend/ops/attention.py          | 27 +++++------
 vllm_ascend/worker/model_runner_v1.py | 56 +++++++++++-----------
 4 files changed, 84 insertions(+), 70 deletions(-)
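
Note (not part of the commit): this series changes the kv cache handed to the
Ascend attention backends from one fused tensor into a tuple of per-component
tensors. A minimal, self-contained sketch of the two layouts, with illustrative
shapes only:

    import torch

    num_blocks, block_size, num_kv_heads = 8, 16, 1
    kv_lora_rank, rope_dim = 512, 64

    # Before: one tensor holding the latent ("nope") part and the rope part
    # side by side in the last dimension.
    combined = torch.zeros(num_blocks, block_size, num_kv_heads,
                           kv_lora_rank + rope_dim)

    # After: a tuple of two tensors, one per component.
    nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
    rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim)
    kv_cache = (nope_cache, rope_cache)

    # Kernels that still expect the fused layout (e.g. the paged-attention
    # decode path changed below) can rebuild it with a concat on the last dim.
    rebuilt = torch.cat([kv_cache[0], kv_cache[1]], dim=-1)
    assert rebuilt.shape == combined.shape
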
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7d7f488f47..193b03b075 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -274,7 +274,7 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        use_kv_cache_int8 = kv_cache.numel(
+        use_kv_cache_int8 = len(kv_cache
         ) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
@@ -315,7 +315,7 @@ def forward(
             # TODO: Remove this contiguous in the future.
             value = value.contiguous()
 
-            if kv_cache.numel() > 0:
+            if len(kv_cache) > 0:
                 if self.key_cache is None:
                     self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
                 slots = attn_metadata.slot_mapping
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b9e51a3e61..2f8decb317 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -659,12 +659,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _compute_prefill_context(
         self,
         query: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         rope_dim: int,
         attn_metadata: AscendMLAMetadata,
         prefix_output: torch.Tensor,
         prefix_lse: torch.Tensor,
     ):
+        assert len(kv_c_and_k_pe_cache) > 1
         prefill_metadata = attn_metadata.prefill
         if prefill_metadata is None or prefill_metadata.chunked_context is None:
             return prefix_output, prefix_lse
@@ -674,21 +675,23 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]
 
         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        cache_kv_c = kv_c_and_k_pe_cache[0]
+        cache_k_pe = kv_c_and_k_pe_cache[1]
+        num_heads = cache_k_pe.size(2)
+        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
@@ -738,10 +741,11 @@ def _forward_prefill(
         query: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
+        assert len(kv_c_and_k_pe_cache) > 1
 
         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
@@ -769,7 +773,7 @@ def _forward_prefill(
             vanilla_chunked_prefill_mla(
                 output=attn_output_torch,
                 query=query,
-                kv_cache=kv_c_and_k_pe_cache,
+                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                 block_tables=attn_metadata.prefill.block_table,
                 query_lens=attn_metadata.prefill.query_lens,
                 context_lens=attn_metadata.prefill.context_lens,
@@ -938,18 +942,13 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
 
-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1008,13 +1007,21 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            attn_output = torch.empty(
+                [num_tokens, self.num_heads, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device)
+            k_cache = torch.cat(
+                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
             torch_npu._npu_paged_attention_mla(
                 query=q,
-                key_cache=kv_c_and_k_pe_cache,
+                key_cache=k_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
+                block_table=attn_metadata.decode.
+                block_table,  # type:ignore
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
@@ -1033,7 +1040,7 @@ def forward(
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
         hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
@@ -1153,8 +1160,11 @@ def forward(
                     prefill_q_pe.contiguous(),
                     prefill_k_pe,
                     max_seq_len=attn_metadata.prefill.max_seq_lens)
+        assert len(
+            kv_cache
+        ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
+            if kv_cache[0].numel(
             ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1164,16 +1174,15 @@ def forward(
                     key_cache=kv_cache[0],
                     value_cache=kv_cache[1],
                     slot_indices=slots)
-        elif kv_cache.numel() > 0:
-            key = torch.cat([
-                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                k_pe
-            ],
-                            dim=2)
-            torch_npu._npu_reshape_and_cache_siso(
-                key=key,
-                key_cache=kv_cache,
-                slot_indices=attn_metadata.slot_mapping.flatten())
+        else:
+            kv_c_normed = kv_c_normed.view(
+                [num_actual_toks, self.num_kv_heads, -1])
+            torch_npu._npu_reshape_and_cache(
+                key=kv_c_normed,
+                value=k_pe,
+                key_cache=kv_cache[0],
+                value_cache=kv_cache[1],
+                slot_indices=attn_metadata.slot_mapping)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py
index 8037c9545b..c0ee5b07d3 100644
--- a/vllm_ascend/ops/attention.py
+++ b/vllm_ascend/ops/attention.py
@@ -138,7 +138,7 @@ def vanilla_chunked_prefill(
 def vanilla_chunked_prefill_mla(
         output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
         query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_cache: torch.Tensor,  # (num_blocks, block_size, latent_kv)
+        kv_c_and_k_pe_cache: tuple[torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
         block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
         query_lens: torch.Tensor,  # (batch_size)
         context_lens: torch.Tensor,  # (batch_size)
@@ -154,20 +154,21 @@ def vanilla_chunked_prefill_mla(
     batch_size = block_tables.size(0)
     assert query_lens.size(0) == batch_size
     num_heads = query.size(1)
-    block_size = kv_cache.size(1)
-    latent_kv_dim = kv_cache.size(3) - rope_dim
-    max_num_blocks_per_seq = block_tables.size(1)
-    batch_size = query_lens.size(0)
-    kv_cache = kv_cache.squeeze()
-    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
-    cache_kv_c_pe = kv_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        latent_kv_dim + rope_dim)[:, :max_context_len, :]
-    # get kv_c and k_pe
+    cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
+    cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()
+    
     # cached_kv_c: [batch_size, max_context_len, latent_kv]
     # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
-    cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
+    batch_size = query_lens.size(0)
+    block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
+    max_num_blocks_per_seq = block_tables.size(1)
+    cache_kv_c = cache_kv_c[block_tables].view(
+        batch_size, max_num_blocks_per_seq * block_size,
+        latent_kv_dim)[:, :max_context_len, :]
+    cache_k_pe = cache_k_pe[block_tables].view(
+        batch_size, max_num_blocks_per_seq * block_size,
+        rope_dim)[:, :max_context_len, :]
+
     # get k_rope and v
     # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
     # value: [batch_size, max_context_len, num_heads, v_head_dim]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 50d610e94b..c61710d612 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2131,35 +2131,39 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    if self.torchair_graph_enabled:
-                        layer_kv_cache_nope = torch.zeros(
-                            kv_cache_shape[:-1] +
-                            (self.model_config.hf_text_config.kv_lora_rank, ),
-                            dtype=self.dtype,
-                            pin_memory=True,
-                            device=self.device)
-                        layer_kv_cache_pe = torch.zeros(
-                            kv_cache_shape[:-1] +
-                            (self.model_config.hf_text_config.qk_rope_head_dim,
-                             ),
-                            dtype=self.dtype,
-                            pin_memory=True,
-                            device=self.device)
-                        kv_caches[layer_name] = (layer_kv_cache_nope,
-                                                 layer_kv_cache_pe)
+                    dtype = kv_cache_spec.dtype
+                    if self.model_config.is_deepseek_mla:
+                        num_blocks, block_size, num_kv_heads, head_dim = kv_cache_shape
+                        rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+                        nope_dim = self.model_config.hf_text_config.kv_lora_rank
+                        assert head_dim == rope_dim + nope_dim, \
+                            f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
+                        nope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, nope_dim)
+                        rope_cache_shape = (num_blocks, block_size,
+                                            num_kv_heads, rope_dim)
+                        nope_cache = torch.zeros(
+                            nope_cache_shape, dtype=dtype, device=self.device,
+                            pin_memory=True)
+                        rope_cache = torch.zeros(
+                            rope_cache_shape, dtype=dtype, device=self.device,
+                            pin_memory=True)
                         kv_caches[layer_name] = (
-                            torch_npu.npu_format_cast(kv_caches[layer_name][0],
-                                                      acl_format),
-                            torch_npu.npu_format_cast(kv_caches[layer_name][1],
-                                                      acl_format),
+                            torch_npu.npu_format_cast(nope_cache, acl_format),
+                            torch_npu.npu_format_cast(rope_cache, acl_format),
                         )
                     else:
-                        kv_caches[layer_name] = torch.zeros(
-                            kv_cache_shape,
-                            dtype=self.kv_cache_dtype,
-                            device=self.device)
-                        kv_caches[layer_name] = \
-                            torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
+                        num_caches = kv_cache_shape[0]
+                        kv_cache_list = []
+                        for i in range(num_caches):
+                            cache_shape = kv_cache_shape[1:]
+                            kv_cache = torch.zeros(cache_shape,
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            kv_cache = torch_npu.npu_format_cast(kv_cache,
+                                                                 acl_format)
+                            kv_cache_list.append(kv_cache)
+                        kv_caches[layer_name] = kv_cache_list
                 else:
                     # TODO: add new branches when introducing more types of
                     # KV cache specs.
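
Note (not part of either commit): for DeepSeek MLA models the runner now
allocates the two caches separately, relying on head_dim being exactly
kv_lora_rank + qk_rope_head_dim, which is what the new assert checks. A small
sketch of that shape arithmetic with made-up, DeepSeek-like sizes (the layer
name below is purely illustrative):

    import torch

    num_blocks, block_size, num_kv_heads = 4, 128, 1
    kv_lora_rank = 512        # latent ("nope") part
    qk_rope_head_dim = 64     # rotary ("rope") part
    head_dim = kv_lora_rank + qk_rope_head_dim

    # Same invariant the patch asserts before splitting the allocation.
    assert head_dim == qk_rope_head_dim + kv_lora_rank

    nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
    rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads,
                             qk_rope_head_dim)

    # Each layer's entry in kv_caches becomes a tuple instead of one tensor.
    kv_caches = {"model.layers.0.self_attn.attn": (nope_cache, rope_cache)}
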
From 70588531ab8b36fc04c0d067b5fd8d3ff20c3460 Mon Sep 17 00:00:00 2001
From: lidenghui
Date: Thu, 3 Jul 2025 15:51:48 +0800
Subject: [PATCH 2/2] fix lint

Signed-off-by: lidenghui
---
 vllm_ascend/attention/attention_v1.py |  4 ++--
 vllm_ascend/attention/mla_v1.py       |  3 +--
 vllm_ascend/ops/attention.py          |  6 ++---
 vllm_ascend/worker/model_runner_v1.py | 32 +++++++++++++++------------
 4 files changed, 24 insertions(+), 21 deletions(-)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 193b03b075..91d4b8836c 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -274,8 +274,8 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        use_kv_cache_int8 = len(kv_cache
-        ) > 0 and kv_cache[0].dtype == torch.int8
+        use_kv_cache_int8 = len(
+            kv_cache) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
                                  self.num_heads,
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 2f8decb317..39612d3dd8 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1020,8 +1020,7 @@ def _forward_decode(
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=attn_metadata.decode.
-                block_table,  # type:ignore
+                block_table=attn_metadata.decode.block_table,  # type:ignore
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py
index c0ee5b07d3..c0562ee5c6 100644
--- a/vllm_ascend/ops/attention.py
+++ b/vllm_ascend/ops/attention.py
@@ -138,7 +138,8 @@ def vanilla_chunked_prefill(
 def vanilla_chunked_prefill_mla(
         output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
         query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_c_and_k_pe_cache: tuple[torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
+        kv_c_and_k_pe_cache: tuple[
+            torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
         block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
         query_lens: torch.Tensor,  # (batch_size)
         context_lens: torch.Tensor,  # (batch_size)
@@ -156,10 +157,9 @@ def vanilla_chunked_prefill_mla(
     num_heads = query.size(1)
     cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
     cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()
-    
+
     # cached_kv_c: [batch_size, max_context_len, latent_kv]
     # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    batch_size = query_lens.size(0)
     block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
     max_num_blocks_per_seq = block_tables.size(1)
     cache_kv_c = cache_kv_c[block_tables].view(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index c61710d612..c65f972742 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2138,19 +2138,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         nope_dim = self.model_config.hf_text_config.kv_lora_rank
                         assert head_dim == rope_dim + nope_dim, \
                             f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
-                        nope_cache_shape = (num_blocks, block_size,
-                                            num_kv_heads, nope_dim)
-                        rope_cache_shape = (num_blocks, block_size,
-                                            num_kv_heads, rope_dim)
-                        nope_cache = torch.zeros(
-                            nope_cache_shape, dtype=dtype, device=self.device,
-                            pin_memory=True)
-                        rope_cache = torch.zeros(
-                            rope_cache_shape, dtype=dtype, device=self.device,
-                            pin_memory=True)
+                        layer_kv_cache_nope_shape = (num_blocks, block_size,
+                                                     num_kv_heads, nope_dim)
+                        layer_kv_cache_pe_shape = (num_blocks, block_size,
+                                                   num_kv_heads, rope_dim)
+                        layer_kv_cache_nope = torch.zeros(
+                            layer_kv_cache_nope_shape,
+                            dtype=dtype,
+                            device=self.device)
+                        layer_kv_cache_pe = torch.zeros(
+                            layer_kv_cache_pe_shape,
+                            dtype=dtype,
+                            device=self.device)
                         kv_caches[layer_name] = (
-                            torch_npu.npu_format_cast(nope_cache, acl_format),
-                            torch_npu.npu_format_cast(rope_cache, acl_format),
+                            torch_npu.npu_format_cast(layer_kv_cache_nope,
+                                                      acl_format),
+                            torch_npu.npu_format_cast(layer_kv_cache_pe,
+                                                      acl_format),
                         )
                     else:
                         num_caches = kv_cache_shape[0]
@@ -2160,8 +2164,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                             kv_cache = torch.zeros(cache_shape,
                                                    dtype=dtype,
                                                    device=self.device)
-                            kv_cache = torch_npu.npu_format_cast(kv_cache,
-                                                                 acl_format)
+                            kv_cache = torch_npu.npu_format_cast(
+                                kv_cache, acl_format)
                             kv_cache_list.append(kv_cache)
                         kv_caches[layer_name] = kv_cache_list
                 else:
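
Note (not part of either commit): after the split, the reference prefill path
(vanilla_chunked_prefill_mla) gathers the cached latent kv and rope parts
independently before doing the attention math. A minimal runnable sketch of
just that gather, with toy sizes and the num_kv_heads dimension already
squeezed out:

    import torch

    num_blocks, block_size = 8, 4
    latent_kv, rope_dim = 16, 8
    batch_size, max_context_len = 2, 6

    cache_kv_c = torch.randn(num_blocks, block_size, latent_kv)
    cache_k_pe = torch.randn(num_blocks, block_size, rope_dim)
    # block_tables[b, j] holds the physical block id for the j-th logical
    # block of sequence b.
    block_tables = torch.tensor([[0, 1], [2, 3]])
    max_num_blocks_per_seq = block_tables.size(1)

    # Index the paged caches by block id, then flatten blocks back into a
    # contiguous [batch_size, max_context_len, dim] view.
    kv_c = cache_kv_c[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size,
        latent_kv)[:, :max_context_len, :]
    k_pe = cache_k_pe[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size,
        rope_dim)[:, :max_context_len, :]

    assert kv_c.shape == (batch_size, max_context_len, latent_kv)
    assert k_pe.shape == (batch_size, max_context_len, rope_dim)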