
v1: Add Whisper model support (encoder-decoder) #21088


Draft · wants to merge 1 commit into base: main

1 change: 0 additions & 1 deletion vllm/attention/__init__.py
@@ -14,7 +14,6 @@
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
6 changes: 0 additions & 6 deletions vllm/inputs/preprocess.py
@@ -869,9 +869,6 @@ def preprocess(
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
@@ -903,9 +900,6 @@ async def preprocess_async(
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(prompt)
9 changes: 4 additions & 5 deletions vllm/model_executor/models/whisper.py
@@ -42,7 +42,7 @@
from vllm.transformers_utils.processor import cached_get_processor

from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only)
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers)

@@ -790,7 +790,7 @@ def _get_prompt_updates(
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal, SupportsV0Only):
SupportsMultiModal):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
@@ -916,10 +916,9 @@ def get_language_model(self) -> torch.nn.Module:

def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(audio_input["input_features"])
return [self.model.get_encoder_outputs(audio_input["input_features"])]

def get_input_embeddings(
self,
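
The list wrapping in get_multimodal_embeddings above matches the SupportsMultiModal convention of one embedding item per multimodal input. A minimal illustrative sketch (not part of the diff; shapes are hypothetical) of why the bare tensor was awkward for callers that iterate over the returned embeddings:

import torch

# Hypothetical Whisper-like encoder output: 1500 frames, 1280 hidden dims.
encoder_out = torch.randn(1500, 1280)

# Bare tensor: iterating yields 1500 row vectors of shape [1280].
assert [e.shape for e in encoder_out][0] == torch.Size([1280])

# One-element list: iterating yields the whole encoder output as a single
# multimodal item, which is what the interface's callers expect.
assert [e.shape for e in [encoder_out]] == [torch.Size([1500, 1280])]
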
177 changes: 158 additions & 19 deletions vllm/v1/attention/backends/flash_attn.py
@@ -130,6 +130,24 @@
prefix_scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0

# Begin encoder attn & enc/dec cross-attn fields...

# (batch_size + 1,). The cumulative sequence lengths of the encoder
# sequences in the batch, used to index into the packed sequence. E.g., if
# the sequence lengths are [4, 6], this is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
cross_slot_mapping: Optional[torch.Tensor] = None

@property
def is_all_encoder_attn_metadata_set(self) -> bool:
"""
All attention metadata required for encoder attention is set.
"""
return (self.encoder_seq_start_loc is not None
and self.max_encoder_seq_len is not None)
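
For reference, a minimal sketch (not part of the diff) of how the cumulative layout described in the encoder_seq_start_loc comment can be derived from per-request encoder lengths:

import torch

# Encoder lengths [4, 6] yield start locations [0, 4, 10], matching the
# example in the field comment above.
encoder_seq_lens = torch.tensor([4, 6])
encoder_seq_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
])
max_encoder_seq_len = int(encoder_seq_lens.max())

assert encoder_seq_start_loc.tolist() == [0, 4, 10]
assert max_encoder_seq_len == 6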


def _get_sliding_window_configs(
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -207,7 +225,13 @@
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())

if (common_attn_metadata.cross_slot_mapping is not None
and common_attn_metadata.max_encoder_seq_len is not None):
# ENCODER_DECODER cross-attention
max_seq_len = common_attn_metadata.max_encoder_seq_len
else:
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -326,6 +350,10 @@
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
# Encoder/cross-attention fields
encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
)
return attn_metadata
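
The branch added above picks where the key/value length comes from: cross-attention reads keys and values from the encoder sequence, so its maximum length bounds the kernel; otherwise the decoder sequence lengths are used. A standalone illustrative sketch of that selection (not the builder itself):

from typing import Optional

import torch

def _pick_max_seq_len(seq_lens_cpu: torch.Tensor,
                      cross_slot_mapping: Optional[torch.Tensor],
                      max_encoder_seq_len: Optional[int]) -> int:
    # Mirrors the builder branch: cross-attention K/V live in the encoder
    # sequence; decoder self-attention K/V live in the decoder KV cache.
    if cross_slot_mapping is not None and max_encoder_seq_len is not None:
        return max_encoder_seq_len
    return int(seq_lens_cpu.max())

assert _pick_max_seq_len(torch.tensor([3, 7]), None, None) == 7
assert _pick_max_seq_len(torch.tensor([3, 7]), torch.zeros(8), 1500) == 1500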

@@ -380,18 +408,32 @@

FlashAttentionBackend.validate_head_size(head_size)

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
self.attn_type = attn_type
self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")

@staticmethod
def _get_causal_option(attn_type: str) -> bool:
Review comment (Contributor): nit: _is_causal_attention?

"""
Determine whether the given attention type is suitable for causal
attention mechanisms.

Args:
attn_type (AttentionType): The type of attention being evaluated

Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
Review comment on lines +433 to +435 (Contributor): nit: isn't this attn_type == AttentionType.DECODER?
Reply (Member Author): true
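
As the exchange above notes, the predicate reduces to a decoder check. An illustrative equivalence test over the four AttentionType members used in this backend (a standalone copy of the helper, not part of the diff):

from vllm.attention import AttentionType

def _get_causal_option(attn_type: str) -> bool:
    return not (attn_type == AttentionType.ENCODER
                or attn_type == AttentionType.ENCODER_ONLY
                or attn_type == AttentionType.ENCODER_DECODER)

for t in (AttentionType.DECODER, AttentionType.ENCODER,
          AttentionType.ENCODER_ONLY, AttentionType.ENCODER_DECODER):
    # Only decoder self-attention uses a causal mask.
    assert _get_causal_option(t) == (t == AttentionType.DECODER)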


def forward(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -428,6 +470,14 @@
# Profiling run.
return output

# Validate attention metadata based on attention type
attn_type = self.attn_type
if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
AttentionType.ENCODER_ONLY)
Review comment on lines +475 to +476 (Contributor): you can re-use get_causal_opt
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -438,22 +488,40 @@
# performance to make sure it does not introduce any overhead.

num_actual_tokens = attn_metadata.num_actual_tokens

# Handle encoder attention differently - no KV cache needed
if attn_type == AttentionType.ENCODER:
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata, layer)

# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)

if self.kv_sharing_target_layer_name is None:
if (self.kv_sharing_target_layer_name is None and (key is not None)
and (value is not None)):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if attn_type == AttentionType.ENCODER_DECODER:
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
updated_slot_mapping = attn_metadata.slot_mapping

reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
updated_slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
@@ -477,7 +545,7 @@
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

flash_attn_varlen_func(
q=query[:num_actual_tokens],
Expand All @@ -489,7 +557,7 @@
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
causal=FlashAttentionImpl._get_causal_option(attn_type),
Review comment (Collaborator): nit: can we just add the causal flag to CommonAttentionMetadata and manipulate the slot-mapping on the CommonAttentionMetadata so we can make more of this backend agnostic? (kinda like: #21093)
Reply (Member Author): yes! thanks for the feedback

alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
@@ -504,7 +572,7 @@
return output

# Cascade attention (rare case).
cascade_attention(

Check failure on line 575 in vllm/v1/attention/backends/flash_attn.py (GitHub Actions / pre-commit): unexpected keyword arguments "causal", "softcap", and "window_size" for "cascade_attention" [call-arg]
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
@@ -516,20 +584,91 @@
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output

def _forward_encoder_attention(
Review comment (Contributor): Awesome!

self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.

Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))

num_kv_tokens, num_kv_heads, head_size = key.shape
key, _ = ops.scaled_fp8_quant(
key.reshape(
(num_kv_tokens, num_kv_heads * head_size)).contiguous(),
layer._k_scale)
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))

value, _ = ops.scaled_fp8_quant(
value.reshape(
(num_kv_tokens, num_kv_heads * head_size)).contiguous(),
layer._v_scale)
value = value.reshape((num_kv_tokens, num_kv_heads, head_size))

# Use encoder-specific metadata for sequence information
cu_seqlens_q = attn_metadata.encoder_seq_start_loc
cu_seqlens_k = attn_metadata.encoder_seq_start_loc
max_seqlen_q = attn_metadata.max_encoder_seq_len
max_seqlen_k = attn_metadata.max_encoder_seq_len

descale_shape = (
cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr]
self.num_kv_heads)

# Call flash attention directly on Q, K, V tensors
flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)

return output
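
For intuition, the encoder path above relies on standard varlen packing: encoder tokens from all requests are concatenated along the token dimension and delimited by encoder_seq_start_loc. A small illustrative sketch of that layout (shapes only, no flash-attn call):

import torch

# Two requests with encoder lengths 4 and 6, packed into one batch with
# layout [num_encoder_tokens, num_heads, head_size] as in the docstring.
num_heads, head_size = 8, 64
q = torch.randn(4 + 6, num_heads, head_size)
cu_seqlens = [0, 4, 10]  # encoder_seq_start_loc as a Python list

# flash_attn_varlen_func uses the same boundaries internally; here we just
# recover the per-request views to show the packing.
per_request = [q[s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]
assert [t.shape[0] for t in per_request] == [4, 6]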


def use_cascade_attention(
common_prefix_len: int,
8 changes: 8 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -59,6 +59,14 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor

# Encoder/cross-attention specific fields (optional)
encoder_seq_start_loc: Optional[torch.Tensor] = None
"""(batch_size + 1,), cumulative encoder sequence lengths"""
max_encoder_seq_len: Optional[int] = None
"""Maximum encoder sequence length in batch"""
cross_slot_mapping: Optional[torch.Tensor] = None
"""Slot mapping for cross-attention KV cache"""

def __post_init__(self):
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
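
A minimal sketch of how a builder for an encoder-decoder model might attach the new optional fields to an existing CommonAttentionMetadata instance; the helper and its arguments are assumptions for illustration, only the field names come from the diff:

from dataclasses import replace

import torch

def with_encoder_fields(common_metadata,
                        encoder_seq_lens: torch.Tensor,
                        cross_slot_mapping: torch.Tensor):
    # Build the cumulative start locations and copy every other field
    # unchanged; the decoder-side required fields stay as they were.
    encoder_seq_start_loc = torch.cat([
        torch.zeros(1, dtype=torch.int32, device=encoder_seq_lens.device),
        torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
    ])
    return replace(
        common_metadata,
        encoder_seq_start_loc=encoder_seq_start_loc,
        max_encoder_seq_len=int(encoder_seq_lens.max()),
        cross_slot_mapping=cross_slot_mapping,
    )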