@@ -480,6 +480,7 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -662,6 +663,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
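+        # "PA_NZ" keeps the paged KV cache in Ascend's NZ (fractal) layout;
+        # "PA" is the default ND paged layout.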
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -671,7 +673,37 @@ def exec_kv(
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
         )
         return k_pe, k_nope
@@ -709,42 +741,50 @@ def _forward_decode(
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
-                q_nope = (q_nope.view(
-                    num_tokens // (self.spec_token_num + 1),
-                    self.spec_token_num + 1,
-                    self.num_heads,
-                    -1,
-                ).transpose(1, 2).contiguous())
-                q_pe = (q_pe.view(
-                    num_tokens // (self.spec_token_num + 1),
-                    self.spec_token_num + 1,
-                    self.num_heads,
-                    -1,
-                ).transpose(1, 2).contiguous())
+                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                     self.spec_token_num + 1, self.num_heads,
+                                     -1)
+                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads, -1)
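+                # With the NZ cache the attention op takes BSND inputs, so the
+                # transpose to BNSD is only needed for the regular (ND) path.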
+                if not self.enable_kv_nz:
+                    q_nope = q_nope.transpose(1, 2).contiguous()
+                    q_pe = q_pe.transpose(1, 2).contiguous()
                 sparse_mode = 3
                 spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             else:
-                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
                 sparse_mode = 0
                 spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
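+            # In the NZ layout the head dim is stored as 16-element fragments, so the
+            # cache is viewed as [num_blocks, num_kv_heads, head_dim // 16, block_size, 16].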
-            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                 self.kv_lora_rank)
-            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                             self.qk_rope_head_dim)
+            if self.enable_kv_nz:
+                k_nope = k_nope.view(-1, self.num_kv_heads,
+                                     self.kv_lora_rank // 16, block_size, 16)
+                k_pe = k_pe.view(-1, self.num_kv_heads,
+                                 self.qk_rope_head_dim // 16, block_size, 16)
+                input_layout = "BSND"
+            else:
+                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                     self.kv_lora_rank)
+                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                                 self.qk_rope_head_dim)
+                input_layout = "BNSD"

-            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
                 k_nope,
                 k_nope,
                 query_rope=q_pe,
                 key_rope=k_pe,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BNSD",
+                input_layout=input_layout,
                 atten_mask=spec_attn_mask,
                 sparse_mode=sparse_mode,
                 scale=self.scale,
@@ -793,10 +833,11 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
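+            # In torchair graph mode the kv_a projection and RMSNorm are fused
+            # into exec_kv/exec_kv_prefill, so they are skipped here.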
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.torchair_graph_enabled:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -809,16 +850,18 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -855,22 +898,25 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
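+                # Gather the cached cos/sin at the prefill positions and reshape
+                # them to [S, 1, 1, D] for rope_single / exec_kv_prefill.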
+                seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+
+                prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+
+                kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
                 prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                                  -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-                else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(