From 5dee54d2652e698b00730817e77ea85bb3162cba Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Jun 2025 11:25:27 -0300 Subject: [PATCH 01/18] Add support for encoder embedding models Signed-off-by: Max de Bayser --- .../models/language/pooling/test_embedding.py | 16 ++----- tests/models/language/pooling/test_jina.py | 8 ++++ tests/models/language/pooling/test_scoring.py | 9 ++++ tests/v1/core/test_kv_cache_utils.py | 2 +- vllm/config.py | 5 ++ vllm/engine/arg_utils.py | 3 +- vllm/model_executor/models/bert.py | 22 ++++----- vllm/model_executor/models/roberta.py | 38 +++++++-------- vllm/v1/attention/backends/flash_attn.py | 30 ++++++++++-- vllm/v1/core/kv_cache_utils.py | 1 + vllm/v1/core/sched/output.py | 2 + vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/core.py | 19 +++++++- vllm/v1/engine/processor.py | 1 + vllm/v1/kv_cache_interface.py | 1 + vllm/v1/request.py | 3 ++ vllm/v1/worker/gpu_input_batch.py | 31 +++++++++++- vllm/v1/worker/gpu_model_runner.py | 47 +++++++++++++++---- vllm/v1/worker/tpu_model_runner.py | 3 ++ 19 files changed, 178 insertions(+), 64 deletions(-) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 5ef9f768c57..c9c98357b16 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -38,19 +38,13 @@ def v1(run_with_both_engines): marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] pytest.param("BAAI/bge-base-en-v1.5", - marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.skip_v1 - ]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2", - marks=[pytest.mark.skip_v1]), - pytest.param("intfloat/multilingual-e5-small", - marks=[pytest.mark.skip_v1]), + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2"), + pytest.param("intfloat/multilingual-e5-small"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), + marks=[pytest.mark.skip_v0]), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) def test_models( diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 0c44683e748..99466405501 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -26,6 +26,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index c75ff144561..1cf2cdc0132 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -23,6 +23,15 @@ "The capital of Germany is Berlin.", ] + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + DTYPE = "half" diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e80ad8a6815..97fff745647 100644 --- 
a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -916,4 +916,4 @@ def test_get_kv_cache_config(): ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) \ No newline at end of file + ]) diff --git a/vllm/config.py b/vllm/config.py index 7549c97b4fe..d112abd7156 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -716,6 +716,11 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: self.override_pooler_config = PoolerConfig( **self.override_pooler_config) + # WIP: currently cuda graphs are not working for encoder models. + logger.warning("CUDA graph is not supported for pooling yet, " + "fallback to the eager mode.") + self.enforce_eager = True + pooler_config = self.override_pooler_config or PoolerConfig() base_config = get_pooling_config(self.model, self.revision) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dd09f514906..8d01c0413e2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1664,7 +1664,8 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): - self.max_num_seqs = default_max_num_seqs[usage_context] + self.max_num_seqs = min(default_max_num_seqs[usage_context], + self.max_num_batched_tokens) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d6f6d9d1fb5..cb3ac791ccb 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -12,7 +12,6 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -28,7 +27,7 @@ from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) -from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .interfaces import SupportsCrossEncoding, SupportsQuant from .utils import WeightsMapper, maybe_prefix @@ -57,7 +56,6 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -342,13 +340,9 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - attn_metadata = get_forward_context().attn_metadata - assert hasattr(attn_metadata, "seq_lens_tensor") - hidden_states = self.embeddings( - input_ids=input_ids, - seq_lens=attn_metadata.seq_lens_tensor, - position_ids=position_ids, - token_type_ids=token_type_ids) + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states) def load_weights(self, weights: Iterable[tuple[str, @@ -388,7 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): +class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
This class encapsulates the BertModel and provides an interface for @@ -411,11 +405,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, + token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -446,8 +442,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: softmax=False) -class BertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding, SupportsQuant): +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 8fa8b89798d..87251236b37 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -22,7 +22,7 @@ get_cross_encoder_activation_function) from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import SupportsCrossEncoding class RobertaEmbedding(nn.Module): @@ -52,41 +52,36 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + input_shape = input_ids.size() inputs_embeds = self.word_embeddings(input_ids) + zero_pos = torch.where(position_ids == 0)[0] + end_pos = torch.cat((zero_pos[1:], + torch.tensor([position_ids.shape[0]], + device=zero_pos.device))) + seq_lens = end_pos - zero_pos + # Replace position ids because in RoBERTa models # they have to start at padding_idx + 1 and ignore # existing padding tokens # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + token_list = torch.split(input_ids, seq_lens.tolist()) + pos_list = [] - token_list = [] - offset = 0 - for seq_len in seq_lens: - pos_list.append(position_ids[offset:offset + seq_len]) - token_list.append(input_ids[offset:offset + seq_len]) - offset += seq_len - - new_pos_list = [] - for positions, tokens in zip(pos_list, token_list): - # Verify assumption that incoming position are - # always a sequence from 0 to N. - expected_pos = torch.arange(positions.size()[0], - dtype=torch.long, - device=inputs_embeds.device) - assert torch.equal(positions, expected_pos) - new_pos_list.append( + for tokens in token_list: + pos_list.append( create_position_ids_from_input_ids(tokens, self.padding_idx)) - position_ids = torch.cat(new_pos_list) + + corrected_positions = torch.cat(pos_list) # Position embeddings. 
- position_embeddings = self.position_embeddings(position_ids) + position_embeddings = self.position_embeddings(corrected_positions) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, @@ -150,8 +145,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert len(loaded), "Unable to load RobertaEmbeddingModel" -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsV0Only): +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4ad7178374b..117c4631b91 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -386,11 +386,13 @@ def __init__( f"Supported head sizes are: {support_head_sizes}. " "Set VLLM_USE_V1=0 to use another attention backend.") - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/decoder cross-attention " + "is not implemented for " "FlashAttentionImpl") + self.attn_type = attn_type self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -509,7 +511,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=True, + causal=_get_causal_option(self.attn_type), alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, @@ -711,3 +713,21 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. 
+ """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9489bcf433f..c52b1e5afb6 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -915,6 +915,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: dtype=spec.dtype, use_mla=spec.use_mla, sliding_window=spec.sliding_window, + attn_type=str(spec.attn_type), ) if is_hybrid(kv_cache_spec): diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6f31031a108..2dd43a6b795 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -24,6 +24,7 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] mm_positions: list[PlaceholderRange] @@ -42,6 +43,7 @@ def from_request( return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, + token_type_ids=request.token_type_ids, mm_inputs=request.mm_inputs, mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 921ccd708cd..dcad52959b3 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -49,6 +49,7 @@ class EngineCoreRequest( request_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index da65550354d..e384dadffcb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -150,6 +150,23 @@ def _initialize_kv_caches( zip(kv_cache_specs, available_gpu_memory) ] + for kv_cache_spec_one_worker in kv_cache_specs: + for _, spec in kv_cache_spec_one_worker.items(): + if isinstance(spec, AttentionSpec) and \ + spec.attn_type != "decoder": + + logger.info("Found non-decoder layer. Disabling " + "prefix cache and chunked prefill") + self.vllm_config.cache_config.\ + enable_prefix_caching = False + self.vllm_config.scheduler_config.\ + enable_chunked_prefill = False + self.vllm_config.scheduler_config.\ + chunked_prefill_enabled = False + self.vllm_config.scheduler_config.\ + long_prefill_token_threshold = 0 + break + # Since we use a shared centralized controller, we need the # `kv_cache_config` to be consistent across all workers to make sure # all the memory operators can be applied to all workers. 
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a0b170ba55a..2ae2de6f4be 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -329,6 +329,7 @@ def process_inputs( return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], + token_type_ids=decoder_inputs.get("token_type_ids"), mm_inputs=sorted_mm_inputs, mm_hashes=sorted_mm_hashes, mm_placeholders=sorted_mm_positions, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c48775adc9b..d8af8486ca5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -75,6 +75,7 @@ class AttentionSpec(KVCacheSpec): head_size: int dtype: torch.dtype use_mla: bool + attn_type: str @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9b96f4599f9..d22220e3e18 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -24,6 +24,7 @@ def __init__( self, request_id: str, prompt_token_ids: list[int], + token_type_ids: Optional[list[int]], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], @@ -74,6 +75,7 @@ def __init__( "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids + self.token_type_ids = token_type_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = self.prompt_token_ids.copy() @@ -118,6 +120,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, + token_type_ids=request.token_type_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 3a2c9ef7dfa..9514439f142 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Datastructures defining a GPU input batch +# Datastructures defining an input batch from dataclasses import dataclass from typing import Optional, cast @@ -27,6 +27,7 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] @@ -89,6 +90,8 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.token_type_ids_cpu_tensor = None + self._token_type_ids_cpu = None self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) @@ -231,6 +234,22 @@ def __init__( self.pooling_params: dict[str, PoolingParams] = {} + @property + def token_type_ids_cpu(self) -> np.ndarray: + if self._token_type_ids_cpu is None: + self.token_type_ids_cpu_tensor = torch.zeros( + self.token_ids_cpu_tensor.shape, + device="cpu", + dtype=torch.int8, + pin_memory=False, + ) + self._token_type_ids_cpu = cast( + torch.Tensor, self.token_type_ids_cpu_tensor).numpy() + return self._token_type_ids_cpu + + def has_token_types(self) -> bool: + return 
self._token_type_ids_cpu is not None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -261,6 +280,9 @@ def add_request( self.num_prompt_tokens[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids + if request.token_type_ids is not None: + self.token_type_ids_cpu[ + req_index, :num_prompt_tokens] = request.token_type_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) self.token_ids_cpu[req_index, @@ -447,6 +469,10 @@ def swap_states(self, i1: int, i2: int) -> None: tmp = self.token_ids_cpu[i1, ...].copy() self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + if self.has_token_types(): + tmp2 = self.token_type_ids_cpu[i1, ...].copy() + self.token_type_ids_cpu[i1, ...] = self.token_type_ids_cpu[i2, ...] + self.token_type_ids_cpu[i2, ...] = tmp2 swap_dict_values(self.generators, i1, i2) swap_dict_values(self.min_tokens, i1, i2) @@ -503,6 +529,9 @@ def condense(self, empty_req_indices: list[int]) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] + if self.has_token_types(): + self.token_type_ids_cpu[empty_index, :num_tokens] = \ + self.token_type_ids_cpu[last_req_index, :num_tokens] self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ last_req_index] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 33036600611..bf68f8fde05 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,7 +6,7 @@ import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np import torch @@ -237,7 +237,7 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) - + self.token_type_ids = None # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -305,6 +305,13 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + def get_token_type_ids(self) -> Optional[torch.Tensor]: + if self.token_type_ids is None: + self.token_type_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int8, + device=self.device) + return self.token_type_ids + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -408,6 +415,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + token_type_ids=new_req_data.token_type_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -627,6 +635,13 @@ def _prepare_inputs( 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) + if self.input_batch.token_type_ids_cpu_tensor is not None: + token_type_ids = torch.index_select( + self.input_batch.token_type_ids_cpu_tensor.flatten(), 0, + torch.from_numpy(token_indices)) + # Copy the tensors to the GPU. 
+ self.get_token_type_ids()[:total_num_scheduled_tokens]\ + .copy_(token_type_ids, non_blocking=True) # Calculate the slot mapping for each KV cache group. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1299,11 +1314,17 @@ def execute_model( else: mm_embeds = [] + has_token_types = self.token_type_ids is not None + model_kwargs = {} + if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] + if has_token_types: + model_kwargs["token_type_ids"] = cast( + torch.Tensor, self.token_type_ids)[:num_scheduled_tokens] if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1319,6 +1340,9 @@ def execute_model( # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] + if has_token_types: + model_kwargs["token_type_ids"] = cast( + torch.Tensor, self.token_type_ids)[:num_input_tokens] inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -1352,6 +1376,7 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **model_kwargs, ) self.maybe_wait_for_kv_save() @@ -1907,7 +1932,6 @@ def _dummy_run( attn_metadata: Optional[dict[str, Any]] = None if capture_attn_cudagraph: - attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. @@ -1925,6 +1949,7 @@ def _dummy_run( max_query_len=num_tokens, ) + attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -1972,6 +1997,7 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2482,7 +2508,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue # TODO: Support other attention modules, e.g., cross-attention - if attn_module.attn_type == AttentionType.DECODER: + # encoder only can also benefit from KV cache for prefix caching + if attn_module.attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_ONLY): if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -2490,17 +2518,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla) + use_mla=use_mla, + attn_type=str(attn_module.attn_type)) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. + use_mla=use_mla, + attn_type=str(attn_module.attn_type)) + elif attn_module.attn_type == AttentionType.ENCODER: + # encoder attention does not need KV cache. 
continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 774caa1a3d9..7e7b80c0dd0 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -394,6 +394,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + token_type_ids=new_req_data.token_type_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -494,6 +495,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=False, + attn_type=str(attn_module.attn_type), ) else: kv_cache_spec[layer_name] = FullAttentionSpec( @@ -502,6 +504,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=False, + attn_type=str(attn_module.attn_type), ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): From b430dba9a1cc235d13fecb4bf46e928dd2e0bb0f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Jun 2025 12:18:42 -0300 Subject: [PATCH 02/18] Use multi-modal support to pass token_type_ids to the model Signed-off-by: Max de Bayser --- vllm/entrypoints/openai/serving_score.py | 11 +- vllm/model_executor/models/bert.py | 217 ++++++++++++++++++++--- vllm/model_executor/models/roberta.py | 86 ++++++--- vllm/v1/core/sched/output.py | 2 - vllm/v1/engine/__init__.py | 1 - vllm/v1/engine/processor.py | 1 - vllm/v1/request.py | 3 - vllm/v1/worker/gpu_input_batch.py | 29 --- vllm/v1/worker/gpu_model_runner.py | 28 +-- vllm/v1/worker/tpu_model_runner.py | 1 - 10 files changed, 265 insertions(+), 114 deletions(-) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 9f333c02ab5..3d4095ce488 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -7,6 +7,7 @@ from fastapi import Request +import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger @@ -180,9 +181,17 @@ async def _cross_encoding_score( input_ids = prompt_inputs["input_ids"] text_token_prompt = \ self._validate_input(request, input_ids, request_prompt) + + token_type_ids = prompt_inputs.get("token_type_ids") + mm_data = None + if envs.VLLM_USE_V1 and token_type_ids is not None: + mm_data = {"token_type_ids": token_type_ids} + token_type_ids = None + engine_prompt = TokensPrompt( prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) + token_type_ids=token_type_ids, + multi_modal_data=mm_data) request_prompts.append(request_prompt) engine_prompts.append(engine_prompt) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index cb3ac791ccb..cb6c5ac52f2 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable -from typing import Optional +from collections.abc import Iterable, Mapping, Sequence +from typing import Optional, Union import torch from torch import nn -from transformers import BertConfig +from 
transformers import BatchFeature, BertConfig from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile @@ -23,11 +23,21 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalFlatField, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, + PlaceholderRange) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) -from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces import (MultiModalEmbeddings, SupportsCrossEncoding, + SupportsMultiModal, SupportsQuant) from .utils import WeightsMapper, maybe_prefix @@ -53,30 +63,46 @@ def __init__(self, config: BertConfig): raise ValueError("Only 'absolute' position_embedding_type" + " is supported") + def maybe_store_input_ids(self, input_ids: torch.Tensor): + pass + def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + apply_layer_norm: bool = True, ) -> torch.Tensor: - input_shape = input_ids.size() - # Input embeddings. - inputs_embeds = self.word_embeddings(input_ids) + # forward was called directly without going + # throught the multi-modal flow + if input_ids is not None and position_ids is not None \ + and inputs_embeds is None and token_type_ids is None: + token_type_ids = torch.zeros(input_ids.size(), + dtype=torch.long, + device=input_ids.device) - # Position embeddings. 
- position_embeddings = self.position_embeddings(position_ids) + tensors_to_add: list[torch.Tensor] = [] - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) + if inputs_embeds is not None: + tensors_to_add.append(inputs_embeds) - token_type_embeddings = self.token_type_embeddings(token_type_ids) + if token_type_ids is not None: + tensors_to_add.append(self.token_type_embeddings(token_type_ids)) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings - embeddings = self.LayerNorm(embeddings) - return embeddings + if position_ids is not None: + tensors_to_add.append(self.position_embeddings(position_ids)) + + if input_ids is not None: + tensors_to_add.append(self.word_embeddings(input_ids)) + + embeds = torch.stack(tensors_to_add, dim=0).sum(dim=0) + + if apply_layer_norm: + return self.LayerNorm(embeds) + else: + return embeds class BertPooler(nn.Module): @@ -337,12 +363,11 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids) + + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds) return self.encoder(hidden_states) def load_weights(self, weights: Iterable[tuple[str, @@ -442,8 +467,143 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: softmax=False) +TOKEN_TYPES = "token_type_ids" + + +class TokenTypeMultiModalProcessor(BaseMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + raise NotImplementedError + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + + assert isinstance(prompt, list) + + mm_data_item = mm_data[TOKEN_TYPES] + if isinstance(mm_data_item, list): + mm_data_item = torch.tensor(mm_data_item) + + prompt_len = len(prompt) + + mm_placeholders = { + TOKEN_TYPES: [PlaceholderRange( + offset=0, + length=prompt_len, + )] + } + + field = MultiModalFlatField([slice(0, prompt_len)]) + mm_item = MultiModalKwargsItem.from_elems( + field.build_elems(modality=TOKEN_TYPES, + key=TOKEN_TYPES, + data=mm_data_item)) + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=prompt, + mm_kwargs=MultiModalKwargs.from_items([mm_item]), + mm_hashes=None, + mm_placeholders=mm_placeholders, + ) + + +class TokenTypeProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {TOKEN_TYPES: 1} + + +class TokenTypeInputBuilder(BaseDummyInputsBuilder[TokenTypeProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + raise NotImplementedError + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + raise NotImplementedError + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], 
+ ) -> ProcessorInputs: + + dummy_prompt = [0] * seq_len + dummy_mm_data = { + TOKEN_TYPES: torch.zeros(seq_len, dtype=torch.int32), + } + return ProcessorInputs(prompt=dummy_prompt, mm_data=dummy_mm_data) + + +class BertMMTokenIdsMixin(SupportsMultiModal): + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + token_type_ids = kwargs.pop(TOKEN_TYPES, None) + + if token_type_ids is None: + return [] + + if not isinstance(token_type_ids, torch.Tensor): + raise ValueError("Incorrect type token_type_ids. " + f"Got type: {type(token_type_ids)}") + + return self.get_language_model().embeddings( + token_type_ids=token_type_ids, apply_layer_norm=False) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + token_type_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + # save for forward() + self.get_language_model().embeddings.maybe_store_input_ids(input_ids) + + token_type_ids: Optional[torch.Tensor] = None + + if token_type_embeddings is not None: + assert isinstance(token_type_embeddings, list) + token_type_embeddings = torch.cat(token_type_embeddings) + else: + token_type_ids = torch.zeros(input_ids.size(), + dtype=torch.long, + device=input_ids.device) + + return self.get_language_model().embeddings( + input_ids=input_ids, + inputs_embeds=token_type_embeddings, + token_type_ids=token_type_ids, + apply_layer_norm=False) + + +@MULTIMODAL_REGISTRY.register_processor(TokenTypeMultiModalProcessor, + info=TokenTypeProcessingInfo, + dummy_inputs=TokenTypeInputBuilder) class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsQuant): + BertMMTokenIdsMixin, SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -470,6 +630,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = ClassifierPooler(vllm_config.model_config, self.classifier, self.bert.pooler) + def get_language_model(self) -> torch.nn.Module: + return self.bert + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] @@ -502,7 +665,7 @@ def pooler( def forward( self, input_ids: Optional[torch.Tensor], - positions: torch.Tensor, + positions: Optional[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 87251236b37..ec0938f64a2 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -14,9 +14,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel +from vllm.model_executor.models.bert import (BertEmbeddingModel, + BertMMTokenIdsMixin, BertModel, + TokenTypeInputBuilder, + TokenTypeMultiModalProcessor, + TokenTypeProcessingInfo) from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) @@ -48,17 +53,10 @@ def __init__(self, config: RobertaConfig): if self.position_embedding_type != 
"absolute": raise ValueError("Only 'absolute' position_embedding_type" + " is supported") + self.input_ids: Optional[torch.Tensor] = None - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - input_shape = input_ids.size() - inputs_embeds = self.word_embeddings(input_ids) - + def _correct_positions(self, input_ids: torch.Tensor, + position_ids: torch.Tensor) -> torch.Tensor: zero_pos = torch.where(position_ids == 0)[0] end_pos = torch.cat((zero_pos[1:], torch.tensor([position_ids.shape[0]], @@ -78,19 +76,56 @@ def forward( pos_list.append( create_position_ids_from_input_ids(tokens, self.padding_idx)) - corrected_positions = torch.cat(pos_list) + return torch.cat(pos_list) + + def maybe_store_input_ids(self, input_ids: torch.Tensor): + self.input_ids = input_ids + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + apply_layer_norm: bool = True, + ) -> torch.Tensor: - # Position embeddings. - position_embeddings = self.position_embeddings(corrected_positions) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, + # forward was called directly without going + # throught the multi-modal flow + if input_ids is not None and position_ids is not None \ + and inputs_embeds is None and token_type_ids is None: + token_type_ids = torch.zeros(input_ids.size(), dtype=torch.long, - device=inputs_embeds.device) + device=input_ids.device) + + tensors_to_add: list[torch.Tensor] = [] + + if inputs_embeds is not None: + tensors_to_add.append(inputs_embeds) + + if token_type_ids is not None: + tensors_to_add.append(self.token_type_embeddings(token_type_ids)) + + if position_ids is not None: + inputs = input_ids if input_ids is not None else self.input_ids + if inputs is None: # it can by during _dummy_run + corrected_positions = position_ids + else: + corrected_positions = self._correct_positions( + input_ids=inputs, position_ids=position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings - embeddings = self.LayerNorm(embeddings) - return embeddings + tensors_to_add.append( + self.position_embeddings(corrected_positions)) + + if input_ids is not None: + tensors_to_add.append(self.word_embeddings(input_ids)) + + embeds = torch.stack(tensors_to_add, dim=0).sum(dim=0) + + if apply_layer_norm: + return self.LayerNorm(embeds) + else: + return embeds # Adapted from transformers @@ -145,7 +180,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert len(loaded), "Unable to load RobertaEmbeddingModel" -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): +@MULTIMODAL_REGISTRY.register_processor(TokenTypeMultiModalProcessor, + info=TokenTypeProcessingInfo, + dummy_inputs=TokenTypeInputBuilder) +class RobertaForSequenceClassification(nn.Module, BertMMTokenIdsMixin, + SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. 
This class encapsulates the BertModel and provides an interface for @@ -185,6 +224,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = ClassifierPooler(vllm_config.model_config, self.classifier) + def get_language_model(self) -> torch.nn.Module: + return self.roberta + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): bert_weights, task_weights = roberta_task_weights_filter(weights) bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 2dd43a6b795..6f31031a108 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -24,7 +24,6 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] - token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] mm_positions: list[PlaceholderRange] @@ -43,7 +42,6 @@ def from_request( return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - token_type_ids=request.token_type_ids, mm_inputs=request.mm_inputs, mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index dcad52959b3..921ccd708cd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -49,7 +49,6 @@ class EngineCoreRequest( request_id: str prompt_token_ids: list[int] - token_type_ids: Optional[list[int]] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 2ae2de6f4be..a0b170ba55a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -329,7 +329,6 @@ def process_inputs( return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], - token_type_ids=decoder_inputs.get("token_type_ids"), mm_inputs=sorted_mm_inputs, mm_hashes=sorted_mm_hashes, mm_placeholders=sorted_mm_positions, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d22220e3e18..9b96f4599f9 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -24,7 +24,6 @@ def __init__( self, request_id: str, prompt_token_ids: list[int], - token_type_ids: Optional[list[int]], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], @@ -75,7 +74,6 @@ def __init__( "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids - self.token_type_ids = token_type_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = self.prompt_token_ids.copy() @@ -120,7 +118,6 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, - token_type_ids=request.token_type_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9514439f142..9262083a159 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -27,7 +27,6 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] - token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] 
mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] @@ -90,8 +89,6 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() - self.token_type_ids_cpu_tensor = None - self._token_type_ids_cpu = None self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) @@ -234,22 +231,6 @@ def __init__( self.pooling_params: dict[str, PoolingParams] = {} - @property - def token_type_ids_cpu(self) -> np.ndarray: - if self._token_type_ids_cpu is None: - self.token_type_ids_cpu_tensor = torch.zeros( - self.token_ids_cpu_tensor.shape, - device="cpu", - dtype=torch.int8, - pin_memory=False, - ) - self._token_type_ids_cpu = cast( - torch.Tensor, self.token_type_ids_cpu_tensor).numpy() - return self._token_type_ids_cpu - - def has_token_types(self) -> bool: - return self._token_type_ids_cpu is not None - @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -280,9 +261,6 @@ def add_request( self.num_prompt_tokens[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids - if request.token_type_ids is not None: - self.token_type_ids_cpu[ - req_index, :num_prompt_tokens] = request.token_type_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) self.token_ids_cpu[req_index, @@ -469,10 +447,6 @@ def swap_states(self, i1: int, i2: int) -> None: tmp = self.token_ids_cpu[i1, ...].copy() self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp - if self.has_token_types(): - tmp2 = self.token_type_ids_cpu[i1, ...].copy() - self.token_type_ids_cpu[i1, ...] = self.token_type_ids_cpu[i2, ...] - self.token_type_ids_cpu[i2, ...] = tmp2 swap_dict_values(self.generators, i1, i2) swap_dict_values(self.min_tokens, i1, i2) @@ -529,9 +503,6 @@ def condense(self, empty_req_indices: list[int]) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] - if self.has_token_types(): - self.token_type_ids_cpu[empty_index, :num_tokens] = \ - self.token_type_ids_cpu[last_req_index, :num_tokens] self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ last_req_index] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bf68f8fde05..61446f89ffe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,7 +6,7 @@ import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch @@ -237,7 +237,6 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) - self.token_type_ids = None # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -305,13 +304,6 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
self.shared_kv_cache_layers: dict[str, str] = {} - def get_token_type_ids(self) -> Optional[torch.Tensor]: - if self.token_type_ids is None: - self.token_type_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int8, - device=self.device) - return self.token_type_ids - def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -415,7 +407,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - token_type_ids=new_req_data.token_type_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -635,13 +626,6 @@ def _prepare_inputs( 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - if self.input_batch.token_type_ids_cpu_tensor is not None: - token_type_ids = torch.index_select( - self.input_batch.token_type_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices)) - # Copy the tensors to the GPU. - self.get_token_type_ids()[:total_num_scheduled_tokens]\ - .copy_(token_type_ids, non_blocking=True) # Calculate the slot mapping for each KV cache group. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1314,17 +1298,11 @@ def execute_model( else: mm_embeds = [] - has_token_types = self.token_type_ids is not None - model_kwargs = {} - if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] - if has_token_types: - model_kwargs["token_type_ids"] = cast( - torch.Tensor, self.token_type_ids)[:num_scheduled_tokens] if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1340,9 +1318,6 @@ def execute_model( # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. 
input_ids = self.input_ids[:num_input_tokens] - if has_token_types: - model_kwargs["token_type_ids"] = cast( - torch.Tensor, self.token_type_ids)[:num_input_tokens] inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -1376,7 +1351,6 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs, ) self.maybe_wait_for_kv_save() diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 7e7b80c0dd0..0d85cc784a1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -394,7 +394,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - token_type_ids=new_req_data.token_type_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, From aad1052264ed20df0b8aecf9dfd584befec5071e Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 24 Jun 2025 12:41:31 -0300 Subject: [PATCH 03/18] reduce diff Signed-off-by: Max de Bayser --- vllm/v1/worker/gpu_model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 61446f89ffe..dc2faa4cb35 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1906,6 +1906,7 @@ def _dummy_run( attn_metadata: Optional[dict[str, Any]] = None if capture_attn_cudagraph: + attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. @@ -1923,7 +1924,6 @@ def _dummy_run( max_query_len=num_tokens, ) - attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -1971,7 +1971,6 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) - if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2482,7 +2481,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue # TODO: Support other attention modules, e.g., cross-attention - # encoder only can also benefit from KV cache for prefix caching if attn_module.attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY): if attn_module.sliding_window is not None: From 7006e8a3ea804f0e9c392129614277945f2073ab Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 14:42:09 -0300 Subject: [PATCH 04/18] Fix cuda graphs for BERT models Signed-off-by: Max de Bayser --- vllm/config.py | 5 -- vllm/model_executor/models/bert.py | 4 +- vllm/model_executor/models/roberta.py | 96 ++++++++++++++++++--------- 3 files changed, 66 insertions(+), 39 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3e8d996e62d..6412e6e293b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -720,11 +720,6 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: self.override_pooler_config = PoolerConfig( **self.override_pooler_config) - # WIP: currently cuda graphs are not working for encoder models. 
- logger.warning("CUDA graph is not supported for pooling yet, " - "fallback to the eager mode.") - self.enforce_eager = True - pooler_config = self.override_pooler_config or PoolerConfig() base_config = get_pooling_config(self.model, self.revision) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index cb6c5ac52f2..76e86b53c4b 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -121,7 +121,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -@support_torch_compile class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -339,6 +338,7 @@ def forward(self, hidden_states: torch.Tensor, return hidden_states +@support_torch_compile class BertModel(nn.Module, SupportsQuant): packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} @@ -665,7 +665,7 @@ def pooler( def forward( self, input_ids: Optional[torch.Tensor], - positions: Optional[torch.Tensor], + positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index ec0938f64a2..9dc66439b58 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -10,6 +10,7 @@ from transformers import RobertaConfig from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -55,29 +56,6 @@ def __init__(self, config: RobertaConfig): " is supported") self.input_ids: Optional[torch.Tensor] = None - def _correct_positions(self, input_ids: torch.Tensor, - position_ids: torch.Tensor) -> torch.Tensor: - zero_pos = torch.where(position_ids == 0)[0] - end_pos = torch.cat((zero_pos[1:], - torch.tensor([position_ids.shape[0]], - device=zero_pos.device))) - seq_lens = end_pos - zero_pos - - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids, seq_lens.tolist()) - - pos_list = [] - for tokens in token_list: - pos_list.append( - create_position_ids_from_input_ids(tokens, self.padding_idx)) - - return torch.cat(pos_list) - def maybe_store_input_ids(self, input_ids: torch.Tensor): self.input_ids = input_ids @@ -107,15 +85,7 @@ def forward( tensors_to_add.append(self.token_type_embeddings(token_type_ids)) if position_ids is not None: - inputs = input_ids if input_ids is not None else self.input_ids - if inputs is None: # it can by during _dummy_run - corrected_positions = position_ids - else: - corrected_positions = self._correct_positions( - input_ids=inputs, position_ids=position_ids) - - tensors_to_add.append( - self.position_embeddings(corrected_positions)) + tensors_to_add.append(self.position_embeddings(position_ids)) if input_ids is not None: tensors_to_add.append(self.word_embeddings(input_ids)) @@ -156,6 +126,32 @@ class 
RobertaEmbeddingModel(BertEmbeddingModel):
         _pooler: An instance of Pooler used for pooling operations.
     """
 
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # Fix Roberta positions here outside of the CUDA graph.
+        # Because we need to extract the sequences from input_ids,
+        # the control flow is data dependent.
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
+
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          token_type_ids=token_type_ids,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors)
+
     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -210,6 +206,7 @@ class RobertaForSequenceClassification(nn.Module, BertMMTokenIdsMixin,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
 
         self.default_activation_function = \
             get_cross_encoder_activation_function(config)
@@ -257,6 +254,9 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
         return self.roberta(input_ids=input_ids,
                             position_ids=positions,
                             inputs_embeds=inputs_embeds,
@@ -288,6 +288,38 @@ def create_position_ids_from_input_ids(input_ids,
     return incremental_indices.long() + padding_idx
 
 
+def replace_roberta_positions(input_ids: torch.Tensor,
+                              position_ids: torch.Tensor,
+                              padding_idx: int) -> None:
+
+    seq_lens: Optional[torch.Tensor] = None
+    attn_metadata = get_forward_context().attn_metadata
+    if attn_metadata is not None:  # can be None during warmup
+        if isinstance(attn_metadata, dict):
+            attn_metadata = next(iter(attn_metadata.values()))
+        # TODO: remove "seq_lens_tensor" after V0 is removed
+        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
+                           getattr(attn_metadata, "seq_lens", None))
+
+    if seq_lens is not None:
+        assert isinstance(seq_lens, torch.Tensor)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+        token_list = torch.split(input_ids, seq_lens.tolist())
+
+        offset = 0
+        for tokens in token_list:
+            length = tokens.shape[0]
+            position_ids[offset:offset+length] = \
+                create_position_ids_from_input_ids(tokens, padding_idx)
+            offset = offset + length
+
+
 def roberta_task_weights_filter(
     all_weights: Iterable[tuple[str, torch.Tensor]]
 ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,

From c99df96506373634cbc409ac891da12c9e7fbc35 Mon Sep 17 00:00:00 2001
From: Max de Bayser
Date: Tue, 1 Jul 2025 21:43:26 -0300
Subject:
[PATCH 05/18] Add token_type_ids multi-modal to LLM._cross_encoding_score Signed-off-by: Max de Bayser --- vllm/engine/arg_utils.py | 2 +- vllm/entrypoints/llm.py | 14 ++++++++++++-- vllm/model_executor/models/roberta.py | 3 ++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 120d7ff9679..e53ee2e8e76 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1683,7 +1683,7 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): self.max_num_seqs = min(default_max_num_seqs[usage_context], - self.max_num_batched_tokens) + self.max_num_batched_tokens or sys.maxsize) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f0404e0bc6e..96614dab6c9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -14,6 +14,7 @@ from tqdm.auto import tqdm from typing_extensions import TypeVar, deprecated +import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, create_sort_beams_key_function) @@ -1208,9 +1209,18 @@ def _cross_encoding_score( prompt_inputs = tokenizer(text=q, text_pair=t, **tokenization_kwargs) + + token_type_ids = prompt_inputs.get("token_type_ids") + mm_data = None + if envs.VLLM_USE_V1 and token_type_ids is not None: + mm_data = {"token_type_ids": token_type_ids} + token_type_ids = None + engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["input_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) + prompt_token_ids=prompt_inputs["prompt_token_ids"], + token_type_ids=token_type_ids, + multi_modal_data=mm_data) + parsed_prompts.append(engine_prompt) self._validate_and_add_requests( diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 9dc66439b58..b9084fc927c 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -310,7 +310,8 @@ def replace_roberta_positions(input_ids: torch.Tensor, # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids, seq_lens.tolist()) + token_list = torch.split(input_ids[:torch.sum(seq_lens)], + seq_lens.tolist()) offset = 0 for tokens in token_list: From bbe0ea7622d5ab3f13dc47a60221b355941ee453 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 21:57:57 -0300 Subject: [PATCH 06/18] fix merge problem Signed-off-by: Max de Bayser --- vllm/model_executor/models/bert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 76e86b53c4b..43664ceeda2 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -492,6 +492,7 @@ def apply( prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, return_mm_hashes: bool = False, ) -> MultiModalInputs: From 019496a510e87dbdf7b3f2e5f92a58b1d051385a Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 22:04:39 -0300 Subject: [PATCH 07/18] fix 
editing mistake Signed-off-by: Max de Bayser --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 96614dab6c9..e69af02a7ff 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1217,7 +1217,7 @@ def _cross_encoding_score( token_type_ids = None engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"], + prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=token_type_ids, multi_modal_data=mm_data) From 6558bdd9cfe8004fceaf948467a9abe48055732b Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 22:27:08 -0300 Subject: [PATCH 08/18] fix missing input ids Signed-off-by: Max de Bayser --- vllm/model_executor/models/bert.py | 6 +++--- vllm/model_executor/models/roberta.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 43664ceeda2..077c04a7b36 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -63,9 +63,6 @@ def __init__(self, config: BertConfig): raise ValueError("Only 'absolute' position_embedding_type" + " is supported") - def maybe_store_input_ids(self, input_ids: torch.Tensor): - pass - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -574,6 +571,9 @@ def get_multimodal_embeddings(self, return self.get_language_model().embeddings( token_type_ids=token_type_ids, apply_layer_norm=False) + def maybe_store_input_ids(self, input_ids: torch.Tensor): + pass + def get_input_embeddings( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index b9084fc927c..3553d49bc60 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -54,10 +54,6 @@ def __init__(self, config: RobertaConfig): if self.position_embedding_type != "absolute": raise ValueError("Only 'absolute' position_embedding_type" + " is supported") - self.input_ids: Optional[torch.Tensor] = None - - def maybe_store_input_ids(self, input_ids: torch.Tensor): - self.input_ids = input_ids def forward( self, @@ -220,6 +216,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = ClassifierPooler(vllm_config.model_config, self.classifier) + self.input_ids: Optional[torch.Tensor] = None + + def maybe_store_input_ids(self, input_ids: torch.Tensor): + self.input_ids = input_ids def get_language_model(self) -> torch.nn.Module: return self.roberta @@ -254,7 +254,7 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - replace_roberta_positions(input_ids=input_ids, + replace_roberta_positions(input_ids=input_ids or self.input_ids, position_ids=positions, padding_idx=self.padding_idx) return self.roberta(input_ids=input_ids, From 33bcc88a88a7f890b7a5a6f63660fed0b20126e6 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 22:31:56 -0300 Subject: [PATCH 09/18] fix mistake Signed-off-by: Max de Bayser --- vllm/model_executor/models/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 077c04a7b36..96d094d261b 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -581,7 +581,7 @@ def get_input_embeddings( ) -> torch.Tensor: # save for forward() - 
self.get_language_model().embeddings.maybe_store_input_ids(input_ids) + self.maybe_store_input_ids(input_ids) token_type_ids: Optional[torch.Tensor] = None From a7432685326edf0d52d0e73a5218cd9655e57be9 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 23:10:56 -0300 Subject: [PATCH 10/18] fix tensor not boolean error Signed-off-by: Max de Bayser --- vllm/model_executor/models/roberta.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 3553d49bc60..a117627959f 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -254,7 +254,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - replace_roberta_positions(input_ids=input_ids or self.input_ids, + _input_ids = input_ids if input_ids is not None else self.input_ids + replace_roberta_positions(input_ids=_input_ids, position_ids=positions, padding_idx=self.padding_idx) return self.roberta(input_ids=input_ids, From 6310f4db99a2e4661549c541758128b84cb8b0bd Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 23:37:02 -0300 Subject: [PATCH 11/18] appease mypy Signed-off-by: Max de Bayser --- vllm/entrypoints/llm.py | 3 ++- vllm/entrypoints/openai/serving_score.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e69af02a7ff..bfb9562ca2f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -39,6 +39,7 @@ from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.multimodal.inputs import MultiModalDataDict from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, PoolingRequestOutput, RequestOutput, ScoringRequestOutput) @@ -1211,7 +1212,7 @@ def _cross_encoding_score( **tokenization_kwargs) token_type_ids = prompt_inputs.get("token_type_ids") - mm_data = None + mm_data: MultiModalDataDict = {} if envs.VLLM_USE_V1 and token_type_ids is not None: mm_data = {"token_type_ids": token_type_ids} token_type_ids = None diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 3d4095ce488..c39bd3b6351 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -24,6 +24,7 @@ from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import MultiModalDataDict from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, @@ -183,7 +184,7 @@ async def _cross_encoding_score( self._validate_input(request, input_ids, request_prompt) token_type_ids = prompt_inputs.get("token_type_ids") - mm_data = None + mm_data: MultiModalDataDict = {} if envs.VLLM_USE_V1 and token_type_ids is not None: mm_data = {"token_type_ids": token_type_ids} token_type_ids = None From 3f793246fc6470d41c380f896b807fd8d2debe57 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 15:36:01 -0300 Subject: [PATCH 12/18] Fix missing args Signed-off-by: Max de Bayser --- tests/v1/core/test_kv_cache_utils.py | 7 +++++++ tests/v1/core/test_prefix_caching.py | 18 ++++++++++++++++-- 
tests/v1/core/test_scheduler.py | 5 +++-- tests/v1/core/test_specialized_manager.py | 4 ++++ tests/v1/engine/test_engine_core_client.py | 4 +++- tests/v1/kv_connector/unit/utils.py | 3 ++- tests/v1/worker/test_gpu_model_runner.py | 3 ++- 7 files changed, 37 insertions(+), 7 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 97fff745647..4eeb9b79247 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -61,6 +62,7 @@ def new_kv_cache_spec(block_size=16, head_size=head_size, dtype=dtype, use_mla=use_mla, + attn_type=AttentionType.DECODER, sliding_window=sliding_window) @@ -75,6 +77,7 @@ def new_sliding_window_spec(block_size=16, head_size=head_size, dtype=dtype, use_mla=use_mla, + attn_type=AttentionType.DECODER, sliding_window=sliding_window) @@ -534,6 +537,7 @@ def test_merge_kv_cache_spec(): head_size=full_spec.head_size, dtype=full_spec.dtype, use_mla=full_spec.use_mla, + attn_type=AttentionType.DECODER, sliding_window=1, ), ] @@ -603,6 +607,7 @@ def test_estimate_max_model_len(model_id, max_model_len, head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, ) # Estimate the maximum model length, 16384 model_len need 8GB estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, @@ -638,6 +643,7 @@ def test_get_max_concurrency_for_kv_cache_config(): head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, ) sliding_window_spec = SlidingWindowSpec( @@ -646,6 +652,7 @@ def test_get_max_concurrency_for_kv_cache_config(): head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, sliding_window=1024, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 7a42778831c..d518d76628f 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,6 +8,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -53,7 +54,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: kv_cache_groups=[ KVCacheGroupSpec( ["layer"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, + 1, + 1, + torch.float32, + False, + attn_type=AttentionType.DECODER), ) ], ) @@ -67,7 +73,12 @@ def make_kv_cache_config_hybrid_model(block_size: int, kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, + 1, + 1, + torch.float32, + False, + attn_type=AttentionType.DECODER), ), KVCacheGroupSpec( ["layer2"], @@ -76,6 +87,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, torch.float32, False, + attn_type=AttentionType.DECODER, sliding_window=2 * block_size), ), KVCacheGroupSpec( @@ -85,6 +97,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, torch.float32, False, + attn_type=AttentionType.DECODER, sliding_window=2 * block_size), ), ], @@ -1218,6 +1231,7 @@ def test_eagle_with_sliding_window(): dtype=torch.float32, 
sliding_window=block_size, use_mla=False, + attn_type=AttentionType.DECODER, ) manager = KVCacheManager( KVCacheConfig( diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 02d2c83ab15..88c574a7518 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -6,6 +6,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -104,7 +105,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks @@ -1354,7 +1355,7 @@ def create_scheduler_with_priority( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index a9e1898df93..ec3e725ae45 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,6 +3,7 @@ import torch +from vllm.attention import AttentionType from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, KVCacheBlock) @@ -26,6 +27,7 @@ def test_sliding_window_possible_cached_prefix(): dtype=torch.float32, sliding_window=4, use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -92,6 +94,7 @@ def test_sliding_window_remove_skipped_blocks(): dtype=torch.float32, sliding_window=4, use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -160,6 +163,7 @@ def test_get_num_blocks_to_allocate(): dtype=torch.float32, sliding_window=4, # Placeholder value, not related to test result use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 65f1da803fb..7718827d5f3 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -16,6 +16,7 @@ from tests.utils import multi_gpu_test from vllm import SamplingParams +from vllm.attention import AttentionType from vllm.distributed.kv_events import (BlockStored, KVEventBatch, ZmqEventPublisher) from vllm.engine.arg_utils import EngineArgs @@ -544,7 +545,8 @@ def create_mock_executor(vllm_config): num_kv_heads=1, head_size=64, dtype=torch.float16, - use_mla=False) + use_mla=False, + attn_type=AttentionType.DECODER) mock_executor.get_kv_cache_specs.return_value = [{ "default": mock_spec diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 983d900606f..b89e62efa51 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -5,6 +5,7 @@ import torch from vllm import SamplingParams +from vllm.attention import AttentionType from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) from vllm.v1.core.sched.scheduler import Scheduler @@ -99,7 +100,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, 
torch.float32, - False)) + False, AttentionType.DECODER)) ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index d13df553db6..99696c9887b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -6,7 +6,7 @@ import pytest import torch -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) from vllm.platforms import current_platform @@ -38,6 +38,7 @@ def initialize_kv_cache(runner: GPUModelRunner): head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, use_mla=False, + attn_type=AttentionType.DECODER, ) tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( From f3f075ae0bb0659d1d276e1d38ab0da0f6789ed6 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 16:58:42 -0300 Subject: [PATCH 13/18] fix mm flag in registry test Signed-off-by: Max de Bayser --- tests/models/test_registry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 01b2260abe8..66c03ff872c 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -51,9 +51,9 @@ def test_registry_imports(model_arch): ("LlamaForCausalLM", False, False, False), ("MllamaForConditionalGeneration", True, False, False), ("LlavaForConditionalGeneration", True, True, False), - ("BertForSequenceClassification", False, False, True), - ("RobertaForSequenceClassification", False, False, True), - ("XLMRobertaForSequenceClassification", False, False, True), + ("BertForSequenceClassification", True, False, True), + ("RobertaForSequenceClassification", True, False, True), + ("XLMRobertaForSequenceClassification", True, False, True), ]) def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): assert ModelRegistry.is_multimodal_model(model_arch) is is_mm From 268099b80fc4eb57f7309eb1181f9dc32772191c Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 22:00:24 -0300 Subject: [PATCH 14/18] remove model from unsupported list Signed-off-by: Max de Bayser --- tests/v1/test_oracle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 7a7ba346a71..55e30ab1ee5 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -13,7 +13,6 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "state-spaces/mamba-130m-hf", # mamba1 - "BAAI/bge-m3", # embedding ] MODEL = "meta-llama/Llama-3.2-1B-Instruct" From 60696b4900c8f22ce43e6920566fc97cd37072b9 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 11 Jul 2025 10:06:07 -0300 Subject: [PATCH 15/18] appease linter Signed-off-by: Max de Bayser --- vllm/entrypoints/score_utils.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index c012ce55ad8..fc50396a8cb 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -196,17 +196,15 @@ def get_score_prompt( prompt_inputs = tokenizer(text=prompt_1 + prompt_2, **tokenization_kwargs) - engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"]) - - if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None: - if 
envs.VLLM_USE_V1: - mm_data = mm_data or {} - mm_data["token_type_ids"] = token_type_ids - else: - engine_prompt["token_type_ids"] = token_type_ids - + mm_data = mm_data or {} + if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None \ + and envs.VLLM_USE_V1: + mm_data["token_type_ids"] = token_type_ids + token_type_ids = None + + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=token_type_ids, + multi_modal_data=mm_data) post_process_tokens(model_config, engine_prompt) - if mm_data is not None: - engine_prompt["multi_modal_data"] = mm_data return full_prompt, engine_prompt From 00bfc79127ae8da2594d7514aa58b59faaf0e9e3 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 11 Jul 2025 10:26:40 -0300 Subject: [PATCH 16/18] lazy import Signed-off-by: Max de Bayser --- vllm/entrypoints/score_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 322e800e189..35a3790470f 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -185,6 +185,7 @@ def get_score_prompt( model_config, tokenizer, ) + from vllm.model_executor.model_loader import get_model_cls model = get_model_cls(model_config) if supports_score_template(model): From 25016496ec591465554be23bc915e1cd1860b50d Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 11 Jul 2025 11:12:50 -0300 Subject: [PATCH 17/18] appease linter Signed-off-by: Max de Bayser --- vllm/entrypoints/score_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 35a3790470f..4b759d78581 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -201,10 +201,9 @@ def get_score_prompt( prompt_inputs = tokenizer(text=prompt_1 + prompt_2, **tokenization_kwargs) - mm_data = mm_data or {} if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None \ and envs.VLLM_USE_V1: - mm_data["token_type_ids"] = token_type_ids + mm_data = {"token_type_ids": token_type_ids, **(mm_data or {})} token_type_ids = None engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"], From 28fb913842c77ac46973af85e1b9943ffa57ee2a Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 11 Jul 2025 11:32:51 -0300 Subject: [PATCH 18/18] appease linter Signed-off-by: Max de Bayser --- vllm/entrypoints/score_utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 4b759d78581..229fc10391d 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -201,14 +201,16 @@ def get_score_prompt( prompt_inputs = tokenizer(text=prompt_1 + prompt_2, **tokenization_kwargs) - if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None \ - and envs.VLLM_USE_V1: - mm_data = {"token_type_ids": token_type_ids, **(mm_data or {})} - token_type_ids = None - - engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"], - token_type_ids=token_type_ids, - multi_modal_data=mm_data) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"]) + + if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None: + if envs.VLLM_USE_V1: + mm_data = {"token_type_ids": token_type_ids, **(mm_data or {})} + else: + engine_prompt["token_type_ids"] = token_type_ids + post_process_tokens(model_config, engine_prompt) + if mm_data is not None: + 
engine_prompt["multi_modal_data"] = mm_data
 
     return full_prompt, engine_prompt
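
Note on the position-id handling introduced in PATCH 04: RoBERTa position ids have to start at padding_idx + 1 and skip padding tokens, and with the flattened V1 batch layout every packed sequence must be renumbered from its own tokens, which is why replace_roberta_positions runs in forward() outside the captured CUDA graph. The snippet below is a minimal standalone sketch of that rule, not code from the patch; toy_position_ids and the example token ids are illustrative stand-ins for create_position_ids_from_input_ids and a real tokenized batch.

import torch

PADDING_IDX = 1  # RoBERTa's pad token id (1 in the HF config)

def toy_position_ids(tokens: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Non-pad tokens get 1-based incremental indices shifted by padding_idx.
    mask = tokens.ne(padding_idx).int()
    incremental = torch.cumsum(mask, dim=0) * mask
    return incremental.long() + padding_idx

# Two sequences (lengths 3 and 4) packed into one flat V1-style batch.
input_ids = torch.tensor([0, 2054, 2003, 0, 2054, 2003, 1996])
seq_lens = [3, 4]

position_ids = torch.empty_like(input_ids)
offset = 0
for tokens in torch.split(input_ids, seq_lens):
    length = tokens.shape[0]
    position_ids[offset:offset + length] = toy_position_ids(tokens, PADDING_IDX)
    offset += length

print(position_ids)  # tensor([2, 3, 4, 2, 3, 4, 5])

Keeping this data-dependent renumbering outside the captured region is what allows the same patch to drop the enforce_eager fallback from vllm/config.py.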