
Commit 71144d6

v1: Add Whisper model support (encoder-decoder)
This brings Whisper support to V1 to close one of the remaining feature gaps with V0. Most of the changes apply to encoder-decoder models generally, though Whisper is the only model explicitly tested and the only encoder-decoder model updated to support V1.

**Whisper Model Implementation:**
- Remove the SupportsV0Only interface constraint to enable V1 compatibility
- Update get_multimodal_embeddings() to return the list format required by V1

**Flash Attention Backend:**
- Add encoder attention metadata fields (encoder_seq_start_loc, max_encoder_seq_len, cross_slot_mapping)
- Implement encoder self-attention support without using the KV cache
- Add cross-attention support for encoder-decoder models with proper KV cache handling

**KV Cache Manager:**
- Introduce CrossAttentionManager for handling the cross-attention KV cache in encoder-decoder models
- Add CrossAttentionSpec for cross-attention cache specification with encoder-based sizing
- Implement allocate_slots_for_cross_attn() for static, encoder-length-based allocation
- Add cross-attention block allocation logic separate from decoder token growth

**Scheduler:**
- Disable prefix caching for encoder-decoder models
- Implement cross-attention block allocation during request scheduling
- Add cross-attention block tracking in state management

**GPU Model Runner:**
- Add encoder input extraction for audio feature processing
- Implement encoder attention metadata building for both self-attention and cross-attention
- Add cross-attention KV cache group handling with proper slot mapping
- Modify input batch creation to accommodate encoder sequence lengths
- Add encoder input processing in the forward pass with proper device/dtype handling
- Update profiling and memory management for encoder-decoder models

The implementation maintains backward compatibility while adding comprehensive encoder-decoder support, with particular focus on Whisper's audio processing pipeline and the cross-attention mechanism between encoder and decoder.

Related to:
- V0 deprecation: #18571
- 2025 Q3 roadmap: #20336

Signed-off-by: Russell Bryant <rbryant@redhat.com>
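The "encoder-based sizing" in the KV cache manager boils down to allocating the cross-attention cache once from the encoder length, rather than growing it per decoded token. A minimal sketch of that arithmetic (illustrative only; `num_cross_attn_blocks` is not a function in this commit, and the real logic lives in `allocate_slots_for_cross_attn()` and `CrossAttentionSpec`):

```python
from math import ceil

def num_cross_attn_blocks(encoder_seq_len: int, block_size: int) -> int:
    """KV cache blocks needed for one request's cross-attention layers."""
    # Sized once from the encoder length; it never grows with decoded tokens.
    return ceil(encoder_seq_len / block_size)

# Whisper's encoder emits 1500 positions per 30-second audio chunk, so with a
# block size of 16 each request needs a fixed 94 cross-attention blocks.
print(num_cross_attn_blocks(1500, 16))  # 94
```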
1 parent 0f199f1 commit 71144d6

14 files changed: +660 -70 lines

vllm/attention/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -14,7 +14,6 @@
     "AttentionMetadata",
     "AttentionType",
     "AttentionMetadataBuilder",
-    "Attention",
     "AttentionState",
     "get_attn_backend",
 ]

vllm/attention/backends/flash_attn.py

Lines changed: 1 addition & 3 deletions

@@ -1004,6 +1004,4 @@ def _get_causal_option(attn_type: str) -> bool:
         attention (i.e., not encoder, encoder-only, or encoder-decoder),
         otherwise returns `False`.
     """
-    return not (attn_type == AttentionType.ENCODER
-                or attn_type == AttentionType.ENCODER_ONLY
-                or attn_type == AttentionType.ENCODER_DECODER)
+    return attn_type == AttentionType.DECODER
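This simplification assumes `AttentionType` has exactly the four members below, so "not encoder, encoder-only, or encoder-decoder" is the same predicate as "is decoder". A self-contained check using a stand-in class (the real one lives in `vllm.attention`):

```python
class AttentionType:  # stand-in with the assumed four members
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"
    ENCODER_DECODER = "encoder_decoder"

for attn_type in (AttentionType.DECODER, AttentionType.ENCODER,
                  AttentionType.ENCODER_ONLY, AttentionType.ENCODER_DECODER):
    old = not (attn_type == AttentionType.ENCODER
               or attn_type == AttentionType.ENCODER_ONLY
               or attn_type == AttentionType.ENCODER_DECODER)
    new = attn_type == AttentionType.DECODER
    assert old == new  # both forms agree for every attention type
```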

vllm/inputs/preprocess.py

Lines changed: 0 additions & 6 deletions

@@ -869,9 +869,6 @@ def preprocess(
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -903,9 +900,6 @@ async def preprocess_async(
         [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
         """
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(prompt)

vllm/model_executor/models/whisper.py

Lines changed: 4 additions & 5 deletions

@@ -42,7 +42,7 @@
 from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
-                         SupportsTranscription, SupportsV0Only)
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)
 
@@ -790,7 +790,7 @@ def _get_prompt_updates(
     info=WhisperProcessingInfo,
     dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal, SupportsV0Only):
+                                      SupportsMultiModal):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
@@ -916,10 +916,9 @@ def get_language_model(self) -> torch.nn.Module:
 
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
-        # TODO: This method does not obey the interface for SupportsMultiModal.
-        # Refactor this once encoder/decoder support is implemented in V1.
+        # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return self.model.get_encoder_outputs(audio_input["input_features"])
+        return [self.model.get_encoder_outputs(audio_input["input_features"])]
 
     def get_input_embeddings(
         self,
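The `get_multimodal_embeddings()` change is purely about the return shape: V1 expects one embedding tensor per multimodal item, so the single encoder output is wrapped in a list. A toy illustration (the 1500 x 1280 shape is an assumption roughly matching whisper-large, not taken from this diff):

```python
import torch

encoder_outputs = torch.randn(1500, 1280)  # (encoder positions, hidden size)

v0_style = encoder_outputs    # before: a bare tensor
v1_style = [encoder_outputs]  # after: a list with one entry per audio input

assert isinstance(v1_style, list) and v1_style[0].shape == (1500, 1280)
```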

vllm/v1/attention/backends/flash_attn.py

Lines changed: 155 additions & 13 deletions

@@ -130,6 +130,24 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into sequence. E.g., if the sequence
+    # length is [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -207,7 +225,13 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -326,6 +350,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
        )
        return attn_metadata
 
@@ -380,18 +408,32 @@ def __init__(
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")
 
+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: Returns `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise returns `False`.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -428,6 +470,14 @@ def forward(
             # Profiling run.
             return output
 
+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -438,22 +488,40 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
             # not padded. However, we don't need to do key[:num_actual_tokens]
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
             reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                 self.kv_cache_dtype,
                 layer._k_scale,
                 layer._v_scale,
@@ -477,7 +545,7 @@ def forward(
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
@@ -489,7 +557,7 @@ def forward(
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -524,12 +592,86 @@ def forward(
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
        )
        return output
 
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
        )
+
+        return output
+
 
 def use_cascade_attention(
     common_prefix_len: int,
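The new encoder metadata fields fit together simply: `encoder_seq_start_loc` is the zero-prefixed cumulative sum of the per-request encoder lengths, matching the `[4, 6] -> [0, 4, 10]` example in the field comment, and `max_encoder_seq_len` is their maximum. A short sketch (how the GPU model runner actually populates these fields is not part of the hunks shown here):

```python
import torch
import torch.nn.functional as F

encoder_seq_lens = torch.tensor([4, 6])  # per-request encoder lengths

# Zero-prefixed cumulative sum -> tensor([0, 4, 10]), i.e. encoder_seq_start_loc
encoder_seq_start_loc = F.pad(torch.cumsum(encoder_seq_lens, dim=0), (1, 0))

max_encoder_seq_len = int(encoder_seq_lens.max())  # 6
```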

vllm/v1/attention/backends/utils.py

Lines changed: 8 additions & 0 deletions

@@ -59,6 +59,14 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
+    # Encoder/cross-attention specific fields (optional)
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    """(batch_size + 1,), cumulative encoder sequence lengths"""
+    max_encoder_seq_len: Optional[int] = None
+    """Maximum encoder sequence length in batch"""
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    """Slot mapping for cross-attention KV cache"""
+
     def __post_init__(self):
         # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
         # mode.
