                                              MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -117,6 +115,8 @@ class AscendMLAMetadata:
    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

+    with_prefill_across_dp: bool = False
+
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: torch.Tensor = None
@@ -260,6 +260,10 @@ def build_dummy(self, num_reqs: int,
                                  PAD_SLOT_ID,
                                  dtype=torch.int32,
                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
@@ -278,15 +282,21 @@ def build_dummy(self, num_reqs: int,
            attn_state=AscendAttentionState.DecodeOnly,
            prefill=None,
            decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            block_tables=block_table,
        )

-    def build(self,
-              num_reqs: int,
-              num_actual_tokens: int,
-              max_query_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              common_prefix_len: Optional[int] = None,
-              graph_pad_size: int = -1) -> AscendMLAMetadata:
+    def build(
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        max_query_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        common_prefix_len: Optional[int] = None,
+        graph_pad_size: int = -1,
+        with_prefill_across_dp: bool = False,
+    ) -> AscendMLAMetadata:
        assert self._num_decodes + self._num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -388,6 +398,7 @@ def build(self,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
+            with_prefill_across_dp=with_prefill_across_dp,
        )


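The builder now threads `with_prefill_across_dp` through `build()` and stores it on the returned `AscendMLAMetadata` (see the dataclass field added in the earlier hunk). A minimal, self-contained sketch of that pattern, using toy names rather than the real builder and runner classes from this repo:

```python
from dataclasses import dataclass


# Toy stand-in (illustrative only) mirroring the pattern in this patch:
# the builder takes a `with_prefill_across_dp` keyword and simply stores it
# on the metadata object it returns, so the attention impl can check it later.
@dataclass
class ToyMetadata:
    num_input_tokens: int = 0
    with_prefill_across_dp: bool = False


def build_toy(num_input_tokens: int,
              with_prefill_across_dp: bool = False) -> ToyMetadata:
    return ToyMetadata(num_input_tokens=num_input_tokens,
                       with_prefill_across_dp=with_prefill_across_dp)


meta = build_toy(num_input_tokens=16, with_prefill_across_dp=True)
assert meta.with_prefill_across_dp
```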
@@ -409,20 +420,7 @@ def __init__(
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
-        # MLA Specific Arguments
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rotary_emb: RotaryEmbedding,
-        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-        # attention backend perspective we rely on the layer to pass in the
-        # correct matrix
-        q_proj: ColumnParallelLinear,
-        kv_b_proj: ColumnParallelLinear,
-        o_proj: RowParallelLinear,
+        kv_sharing_target_layer_name: Optional[str] = None,
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
@@ -431,25 +429,20 @@ def __init__(
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_head_dim
-        self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb
-
-        self.q_proj = q_proj
-        self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
-
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj']
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
        self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
        self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+
        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
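With this change the MLA-specific arguments (`q_lora_rank`, `rotary_emb`, `q_proj`, `kv_b_proj`, `o_proj`, ...) are no longer explicit `__init__` parameters; they arrive through `**kwargs`, while `kv_sharing_target_layer_name` takes their place in the signature. A toy sketch of the kwargs pattern (illustrative class and names, not the real vLLM layer):

```python
# Required MLA fields are read with kwargs['...'] (raising KeyError if the
# caller forgets one), while genuinely optional modules use kwargs.get(..., None).
class ToyMLAImpl:
    def __init__(self, num_heads: int, **kwargs) -> None:
        self.num_heads = num_heads
        self.q_lora_rank = kwargs['q_lora_rank']                  # required
        self.kv_lora_rank = kwargs['kv_lora_rank']                # required
        self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)  # optional


impl = ToyMLAImpl(num_heads=16, q_lora_rank=None, kv_lora_rank=512)
assert impl.kv_lora_rank == 512 and impl.kv_a_layernorm is None
```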
@@ -621,7 +614,7 @@ def exec_kv(
        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            kv,
            self.kv_a_layernorm.weight,
            cos,
@@ -643,7 +636,7 @@ def rope_single(
        B, N, D = x.shape
        S = 1
        x = x.view(B, N, S, D)
-        x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
+        x = torch_npu.npu_interleave_rope(x, cos, sin)
        return x.view(B, N, D)

    def _forward_decode(
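Both `exec_kv` and `rope_single` switch from the `torch.ops.npu_inference` namespace to calling the ops directly on the `torch_npu` module, with the same arguments. A shape sketch for the `rope_single` path, assuming an Ascend NPU and the `torch_npu` package are available; the dimensions are illustrative, not taken from this patch:

```python
import torch
import torch_npu  # Ascend PyTorch adapter

B, N, D = 4, 16, 64                   # batch, heads, rope head dim (illustrative)
x = torch.randn(B, N, 1, D).npu()     # [B, N, S=1, D], as rope_single expects
cos = torch.randn(B, 1, 1, D).npu()   # broadcast over heads and S
sin = torch.randn(B, 1, 1, D).npu()

# Only the namespace changed in this patch; the call itself is unchanged.
x = torch_npu.npu_interleave_rope(x, cos, sin)
print(x.view(B, N, D).shape)          # back to [B, N, D]
```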
@@ -766,6 +759,7 @@ def forward(
            sin = sin[attn_metadata.decode.input_positions]
            cos = cos[:, None, None, :]
            sin = sin[:, None, None, :]
+
            decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
            decode_k_pe, decode_k_nope = self.exec_kv(
                hidden_states_or_kv_c_normed, cos, sin, kv_cache,