
Commit 96be9ad

v1: Add Whisper model support (encoder-decoder)
This brings Whisper support to V1 to close one of the remaining feature gaps with V0. Most of the changes apply to encoder-decoder models generally, though Whisper is the only one explicitly tested and the only encoder-decoder model updated to support V1.

**Whisper Model Implementation:**
- Remove the SupportsV0Only interface constraint to enable V1 compatibility
- Update get_multimodal_embeddings() to return the list format required by V1

**Flash Attention Backend:**
- Add encoder attention metadata fields (encoder_seq_start_loc, max_encoder_seq_len, cross_slot_mapping)
- Implement encoder self-attention support without using the KV cache
- Add cross-attention support for encoder-decoder models with proper KV cache handling

**KV Cache Manager:**
- Introduce CrossAttentionManager for handling the cross-attention KV cache in encoder-decoder models
- Add CrossAttentionSpec for cross-attention cache specification with encoder-based sizing
- Implement allocate_slots_for_cross_attn() for static, encoder-length-based allocation
- Add cross-attention block allocation logic separate from decoder token growth

**Scheduler:**
- Disable prefix caching for encoder-decoder models
- Implement cross-attention block allocation during request scheduling
- Add cross-attention block tracking in state management

**GPU Model Runner:**
- Add encoder input extraction for audio feature processing
- Implement encoder attention metadata building for both self-attention and cross-attention
- Add cross-attention KV cache group handling with proper slot mapping
- Modify input batch creation to accommodate encoder sequence lengths
- Add encoder input processing in the forward pass with proper device/dtype handling
- Update profiling and memory management for encoder-decoder models

The implementation maintains backward compatibility while adding comprehensive encoder-decoder support, with particular focus on Whisper's audio processing pipeline and the cross-attention between encoder and decoder.

Related to:
- V0 deprecation: #18571
- 2025 Q3 roadmap: #20336

Signed-off-by: Russell Bryant <rbryant@redhat.com>
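
A minimal sketch of the static, encoder-length-based cross-attention allocation described above. This is illustrative only: allocate_slots_for_cross_attn() is named in the commit message, but the helper, block size, and numbers below are assumptions, not code from the diff.

import math

def num_cross_attn_blocks(num_encoder_tokens: int, block_size: int) -> int:
    # One KV slot per encoder position, rounded up to whole cache blocks.
    # Allocated once per request; it does not grow as decoder tokens are generated.
    return math.ceil(num_encoder_tokens / block_size)

# Whisper pads/truncates audio to 30 s, which yields 1500 encoder positions,
# so with a block size of 16 a request needs ceil(1500 / 16) = 94 blocks.
print(num_cross_attn_blocks(1500, 16))  # 94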
1 parent c586b55 commit 96be9ad

File tree: 12 files changed (+593, -90 lines)

vllm/attention/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
     "AttentionMetadata",
     "AttentionType",
     "AttentionMetadataBuilder",
-    "Attention",
     "AttentionState",
     "get_attn_backend",
 ]

vllm/inputs/preprocess.py

Lines changed: 0 additions & 6 deletions
@@ -869,9 +869,6 @@ def preprocess(
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -903,9 +900,6 @@ async def preprocess_async(
         [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
         """
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(prompt)

vllm/model_executor/models/whisper.py

Lines changed: 4 additions & 5 deletions
@@ -39,7 +39,7 @@
 from vllm.transformers_utils.processor import cached_get_processor

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
-                         SupportsTranscription, SupportsV0Only)
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)

@@ -757,7 +757,7 @@ def _get_prompt_updates(
     info=WhisperProcessingInfo,
     dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal, SupportsV0Only):
+                                      SupportsMultiModal):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",

@@ -879,10 +879,9 @@ def get_language_model(self) -> torch.nn.Module:

     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
-        # TODO: This method does not obey the interface for SupportsMultiModal.
-        # Refactor this once encoder/decoder support is implemented in V1.
+        # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return self.model.get_encoder_outputs(audio_input["input_features"])
+        return [self.model.get_encoder_outputs(audio_input["input_features"])]

     def get_input_embeddings(
         self,
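
The list wrapping above matters because V1 treats multimodal embeddings as a sequence with one entry per multimodal item. A schematic sketch of that contract, with made-up shapes (1500 encoder positions, hidden size 1280) standing in for Whisper's real encoder outputs:

import torch

def fake_encoder_outputs(num_frames: int = 1500,
                         hidden_size: int = 1280) -> torch.Tensor:
    # Stand-in for self.model.get_encoder_outputs(...); shapes are illustrative.
    return torch.randn(num_frames, hidden_size)

# V1 expects a sequence of per-item embeddings, hence the single-element list.
multimodal_embeddings = [fake_encoder_outputs()]
assert isinstance(multimodal_embeddings, (list, tuple))
assert multimodal_embeddings[0].shape == (1500, 1280)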

vllm/v1/attention/backends/flash_attn.py

Lines changed: 151 additions & 36 deletions
@@ -134,6 +134,16 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0

+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into the sequence. E.g., if the
+    # sequence lengths are [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
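
The comment on encoder_seq_start_loc can be reproduced directly: it is the per-request encoder lengths turned into cumulative start offsets, with a leading zero and the total at the end. A small check in plain PyTorch (not backend code):

import torch

encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
])
print(encoder_seq_start_loc)                       # tensor([ 0,  4, 10], dtype=torch.int32)
max_encoder_seq_len = int(encoder_seq_lens.max())  # 6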
@@ -146,6 +156,14 @@ class LocalAttentionMetadata:

     local_attn_metadata: Optional[LocalAttentionMetadata] = None

+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+

 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -212,14 +230,22 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.aot_sliding_window: Optional[tuple[int, int]] = None

     def build(
-        self, common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata
-    ) -> FlashAttentionMetadata:
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        # Encoder/cross-attention metadata (optional)
+        encoder_seq_start_loc: Optional[torch.Tensor] = None,
+        max_encoder_seq_len: Optional[int] = None,
+        cross_slot_mapping: Optional[torch.Tensor] = None):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        if cross_slot_mapping is not None and max_encoder_seq_len is not None:
+            # ENCODER_DECODER cross-attention
+            max_seq_len = max_encoder_seq_len
+        else:
+            max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -379,6 +405,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=encoder_seq_start_loc,
+            max_encoder_seq_len=max_encoder_seq_len,
+            cross_slot_mapping=cross_slot_mapping,
         )
         return attn_metadata

@@ -433,18 +463,32 @@ def __init__(

         FlashAttentionBackend.validate_head_size(head_size)

-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
             and not flash_attn_supports_fp8():
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")

+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: Returns `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise returns `False`.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
     def forward(
         self,
         layer: torch.nn.Module,
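
In effect, only decoder self-attention keeps a causal mask; encoder self-attention, encoder-only attention, and encoder/decoder cross-attention all run unmasked. A quick illustration of that mapping, assuming the imports below resolve in a full vLLM checkout:

from vllm.attention import AttentionType
from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl

# Only DECODER should report True (causal); the encoder variants are bidirectional.
for attn_type in (AttentionType.DECODER, AttentionType.ENCODER,
                  AttentionType.ENCODER_ONLY, AttentionType.ENCODER_DECODER):
    print(attn_type, FlashAttentionImpl._get_causal_option(attn_type))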
@@ -481,6 +525,14 @@ def forward(
             # Profiling run.
             return output

+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -491,22 +543,40 @@ def forward(
        # performance to make sure it does not introduce any overhead.

         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)

-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
             # not padded. However, we don't need to do key[:num_actual_tokens]
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
             reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                 self.kv_cache_dtype,
                 layer._k_scale,
                 layer._v_scale,
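
cross_slot_mapping plays the same role for encoder tokens that slot_mapping plays for decoder tokens: one flat KV-cache slot index per position, derived from the blocks allocated to the request. A schematic sketch of that layout; the real mapping is built in the GPU model runner, and the block ids below are made up:

import torch

def build_cross_slot_mapping(block_ids: list[int],
                             num_encoder_tokens: int,
                             block_size: int) -> torch.Tensor:
    block_ids_t = torch.tensor(block_ids, dtype=torch.int64)
    positions = torch.arange(num_encoder_tokens, dtype=torch.int64)
    # slot = block_number * block_size + offset_within_block
    return (block_ids_t[positions // block_size] * block_size
            + positions % block_size)

print(build_cross_slot_mapping([7, 3], num_encoder_tokens=6, block_size=4))
# tensor([28, 29, 30, 31, 12, 13])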
@@ -544,7 +614,7 @@ def forward(
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata

-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
@@ -556,7 +626,7 @@ def forward(
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -570,33 +640,78 @@ def forward(
            )
            return output

-        assert not use_local_attn, (
-            "Cascade attention does not support local attention.")
-        # Cascade attention (rare case).
-        cascade_attention(
-            output[:num_actual_tokens],
-            query[:num_actual_tokens],
-            key_cache,
-            value_cache,
-            cu_query_lens=attn_metadata.query_start_loc,
-            max_query_len=attn_metadata.max_query_len,
-            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
-            prefix_kv_lens=attn_metadata.prefix_kv_lens,
-            suffix_kv_lens=attn_metadata.suffix_kv_lens,
-            max_kv_len=attn_metadata.max_seq_len,
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
             alibi_slopes=self.alibi_slopes,
-            sliding_window=self.sliding_window,
-            logits_soft_cap=self.logits_soft_cap,
-            block_table=attn_metadata.block_table,
-            common_prefix_len=attn_metadata.common_prefix_len,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
             fa_version=self.vllm_flash_attn_version,
-            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
-            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
         )
+
        return output