From 8a2c588d2097901296fed616dbe817fc543acab0 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Sun, 20 Jul 2025 16:46:24 -0300 Subject: [PATCH 1/6] Support encoder-only models without KV-Cache Add support for encoder models such as BERT which don't support a KV cache due to the non-causal attention. Since the KV Cache Spec is used to build the attention metadata for decoder models, this PR initializes the attention metadata builds for encoder-only models directly from the layers and adds a function to build the attention metadata. This PR combines elements of PRs https://github.com/vllm-project/vllm/pull/21088 and https://github.com/vllm-project/vllm/pull/19988 Summary of changes: **Flash Attention Backend:** - Implement encoder self-attention support without using KV cache **Scheduler:** - Disable chunked prefill for models without KV cache **GPU Model Runner:** - Implement encoder-only attention metadata building for self-attention Related to: - V0 deprecation: #18571 - 2025 Q3 roadmap: #20336 Signed-off-by: Max de Bayser Co-authored-by: Russell Bryant --- tests/entrypoints/openai/test_rerank.py | 2 +- .../models/language/pooling/test_embedding.py | 14 +- tests/models/language/pooling/test_jina.py | 8 + vllm/engine/arg_utils.py | 3 +- vllm/model_executor/models/bert.py | 18 +- vllm/model_executor/models/roberta.py | 91 ++++++--- vllm/v1/attention/backends/flash_attn.py | 120 ++++++++++- vllm/v1/engine/core.py | 6 + vllm/v1/worker/gpu_model_runner.py | 189 ++++++++++++++---- 9 files changed, 357 insertions(+), 94 deletions(-) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4da97fe13691..912313ce133e 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -124,4 +124,4 @@ def test_invocations(server: RemoteOpenAIServer): invocation_output["results"]): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.01) + invocations_result["relevance_score"], rel=0.05) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index cc9e4102d5b7..ba42e389fc15 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -39,17 +39,9 @@ def v1(run_with_both_engines): pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] - pytest.param( - "BAAI/bge-base-en-v1.5", - marks=[ - # CPU only supports V1 - pytest.mark.core_model, - pytest.mark.skip_v1 - ]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2", - marks=[pytest.mark.skip_v1]), - pytest.param("intfloat/multilingual-e5-small", - marks=[pytest.mark.skip_v1]), + pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2"), + pytest.param("intfloat/multilingual-e5-small"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", marks=[pytest.mark.skip_v1]), # [Cross-Encoder] diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 9bfe7411e16b..ca3dc45c32b9 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -23,6 +23,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for 
every + # test in a package + pass + + @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 019ff033eda2..3075958f0dc8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1670,7 +1670,8 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): - self.max_num_seqs = default_max_num_seqs[usage_context] + self.max_num_seqs = min(default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 006f547bb461..35fb4e57539e 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -12,7 +12,6 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -59,7 +58,6 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -109,7 +107,6 @@ def forward( return pooled_output -@support_torch_compile class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -327,6 +324,7 @@ def forward(self, hidden_states: torch.Tensor, return hidden_states +@support_torch_compile class BertModel(nn.Module, SupportsQuant): is_pooling_model = True @@ -357,13 +355,9 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - attn_metadata = get_forward_context().attn_metadata - assert hasattr(attn_metadata, "seq_lens_tensor") - hidden_states = self.embeddings( - input_ids=input_ids, - seq_lens=attn_metadata.seq_lens_tensor, - position_ids=position_ids, - token_type_ids=token_type_ids) + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states) def load_weights(self, weights: Iterable[tuple[str, @@ -404,7 +398,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): +class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
This class encapsulates the BertModel and provides an interface for @@ -429,11 +423,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, + token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 7d3b56ced5c4..ecfb35148685 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,6 +9,7 @@ from transformers import RobertaConfig from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -50,39 +51,12 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() inputs_embeds = self.word_embeddings(input_ids) - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - pos_list = [] - token_list = [] - offset = 0 - for seq_len in seq_lens: - pos_list.append(position_ids[offset:offset + seq_len]) - token_list.append(input_ids[offset:offset + seq_len]) - offset += seq_len - - new_pos_list = [] - for positions, tokens in zip(pos_list, token_list): - # Verify assumption that incoming position are - # always a sequence from 0 to N. - expected_pos = torch.arange(positions.size()[0], - dtype=torch.long, - device=inputs_embeds.device) - assert torch.equal(positions, expected_pos) - new_pos_list.append( - create_position_ids_from_input_ids(tokens, self.padding_idx)) - position_ids = torch.cat(new_pos_list) - # Position embeddings. position_embeddings = self.position_embeddings(position_ids) if token_type_ids is None: @@ -124,6 +98,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel): _pooler: An instance of Pooler used for pooling operations. """ + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # Fix Roberta positions here outside of the CUDA graph. + # Because we need the to extract the sequences from + # input_ids the control flow is data dependent. 
+ replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) + + return self.model(input_ids=input_ids, + position_ids=positions, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> Union[BertModel, BertWithRope]: @@ -180,6 +180,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, @@ -206,6 +207,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) return self.roberta(input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds, @@ -235,3 +239,36 @@ def create_position_ids_from_input_ids(input_ids, past_key_values_length) * mask return incremental_indices.long() + padding_idx + + +def replace_roberta_positions(input_ids: torch.Tensor, + position_ids: torch.Tensor, + padding_idx: int) -> None: + + seq_lens: Optional[torch.Tensor] = None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: # can be None during warmup + if isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values())) + # TODO: remove "seq_lens_tensor" after V0 is removed + seq_lens = getattr(attn_metadata, "seq_lens_tensor", + getattr(attn_metadata, "seq_lens", None)) + + if seq_lens is not None: + assert isinstance(seq_lens, torch.Tensor) + + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + token_list = torch.split(input_ids[:torch.sum(seq_lens)], + seq_lens.tolist()) + + offset = 0 + for tokens in token_list: + length = tokens.shape[0] + position_ids[offset:offset+length] = \ + create_position_ids_from_input_ids(tokens, padding_idx) + offset = offset + length diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ad414ee0a1fc..a6b393583293 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -376,11 +376,14 @@ def __init__( FlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/decoder cross-attention " + "is not implemented for " "FlashAttentionImpl") + + self.attn_type = attn_type self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -388,6 +391,24 @@ def __init__( raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this 
device.") + @staticmethod + def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) + def forward( self, layer: torch.nn.Module, @@ -424,6 +445,8 @@ def forward( # Profiling run. return output + attn_type = self.attn_type + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -434,6 +457,18 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, ): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention(query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, layer) + + # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: @@ -485,7 +520,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=True, + causal=FlashAttentionImpl._get_causal_option(attn_type), alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, @@ -526,6 +561,81 @@ def forward( ) return output + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. 
+ + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + """ + # For encoder attention, process FP8 quantization if needed + if self.kv_cache_dtype.startswith("fp8"): + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + num_kv_tokens, num_kv_heads, head_size = key.shape + key, _ = ops.scaled_fp8_quant( + key.reshape( + (num_kv_tokens, num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape( + (num_kv_tokens, num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape((num_kv_tokens, num_kv_heads, head_size)) + + # Use encoder-specific metadata for sequence information + # TODO: handle cross-encoder metadata fields + cu_seqlens_q = attn_metadata.query_start_loc + cu_seqlens_k = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_query_len + + descale_shape = ( + cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] + self.num_kv_heads) + + # Call flash attention directly on Q, K, V tensors + flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=False, # Encoder attention is bidirectional + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + return output + def use_cascade_attention( common_prefix_len: int, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ca636bf5a6f7..e7ccb0e3e8a0 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -109,6 +109,12 @@ def __init__(self, "compatibility may not be maintained.", vllm_config.scheduler_config.scheduler_cls) + if len(kv_cache_config.kv_cache_groups) == 0: + # Encoder models without KV cache don't support + # chunked prefill. But do SSM models? 
+ logger.info("Disabling chunked prefill for model without KVCache") + vllm_config.scheduler_config.chunked_prefill_enabled = False + self.scheduler: SchedulerInterface = Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 670e653929ce..2b74a68bd9d2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -126,6 +126,7 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.is_pooling_model = model_config.pooler_config is not None + self.is_encoder_only_model = False self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -697,6 +698,21 @@ def _prepare_inputs( spec_decode_common_attn_metadata = None attn_metadata: dict[str, Any] = {} + + # Prepare encoder attention metadata separately + # (encoder layers are not in KV cache groups) + if self.is_encoder_only_model: + common_attn_metadata, encoder_attn_metdata = \ + self._build_encoder_only_attn_metadata( + scheduler_output) + + # Add encoder attention metadata for all encoder layers + attention_layers = get_layers_from_vllm_config( + self.vllm_config, Attention) + for layer_name, attn_module in attention_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_metadata[layer_name] = encoder_attn_metdata + # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -2403,6 +2419,49 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def _initialize_single_attn_backend( + self, kv_cache_spec: KVCacheSpec + ) -> tuple[AttentionBackend, AttentionMetadataBuilder]: + if isinstance(kv_cache_spec, AttentionSpec): + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=kv_cache_spec.use_mla, + ) + if attn_backend_i is None: + error_msg = (f"Error with get_attn_backend: " + f"{kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + elif isinstance(kv_cache_spec, MambaSpec): + attn_backend_i = Mamba2AttentionBackend + else: + raise ValueError( + f"Unknown KV cache spec type: {type(kv_cache_spec)}") + + attn_metadata_builder_i = attn_backend_i.get_builder_cls()( + kv_cache_spec, + self.vllm_config, + self.device, + ) + + if (self.full_cuda_graph + and not attn_metadata_builder_i.full_cudagraph_supported): + raise ValueError( + f"Full CUDAGraph not supported for " + f"{attn_backend_i.__name__}. Turn off CompilationConfig." + f"full_cuda_graph or use a different attention backend.") + return attn_backend_i, attn_metadata_builder_i + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
@@ -2413,48 +2472,53 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - attn_backend_i = get_attn_backend( - kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = (f"Error with get_attn_backend: " - f"{kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = Mamba2AttentionBackend - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") - - attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_spec, - self.vllm_config, - self.device, - ) - - if (self.full_cuda_graph - and not attn_metadata_builder_i.full_cudagraph_supported): - raise ValueError( - f"Full CUDAGraph not supported for " - f"{attn_backend_i.__name__}. Turn off CompilationConfig." - f"full_cuda_graph or use a different attention backend.") + attn_backend_i, attn_metadata_builder_i = \ + self._initialize_single_attn_backend(kv_cache_spec) self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + if len(self.attn_backends) > 0: + return + + # Check if model is encoder-only + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + kv_cache_specs = list[KVCacheSpec]() + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for attn_module in attn_layers.values(): + + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + if attn_module.sliding_window is not None: + kv_cache_specs.append( + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + use_mla=use_mla)) + else: + kv_cache_specs.append( + FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla)) + else: + raise ValueError("Expected only encoder-only layers") + + if len(kv_cache_specs) > 0: + assert len(kv_cache_specs) == len(attn_layers), \ + "All or none of the layers are expected to be encoder-only" + + attn_backend, attn_metadata_builder = \ + self._initialize_single_attn_backend(kv_cache_specs[0]) + self.attn_backends.append(attn_backend) + self.attn_metadata_builders.append(attn_metadata_builder) + self.is_encoder_only_model = True + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2771,3 +2835,52 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec + + def _build_encoder_only_attn_metadata( + self, scheduler_output: "SchedulerOutput") -> \ + tuple[CommonAttentionMetadata, Any]: + """Prepare encoder attention metadata for encoder-only models. 
+ + Args: + scheduler_output: Scheduler output + + Returns: + dict[str, Any]: Encoder attention metadata + """ + num_reqs = self.input_batch.num_reqs + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + max_num_scheduled_tokens = max(tokens) + + # Use the first attention metadata builder + # to create encoder attention metadata + builder = self.attn_metadata_builders[0] + + dummy_block_table = torch.zeros((num_reqs, 1), + dtype=torch.int32, + device=self.device) + dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), + dtype=torch.int32, + device=self.device) + + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + block_table_tensor=dummy_block_table, + slot_mapping=dummy_slot_mapping, + ) + + return common_metadata, builder.build( + common_prefix_len=0, # No cascade for encoder + common_attn_metadata=common_metadata, + ) From 1f3fcc4092599231788ee32dbcb82ff56873876d Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 21 Jul 2025 15:03:49 -0300 Subject: [PATCH 2/6] address review comments Signed-off-by: Max de Bayser --- tests/entrypoints/openai/test_rerank.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 5 +---- vllm/v1/worker/gpu_model_runner.py | 12 ++++++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 912313ce133e..4da97fe13691 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -124,4 +124,4 @@ def test_invocations(server: RemoteOpenAIServer): invocation_output["results"]): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.05) + invocations_result["relevance_score"], rel=0.01) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4ea5001b5008..14f502ed3d9e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -403,9 +403,7 @@ def _get_causal_option(attn_type: str) -> bool: attention (i.e., not encoder, encoder-only, or encoder-decoder), otherwise returns `False`. 
""" - return not (attn_type == AttentionType.ENCODER - or attn_type == AttentionType.ENCODER_ONLY - or attn_type == AttentionType.ENCODER_DECODER) + return attn_type == AttentionType.DECODER def forward( self, @@ -601,7 +599,6 @@ def _forward_encoder_attention( value = value.reshape((num_kv_tokens, num_kv_heads, head_size)) # Use encoder-specific metadata for sequence information - # TODO: handle cross-encoder metadata fields cu_seqlens_q = attn_metadata.query_start_loc cu_seqlens_k = attn_metadata.query_start_loc max_seqlen_q = attn_metadata.max_query_len diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e49507542ed4..cc42797fada3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2476,13 +2476,13 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # Check if model is encoder-only block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla - kv_cache_specs = list[KVCacheSpec]() + attn_specs = list[AttentionSpec]() attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for attn_module in attn_layers.values(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: if attn_module.sliding_window is not None: - kv_cache_specs.append( + attn_specs.append( SlidingWindowSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, @@ -2491,7 +2491,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: sliding_window=attn_module.sliding_window, use_mla=use_mla)) else: - kv_cache_specs.append( + attn_specs.append( FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, @@ -2501,12 +2501,12 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: else: raise ValueError("Expected only encoder-only layers") - if len(kv_cache_specs) > 0: - assert len(kv_cache_specs) == len(attn_layers), \ + if len(attn_specs) > 0: + assert len(attn_specs) == len(attn_layers), \ "All or none of the layers are expected to be encoder-only" attn_backend, attn_metadata_builder = \ - self._initialize_single_attn_backend(kv_cache_specs[0]) + self._initialize_single_attn_backend(attn_specs[0]) self.attn_backends.append(attn_backend) self.attn_metadata_builders.append(attn_metadata_builder) self.is_encoder_only_model = True From 8e2cba14052293e6bdd16ad802c9183165032006 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 21 Jul 2025 22:10:48 -0300 Subject: [PATCH 3/6] remove sliding window attention case Signed-off-by: Max de Bayser --- vllm/v1/worker/gpu_model_runner.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cc42797fada3..2411dc9f5d2a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2481,23 +2481,15 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: for attn_module in attn_layers.values(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: - if attn_module.sliding_window is not None: - attn_specs.append( - SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla)) - else: - attn_specs.append( - FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - 
dtype=self.kv_cache_dtype, - use_mla=use_mla)) + assert attn_module.sliding_window is None, "Sliding " + "window attention is not supported for encoder-only models" + + attn_specs.append( + FullAttentionSpec(block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla)) else: raise ValueError("Expected only encoder-only layers") From 735761455724fa8d298065c147910879e3b8677f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 21 Jul 2025 22:14:08 -0300 Subject: [PATCH 4/6] address review comment Signed-off-by: Max de Bayser --- vllm/v1/attention/backends/flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 14f502ed3d9e..1fd4ce4f5e4f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -390,7 +390,7 @@ def __init__( "FlashAttention does not support fp8 kv-cache on this device.") @staticmethod - def _get_causal_option(attn_type: str) -> bool: + def _is_causal_attention(attn_type: str) -> bool: """ Determine whether the given attention type is suitable for causal attention mechanisms. @@ -516,7 +516,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=FlashAttentionImpl._get_causal_option(attn_type), + causal=FlashAttentionImpl._is_causal_attention(attn_type), alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, From aa69e9246cb555c7d150daba728cb86eaed06fef Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 21 Jul 2025 22:48:53 -0300 Subject: [PATCH 5/6] make causal a flag in common attention metadata Signed-off-by: Max de Bayser --- tests/v1/attention/utils.py | 1 + vllm/v1/attention/backends/flash_attn.py | 25 ++++++------------------ vllm/v1/attention/backends/utils.py | 3 +++ vllm/v1/spec_decode/eagle.py | 1 + vllm/v1/worker/gpu_model_runner.py | 5 ++++- 5 files changed, 15 insertions(+), 20 deletions(-) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 30cfbdda5d86..f3ce33fcfa12 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -93,6 +93,7 @@ def create_common_attn_metadata( max_query_len=max_query_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + causal=True, ) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1fd4ce4f5e4f..b8f50eb4df0f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -130,6 +130,8 @@ class FlashAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None max_num_splits: int = 0 + causal: bool = True + def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: @@ -213,6 +215,7 @@ def build(self, seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + causal = common_attn_metadata.causal # the overhead of the aot schedule is not worth it for spec-decode aot_schedule = self.aot_schedule and not fast_build @@ -288,7 +291,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_query_len=max_query_len, seqlens=seq_lens, max_seq_len=max_seq_len, - causal=True) + causal=causal) if self.use_full_cuda_graph: assert scheduler_metadata is not None @@ -326,7 +329,7 @@ def 
schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - ) + causal=causal) return attn_metadata def can_run_in_cudagraph( @@ -389,22 +392,6 @@ def __init__( raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") - @staticmethod - def _is_causal_attention(attn_type: str) -> bool: - """ - Determine whether the given attention type is suitable for causal - attention mechanisms. - - Args: - attn_type (AttentionType): The type of attention being evaluated - - Returns: - bool: Returns `True` if the attention type is suitable for causal - attention (i.e., not encoder, encoder-only, or encoder-decoder), - otherwise returns `False`. - """ - return attn_type == AttentionType.DECODER - def forward( self, layer: torch.nn.Module, @@ -516,7 +503,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=FlashAttentionImpl._is_causal_attention(attn_type), + causal=attn_metadata.causal, alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fc8649d587ee..3553364d3080 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -59,6 +59,8 @@ class CommonAttentionMetadata: block_table_tensor: torch.Tensor slot_mapping: torch.Tensor + causal: bool + M = TypeVar("M") @@ -395,6 +397,7 @@ def make_local_attention_virtual_batches( max_query_len=seqlens_q_local.max(), block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, + causal=True, ) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 967847c02ff2..63f6fc276189 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -330,6 +330,7 @@ def prepare_inputs( max_query_len=new_query_len_per_req.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2411dc9f5d2a..a7f5f17d7139 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -734,6 +734,7 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, + causal=True, ) if self.speculative_config and \ @@ -2088,7 +2089,8 @@ def _dummy_run( block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. 
- block_table[kv_cache_group_id].slot_mapping[:num_tokens]) + block_table[kv_cache_group_id].slot_mapping[:num_tokens], + causal=True) attn_metadata_i = self.attn_metadata_builders[ kv_cache_group_id].build_for_cudagraph_capture( @@ -2861,6 +2863,7 @@ def _build_encoder_only_attn_metadata( max_query_len=max_num_scheduled_tokens, block_table_tensor=dummy_block_table, slot_mapping=dummy_slot_mapping, + causal=False, ) return common_metadata, builder.build( From d81e14329a5e87eeafa68bb3368c2376e9f57cd3 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 23 Jul 2025 11:11:15 -0300 Subject: [PATCH 6/6] fix typo Signed-off-by: Max de Bayser --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bd86f0eb7c8b..0641078d534a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -698,7 +698,7 @@ def _prepare_inputs( # Prepare encoder attention metadata separately # (encoder layers are not in KV cache groups) if self.is_encoder_only_model: - common_attn_metadata, encoder_attn_metdata = \ + common_attn_metadata, encoder_attn_metadata = \ self._build_encoder_only_attn_metadata( scheduler_output) @@ -707,7 +707,7 @@ def _prepare_inputs( self.vllm_config, Attention) for layer_name, attn_module in attention_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_metadata[layer_name] = encoder_attn_metdata + attn_metadata[layer_name] = encoder_attn_metadata # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata.
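
Usage sketch: with this series applied, encoder-only pooling models such as
BAAI/bge-base-en-v1.5 should run on the V1 engine without allocating a KV
cache. The snippet below is a minimal offline check; it assumes the existing
vLLM LLM entrypoint with task="embed", the VLLM_USE_V1 environment variable,
and the LLM.embed() API, none of which are modified by these patches and all
of which may differ across vLLM versions.

    import os

    # Select the V1 engine so the encoder-only attention path from this series
    # is exercised (assumption: V1 is opt-in via this environment variable).
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM

    # One of the encoder-only models re-enabled for V1 in
    # tests/models/language/pooling/test_embedding.py above.
    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

    outputs = llm.embed(["encoder-only models no longer need a KV cache"])
    embedding = outputs[0].outputs.embedding
    print(len(embedding))  # hidden size of the encoder, 768 for bge-base

Running the same snippet with VLLM_USE_V1 unset compares the V1 path against
the V0 baseline, which is essentially what the run_with_both_engines fixture
does for the tests touched in this series.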