@@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
-    max_context_len: int
+    max_seq_lens: int


 @dataclass
@@ -65,6 +65,7 @@ class AscendMLADecodeMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    max_seq_lens: int


 @dataclass
@@ -131,11 +132,6 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        # self.attn_mask = None
-        # if AscendMLAMetadataBuilder._attn_mask_builder is None:
-        #     AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-        #         128, self.runner.model_config.dtype
-        #     )

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -222,12 +218,14 @@ def build(self,
                                                                 num_reqs]
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
-        max_context_len = seq_lens.max().item()
+        max_seq_lens = seq_lens.max().item()

         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
+            max_query_len = query_lens[tokens_start:].max().item()
+            max_seq_lens = seq_lens[tokens_start:].max().item()

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -236,15 +234,17 @@ def build(self,
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
-                max_context_len=max_context_len,
+                max_seq_lens=max_seq_lens,
             )

         decode_metadata = None
         if self._num_decodes > 0:
+            max_seq_lens = seq_lens[:self._num_decodes].max().item()
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions[:self._num_decode_tokens],
                 block_table=block_table[:self._num_decode_tokens, ...],
-                seq_lens=seq_lens[:self._num_decode_tokens])
+                seq_lens=seq_lens[:self._num_decode_tokens],
+                max_seq_lens=max_seq_lens)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
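Note: the builder relies on reorder_batch having placed all decode requests ahead of the prefill requests, so the per-group maxima above come from contiguous slices of seq_lens and query_lens. A minimal sketch with hypothetical lengths (decode requests contribute one token each, so the request offset and token offset coincide here):

import torch

# Hypothetical batch: 2 decode requests followed by 2 prefill requests.
num_decodes = 2
seq_lens = torch.tensor([17, 33, 255, 129])
query_lens = torch.tensor([1, 1, 200, 100])

decode_max_seq_lens = seq_lens[:num_decodes].max().item()       # 33
prefill_max_query_len = query_lens[num_decodes:].max().item()   # 200
prefill_max_seq_lens = seq_lens[num_decodes:].max().item()      # 255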
@@ -306,12 +306,18 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: this padding should be removed once the kernel is ready.
+        # npu_flash_attention only works when head_dim is divisible by 128, so
+        # we pad up to the target size here and slice the final result back down.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
         # method from the rotary embedding and call it directly
         # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb.forward_native
+        self.rotary_emb = rotary_emb

         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
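Note: padding_head_dim rounds the combined nope+rope head dim up to the next multiple of 128 for the NPU flash-attention kernel. A quick worked example, assuming DeepSeek-style MLA sizes (qk_nope_head_dim=128, qk_rope_head_dim=64; these values are assumptions, not taken from this diff):

qk_nope_head_dim = 128  # assumed
qk_rope_head_dim = 64   # assumed
padding_head_dim = ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
# (192 - 1) // 128 + 1 == 2, so the 192-wide q/k heads are padded to 256.
assert padding_head_dim == 256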
@@ -409,37 +415,73 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None

-        # TODO: enable this compute for flash attention computation
-        # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-        #     -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
-        # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
-        #                                    value=0)
         num_tokens = query.size(0)
-        attn_output = torch.empty(num_tokens,
-                                  self.num_heads,
-                                  self.v_head_dim,
-                                  dtype=query.dtype,
-                                  device=query.device)
-        # current requests is chunked in prefill, disable flash attention with chunked prefill
-        vanilla_chunked_prefill_mla(
-            output=attn_output,
-            query=query,
-            kv_cache=kv_c_and_k_pe_cache,
-            block_tables=attn_metadata.prefill.block_table,
-            query_lens=attn_metadata.prefill.query_lens,
-            context_lens=attn_metadata.prefill.context_lens,
-            kv_b_proj=self.kv_b_proj,
-            max_query_len=attn_metadata.prefill.max_query_len,
-            max_context_len=attn_metadata.prefill.max_context_len,
-            nope_dim=self.qk_nope_head_dim,
-            rope_dim=self.qk_rope_head_dim,
-            v_head_dim=self.v_head_dim,
-            scale=self.scale,
-            alibi_slopes=None,
-            causal=True)
-        attn_output = attn_output.view(
+        attn_output = None
+        # There are only two possible prefill states here: ChunkedPrefill or PrefillOnly.
+        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads * self.v_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            # The current request is chunked in prefill; flash attention cannot be used with chunked prefill.
+            vanilla_chunked_prefill_mla(
+                output=attn_output,
+                query=query,
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads,
+                                      self.padding_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
+                -1, self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim).split(
+                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                            dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
+            torch_npu._npu_flash_attention(
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
+                mask=attn_metadata.attn_mask,
+                seq_len=attn_metadata.prefill.context_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
+        else:
+            raise RuntimeError(
+                "Unexpected path reached: AscendMLAImpl should only see the PrefillOnly and ChunkedPrefill states in _forward_prefill. Please file a bug to vllm-ascend!"
+            )
+        attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])

         return self.o_proj(attn_output)[0]

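To make the PrefillOnly branch easier to follow, here is a hedged, CPU-runnable sketch of the same pad-then-slice workaround. torch.nn.functional.scaled_dot_product_attention stands in for torch_npu._npu_flash_attention, and all sizes and variable names (pad_q, pad_k, pad_v) are made up for illustration; the point is only that q/k/v are zero-padded up to padding_head_dim and the output is sliced back to v_head_dim.

import torch
import torch.nn.functional as F

# Made-up sizes for illustration only.
num_tokens, num_heads = 8, 4
qk_head_dim, v_head_dim = 192, 128      # 128 (nope) + 64 (rope); value head dim
padding_head_dim = 256                  # next multiple of 128

query = torch.randn(num_tokens, num_heads, qk_head_dim)
key = torch.randn(num_tokens, num_heads, qk_head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad the last (head) dim; the extra zeros add nothing to the q.k dot
# products as long as query and key are padded consistently.
pad_q = F.pad(query, [0, padding_head_dim - qk_head_dim])
pad_k = F.pad(key, [0, padding_head_dim - qk_head_dim])
pad_v = F.pad(value, [0, padding_head_dim - v_head_dim])

# Stand-in for torch_npu._npu_flash_attention; SDPA here takes [heads, tokens, dim]
# and uses a scale based on the unpadded head dim, mirroring self.scale.
out = F.scaled_dot_product_attention(pad_q.transpose(0, 1),
                                     pad_k.transpose(0, 1),
                                     pad_v.transpose(0, 1),
                                     is_causal=True,
                                     scale=qk_head_dim**-0.5)

# Slice the padded head dim back down to the true value head dim.
out = out.transpose(0, 1)[:, :, :v_head_dim]
assert out.shape == (num_tokens, num_heads, v_head_dim)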
@@ -457,7 +499,7 @@ def _forward_decode(

         q = torch.cat([q_nope, q_pe], dim=-1)
         num_tokens = q.size(0)
-        attn_output = torch.randn(
+        attn_output = torch.empty(
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
@@ -522,8 +564,10 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions,
+                decode_q_pe.contiguous(),
+                decode_k_pe,
+                max_seq_len=attn_metadata.decode.max_seq_lens)

         if has_prefill:
             assert attn_metadata.prefill is not None
@@ -533,7 +577,9 @@ def forward(

             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                prefill_q_pe.contiguous(),
+                prefill_k_pe,
+                max_seq_len=attn_metadata.prefill.max_seq_lens)

         if kv_cache.numel() > 0:
             key = torch.cat([