v1: Add Whisper model support (encoder-decoder) #21088
base: main
File: vllm/v1/attention/backends/flash_attn.py
@@ -130,6 +130,24 @@
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    # Begin encoder attn & enc/dec cross-attn fields...

    # (batch_size + 1,). The cumulative sequence lengths of the encoder
    # sequences in the batch, used to index into sequence. E.g., if the sequence
    # length is [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None
    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None
    cross_slot_mapping: Optional[torch.Tensor] = None

    @property
    def is_all_encoder_attn_metadata_set(self) -> bool:
        """
        All attention metadata required for encoder attention is set.
        """
        return (self.encoder_seq_start_loc is not None
                and self.max_encoder_seq_len is not None)


def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
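For reference, encoder_seq_start_loc follows the usual varlen convention: one more entry than there are encoder sequences, holding the running prefix sums of their lengths. A minimal sketch of how such a tensor could be built (names and values here are illustrative, not taken from this diff):

import torch

# Per-request encoder lengths, e.g. two audio segments of 4 and 6 frames.
encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)

# Prepend 0 and take the cumulative sum: [4, 6] -> [0, 4, 10].
encoder_seq_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
])

print(encoder_seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)

Consecutive pairs of entries then delimit each request's tokens in the flattened encoder sequence, which is how flash_attn_varlen_func consumes its cu_seqlens arguments.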
@@ -207,7 +225,13 @@
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())

        if (common_attn_metadata.cross_slot_mapping is not None
                and common_attn_metadata.max_encoder_seq_len is not None):
            # ENCODER_DECODER cross-attention
            max_seq_len = common_attn_metadata.max_encoder_seq_len
        else:
            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
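This branch is needed because the key/value side of encoder-decoder cross-attention comes from the encoder output, so the maximum key length must be taken from the encoder rather than from the (usually much shorter) decoder sequences. A small, self-contained illustration of that selection; the function name and the 1500-frame Whisper encoder length are assumptions used only for this sketch:

def pick_max_kv_len(is_cross_attention: bool,
                    decoder_seq_lens: list[int],
                    max_encoder_seq_len: int) -> int:
    # Cross-attention reads keys/values produced by the encoder.
    if is_cross_attention:
        return max_encoder_seq_len
    # Decoder self-attention reads keys/values from the decoder itself.
    return max(decoder_seq_lens)

# Two requests with 3 and 7 decoder tokens; a Whisper-style encoder that
# always emits 1500 frames for its 30 s window (illustrative numbers).
print(pick_max_kv_len(False, [3, 7], 1500))  # 7
print(pick_max_kv_len(True, [3, 7], 1500))   # 1500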
@@ -326,6 +350,10 @@
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            # Encoder/cross-attention fields
            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
        )
        return attn_metadata
@@ -380,18 +408,32 @@

        FlashAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.attn_type = attn_type
        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

    @staticmethod
    def _get_causal_option(attn_type: str) -> bool:
        """
        Determine whether the given attention type is suitable for causal
        attention mechanisms.

        Args:
            attn_type (AttentionType): The type of attention being evaluated

        Returns:
            bool: Returns `True` if the attention type is suitable for causal
            attention (i.e., not encoder, encoder-only, or encoder-decoder),
            otherwise returns `False`.
        """
        return not (attn_type == AttentionType.ENCODER
                    or attn_type == AttentionType.ENCODER_ONLY
                    or attn_type == AttentionType.ENCODER_DECODER)
Comment on lines +433 to +435:
nit: isn't this
Reply: true

    def forward(
        self,
        layer: torch.nn.Module,
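For clarity, this is what the new helper evaluates to for each attention type. The snippet below is only a sketch; it assumes AttentionType is importable from vllm.attention and compares by equality exactly as the diff does:

from vllm.attention import AttentionType  # import path assumed

for attn_type in (AttentionType.DECODER, AttentionType.ENCODER,
                  AttentionType.ENCODER_ONLY, AttentionType.ENCODER_DECODER):
    causal = not (attn_type == AttentionType.ENCODER
                  or attn_type == AttentionType.ENCODER_ONLY
                  or attn_type == AttentionType.ENCODER_DECODER)
    print(attn_type, causal)

# Only DECODER self-attention ends up with a causal mask; encoder
# self-attention and encoder/decoder cross-attention attend to all keys.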
@@ -428,6 +470,14 @@
            # Profiling run.
            return output

        # Validate attention metadata based on attention type
        attn_type = self.attn_type
        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
                          AttentionType.ENCODER_ONLY)
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead

Comment on lines +475 to +476: you can re-use get_causal_opt
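The reviewer's suggestion to reuse the helper could look roughly like the following; this is a sketch of the nit only, not what the diff currently contains:

        # Sketch of the reviewer's nit: the encoder-side attention types are
        # exactly the non-causal ones, so the tuple membership test could be
        # replaced by the existing helper.
        if (not self._get_causal_option(attn_type)
                and not attn_metadata.is_all_encoder_attn_metadata_set):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")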
@@ -438,22 +488,40 @@
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type == AttentionType.ENCODER:
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
        if (self.kv_sharing_target_layer_name is None and (key is not None)
                and (value is not None)):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            if attn_type == AttentionType.ENCODER_DECODER:
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                updated_slot_mapping = attn_metadata.slot_mapping

            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                updated_slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
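The slot-mapping switch is the heart of the cross-attention cache handling: decoder self-attention writes one new slot per generated token, while ENCODER_DECODER layers write the encoder's keys/values into the slots addressed by cross_slot_mapping so later decode steps can read them back. The added "key is not None / value is not None" guard suggests that on pure decode steps the cross-attention layer receives no fresh key/value tensors at all. A rough sketch with simplified names; the helper below is hypothetical and not part of the diff:

from typing import Optional

import torch


def select_slot_mapping(is_cross_attention: bool,
                        key: Optional[torch.Tensor],
                        slot_mapping: torch.Tensor,
                        cross_slot_mapping: Optional[torch.Tensor]
                        ) -> Optional[torch.Tensor]:
    # Hypothetical helper mirroring the branch above.  If no new key/value
    # tensors arrive there is nothing to write into the cache.
    if key is None:
        return None
    # Cross-attention stores the encoder's K/V in their own reserved slots;
    # decoder self-attention appends to the regular per-token slot mapping.
    return cross_slot_mapping if is_cross_attention else slot_mapping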
@@ -477,7 +545,7 @@
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
@@ -489,7 +557,7 @@
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                causal=FlashAttentionImpl._get_causal_option(attn_type),
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,

Comment: nit: can we just add the (kinda like: #21093)
Reply: yes! thanks for the feedback
@@ -504,7 +572,7 @@
            return output

        # Cascade attention (rare case).
        cascade_attention(
Check failure on line 575 in vllm/v1/attention/backends/flash_attn.py
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
@@ -516,20 +584,91 @@
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )
        return output
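On the descale change: layer._q_scale and the k/v scales are per-tensor scalar scales, while the call sites now pass a tensor of shape (cu_seqlens_q.shape[0] - 1, num_kv_heads), i.e. one value per sequence and KV head, so the scalars are broadcast with expand(), which creates a view without copying data. A tiny standalone sketch of that broadcast:

import torch

num_seqs, num_kv_heads = 2, 4
q_scale = torch.tensor(0.5)                # per-tensor scale, 0-dim tensor

descale_shape = (num_seqs, num_kv_heads)
q_descale = q_scale.expand(descale_shape)  # broadcast view, no data copied

print(q_descale.shape)   # torch.Size([2, 4])
print(q_descale[1, 3])   # tensor(0.5000)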

Comment on _forward_encoder_attention: Awesome!

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

            num_kv_tokens, num_kv_heads, head_size = key.shape
            key, _ = ops.scaled_fp8_quant(
                key.reshape(
                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
                layer._k_scale)
            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))

            value, _ = ops.scaled_fp8_quant(
                value.reshape(
                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
                layer._v_scale)
            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))

        # Use encoder-specific metadata for sequence information
        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
        max_seqlen_q = attn_metadata.max_encoder_seq_len
        max_seqlen_k = attn_metadata.max_encoder_seq_len

        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads)

        # Call flash attention directly on Q, K, V tensors
        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )

        return output


def use_cascade_attention(
    common_prefix_len: int,
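Since _forward_encoder_attention takes already-flattened [num_encoder_tokens, heads, head_size] tensors together with encoder_seq_start_loc, callers are expected to pack the per-request encoder states into the varlen layout first. A hedged sketch of that packing, with shapes and names invented for illustration:

import torch

num_heads, head_size = 8, 64
seq_lens = [4, 6]                       # per-request encoder lengths
max_len = max(seq_lens)

# A padded batch of per-request encoder hidden states (illustrative values).
padded = torch.randn(len(seq_lens), max_len, num_heads, head_size)

# Drop the padding and concatenate the real tokens of every request.
flat = torch.cat([padded[i, :n] for i, n in enumerate(seq_lens)], dim=0)
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)

print(flat.shape)            # torch.Size([10, 8, 64])
print(int(cu_seqlens[-1]))   # 10 == flat.shape[0]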
Comment: nit: _is_causal_attention?