@@ -659,12 +659,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _compute_prefill_context(
         self,
         query: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         rope_dim: int,
         attn_metadata: AscendMLAMetadata,
         prefix_output: torch.Tensor,
         prefix_lse: torch.Tensor,
     ):
+        assert len(kv_c_and_k_pe_cache) > 1
         prefill_metadata = attn_metadata.prefill
         if prefill_metadata is None or prefill_metadata.chunked_context is None:
             return prefix_output, prefix_lse
@@ -674,21 +675,23 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]
 
         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        cache_kv_c = kv_c_and_k_pe_cache[0]
+        cache_k_pe = kv_c_and_k_pe_cache[1]
+        num_heads = cache_k_pe.size(2)
+        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
@@ -738,10 +741,11 @@ def _forward_prefill(
         query: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
+        assert len(kv_c_and_k_pe_cache) > 1
 
         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
@@ -769,7 +773,7 @@ def _forward_prefill(
             vanilla_chunked_prefill_mla(
                 output=attn_output_torch,
                 query=query,
-                kv_cache=kv_c_and_k_pe_cache,
+                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                 block_tables=attn_metadata.prefill.block_table,
                 query_lens=attn_metadata.prefill.query_lens,
                 context_lens=attn_metadata.prefill.context_lens,
@@ -938,18 +942,13 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
 
-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1008,13 +1007,21 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            attn_output = torch.empty(
+                [num_tokens, self.num_heads, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device)
+            k_cache = torch.cat(
+                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
             torch_npu._npu_paged_attention_mla(
                 query=q,
-                key_cache=kv_c_and_k_pe_cache,
+                key_cache=k_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
+                block_table=attn_metadata.decode.
+                block_table,  # type:ignore
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
@@ -1033,7 +1040,7 @@ def forward(
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
         hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
@@ -1153,8 +1160,11 @@ def forward(
                     prefill_q_pe.contiguous(),
                     prefill_k_pe,
                     max_seq_len=attn_metadata.prefill.max_seq_lens)
+        assert len(
+            kv_cache
+        ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
+            if kv_cache[0].numel(
             ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1164,16 +1174,15 @@ def forward(
                                                  key_cache=kv_cache[0],
                                                  value_cache=kv_cache[1],
                                                  slot_indices=slots)
-        elif kv_cache.numel() > 0:
-            key = torch.cat([
-                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                k_pe
-            ],
-                            dim=2)
-            torch_npu._npu_reshape_and_cache_siso(
-                key=key,
-                key_cache=kv_cache,
-                slot_indices=attn_metadata.slot_mapping.flatten())
+        else:
+            kv_c_normed = kv_c_normed.view(
+                [num_actual_toks, self.num_kv_heads, -1])
+            torch_npu._npu_reshape_and_cache(
+                key=kv_c_normed,
+                value=k_pe,
+                key_cache=kv_cache[0],
+                value_cache=kv_cache[1],
+                slot_indices=attn_metadata.slot_mapping)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
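
Below is a minimal, illustrative sketch (not part of the diff) of the cache layout this change moves to: the MLA kv cache is passed around as a (nope_cache, rope_cache) tuple instead of one packed tensor, and the decode path only re-packs the two halves where the fused kernel still expects a single key cache. The dimensions (num_blocks, block_size, num_kv_heads, kv_lora_rank, rope_dim) are assumed placeholder values, not values taken from the patch.

import torch

# Assumed placeholder dimensions, for illustration only.
num_blocks, block_size, num_kv_heads = 8, 128, 1
kv_lora_rank, rope_dim = 512, 64

# kv_cache[0] holds the latent ("nope") part, kv_cache[1] the rope part.
nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim)
kv_cache = (nope_cache, rope_cache)
assert len(kv_cache) > 1

# Re-pack the two halves into a single key cache, mirroring the
# k_cache = torch.cat(...) step added in the decode branch above.
k_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1)
assert k_cache.size(-1) == kv_lora_rank + rope_dim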