                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch

+if vllm_version_is("main"):
+    from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+

 class AscendMLABackend(AttentionBackend):

@@ -57,6 +62,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -90,6 +96,9 @@ class AscendMLAMetadata:

     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor

     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -231,6 +240,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -245,6 +255,7 @@ def build(self,
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
             num_reqs]
@@ -258,6 +269,8 @@ def build(self,
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -268,6 +281,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )

         decode_metadata = None
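
A note on the `prefill_query_start_loc` added above: `common_attn_metadata.query_start_loc` holds the cumulative per-request query-token offsets (length `num_reqs + 1`, starting at 0), and in this builder decode requests come before prefills. Slicing from `reqs_start` and subtracting `query_start_loc[reqs_start]` rebases the offsets so the prefill-only slice starts at zero. A minimal sketch with made-up request sizes (the tensor values are illustrative, not taken from the source):

```python
import torch

# Illustrative batch: 4 requests, the first two are decodes (1 token each),
# the last two are prefills with 5 and 3 query tokens.
query_start_loc = torch.tensor([0, 1, 2, 7, 10])
reqs_start = 2  # index of the first prefill request

# Same arithmetic as in the builder: drop the decode entries and rebase
# the cumulative offsets so they start at zero for the prefill tokens.
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
print(prefill_query_start_loc)  # tensor([0, 5, 8])
```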
@@ -324,6 +338,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )

@@ -373,6 +390,12 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: this padding should be removed once the kernel is ready
+        # npu_flash_attention only works when head_dim is divisible by 128, so we pad it to the target size here
+        # and slice the final result to guarantee its functionality.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
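
To make the padding comment above concrete: `padding_head_dim` rounds `qk_nope_head_dim + qk_rope_head_dim` up to the next multiple of 128. A small sketch of the same round-up expression (the helper name and example sizes are mine, chosen to resemble a typical MLA config of 128 nope + 64 rope dims):

```python
def round_up_to_multiple_of_128(head_dim: int) -> int:
    # Same arithmetic as padding_head_dim above: round up to the next
    # multiple of 128; an already-aligned value is left unchanged.
    return ((head_dim - 1) // 128 + 1) * 128

# 128 nope dims + 64 rope dims -> 192, padded to 256 for the kernel.
assert round_up_to_multiple_of_128(128 + 64) == 256
assert round_up_to_multiple_of_128(128) == 128
```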
@@ -470,11 +493,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)

     def _forward_prefill(
         self,
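
The hunk above only drops the `.contiguous()` calls and the `torch_npu.npu_format_cast(..., 29)` casts; the `(L, N, V) -> (N, L, V)` and `(L, N, P) -> (N, P, L)` reshapes are unchanged, so `W_UV` and `W_UK_T` are now kept as plain (possibly non-contiguous) views in the default tensor format. A quick shape check with illustrative dimensions (the sizes are arbitrary, not from a real model):

```python
import torch

# L = kv_lora_rank, N = num_heads, P = qk_nope_head_dim, V = v_head_dim
# (illustrative sizes only).
L, N, P, V = 512, 16, 128, 128
W_UK = torch.randn(L, N, P)
W_UV = torch.randn(L, N, V)

# Same reshapes as in the diff, without .contiguous() or a format cast.
assert W_UV.transpose(0, 1).shape == (N, L, V)   # (L, N, V) -> (N, L, V)
assert W_UK.permute(1, 2, 0).shape == (N, P, L)  # (L, N, P) -> (N, P, L)
```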
@@ -514,7 +535,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -523,17 +544,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"