from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch,
+                               npu_stream_switch, npu_wait_tensor)

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
@@ -627,22 +627,25 @@ def __init__(
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla

        # Adapt torch air graph mode with spec decoding.
        speculative_config = get_current_vllm_config().speculative_config
        if speculative_config is not None:
            self.spec_token_num = speculative_config.num_speculative_tokens
            assert self.spec_token_num > 0

-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
        return self.o_proj(x)[0]

    # Return `ql_nope`, `q_pe`
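
For context, the `npu_prefetch` call added above warms the `o_proj` weight ahead of the `o_proj` matmul, capped at 16 MB and gated on `enable_multistream_mla`. Below is a minimal sketch of the gating semantics this call site appears to rely on; the clamping logic and the underlying `torch_npu.npu_prefetch` primitive are assumptions, not a copy of the actual `vllm_ascend.utils` implementation.

import torch
import torch_npu


def npu_prefetch_sketch(weight: torch.Tensor,
                        dependency: torch.Tensor,
                        max_size: int = 0,
                        enabled: bool = True) -> None:
    # No-op unless the caller opted in (here: enable_multistream_mla).
    if not enabled:
        return
    # Never prefetch more than the weight occupies; honor the caller's cap
    # (16 MB for o_proj above) when one is given.
    weight_bytes = weight.element_size() * weight.numel()
    if max_size <= 0 or max_size > weight_bytes:
        max_size = weight_bytes
    # Assumed torch_npu primitive: asynchronously prefetch `weight`, ordered
    # after `dependency` has been produced.
    torch_npu.npu_prefetch(weight, dependency, max_size)
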
@@ -933,20 +936,17 @@ def exec_kv(
        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
        return k_pe, k_nope

    def exec_kv_prefill(
@@ -999,6 +999,7 @@ def _forward_decode(
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
    ) -> torch.Tensor:
        decode_meta = attn_metadata.decode
        assert decode_meta is not None
@@ -1093,7 +1094,8 @@ def _forward_decode(
                out=attn_output)
        current_ms_metadata = get_multistream_comm_context()
        if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
        else:
            current_ms_metadata.before_comm_event.record()
            with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1109,6 +1111,7 @@ def forward(
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
+        enable_multistream_mla: bool = False,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
@@ -1158,27 +1161,21 @@ def forward(
        if self.running_in_graph:
            cos = attn_metadata.decode.cos
            sin = attn_metadata.decode.sin
-            # Without explicitly controlling the order, IndexByTensor operations
-            # would be placed after `matmul W_KV_T` hindering the overlapping of
-            # KvRmsNormRopeCache and SingleRope.
-            npu_wait_tensor(decode_hs_or_q_c,
-                            cos,
-                            enabled=self.enable_multistream_mla)
-            npu_wait_tensor(decode_hs_or_q_c,
-                            sin,
-                            enabled=self.enable_multistream_mla)
+            with npu_stream_switch("mla_secondary",
+                                   0,
+                                   enabled=enable_multistream_mla):
+                decode_k_pe, decode_k_nope = self.exec_kv(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
        decode_ql_nope, decode_q_pe = \
            self._q_proj_and_k_up_proj(decode_hs_or_q_c)
        if self.running_in_graph:
-            decode_k_pe, decode_k_nope = self.exec_kv(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
            with npu_stream_switch("mla_secondary",
                                   0,
-                                   enabled=self.enable_multistream_mla):
+                                   enabled=enable_multistream_mla):
                npu_wait_tensor(decode_q_pe,
                                decode_k_pe,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
        else:
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
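
The rewritten hunk above relies on `npu_stream_switch` and `npu_wait_tensor` degrading to no-ops when `enable_multistream_mla` is false, so the same code path serves both the overlapped and the plain schedule. A rough sketch of wrappers with that behavior, assuming they delegate to torchair's stream-scope primitives (the import paths are an assumption, not necessarily the real `vllm_ascend.utils` code):

from contextlib import nullcontext

import torch
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor


def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    # Run the enclosed ops on a named secondary stream only when enabled;
    # otherwise leave them on the default stream.
    return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    *,
                    enabled: bool = True):
    # Insert a cross-stream dependency: `self` is not consumed before
    # `dependency` is ready (e.g. decode_q_pe waiting on decode_k_pe above).
    # Pass-through when disabled.
    return _npu_wait_tensor(self, dependency) if enabled else self
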
@@ -1253,7 +1250,8 @@ def forward(
        if self.running_in_graph:
            return self._forward_decode(decode_ql_nope, decode_q_pe,
                                        decode_k_nope, decode_k_pe,
-                                        kv_cache, attn_metadata)
+                                        kv_cache, attn_metadata,
+                                        enable_multistream_mla)
        else:
            output_decode = self._forward_decode(decode_ql_nope,
                                                 decode_q_pe,
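
Because the flag is no longer cached in `__init__`, callers are now expected to read it from the ascend config and pass it into `forward()` on every call. A hypothetical caller-side sketch (`mla_impl` and `mla_forward_args` are illustrative stand-ins; the positional arguments are outside this diff):

# Illustrative only: decide at the call site whether multistream MLA applies
# and thread the flag through the attention forward pass.
enable_multistream_mla = \
    get_ascend_config().torchair_graph_config.enable_multistream_mla
output = mla_impl.forward(*mla_forward_args,
                          output=output,
                          enable_multistream_mla=enable_multistream_mla)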