From f62a66e20e5adddee3df28a19bf7394291c0cce4 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Wed, 2 Jul 2025 00:41:57 +0000
Subject: [PATCH] v1: Add Whisper model support (encoder-decoder)

This brings Whisper support to V1 to close one of the remaining feature
gaps with V0. Most of the changes apply to encoder-decoder models
generally, though Whisper is the only one explicitly tested and is the
only encoder-decoder model updated to support V1.

**Whisper Model Implementation:**
- Remove SupportsV0Only interface constraint to enable V1 compatibility
- Update get_multimodal_embeddings() to return the list format required
  by V1

**Flash Attention Backend:**
- Add encoder attention metadata fields (encoder_seq_start_loc,
  max_encoder_seq_len, cross_slot_mapping)
- Implement encoder self-attention support without using a KV cache
- Add cross-attention support for encoder-decoder models with proper
  KV cache handling

**KV Cache Manager:**
- Introduce CrossAttentionManager for handling cross-attention KV cache
  in encoder-decoder models
- Add CrossAttentionSpec for cross-attention cache specification with
  encoder-based sizing
- Implement allocate_slots_for_cross_attn() for static
  encoder-length-based allocation
- Add cross-attention block allocation logic separate from decoder
  token growth

**Scheduler:**
- Disable prefix caching for encoder-decoder models
- Implement cross-attention block allocation during request scheduling
- Add cross-attention block tracking in state management

**GPU Model Runner:**
- Add encoder input extraction for audio feature processing
- Implement encoder attention metadata building for both self-attention
  and cross-attention
- Add cross-attention KV cache group handling with proper slot mapping
- Modify input batch creation to accommodate encoder sequence lengths
- Add encoder input processing in the forward pass with proper
  device/dtype handling
- Update profiling and memory management for encoder-decoder models

The implementation maintains backward compatibility while adding
comprehensive encoder-decoder support, with particular focus on
Whisper's audio processing pipeline and the cross-attention mechanisms
between encoder and decoder.
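To illustrate the sizing model described above (a minimal sketch only,
not code from this diff; the helper names are hypothetical): the
cross-attention KV cache is reserved once per request from the encoder
length (max_source_positions, 1500 for Whisper) and never grows as
decoder tokens are generated.

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, used to turn a token count into a block count.
        return -(-a // b)

    def cross_attn_blocks_for_request(max_encoder_len: int,
                                      block_size: int) -> int:
        # One static allocation per request, independent of how many
        # decoder tokens are generated afterwards.
        return cdiv(max_encoder_len, block_size)

    # Whisper's max_source_positions is 1500; with an example block size
    # of 16 this works out to 94 blocks per request.
    assert cross_attn_blocks_for_request(1500, 16) == 94

This is why the scheduler can reserve the cross-attention blocks up
front when a request is first scheduled.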
Related to: - V0 deprecation: #18571 - 2025 Q3 roadmap: #20336 Signed-off-by: Russell Bryant --- vllm/attention/__init__.py | 1 - vllm/inputs/preprocess.py | 6 - vllm/model_executor/models/whisper.py | 9 +- vllm/v1/attention/backends/flash_attn.py | 177 ++++++++-- vllm/v1/attention/backends/utils.py | 8 + vllm/v1/core/kv_cache_coordinator.py | 31 +- vllm/v1/core/kv_cache_manager.py | 39 +++ vllm/v1/core/sched/scheduler.py | 33 +- vllm/v1/core/single_type_kv_cache_manager.py | 56 ++- vllm/v1/engine/processor.py | 5 - vllm/v1/kv_cache_interface.py | 18 + vllm/v1/worker/gpu_model_runner.py | 338 ++++++++++++++++++- vllm/v1/worker/utils.py | 14 +- 13 files changed, 662 insertions(+), 73 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 344040586a53..dcb2aa68fbee 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -14,7 +14,6 @@ "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", - "Attention", "AttentionState", "get_attn_backend", ] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index deda9bc23daf..103776fb059d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -869,9 +869,6 @@ def preprocess( ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( @@ -903,9 +900,6 @@ async def preprocess_async( [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. """ if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async(prompt) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index d98dab5fac0e..b04e683b23de 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -42,7 +42,7 @@ from vllm.transformers_utils.processor import cached_get_processor from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription, SupportsV0Only) + SupportsTranscription) from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, make_layers) @@ -790,7 +790,7 @@ def _get_prompt_updates( info=WhisperProcessingInfo, dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal, SupportsV0Only): + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -916,10 +916,9 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - # TODO: This method does not obey the interface for SupportsMultiModal. - # Refactor this once encoder/decoder support is implemented in V1. + # Required as part of SupportsMultiModal interface. 
audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs(audio_input["input_features"]) + return [self.model.get_encoder_outputs(audio_input["input_features"])] def get_input_embeddings( self, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d5b30ac685ac..219448a3e6e8 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -130,6 +130,24 @@ class FlashAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None max_num_splits: int = 0 + # Begin encoder attn & enc/dec cross-attn fields... + + # (batch_size + 1,). The cumulative sequence lengths of the encoder + # sequences in the batch, used to index into sequence. E.g., if the sequence + # length is [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + cross_slot_mapping: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self) -> bool: + """ + All attention metadata required for encoder attention is set. + """ + return (self.encoder_seq_start_loc is not None + and self.max_encoder_seq_len is not None) + def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: @@ -207,7 +225,13 @@ def build(self, num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + + if (common_attn_metadata.cross_slot_mapping is not None + and common_attn_metadata.max_encoder_seq_len is not None): + # ENCODER_DECODER cross-attention + max_seq_len = common_attn_metadata.max_encoder_seq_len + else: + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -326,6 +350,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, + # Encoder/cross-attention fields + encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc, + max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len, + cross_slot_mapping=common_attn_metadata.cross_slot_mapping, ) return attn_metadata @@ -380,11 +408,7 @@ def __init__( FlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") + self.attn_type = attn_type self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -392,6 +416,24 @@ def __init__( raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") + @staticmethod + def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. 
+ """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) + def forward( self, layer: torch.nn.Module, @@ -428,6 +470,14 @@ def forward( # Profiling run. return output + # Validate attention metadata based on attention type + attn_type = self.attn_type + if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER, + AttentionType.ENCODER_ONLY) + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -438,9 +488,22 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if attn_type == AttentionType.ENCODER: + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention(query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, layer) + + # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + if (self.kv_sharing_target_layer_name is None and (key is not None) + and (value is not None)): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -448,12 +511,17 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
+ if attn_type == AttentionType.ENCODER_DECODER: + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + updated_slot_mapping = attn_metadata.slot_mapping + reshape_and_cache_flash( key, value, key_cache, value_cache, - attn_metadata.slot_mapping, + updated_slot_mapping, self.kv_cache_dtype, layer._k_scale, layer._v_scale, @@ -477,7 +545,7 @@ def forward( block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) flash_attn_varlen_func( q=query[:num_actual_tokens], @@ -489,7 +557,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=True, + causal=FlashAttentionImpl._get_causal_option(attn_type), alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, @@ -516,20 +584,91 @@ def forward( suffix_kv_lens=attn_metadata.suffix_kv_lens, max_kv_len=attn_metadata.max_seq_len, softmax_scale=self.scale, + causal=True, alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window, - logits_soft_cap=self.logits_soft_cap, - block_table=attn_metadata.block_table, - common_prefix_len=attn_metadata.common_prefix_len, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, fa_version=self.vllm_flash_attn_version, - prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, - suffix_scheduler_metadata=attn_metadata.scheduler_metadata, - q_descale=layer._q_scale, - k_descale=layer._k_scale, - v_descale=layer._v_scale, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) return output + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. 
+ + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + """ + # For encoder attention, process FP8 quantization if needed + if self.kv_cache_dtype.startswith("fp8"): + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + num_kv_tokens, num_kv_heads, head_size = key.shape + key, _ = ops.scaled_fp8_quant( + key.reshape( + (num_kv_tokens, num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape( + (num_kv_tokens, num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape((num_kv_tokens, num_kv_heads, head_size)) + + # Use encoder-specific metadata for sequence information + cu_seqlens_q = attn_metadata.encoder_seq_start_loc + cu_seqlens_k = attn_metadata.encoder_seq_start_loc + max_seqlen_q = attn_metadata.max_encoder_seq_len + max_seqlen_k = attn_metadata.max_encoder_seq_len + + descale_shape = ( + cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] + self.num_kv_heads) + + # Call flash attention directly on Q, K, V tensors + flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=False, # Encoder attention is bidirectional + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + return output + def use_cascade_attention( common_prefix_len: int, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b6a06b17bca2..6179f208c05e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -59,6 +59,14 @@ class CommonAttentionMetadata: block_table_tensor: torch.Tensor slot_mapping: torch.Tensor + # Encoder/cross-attention specific fields (optional) + encoder_seq_start_loc: Optional[torch.Tensor] = None + """(batch_size + 1,), cumulative encoder sequence lengths""" + max_encoder_seq_len: Optional[int] = None + """Maximum encoder sequence length in batch""" + cross_slot_mapping: Optional[torch.Tensor] = None + """Slot mapping for cross-attention KV cache""" + def __post_init__(self): # Fill unused with -1. Needed for reshape_and_cache in full cuda graph # mode. 
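# Illustration (not part of this diff): the encoder_seq_start_loc field
# added to CommonAttentionMetadata above follows the standard
# (batch_size + 1,) cumulative-length layout consumed by varlen attention
# kernels. With example encoder lengths [4, 6]:
import torch

encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.cat(
    (torch.zeros(1, dtype=torch.int32),
     torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32)))
# encoder_seq_start_loc -> tensor([ 0,  4, 10], dtype=torch.int32)
max_encoder_seq_len = int(encoder_seq_lens.max())  # 6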
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index de72e60434ad..3d94ffd9cf3a 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,7 +6,7 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, get_manager_for_kv_cache_spec) + CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.request import Request @@ -43,9 +43,12 @@ def __init__( ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) - def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: + def get_num_blocks_to_allocate(self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[ + list[KVCacheBlock], ...], + cross_attn: bool = False) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -61,8 +64,14 @@ def get_num_blocks_to_allocate( """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + if cross_attn and isinstance(manager, CrossAttentionManager): + # For cross-attention, we issue a single static allocation + # of blocks based on the number of encoder input tokens. + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, []) + elif not cross_attn: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) return num_blocks_to_allocate def save_new_computed_blocks( @@ -80,8 +89,11 @@ def save_new_computed_blocks( manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> tuple[list[KVCacheBlock], ...]: + def allocate_new_blocks( + self, + request_id: str, + num_tokens: int, + cross_attn: bool = False) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -95,7 +107,8 @@ def allocate_new_blocks(self, request_id: str, The new allocated blocks. """ return tuple( - manager.allocate_new_blocks(request_id, num_tokens) + (manager.allocate_new_blocks(request_id, num_tokens) if isinstance( + manager, CrossAttentionManager) == cross_attn else []) for manager in self.single_type_managers) def cache_blocks(self, request: Request, block_hashes: list[BlockHash], diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index e820a0ad6d5d..d54e64cb4232 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -307,6 +307,45 @@ def allocate_slots( return KVCacheBlocks(new_blocks) + def allocate_slots_for_cross_attn( + self, + request: Request, + num_encoder_tokens: int, + ) -> Optional[KVCacheBlocks]: + """Add slots for cross-attention blocks. + + This is separate from the main `allocate_slots` function because + cross-attention blocks are allocated based on the max encoder length, + which is a static value. The number of blocks to allocate is not + affected by the number of decoder tokens. + + Args: + request: The request to allocate slots. + num_encoder_tokens: The number of tokens sent to the encoder. 
+ + Returns: + A list of new allocated blocks. + """ + if num_encoder_tokens == 0: + raise ValueError("num_encoder_tokens must be greater than 0") + + num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_encoder_tokens, + new_computed_blocks=tuple(), + cross_attn=True, + ) + + if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + # Cannot allocate new blocks + return None + + new_blocks = self.coordinator.allocate_new_blocks(request.request_id, + num_encoder_tokens, + cross_attn=True) + + return KVCacheBlocks(new_blocks) + def free(self, request: Request) -> None: """Free the blocks allocated for the request. We free the blocks in reverse order so that he tail blocks are evicted diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 446f98034cb8..a73c8fbf6b3c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,7 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -58,6 +58,7 @@ def __init__( self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder # include_finished_set controls whether a separate set of finished # request ids should be included in the EngineCoreOutputs returned @@ -150,11 +151,17 @@ def __init__( self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens + enable_caching = self.cache_config.enable_prefix_caching or False + if self.is_encoder_decoder: + # prefix caching for encoder-decoder models is not currently + # supported + enable_caching = False + # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, + enable_caching=enable_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, @@ -399,6 +406,7 @@ def schedule(self) -> SchedulerOutput: encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget + new_cross_blocks: Optional[KVCacheBlocks] = None # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: @@ -436,6 +444,22 @@ def schedule(self) -> SchedulerOutput: if num_new_tokens == 0: # The request cannot be scheduled. break + if self.is_encoder_decoder: + # For encoder-decoder models, we allocate slots for + # the cross-attention blocks based on the max + # encoder length. This is a single static allocation + # and does not grow with the number of decoder + # tokens. + max_encoder_len = (self.vllm_config.model_config. + hf_config.max_source_positions) + new_cross_blocks = (self.kv_cache_manager. + allocate_slots_for_cross_attn( + request, + max_encoder_len, + )) + if new_cross_blocks is None: + # The request cannot be scheduled. 
+ break new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -454,9 +478,12 @@ def schedule(self) -> SchedulerOutput: # This information is used to determine if a load is # needed for this request. if self.connector is not None: + update_blocks = new_computed_blocks + new_blocks + if new_cross_blocks is not None: + update_blocks += new_cross_blocks self.connector.update_state_after_alloc( request, - new_computed_blocks + new_blocks, + update_blocks, num_external_computed_tokens, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 1560406c9004..423a9380b582 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -8,8 +8,9 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + CrossAttentionSpec, FullAttentionSpec, + KVCacheSpec, MambaSpec, + SlidingWindowSpec) from vllm.v1.request import Request @@ -433,11 +434,62 @@ def allocate_new_blocks(self, request_id: str, return new_blocks +class CrossAttentionManager(SingleTypeKVCacheManager): + """Manager for cross-attention KV cache in encoder-decoder models.""" + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + # We do not allocate blocks as decoder tokens are generated, so this + # method is not relevant. + pass + + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], + num_tokens: int) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so this method is not relevant. + pass + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + # Cross-attention blocks contain request-specific encoder states + # and are not shared between different requests + return 0 + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, CrossAttentionSpec), ( + "CrossAttentionManager can only be used for cross-attention groups" + ) + # Cross-attention does not benefit from prefix caching since: + # 1. Encoder states are unique per request (different audio/image + # inputs) + # 2. Encoder states are computed once per request, not incrementally + # 3. 
No reusable prefix exists between different multimodal inputs + # Return empty blocks to indicate no cache hits + return tuple([] for _ in range(len(kv_cache_group_ids))) + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Cross-attention blocks represent encoder states which are needed + # for the entire decoding process, so no blocks should be skipped + pass + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, ChunkedLocalAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, } diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7af4ed54a220..99ed34ddb5b7 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -232,7 +232,6 @@ def process_inputs( ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Support encoder-decoder models. self._validate_lora(lora_request) self._validate_params(params, lora_request) if trace_headers is not None: @@ -273,10 +272,6 @@ def process_inputs( encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - # TODO: Impl encoder-decoder - if encoder_inputs is not None: - raise NotImplementedError - sampling_params = None pooling_params = None if isinstance(params, SamplingParams): diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6726709955f7..1bbba80fc85f 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -198,6 +198,24 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return self.page_size_bytes +@dataclass +class CrossAttentionSpec(AttentionSpec): + """ + KV cache spec for cross-attention layers in encoder-decoder models. 
+ """ + + @property + def type_id(self) -> str: + return f"cross_attention_{self.block_size}_{self.page_size_bytes}" + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # For cross-attention, we need to cache encoder states + # Use max_source_positions for encoder length (e.g., 1500 for Whisper) + max_encoder_len = ( + vllm_config.model_config.hf_config.max_source_positions) + return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c3eeb6c2e390..5374bfd48783 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,6 +15,7 @@ import vllm.envs as envs from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, @@ -51,8 +52,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, + CrossAttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -148,6 +149,13 @@ def __init__( ) self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_cache_size = encoder_cache_size + if self.model_config.is_encoder_decoder: + # If specified in the model config, this attribute defines the + # maximum length of the encoder input. + self.max_encoder_len = getattr(self.model_config.hf_config, + 'max_source_positions', 0) + else: + self.max_encoder_len = 0 # Sampler self.sampler = Sampler() @@ -206,7 +214,10 @@ def __init__( # the block_sizes in the kv cache config. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. + # TODO(russellb): Is this correct? + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -696,6 +707,24 @@ def _prepare_inputs( spec_decode_common_attn_metadata = None attn_metadata: dict[str, Any] = {} + encoder_attn_metadata: dict[str, Any] = {} + + # Prepare encoder attention metadata separately + # (encoder layers are not in KV cache groups) + # This is only necessary when there are encoder inputs + # to process. Otherwise, encoder attention won't run. + if (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + encoder_attn_metadata = self._build_encoder_attn_metadata( + scheduler_output) + + # Add encoder attention metadata for all encoder layers + attention_layers = get_layers_from_vllm_config( + self.vllm_config, Attention) + for layer_name, attn_module in attention_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER: + attn_metadata[layer_name] = encoder_attn_metadata + # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -718,6 +747,12 @@ def _prepare_inputs( slot_mapping=slot_mapping, ) + is_enc_dec = isinstance(kv_cache_group_spec.kv_cache_spec, + CrossAttentionSpec) + if is_enc_dec: + encoder_attn_metadata = self._build_encoder_attn_metadata( + scheduler_output, common_attn_metadata) + if self.speculative_config and \ spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata @@ -740,10 +775,11 @@ def _prepare_inputs( builder, ) - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) + attn_metadata_i = (encoder_attn_metadata + if is_enc_dec else builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + )) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1103,6 +1139,90 @@ def _gather_mm_embeddings( mm_embeds.append(mm_embeds_item) return mm_embeds + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models like Whisper. + + This method extracts audio input features and creates encoder positions + from scheduled encoder inputs. These are only needed when the encoder + needs to process new MM inputs (typically on the first processing step). + """ + input_features_list = [] + total_encoder_tokens = 0 + + for req_id, encoder_input_ids in ( + scheduler_output.scheduled_encoder_inputs.items()): + req_state = self.requests[req_id] + + for mm_input_id in encoder_input_ids: + if mm_input_id < len(req_state.mm_inputs): + mm_input = req_state.mm_inputs[mm_input_id] + # Extract input_features from MM input kwargs + if "input_features" in mm_input: + features = mm_input["input_features"] + input_features_list.append(features) + # Calculate encoder sequence length for this input + if isinstance(features, torch.Tensor): + # For Whisper: use max_source_positions from config + # which represents the encoder sequence length + encoder_seq_len = getattr( + self.model_config.hf_config, + 'max_source_positions', 1500) + total_encoder_tokens += encoder_seq_len + elif isinstance(features, list): + encoder_seq_len = getattr( + self.model_config.hf_config, + 'max_source_positions', 1500) + total_encoder_tokens += (len(features) * + encoder_seq_len) + + if not input_features_list: + return {} + + # Concatenate all input features into a single tensor + if len(input_features_list) == 1 and isinstance( + input_features_list[0], torch.Tensor): + input_features = input_features_list[0] + # Ensure we have the correct 4D shape + # [batch, channels, mel_bins, time] + if input_features.dim() == 3: + # Add batch dim: [ch, mel, time] -> [1, ch, mel, time] + input_features = input_features.unsqueeze(0) + else: + # Handle list of tensors + processed_features = [] + for feat in input_features_list: + if isinstance(feat, torch.Tensor): + # Ensure 4D shape + if feat.dim() == 3: + feat = feat.unsqueeze(0) + processed_features.append(feat) + else: + processed_features.append(torch.stack(feat)) + input_features = torch.cat(processed_features) + + # Move input_features to the correct device and dtype + input_features = input_features.to(device=self.device, + dtype=self.model_config.dtype) + + # Create encoder positions (similar to how V0 does it) + encoder_positions = torch.arange(total_encoder_tokens, + dtype=torch.long, + device=self.device) + + # Create encoder input_ids (dummy tokens 
for encoder) + encoder_input_ids = torch.zeros(total_encoder_tokens, + dtype=torch.long, + device=self.device) + + return { + "input_features": input_features, + "encoder_input_ids": encoder_input_ids, + "encoder_positions": encoder_positions, + } + def get_model(self) -> nn.Module: return self.model @@ -1340,14 +1460,16 @@ def execute_model( # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if self.is_multimodal_model: + if (self.is_multimodal_model + and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - if self.is_multimodal_model and get_pp_group().is_first_rank: + if self.is_multimodal_model and get_pp_group().is_first_rank and ( + not self.model_config.is_encoder_decoder): # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. @@ -1394,11 +1516,18 @@ def execute_model( ): self.maybe_setup_kv_connector(scheduler_output) + extra_kwargs: dict = {} + if (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + extra_kwargs.update(encoder_inputs) + model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **extra_kwargs, ) self.maybe_wait_for_kv_save() @@ -2008,7 +2137,8 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model - if self.is_multimodal_model: + if (self.is_multimodal_model + and not self.model_config.is_encoder_decoder): input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: @@ -2192,9 +2322,9 @@ def _dummy_pooler_run( def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. - # TODO: handle encoder-decoder models once we support them. 
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 - and self.encoder_cache_size > 0): + and self.encoder_cache_size > 0 + and not self.model_config.is_encoder_decoder): # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when @@ -2394,7 +2524,7 @@ def may_reinitialize_input_batch(self, "for more details.") self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -2625,7 +2755,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: use_local_attention = (self.attention_chunk_size is not None and attn_module.impl.use_irope) @@ -2652,12 +2781,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") @@ -2688,3 +2822,175 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec + + def _build_encoder_attn_metadata( + self, + scheduler_output: "SchedulerOutput", + common_attn_metadata: Optional[CommonAttentionMetadata] = None + ) -> dict[str, Any]: + """Prepare encoder attention metadata for encoder-decoder models. 
+ + Args: + scheduler_output: Scheduler output + + Returns: + dict[str, Any]: Encoder attention metadata + """ + # Get encoder input information from scheduled encoder inputs + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + + # Calculate encoder sequence lengths and cross slot mappings + encoder_seq_lens = [] + cross_slot_mapping = [] + num_encoder_tokens = 0 + + for req_id in scheduled_encoder_inputs: + encoder_seq_len = self.max_encoder_len + encoder_seq_lens.append(encoder_seq_len) + num_encoder_tokens += encoder_seq_len + + if self.model_config.is_encoder_decoder: + # Build cross slot mapping for this request + req_state = self.requests.get(req_id) + if req_state is None: + # During memory profiling or if request not found, + # use dummy slot mapping + cross_slot_mapping.extend([PAD_SLOT_ID] * encoder_seq_len) + else: + # Find the KV cache group that uses CrossAttentionSpec + cross_attn_group_idx = None + for i, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, + CrossAttentionSpec): + cross_attn_group_idx = i + break + + if ((cross_attn_group_idx is not None) and + (cross_attn_group_idx < len(req_state.block_ids))): + # Get cross attention block IDs for this request + cross_block_ids = req_state.block_ids[ + cross_attn_group_idx] + block_size = self.kv_cache_config.kv_cache_groups[ + cross_attn_group_idx].kv_cache_spec.block_size + + # Calculate slot mapping from block IDs + for i in range(encoder_seq_len): + block_number = cross_block_ids[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + cross_slot_mapping.append(slot) + else: + # This can happen if cross-attention blocks are not + # allocated for this request. Pad with PAD_SLOT_ID. + cross_slot_mapping.extend([PAD_SLOT_ID] * + encoder_seq_len) + + # Create encoder sequence start locations (cumulative sum) + encoder_seq_start_loc = [0] + for seq_len in encoder_seq_lens: + encoder_seq_start_loc.append(encoder_seq_start_loc[-1] + seq_len) + + # Convert to tensors + encoder_seq_lens_tensor = torch.tensor(encoder_seq_lens, + dtype=torch.int32, + device=self.device) + encoder_seq_start_loc_tensor = torch.tensor(encoder_seq_start_loc, + dtype=torch.int32, + device=self.device) + + encoder_metadata = { + "encoder_seq_start_loc": encoder_seq_start_loc_tensor, + "max_encoder_seq_len": self.max_encoder_len, + } + + # Use the first attention metadata builder + # to create encoder attention metadata + builder = self.attn_metadata_builders[0] + + # Create encoder-specific common attention metadata + # If we're building metadata for cross-attention, we use the + # common_attn_metadata built from decoder details and it gets + # passed in here. + if common_attn_metadata is None: + # ENCODER self-attention + # Create dummy tensors for required fields + dummy_block_table = torch.zeros((len(encoder_seq_lens), 1), + dtype=torch.int32, + device=self.device) + dummy_slot_mapping = torch.zeros((num_encoder_tokens, ), + dtype=torch.int32, + device=self.device) + dummy_computed_tokens = torch.zeros((len(encoder_seq_lens), ), + dtype=torch.int32, + device="cpu") + + common_metadata = CommonAttentionMetadata( + query_start_loc=encoder_metadata["encoder_seq_start_loc"], + query_start_loc_cpu=encoder_metadata["encoder_seq_start_loc"]. 
+ cpu(), + seq_lens=encoder_seq_lens_tensor, + seq_lens_cpu=encoder_seq_lens_tensor.cpu(), + num_computed_tokens_cpu=dummy_computed_tokens, + num_reqs=len(encoder_seq_lens), + num_actual_tokens=num_encoder_tokens, + max_query_len=encoder_metadata["max_encoder_seq_len"], + block_table_tensor=dummy_block_table, + slot_mapping=dummy_slot_mapping, + ) + else: + # ENCODER_DECODER cross-attention + seq_lens_tensor = torch.full( + (common_attn_metadata.num_reqs, ), + self.max_encoder_len, + dtype=torch.int32, + device=self.device, + ) + seq_lens_cpu = torch.full( + (common_attn_metadata.num_reqs, ), + self.max_encoder_len, + dtype=torch.int32, + device="cpu", + ) + common_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + query_start_loc_cpu=common_attn_metadata.query_start_loc_cpu, + seq_lens=seq_lens_tensor, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + max_query_len=common_attn_metadata.max_query_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + ) + cross_slot_mapping_tensor = torch.tensor(cross_slot_mapping, + dtype=torch.int64, + device=self.device) + encoder_metadata["cross_slot_mapping"] = cross_slot_mapping_tensor + + # Set encoder fields in common metadata and build + if common_attn_metadata is None: + # ENCODER self-attention - set encoder fields + common_metadata.encoder_seq_start_loc = encoder_metadata[ + "encoder_seq_start_loc"] + common_metadata.max_encoder_seq_len = encoder_metadata[ + "max_encoder_seq_len"] + return builder.build( + common_prefix_len=0, # No cascade for encoder + common_attn_metadata=common_metadata, + ) + else: + # ENCODER_DECODER cross-attention - set both encoder and cross + # fields + common_metadata.encoder_seq_start_loc = encoder_metadata[ + "encoder_seq_start_loc"] + common_metadata.max_encoder_seq_len = encoder_metadata[ + "max_encoder_seq_len"] + common_metadata.cross_slot_mapping = encoder_metadata.get( + "cross_slot_mapping") + return builder.build( + common_prefix_len=0, # No cascade for encoder + common_attn_metadata=common_metadata, + ) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 3ecb1d7dd656..c48a27aae44f 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -147,14 +147,14 @@ def bind_kv_cache( index2name[extract_layer_index(layer_name)].append(layer_name) for layer_index in sorted(index2name.keys()): + # Some models (like encoder-decoder models) may have multiple + # layers with the same index, so we need to append all of them. + # For an encoder-decoder model, each decoder layer has + # self-attention (AttentionType.DECODER) + # and cross-attention (AttentionType.ENCODER_DECODER). layer_names = index2name[layer_index] - if len(layer_names) > 1: - # One typical case is encoder-decoder model, e.g., bart. - # The cross attention and self attention in the same decoder layer - # has different layer_name but the same layer_index. - raise NotImplementedError - layer_name = layer_names[0] - runner_kv_caches.append(kv_caches[layer_name]) + for layer_name in layer_names: + runner_kv_caches.append(kv_caches[layer_name]) # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items():