@@ -556,6 +556,7 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
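
Note: `enable_kv_nz` is read from the same TorchAir graph config object that already carries `enabled`. A minimal usage sketch follows, assuming the flag is exposed through vllm-ascend's `additional_config` dict the way the existing `torchair_graph_config` options are; treat the exact keys and the model name as illustrative, not authoritative.

    from vllm import LLM

    # Hypothetical way to switch the new NZ KV-cache layout on; only
    # "enable_kv_nz" is new in this diff, "enabled" is the existing
    # TorchAir graph-mode switch.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative model choice
        enforce_eager=False,
        additional_config={
            "torchair_graph_config": {
                "enabled": True,
                "enable_kv_nz": True,
            },
        },
    )
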
@@ -859,6 +860,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -868,7 +870,37 @@ def exec_kv(
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
         )
         return k_pe, k_nope
 
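
Note: with NZ enabled, the decode path (`exec_kv`) asks the cache op for "PA_NZ" while the new prefill path (`exec_kv_prefill`) asks for "PA_BLK_NZ" and additionally passes `is_output_kv=True` so the op also returns `k_pe`/`k_nope` for the prefill attention. The mapping itself, written out purely as an illustration of the two branches above (this helper does not exist in the file):

    def select_cache_mode(enable_kv_nz: bool, is_prefill: bool) -> str:
        # Mirrors the cache_mode choices made in exec_kv / exec_kv_prefill.
        if not enable_kv_nz:
            return "PA"
        return "PA_BLK_NZ" if is_prefill else "PA_NZ"

    assert select_cache_mode(False, True) == "PA"
    assert select_cache_mode(True, False) == "PA_NZ"
    assert select_cache_mode(True, True) == "PA_BLK_NZ"
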
@@ -906,42 +938,50 @@ def _forward_decode(
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
-                q_nope = (q_nope.view(
-                    num_tokens // (self.spec_token_num + 1),
-                    self.spec_token_num + 1,
-                    self.num_heads,
-                    -1,
-                ).transpose(1, 2).contiguous())
-                q_pe = (q_pe.view(
-                    num_tokens // (self.spec_token_num + 1),
-                    self.spec_token_num + 1,
-                    self.num_heads,
-                    -1,
-                ).transpose(1, 2).contiguous())
+                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                     self.spec_token_num + 1, self.num_heads,
+                                     -1)
+                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads, -1)
+                if not self.enable_kv_nz:
+                    q_nope = q_nope.transpose(1, 2).contiguous()
+                    q_pe = q_pe.transpose(1, 2).contiguous()
                 sparse_mode = 3
                 spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             else:
-                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
                 sparse_mode = 0
                 spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
-            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                 self.kv_lora_rank)
-            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                             self.qk_rope_head_dim)
+            if self.enable_kv_nz:
+                k_nope = k_nope.view(-1, self.num_kv_heads,
+                                     self.kv_lora_rank // 16, block_size, 16)
+                k_pe = k_pe.view(-1, self.num_kv_heads,
+                                 self.qk_rope_head_dim // 16, block_size, 16)
+                input_layout = "BSND"
+            else:
+                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                     self.kv_lora_rank)
+                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                                 self.qk_rope_head_dim)
+                input_layout = "BNSD"
 
-            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
                 k_nope,
                 k_nope,
                 query_rope=q_pe,
                 key_rope=k_pe,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BNSD",
+                input_layout=input_layout,
                 atten_mask=spec_attn_mask,
                 sparse_mode=sparse_mode,
                 scale=self.scale,
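
Note: in the NZ branch, `k_nope`/`k_pe` are viewed as 5-D tensors whose last dimension is a 16-element fragment, and the fused attention op is told `input_layout="BSND"` instead of `"BNSD"`. Below is a shape-only sketch with made-up sizes; the actual element ordering comes from the cache having been written with an NZ cache_mode, so the view here only demonstrates the shape arithmetic the kernel expects.

    import torch

    # Toy sizes; real values come from the model config and block table.
    num_blocks, num_kv_heads, block_size, kv_lora_rank = 4, 1, 128, 512
    k_nope = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank)

    # The head dims must be multiples of 16 for the fragment view to work.
    assert kv_lora_rank % 16 == 0
    k_nope_nz = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16,
                            block_size, 16)
    print(k_nope_nz.shape)  # torch.Size([4, 1, 32, 128, 16])
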
@@ -990,10 +1030,11 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.torchair_graph_enabled:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -1006,16 +1047,18 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
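
Note: the slicing above relies on the flattened token batch being ordered decode-first, prefill-second; the new guards skip it in TorchAir graph mode because that path derives the prefill KV inside `exec_kv_prefill` instead. A toy illustration of the split (all sizes made up):

    import torch

    num_actual_toks, num_decode_tokens = 10, 4
    hidden = torch.arange(12, dtype=torch.float32).unsqueeze(-1)  # padded batch

    hidden = hidden[:num_actual_toks, ...]     # drop padding rows
    decode_part = hidden[:num_decode_tokens]   # first 4 tokens are decodes
    prefill_part = hidden[num_decode_tokens:]  # remaining 6 are prefills
    assert decode_part.shape[0] + prefill_part.shape[0] == num_actual_toks
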
@@ -1052,22 +1095,25 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
+                seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+
+                prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+
+                kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
                 prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                                  -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-                else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
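
Note: the new TorchAir prefill branch gathers per-token rows from the rotary cos/sin caches and broadcasts them to `[num_tokens, 1, 1, rope_dim]` before handing them to `rope_single` and `exec_kv_prefill`. A shape-only sketch of that gather with toy sizes; `cos_cached`, `max_position_embeddings`, and the positions tensor stand in for the attributes referenced above, and the actual rotation applied by `rope_single` is not reproduced here.

    import torch

    max_position_embeddings, rope_dim = 2048, 64
    cos_cached = torch.randn(max_position_embeddings, rope_dim)
    input_positions = torch.tensor([3, 7, 8, 9, 100])  # prefill positions

    cos = cos_cached[:max_position_embeddings].to(dtype=torch.bfloat16)
    cos = cos[input_positions]   # one row per prefill token: [5, 64]
    cos = cos[:, None, None, :]  # broadcastable against [B, N, S, D] tensors
    print(cos.shape)             # torch.Size([5, 1, 1, 64])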