
Commit e46dc14

Enable kvcache_nz for the decode process in torchair graph mode (#1098)
What this PR does / why we need it?
Enables kvcache_nz for the decode process in torchair graph mode, which reduces the time consumed by the fused attention (FA) operator on long sequences.

Does this PR introduce any user-facing change?
To enable kvcache_nz, set additional_config.torchair_graph_config.enable_kv_nz=True.

How was this patch tested?
1. Tested with a DeepSeek model: with batch size 64 and seq_len 1k+3k, total FA time across 61 layers improves from 20.80 ms to 19.76 ms.
2. Operator precision test: aclnnFusedInferAttentionScoreV3_result.csv (https://github.com/user-attachments/files/20664138/aclnnFusedInferAttentionScoreV3_result.csv)
3. TPOT test from @ttanzhiqiang; a curl request returns a normal result: #1098 (comment), #1098 (comment)

Signed-off-by: chenwaner <861645847@qq.com>
1 parent 4153a50 commit e46dc14
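To illustrate the user-facing switch, here is a minimal offline sketch, assuming a vLLM build in which the `LLM` constructor accepts the `additional_config` dict that vllm-ascend reads (the model name and generation settings are placeholders):

```python
from vllm import LLM, SamplingParams

# Sketch only: turn on torchair graph mode together with the NZ kv-cache layout.
# The "torchair_graph_config" keys mirror docs/source/user_guide/additional_config.md.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_kv_nz": True,
        },
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```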

File tree

3 files changed: +96 −47 lines

docs/source/user_guide/additional_config.md
vllm_ascend/ascend_config.py
vllm_ascend/attention/mla_v1.py

docs/source/user_guide/additional_config.md

Lines changed: 3 additions & 1 deletion
@@ -44,6 +44,7 @@ The details of each config option are as follows:
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout |
 
 **ascend_scheduler_config**
 
@@ -64,7 +65,8 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_moe": false
+        "enable_multistream_moe": false,
+        "enable_kv_nz": false
     },
     "ascend_scheduler_config": {
         "enabled": true,

vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def __init__(self, torchair_graph_config):
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
+        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
 
         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
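As a quick standalone illustration of what this one-line change does (not the repo's code), the flag follows the same `dict.get` defaulting pattern as the neighbouring options, so leaving `enable_kv_nz` out of `additional_config` keeps the default ND behaviour:

```python
# Standalone sketch of the defaulting behaviour added above.
torchair_graph_config = {"enabled": True, "use_cached_graph": True}  # hypothetical user config
enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
print(enable_kv_nz)  # False unless the user explicitly sets "enable_kv_nz": true
```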

vllm_ascend/attention/mla_v1.py

Lines changed: 92 additions & 46 deletions
@@ -480,6 +480,7 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -662,6 +663,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -671,7 +673,37 @@ def exec_kv(
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
         )
         return k_pe, k_nope
 
@@ -709,42 +741,50 @@ def _forward_decode(
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
-                q_nope = (q_nope.view(
-                    num_tokens // (self.spec_token_num + 1),
-                    self.spec_token_num + 1,
-                    self.num_heads,
-                    -1,
-                ).transpose(1, 2).contiguous())
-                q_pe = (q_pe.view(
-                    num_tokens // (self.spec_token_num + 1),
-                    self.spec_token_num + 1,
-                    self.num_heads,
-                    -1,
-                ).transpose(1, 2).contiguous())
+                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                     self.spec_token_num + 1, self.num_heads,
+                                     -1)
+                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads, -1)
+                if not self.enable_kv_nz:
+                    q_nope = q_nope.transpose(1, 2).contiguous()
+                    q_pe = q_pe.transpose(1, 2).contiguous()
                 sparse_mode = 3
                 spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             else:
-                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
                 sparse_mode = 0
                 spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
-            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                 self.kv_lora_rank)
-            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                             self.qk_rope_head_dim)
+            if self.enable_kv_nz:
+                k_nope = k_nope.view(-1, self.num_kv_heads,
+                                     self.kv_lora_rank // 16, block_size, 16)
+                k_pe = k_pe.view(-1, self.num_kv_heads,
+                                 self.qk_rope_head_dim // 16, block_size, 16)
+                input_layout = "BSND"
+            else:
+                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                     self.kv_lora_rank)
+                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                                 self.qk_rope_head_dim)
+                input_layout = "BNSD"
 
-            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
                 k_nope,
                 k_nope,
                 query_rope=q_pe,
                 key_rope=k_pe,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BNSD",
+                input_layout=input_layout,
                 atten_mask=spec_attn_mask,
                 sparse_mode=sparse_mode,
                 scale=self.scale,
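For context on the reshapes above, a minimal plain-PyTorch sketch (illustrative sizes only, not the repo's code) of how the same cache block is viewed on the two paths; it only demonstrates the shapes, since the physical NZ data ordering is produced by the NPU cache ops (`cache_mode="PA_NZ"` / `"PA_BLK_NZ"`):

```python
import torch

# Illustrative sizes; real values come from the model config and the cache allocator.
num_blocks, block_size, num_kv_heads, kv_lora_rank = 4, 128, 1, 512

k_nope_cache = torch.randn(num_blocks, block_size, num_kv_heads * kv_lora_rank)

# ND path (enable_kv_nz=False): [num_blocks, num_kv_heads, block_size, kv_lora_rank], input_layout="BNSD"
k_nd = k_nope_cache.view(-1, num_kv_heads, block_size, kv_lora_rank)

# NZ path (enable_kv_nz=True): the head dim is split into 16-wide fractal tiles,
# giving [num_blocks, num_kv_heads, kv_lora_rank // 16, block_size, 16], input_layout="BSND"
k_nz = k_nope_cache.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)

print(k_nd.shape)  # torch.Size([4, 1, 128, 512])
print(k_nz.shape)  # torch.Size([4, 1, 32, 128, 16])
```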
@@ -793,10 +833,11 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.torchair_graph_enabled:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -809,16 +850,18 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -855,22 +898,25 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
+                seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+
+                prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+
+                kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
                 prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                                  -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-                else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(

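As a rough companion to the new prefill RoPE path, here is a standalone sketch of the cos/sin gathering and broadcasting pattern (`cos_cached[:seq_len]`, indexing by positions, then `[:, None, None, :]`); the rotation itself uses a generic rotate-half convention purely for illustration and is not necessarily the convention of the repo's `rope_single`:

```python
import torch

def rope_single_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Illustrative rotate-half RoPE; x: [T, H, D], cos/sin: [T, 1, 1, D]."""
    t, h, d = x.shape
    x = x.view(t, h, 1, d)
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos + rotated * sin).view(t, h, d)

max_pos, rope_dim = 1024, 64                      # hypothetical sizes
cos_cached = torch.randn(max_pos, rope_dim)
sin_cached = torch.randn(max_pos, rope_dim)
positions = torch.tensor([0, 1, 2, 5])            # per-token positions of the prefill chunk

# Mirror the pattern added above: slice the cache, gather by position, add broadcast dims.
cos = cos_cached[:max_pos][positions][:, None, None, :]
sin = sin_cached[:max_pos][positions][:, None, None, :]

q_pe = torch.randn(4, 16, rope_dim)               # [num_tokens, num_heads, rope_dim]
q_pe = rope_single_sketch(q_pe, cos, sin)
print(q_pe.shape)  # torch.Size([4, 16, 64])
```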