@@ -134,6 +134,16 @@ class FlashAttentionMetadata:
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into sequence. E.g., if the sequence
+    # length is [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
    # for local attention
    @dataclass
    class LocalAttentionMetadata:
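
A minimal sketch of how the new encoder fields relate to the `[4, 6] -> [0, 4, 10]` example in the comment above. `encoder_seq_lens` is a hypothetical input tensor for illustration, not a name from this diff:

```python
import torch

# Assumed per-request encoder lengths (illustration only).
encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)

# (batch_size + 1,) cumulative start offsets: [4, 6] -> [0, 4, 10]
encoder_seq_start_loc = torch.zeros(encoder_seq_lens.numel() + 1,
                                    dtype=torch.int32)
encoder_seq_start_loc[1:] = torch.cumsum(encoder_seq_lens, dim=0)

# Maximum sequence length among the encoder sequences
max_encoder_seq_len = int(encoder_seq_lens.max())

print(encoder_seq_start_loc.tolist())  # [0, 4, 10]
print(max_encoder_seq_len)             # 6
```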
@@ -146,6 +156,14 @@ class LocalAttentionMetadata:

    local_attn_metadata: Optional[LocalAttentionMetadata] = None

+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+

def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -212,14 +230,22 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(
-        self, common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> FlashAttentionMetadata:
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        # Encoder/cross-attention metadata (optional)
+        encoder_seq_start_loc: Optional[torch.Tensor] = None,
+        max_encoder_seq_len: Optional[int] = None,
+        cross_slot_mapping: Optional[torch.Tensor] = None):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        if cross_slot_mapping is not None and max_encoder_seq_len is not None:
+            # ENCODER_DECODER cross-attention
+            max_seq_len = max_encoder_seq_len
+        else:
+            max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table = self.block_table
@@ -379,6 +405,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=encoder_seq_start_loc,
+            max_encoder_seq_len=max_encoder_seq_len,
+            cross_slot_mapping=cross_slot_mapping,
        )
        return attn_metadata
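
A toy reproduction of the `max_seq_len` selection that `build()` now performs: when cross-attention metadata is supplied, the KV length comes from the encoder; otherwise it falls back to the decoder sequence lengths. `select_max_seq_len` is a standalone stand-in for illustration, not part of the builder's API:

```python
from typing import Optional

import numpy as np
import torch


def select_max_seq_len(seq_lens_np: np.ndarray,
                       max_encoder_seq_len: Optional[int],
                       cross_slot_mapping: Optional[torch.Tensor]) -> int:
    if cross_slot_mapping is not None and max_encoder_seq_len is not None:
        # ENCODER_DECODER cross-attention: the KV side is the encoder output.
        return max_encoder_seq_len
    # Decoder self-attention: use the decoder sequence lengths as before.
    return int(seq_lens_np.max())


decoder_lens = np.array([7, 3, 5])
assert select_max_seq_len(decoder_lens, None, None) == 7
assert select_max_seq_len(decoder_lens, 12,
                          torch.zeros(4, dtype=torch.int64)) == 12
```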
@@ -433,18 +463,32 @@ def __init__(

        FlashAttentionBackend.validate_head_size(head_size)

-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            `False` otherwise.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
    def forward(
        self,
        layer: torch.nn.Module,
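
The new `_get_causal_option` helper boils down to: only decoder-style attention runs with a causal mask. A minimal standalone check of that intent, assuming the usual vLLM `AttentionType` string values ("decoder", "encoder", "encoder_only", "encoder_decoder"):

```python
# Standalone sketch of the causal-flag logic; AttentionType constants are
# assumed to be plain strings here.
def get_causal_option(attn_type: str) -> bool:
    # Encoder, encoder-only, and encoder/decoder cross-attention attend over
    # all keys (bidirectional); only decoder attention is causal.
    return attn_type not in ("encoder", "encoder_only", "encoder_decoder")


assert get_causal_option("decoder") is True
assert get_causal_option("encoder") is False
assert get_causal_option("encoder_only") is False
assert get_causal_option("encoder_decoder") is False
```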
@@ -481,6 +525,14 @@ def forward(
            # Profiling run.
            return output

+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -491,22 +543,40 @@ def forward(
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
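
To show what routing the cache write through `updated_slot_mapping` achieves: for ENCODER_DECODER layers, the encoder tokens' K/V are written to the slots reserved for the cross-attention cache instead of the decoder slots. A conceptual stand-in for `reshape_and_cache_flash` (shapes, slot values, and names are illustrative, not the kernel's real layout):

```python
import torch


def write_kv(k_cache, v_cache, key, value, slot_mapping):
    # Scatter per-token K/V into a flat, slot-addressed cache.
    k_cache[slot_mapping] = key    # key:   [num_tokens, num_kv_heads, head_size]
    v_cache[slot_mapping] = value  # value: [num_tokens, num_kv_heads, head_size]


num_slots, num_kv_heads, head_size = 16, 2, 4
k_cache = torch.zeros(num_slots, num_kv_heads, head_size)
v_cache = torch.zeros_like(k_cache)

# ENCODER_DECODER: encoder tokens' K/V land at cross_slot_mapping, computed
# once per request; decoder layers keep using attn_metadata.slot_mapping.
cross_slot_mapping = torch.tensor([8, 9, 10, 11])  # 4 encoder tokens
enc_k = torch.randn(4, num_kv_heads, head_size)
enc_v = torch.randn(4, num_kv_heads, head_size)
write_kv(k_cache, v_cache, enc_k, enc_v, cross_slot_mapping)
```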
@@ -544,7 +614,7 @@ def forward(
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
@@ -556,7 +626,7 @@ def forward(
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
@@ -570,33 +640,78 @@ def forward(
            )
            return output

-        assert not use_local_attn, (
-            "Cascade attention does not support local attention.")
-        # Cascade attention (rare case).
-        cascade_attention(
-            output[:num_actual_tokens],
-            query[:num_actual_tokens],
-            key_cache,
-            value_cache,
-            cu_query_lens=attn_metadata.query_start_loc,
-            max_query_len=attn_metadata.max_query_len,
-            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
-            prefix_kv_lens=attn_metadata.prefix_kv_lens,
-            suffix_kv_lens=attn_metadata.suffix_kv_lens,
-            max_kv_len=attn_metadata.max_seq_len,
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
-            sliding_window=self.sliding_window,
-            logits_soft_cap=self.logits_soft_cap,
-            block_table=attn_metadata.block_table,
-            common_prefix_len=attn_metadata.common_prefix_len,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
-            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
-            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
        )
+
        return output
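
For readers who want to sanity-check the ENCODER branch, a pure-PyTorch reference of what `_forward_encoder_attention` computes: non-causal attention over variable-length sequences packed along dim 0 and delimited by `encoder_seq_start_loc`. This is only a readability aid using SDPA's default `1/sqrt(head_size)` scaling and equal Q/KV head counts, not the FlashAttention kernel or its FP8 path:

```python
import torch
import torch.nn.functional as F


def encoder_attention_reference(query: torch.Tensor,
                                key: torch.Tensor,
                                value: torch.Tensor,
                                seq_start_loc: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(query)
    for i in range(seq_start_loc.numel() - 1):
        s, e = int(seq_start_loc[i]), int(seq_start_loc[i + 1])
        # SDPA expects [..., seq_len, head_size]; move heads in front of seq.
        q = query[s:e].transpose(0, 1)
        k = key[s:e].transpose(0, 1)
        v = value[s:e].transpose(0, 1)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        out[s:e] = attn.transpose(0, 1)
    return out


# Two encoder sequences of lengths 4 and 6 packed into 10 tokens.
seq_start_loc = torch.tensor([0, 4, 10])
q = torch.randn(10, 8, 64)  # [num_tokens, num_heads, head_size]
k = torch.randn(10, 8, 64)
v = torch.randn(10, 8, 64)
out = encoder_attention_reference(q, k, v, seq_start_loc)
assert out.shape == q.shape
```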