         LinearBase, RowParallelLinear,
         UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -57,6 +58,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -90,6 +92,9 @@ class AscendMLAMetadata:

     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor

     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -231,6 +236,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -245,6 +251,7 @@ def build(self,
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
             num_reqs]
@@ -258,6 +265,8 @@ def build(self,
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -268,6 +277,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )

         decode_metadata = None
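
For reference, a minimal standalone sketch of the query_start_loc rebasing used for the prefill slice above; the batch layout is an assumption for illustration, not data from this PR.

import torch

# Assumed batch: 2 decode requests (1 token each) followed by 2 prefill
# requests with 5 and 8 tokens, so the cumulative query starts are:
query_start_loc = torch.tensor([0, 1, 2, 7, 15])
reqs_start = 2  # index of the first prefill request

# Same expression as in the diff: shift the cumulative offsets so the
# prefill portion starts at zero.
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
print(prefill_query_start_loc)  # tensor([ 0,  5, 13])
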
@@ -324,6 +334,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )


@@ -373,6 +386,12 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: the padding below should be removed once the kernel is ready;
+        # npu_flash_attention only works when head_dim is divisible by 128, so we pad
+        # to the target size here and slice the final result to keep the output correct.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
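
As a side note, below is a minimal CPU-only sketch of the padding scheme described in the TODO above; the head dims and the plain torch buffers are assumptions for illustration, not the NPU kernel itself.

import torch
import torch.nn.functional as F

# Assumed example dims (DeepSeek-V2-style MLA heads, for illustration only).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
num_tokens, num_heads = 4, 8

# Same rounding as the diff: pad qk_nope_head_dim + qk_rope_head_dim (192 here)
# up to the next multiple of 128.
padding_head_dim = ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
assert padding_head_dim == 256

query = torch.randn(num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad the last dim so q/k/v all reach padding_head_dim before the kernel call.
pad_query = F.pad(query, [0, padding_head_dim - qk_nope_head_dim - qk_rope_head_dim], value=0)
pad_value = F.pad(value, [0, padding_head_dim - v_head_dim], value=0)
assert pad_query.shape[-1] == pad_value.shape[-1] == padding_head_dim

# The kernel fills a (num_tokens, num_heads, padding_head_dim) buffer; slicing
# the last dim recovers the real v_head_dim output, as done in _forward_prefill.
attn_output = torch.empty(num_tokens, num_heads, padding_head_dim)
attn_output = attn_output.view(-1, num_heads, padding_head_dim)[:, :, :v_head_dim]
assert attn_output.shape == (num_tokens, num_heads, v_head_dim)
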
@@ -470,11 +489,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)

     def _forward_prefill(
         self,
@@ -514,7 +531,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -523,17 +540,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"