@@ -664,12 +664,13 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _compute_prefill_context(
         self,
         query: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         rope_dim: int,
         attn_metadata: AscendMLAMetadata,
         prefix_output: torch.Tensor,
         prefix_lse: torch.Tensor,
     ):
+        assert len(kv_c_and_k_pe_cache) > 1
         prefill_metadata = attn_metadata.prefill
         if prefill_metadata is None or prefill_metadata.chunked_context is None:
             return prefix_output, prefix_lse
@@ -679,21 +680,23 @@ def _compute_prefill_context(
         q_nope = query[..., :self.qk_nope_head_dim]

         seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
-        latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
-        cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
-        cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
+        cache_kv_c = kv_c_and_k_pe_cache[0]
+        cache_k_pe = kv_c_and_k_pe_cache[1]
+        num_heads = cache_k_pe.size(2)
+        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
-                                      kv_c_and_k_pe_cache.size(2),
+                                      num_heads,
                                       latent_kv_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_pe = torch.empty(toks,
-                               kv_c_and_k_pe_cache.size(2),
+                               num_heads,
                                rope_dim,
                                dtype=query.dtype,
                                device=query.device)
@@ -743,10 +746,11 @@ def _forward_prefill(
         query: torch.Tensor,
         kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
+        assert len(kv_c_and_k_pe_cache) > 1

         num_tokens = query.size(0)
         attn_output = torch.empty(num_tokens,
@@ -774,7 +778,7 @@ def _forward_prefill(
             vanilla_chunked_prefill_mla(
                 output=attn_output_torch,
                 query=query,
-                kv_cache=kv_c_and_k_pe_cache,
+                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                 block_tables=attn_metadata.prefill.block_table,
                 query_lens=attn_metadata.prefill.query_lens,
                 context_lens=attn_metadata.prefill.context_lens,
@@ -939,19 +943,14 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
         enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None

-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1010,13 +1009,21 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            attn_output = torch.empty(
+                [num_tokens, self.num_heads, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device)
+            k_cache = torch.cat(
+                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
             torch_npu._npu_paged_attention_mla(
                 query=q,
-                key_cache=kv_c_and_k_pe_cache,
+                key_cache=k_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
+                block_table=attn_metadata.decode.
+                block_table,  # type:ignore
                 context_lens=attn_metadata.decode.seq_lens,  # type:ignore
                 mla_vheadsize=self.kv_lora_rank,
                 out=attn_output)
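
The decode path above rebuilds the fused cache layout on the fly before calling the paged-attention kernel. A minimal pure-PyTorch sketch of that step, with block-layout shapes assumed for illustration (they are not taken from this diff): the nope cache holds `kv_lora_rank` features per slot, the rope cache holds `rope_dim` features, and concatenating them along the last dimension reproduces the single-tensor layout the old code passed as `key_cache`.

```python
import torch

# Assumed block-layout dimensions, for illustration only.
num_blocks, block_size, num_kv_heads = 4, 16, 1
kv_lora_rank, rope_dim = 512, 64

nope_cache = torch.randn(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.randn(num_blocks, block_size, num_kv_heads, rope_dim)

# Equivalent of the new decode branch: fuse the two caches along the last dim.
k_cache = torch.cat([nope_cache, rope_cache], dim=-1)
assert k_cache.shape[-1] == kv_lora_rank + rope_dim
```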
@@ -1036,7 +1043,7 @@ def forward(
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
         hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
@@ -1167,8 +1174,11 @@ def forward(
                     prefill_q_pe.contiguous(),
                     prefill_k_pe,
                     max_seq_len=attn_metadata.prefill.max_seq_lens)
+        assert len(
+            kv_cache
+        ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         if self.torchair_graph_enabled:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
+            if kv_cache[0].numel(
             ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 slots = attn_metadata.slot_mapping
                 # NOTE: Separate the kv cache in advance to avoid OOM or other issues
@@ -1178,16 +1188,15 @@ def forward(
                     key_cache=kv_cache[0],
                     value_cache=kv_cache[1],
                     slot_indices=slots)
-        elif kv_cache.numel() > 0:
-            key = torch.cat([
-                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                k_pe
-            ],
-                            dim=2)
-            torch_npu._npu_reshape_and_cache_siso(
-                key=key,
-                key_cache=kv_cache,
-                slot_indices=attn_metadata.slot_mapping.flatten())
+        else:
+            kv_c_normed = kv_c_normed.view(
+                [num_actual_toks, self.num_kv_heads, -1])
+            torch_npu._npu_reshape_and_cache(
+                key=kv_c_normed,
+                value=k_pe,
+                key_cache=kv_cache[0],
+                value_cache=kv_cache[1],
+                slot_indices=attn_metadata.slot_mapping)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
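
For reference, a hedged pure-PyTorch sketch of what the new cache-write branch does semantically: `kv_c_normed` (the latent, nope part) is scattered into the first cache and `k_pe` (the rope part) into the second, addressed by `slot_mapping`, where a flat slot index decomposes as `block_id * block_size + offset`. On device this is done by the `_npu_reshape_and_cache` kernel; the function and tensor names below are illustrative only, not part of the patch.

```python
import torch

def reference_reshape_and_cache(key, value, key_cache, value_cache, slot_indices):
    """Scatter per-token nope/rope entries into block-layout caches.

    key:         [num_tokens, num_kv_heads, kv_lora_rank]
    value:       [num_tokens, num_kv_heads, rope_dim]
    key_cache:   [num_blocks, block_size, num_kv_heads, kv_lora_rank]
    value_cache: [num_blocks, block_size, num_kv_heads, rope_dim]
    slot_indices: [num_tokens] flat slot ids (block_id * block_size + offset)
    """
    block_size = key_cache.shape[1]
    block_ids = slot_indices // block_size
    offsets = slot_indices % block_size
    key_cache[block_ids, offsets] = key
    value_cache[block_ids, offsets] = value

if __name__ == "__main__":
    num_blocks, block_size, num_kv_heads = 4, 16, 1
    kv_lora_rank, rope_dim = 512, 64
    nope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, kv_lora_rank)
    rope_cache = torch.zeros(num_blocks, block_size, num_kv_heads, rope_dim)
    kv_c = torch.randn(3, num_kv_heads, kv_lora_rank)
    k_pe = torch.randn(3, num_kv_heads, rope_dim)
    slots = torch.tensor([0, 1, 17])  # block 0 offsets 0/1, block 1 offset 1
    reference_reshape_and_cache(kv_c, k_pe, nope_cache, rope_cache, slots)
```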