@@ -130,6 +130,24 @@ class FlashAttentionMetadata:
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into the sequence. E.g., if the
+    # encoder sequence lengths are [4, 6], this is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        Whether all attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+

def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
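
As a quick illustration of the new metadata fields, `encoder_seq_start_loc` is the exclusive prefix sum of the per-sequence encoder lengths, and `is_all_encoder_attn_metadata_set` just checks that both encoder fields were populated. A minimal sketch with the made-up lengths from the comment above ([4, 6]), not part of the commit:

```python
# Minimal sketch (illustrative values only): building encoder_seq_start_loc
# from per-sequence encoder lengths, mirroring the comment above.
import torch

encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens.numel() + 1,
                                    dtype=torch.int32)
encoder_seq_start_loc[1:] = torch.cumsum(encoder_seq_lens, dim=0)

print(encoder_seq_start_loc.tolist())              # [0, 4, 10]
max_encoder_seq_len = int(encoder_seq_lens.max())  # 6
```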
@@ -207,7 +225,13 @@ def build(self,
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
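
The branch above picks the KV-side length bound: for encoder/decoder cross-attention the keys and values come from the encoder, so the encoder maximum is the right bound rather than the decoder `seq_lens`. A small standalone restatement of that selection (the helper name is hypothetical, not from this change):

```python
# Hypothetical helper mirroring the branch above: choose the KV-side
# sequence-length bound for the attention kernel.
from typing import Optional

import torch


def select_max_seq_len(seq_lens_cpu: torch.Tensor,
                       cross_slot_mapping: Optional[torch.Tensor],
                       max_encoder_seq_len: Optional[int]) -> int:
    if cross_slot_mapping is not None and max_encoder_seq_len is not None:
        # ENCODER_DECODER cross-attention: K/V lengths are encoder lengths.
        return max_encoder_seq_len
    return int(seq_lens_cpu.max())


assert select_max_seq_len(torch.tensor([3, 7]), None, None) == 7
assert select_max_seq_len(torch.tensor([3, 7]), torch.tensor([0, 1]), 12) == 12
```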
@@ -326,6 +350,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
        )
        return attn_metadata

@@ -380,18 +408,32 @@ def __init__(

        FlashAttentionBackend.validate_head_size(head_size)

-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
                and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            `False` otherwise.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
    def forward(
        self,
        layer: torch.nn.Module,
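
`_get_causal_option` keeps the causal mask only for decoder self-attention; encoder, encoder-only, and encoder/decoder cross-attention are bidirectional. A self-contained restatement with stand-in string constants (the real code uses vLLM's `AttentionType`):

```python
# Stand-in constants for illustration; the real values live on AttentionType.
DECODER = "decoder"
ENCODER = "encoder"
ENCODER_ONLY = "encoder_only"
ENCODER_DECODER = "encoder_decoder"


def get_causal_option(attn_type: str) -> bool:
    # Causal masking only applies to decoder self-attention.
    return attn_type not in (ENCODER, ENCODER_ONLY, ENCODER_DECODER)


assert get_causal_option(DECODER)
assert not any(get_causal_option(t)
               for t in (ENCODER, ENCODER_ONLY, ENCODER_DECODER))
```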
@@ -428,6 +470,14 @@ def forward(
            # Profiling run.
            return output

+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -438,22 +488,40 @@ def forward(
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
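
The two changes above decide what gets written into the KV cache: cross-attention layers store encoder K/V through `cross_slot_mapping`, and the added `key is not None` / `value is not None` guard skips the write when a step supplies no new K/V (the cached encoder K/V is simply reused). A hypothetical sketch of that decision, not from the diff:

```python
# Hypothetical sketch of the cache-write decision: returns the slot mapping
# to pass to reshape_and_cache_flash, or None when the write is skipped.
from typing import Optional

import torch

ENCODER_DECODER = "encoder_decoder"  # stand-in for AttentionType.ENCODER_DECODER


def pick_slot_mapping(attn_type: str,
                      slot_mapping: torch.Tensor,
                      cross_slot_mapping: Optional[torch.Tensor],
                      key: Optional[torch.Tensor],
                      value: Optional[torch.Tensor]
                      ) -> Optional[torch.Tensor]:
    if key is None or value is None:
        return None  # nothing new to cache this step
    if attn_type == ENCODER_DECODER:
        return cross_slot_mapping  # encoder K/V goes to the cross-attn cache
    return slot_mapping
```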
@@ -477,7 +545,7 @@ def forward(
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
@@ -489,7 +557,7 @@ def forward(
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
@@ -524,12 +592,86 @@ def forward(
                fa_version=self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-                q_descale=layer._q_scale,
-                k_descale=layer._k_scale,
-                v_descale=layer._v_scale,
+                q_descale=layer._q_scale.expand(descale_shape),
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output

+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+        )
+
+        return output
+

def use_cascade_attention(
    common_prefix_len: int,
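
With an FP8 KV cache, `flash_attn_varlen_func` now receives descale tensors shaped `(num_requests, num_kv_heads)` rather than bare scalars, which is why `descale_shape` switched to `self.num_kv_heads` and the `layer._*_scale` tensors are `.expand()`ed. A quick illustration with made-up sizes (the expand is a broadcast view, not a copy):

```python
# Made-up sizes; shows how a single-element scale broadcasts to the
# (num_requests, num_kv_heads) descale shape.
import torch

num_kv_heads = 8
cu_seqlens_q = torch.tensor([0, 2, 5, 9], dtype=torch.int32)  # 3 requests
descale_shape = (cu_seqlens_q.shape[0] - 1, num_kv_heads)

k_scale = torch.tensor([0.5])             # stand-in for layer._k_scale
k_descale = k_scale.expand(descale_shape)

print(tuple(k_descale.shape))             # (3, 8)
assert k_descale.stride() == (0, 0)       # broadcast view, no data copied
```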