                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.utils import vllm_major_version_is, vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
+if vllm_major_version_is("0.9.0"):
+    from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+
 
 class AscendMLABackend(AttentionBackend):
 
@@ -58,6 +62,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -91,6 +96,9 @@ class AscendMLAMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
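
The three new fields carry the batch layout that the rest of the backend needs; `query_start_loc` follows the usual cumulative-offsets convention (a prefix sum over per-request query lengths, as exposed by `CommonAttentionMetadata`). A minimal sketch of that layout, with made-up lengths:

```python
import torch

# Illustrative only: per-request query lengths for a batch of 3 requests.
query_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# query_start_loc is the exclusive prefix sum with a leading zero, so request
# i's tokens occupy the slice query_start_loc[i]:query_start_loc[i + 1].
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
print(query_start_loc)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```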
@@ -232,6 +240,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
 
@@ -243,15 +252,14 @@ def build(self,
             block_table = (self.runner.input_batch.block_table.
                            get_device_tensor()[:num_reqs])
         else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
 
+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
             num_reqs]
@@ -265,6 +273,8 @@ def build(self,
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
 
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
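
The subtraction above re-bases the shared cumulative offsets so that the prefill-only slice starts at zero. A toy illustration with made-up numbers, where `reqs_start` is the index of the first prefill request in the batch:

```python
import torch

# Illustrative: 2 decode requests (1 token each) followed by 2 prefill requests.
query_start_loc = torch.tensor([0, 1, 2, 6, 11])  # cumulative token offsets
reqs_start = 2                                    # first prefill request

prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
print(prefill_query_start_loc)  # tensor([0, 4, 9])
```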
@@ -275,6 +285,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )
 
         decode_metadata = None
@@ -331,6 +342,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )
 
 
@@ -380,6 +394,12 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: the padding below should be removed once the kernel is ready.
+        # We found npu_flash_attention only works when head_dim is divisible by 128, so we pad to the target size here
+        # and slice the final result to guarantee correctness.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128
 
         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
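
The `padding_head_dim` expression rounds the combined nope+rope head dim up to the next multiple of 128. A quick check of the arithmetic (the 128/64 head dims are illustrative DeepSeek-style values, not taken from this diff):

```python
def round_up_to_128(head_dim: int) -> int:
    # Same arithmetic as self.padding_head_dim in the hunk above.
    return ((head_dim - 1) // 128 + 1) * 128

assert round_up_to_128(128 + 64) == 256  # 192-dim q/k heads are padded to 256
assert round_up_to_128(128) == 128       # already aligned, left unchanged
assert round_up_to_128(129) == 256
```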
@@ -477,11 +497,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _forward_prefill(
         self,
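
The hunk above drops the `.contiguous()` calls and the `npu_format_cast` casts, keeping only the layout permutes described by the inline comments. A shape-only sanity check of those permutes (toy sizes, runnable on CPU; the L/N/P/V names follow the comments, the concrete numbers are made up):

```python
import torch

L, N, P, V = 512, 16, 128, 128  # toy sizes: latent dim, heads, nope dim, v dim

W_UK = torch.randn(L, N, P)
W_UV = torch.randn(L, N, V)

# Convert from (L, N, V) to (N, L, V)
assert W_UV.transpose(0, 1).shape == (N, L, V)
# Convert from (L, N, P) to (N, P, L)
assert W_UK.permute(1, 2, 0).shape == (N, P, L)
```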
@@ -521,7 +539,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -530,17 +548,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"