 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -579,13 +579,18 @@ def __init__(
             " please make sure after the tensor parallel split, num_heads / num_kv_heads in "
             "{32, 64, 128}.")
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
         return self.o_proj(x, is_prefill=False)[0]
 
     # Return `ql_nope`, `q_pe`
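
For readers who don't have vllm_ascend/utils.py open, here is a minimal sketch of what the newly imported npu_prefetch helper is assumed to look like: a no-op unless explicitly enabled, with the number of prefetched bytes capped by max_size. The torch_npu.npu_prefetch(input, dependency, max_size) call and the exact argument handling are assumptions, not a verbatim copy of the real helper.

    import torch
    import torch_npu

    def npu_prefetch(input: torch.Tensor,
                     dependency: torch.Tensor,
                     max_size: int = 0,
                     *,
                     enabled: bool = True):
        # Gate on the feature flag (enable_multistream_mla at the call site)
        # so the default path is completely untouched.
        if not enabled:
            return
        input_size = input.element_size() * input.numel()
        # A non-positive or oversized cap falls back to the full tensor size.
        if max_size <= 0 or max_size > input_size:
            max_size = input_size
        # Assumed torch_npu API: start pulling `input` (the o_proj weight)
        # toward the compute unit once `dependency` (the attention output x)
        # is scheduled, so the weight load overlaps preceding work instead
        # of stalling o_proj.
        torch_npu.npu_prefetch(input, dependency, max_size)

With the 16MB cap used above, at most the first 16MB of the o_proj weight shard is prefetched; anything beyond that is loaded by the matmul itself as usual.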
@@ -864,7 +869,6 @@ def exec_kv(
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
-        enable_multistream_mla: bool = False,
     ):
 
         B = hidden_states.shape[0]
@@ -874,21 +878,18 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
-        return k_pe, k_nope
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope, kv
 
     def exec_kv_prefill(
         self,
@@ -940,6 +941,7 @@ def _forward_decode(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1020,7 +1022,8 @@ def _forward_decode(
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                               enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1037,6 +1040,7 @@ def forward(
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
+        ckq: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1091,6 +1095,15 @@ def forward(
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                    ckq,
+                                    enabled=enable_multistream_mla)
+                    decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1100,12 +1113,13 @@ def forward(
                 npu_wait_tensor(decode_hs_or_q_c,
                                 sin,
                                 enabled=enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                decode_kv,
+                                enabled=enable_multistream_mla)
+
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping, enable_multistream_mla)
                 with npu_stream_switch("mla_secondary",
                                        0,
                                        enabled=enable_multistream_mla):
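
The npu_wait_tensor(decode_hs_or_q_c, decode_kv, ...) call added here is what keeps the default-stream `matmul W_KV_T` (inside _q_proj_and_k_up_proj) from being scheduled ahead of the secondary-stream exec_kv. Conceptually it is a cross-stream ordering constraint, similar to the event-based sketch below; this is an illustration with assumed names and standard stream/event primitives, not the actual npu_wait_tensor implementation from vllm_ascend.utils.

    import torch

    def wait_for_secondary_stream(main_input: torch.Tensor,
                                  secondary_stream: "torch.npu.Stream") -> torch.Tensor:
        # Record an event on the secondary stream; it completes only after all
        # work already queued there, e.g. the exec_kv that produces decode_kv.
        event = torch.npu.Event()
        with torch.npu.stream(secondary_stream):
            event.record()
        # Make the current (default) stream wait on that event before any op
        # consuming `main_input` (e.g. the W_KV_T matmul) runs; the host thread
        # is never blocked.
        torch.npu.current_stream().wait_event(event)
        return main_input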
@@ -1194,7 +1208,8 @@ def forward(
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
-                                            kv_cache, attn_metadata)
+                                            kv_cache, attn_metadata,
+                                            enable_multistream_mla)
             else:
                 output_decode = self._forward_decode(decode_ql_nope,
                                                      decode_q_pe,
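
Putting the hunks together, the decode path with enable_multistream_mla=True is intended to schedule roughly as follows. This is a condensed restatement of the diff plus one assumption: that ckq is the q_a_proj output computed by the calling model code, which is not part of this file.

    # default stream                                    "mla_secondary" stream
    # ------------------------------------------------  ------------------------------------------
    # ckq = q_a_proj(hidden_states)    (assumed caller)
    #                                                    npu_wait_tensor(..., ckq)
    #                                                    decode_k_pe, decode_k_nope, decode_kv = exec_kv(...)
    # npu_wait_tensor(decode_hs_or_q_c, cos / sin / decode_kv)
    # decode_ql_nope, decode_q_pe = _q_proj_and_k_up_proj(...)
    #                                                    existing secondary-stream block on decode_q_pe
    # _forward_decode(..., enable_multistream_mla):
    #     paged attention over the MLA cache
    #     npu_prefetch(o_proj.weight, attn_output, max_size=16MB)
    #     o_proj(...)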