From 5dee54d2652e698b00730817e77ea85bb3162cba Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Jun 2025 11:25:27 -0300 Subject: [PATCH 01/14] Add support for encoder embedding models Signed-off-by: Max de Bayser --- .../models/language/pooling/test_embedding.py | 16 ++----- tests/models/language/pooling/test_jina.py | 8 ++++ tests/models/language/pooling/test_scoring.py | 9 ++++ tests/v1/core/test_kv_cache_utils.py | 2 +- vllm/config.py | 5 ++ vllm/engine/arg_utils.py | 3 +- vllm/model_executor/models/bert.py | 22 ++++----- vllm/model_executor/models/roberta.py | 38 +++++++-------- vllm/v1/attention/backends/flash_attn.py | 30 ++++++++++-- vllm/v1/core/kv_cache_utils.py | 1 + vllm/v1/core/sched/output.py | 2 + vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/core.py | 19 +++++++- vllm/v1/engine/processor.py | 1 + vllm/v1/kv_cache_interface.py | 1 + vllm/v1/request.py | 3 ++ vllm/v1/worker/gpu_input_batch.py | 31 +++++++++++- vllm/v1/worker/gpu_model_runner.py | 47 +++++++++++++++---- vllm/v1/worker/tpu_model_runner.py | 3 ++ 19 files changed, 178 insertions(+), 64 deletions(-) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 5ef9f768c57..c9c98357b16 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -38,19 +38,13 @@ def v1(run_with_both_engines): marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] pytest.param("BAAI/bge-base-en-v1.5", - marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.skip_v1 - ]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2", - marks=[pytest.mark.skip_v1]), - pytest.param("intfloat/multilingual-e5-small", - marks=[pytest.mark.skip_v1]), + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2"), + pytest.param("intfloat/multilingual-e5-small"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), + marks=[pytest.mark.skip_v0]), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) def test_models( diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 0c44683e748..99466405501 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -26,6 +26,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index c75ff144561..1cf2cdc0132 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -23,6 +23,15 @@ "The capital of Germany is Berlin.", ] + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + DTYPE = "half" diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e80ad8a6815..97fff745647 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -916,4 +916,4 @@ def test_get_kv_cache_config(): ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) \ No newline at end of file + ]) diff --git a/vllm/config.py b/vllm/config.py index 7549c97b4fe..d112abd7156 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -716,6 +716,11 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: self.override_pooler_config = PoolerConfig( **self.override_pooler_config) + # WIP: currently cuda graphs are not working for encoder models. + logger.warning("CUDA graph is not supported for pooling yet, " + "fallback to the eager mode.") + self.enforce_eager = True + pooler_config = self.override_pooler_config or PoolerConfig() base_config = get_pooling_config(self.model, self.revision) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dd09f514906..8d01c0413e2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1664,7 +1664,8 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): - self.max_num_seqs = default_max_num_seqs[usage_context] + self.max_num_seqs = min(default_max_num_seqs[usage_context], + self.max_num_batched_tokens) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d6f6d9d1fb5..cb3ac791ccb 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -12,7 +12,6 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -28,7 +27,7 @@ from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) -from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .interfaces import SupportsCrossEncoding, SupportsQuant from .utils import WeightsMapper, maybe_prefix @@ -57,7 +56,6 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -342,13 +340,9 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - attn_metadata = get_forward_context().attn_metadata - assert hasattr(attn_metadata, "seq_lens_tensor") - hidden_states = self.embeddings( - input_ids=input_ids, - seq_lens=attn_metadata.seq_lens_tensor, - position_ids=position_ids, - token_type_ids=token_type_ids) + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states) def load_weights(self, weights: Iterable[tuple[str, @@ -388,7 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): +class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -411,11 +405,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, + token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -446,8 +442,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: softmax=False) -class BertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding, SupportsQuant): +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 8fa8b89798d..87251236b37 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -22,7 +22,7 @@ get_cross_encoder_activation_function) from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import SupportsCrossEncoding class RobertaEmbedding(nn.Module): @@ -52,41 +52,36 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + input_shape = input_ids.size() inputs_embeds = self.word_embeddings(input_ids) + zero_pos = torch.where(position_ids == 0)[0] + end_pos = torch.cat((zero_pos[1:], + torch.tensor([position_ids.shape[0]], + device=zero_pos.device))) + seq_lens = end_pos - zero_pos + # Replace position ids because in RoBERTa models # they have to start at padding_idx + 1 and ignore # existing padding tokens # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + token_list = torch.split(input_ids, seq_lens.tolist()) + pos_list = [] - token_list = [] - offset = 0 - for seq_len in seq_lens: - pos_list.append(position_ids[offset:offset + seq_len]) - token_list.append(input_ids[offset:offset + seq_len]) - offset += seq_len - - new_pos_list = [] - for positions, tokens in zip(pos_list, token_list): - # Verify assumption that incoming position are - # always a sequence from 0 to N. - expected_pos = torch.arange(positions.size()[0], - dtype=torch.long, - device=inputs_embeds.device) - assert torch.equal(positions, expected_pos) - new_pos_list.append( + for tokens in token_list: + pos_list.append( create_position_ids_from_input_ids(tokens, self.padding_idx)) - position_ids = torch.cat(new_pos_list) + + corrected_positions = torch.cat(pos_list) # Position embeddings. - position_embeddings = self.position_embeddings(position_ids) + position_embeddings = self.position_embeddings(corrected_positions) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, @@ -150,8 +145,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert len(loaded), "Unable to load RobertaEmbeddingModel" -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsV0Only): +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4ad7178374b..117c4631b91 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -386,11 +386,13 @@ def __init__( f"Supported head sizes are: {support_head_sizes}. " "Set VLLM_USE_V1=0 to use another attention backend.") - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/decoder cross-attention " + "is not implemented for " "FlashAttentionImpl") + self.attn_type = attn_type self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -509,7 +511,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=True, + causal=_get_causal_option(self.attn_type), alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, @@ -711,3 +713,21 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9489bcf433f..c52b1e5afb6 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -915,6 +915,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: dtype=spec.dtype, use_mla=spec.use_mla, sliding_window=spec.sliding_window, + attn_type=str(spec.attn_type), ) if is_hybrid(kv_cache_spec): diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6f31031a108..2dd43a6b795 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -24,6 +24,7 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] mm_positions: list[PlaceholderRange] @@ -42,6 +43,7 @@ def from_request( return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, + token_type_ids=request.token_type_ids, mm_inputs=request.mm_inputs, mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 921ccd708cd..dcad52959b3 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -49,6 +49,7 @@ class EngineCoreRequest( request_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index da65550354d..e384dadffcb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -150,6 +150,23 @@ def _initialize_kv_caches( zip(kv_cache_specs, available_gpu_memory) ] + for kv_cache_spec_one_worker in kv_cache_specs: + for _, spec in kv_cache_spec_one_worker.items(): + if isinstance(spec, AttentionSpec) and \ + spec.attn_type != "decoder": + + logger.info("Found non-decoder layer. Disabling " + "prefix cache and chunked prefill") + self.vllm_config.cache_config.\ + enable_prefix_caching = False + self.vllm_config.scheduler_config.\ + enable_chunked_prefill = False + self.vllm_config.scheduler_config.\ + chunked_prefill_enabled = False + self.vllm_config.scheduler_config.\ + long_prefill_token_threshold = 0 + break + # Since we use a shared centralized controller, we need the # `kv_cache_config` to be consistent across all workers to make sure # all the memory operators can be applied to all workers. diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a0b170ba55a..2ae2de6f4be 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -329,6 +329,7 @@ def process_inputs( return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], + token_type_ids=decoder_inputs.get("token_type_ids"), mm_inputs=sorted_mm_inputs, mm_hashes=sorted_mm_hashes, mm_placeholders=sorted_mm_positions, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c48775adc9b..d8af8486ca5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -75,6 +75,7 @@ class AttentionSpec(KVCacheSpec): head_size: int dtype: torch.dtype use_mla: bool + attn_type: str @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9b96f4599f9..d22220e3e18 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -24,6 +24,7 @@ def __init__( self, request_id: str, prompt_token_ids: list[int], + token_type_ids: Optional[list[int]], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], @@ -74,6 +75,7 @@ def __init__( "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids + self.token_type_ids = token_type_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = self.prompt_token_ids.copy() @@ -118,6 +120,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, + token_type_ids=request.token_type_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 3a2c9ef7dfa..9514439f142 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Datastructures defining a GPU input batch +# Datastructures defining an input batch from dataclasses import dataclass from typing import Optional, cast @@ -27,6 +27,7 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] @@ -89,6 +90,8 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.token_type_ids_cpu_tensor = None + self._token_type_ids_cpu = None self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) @@ -231,6 +234,22 @@ def __init__( self.pooling_params: dict[str, PoolingParams] = {} + @property + def token_type_ids_cpu(self) -> np.ndarray: + if self._token_type_ids_cpu is None: + self.token_type_ids_cpu_tensor = torch.zeros( + self.token_ids_cpu_tensor.shape, + device="cpu", + dtype=torch.int8, + pin_memory=False, + ) + self._token_type_ids_cpu = cast( + torch.Tensor, self.token_type_ids_cpu_tensor).numpy() + return self._token_type_ids_cpu + + def has_token_types(self) -> bool: + return self._token_type_ids_cpu is not None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -261,6 +280,9 @@ def add_request( self.num_prompt_tokens[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids + if request.token_type_ids is not None: + self.token_type_ids_cpu[ + req_index, :num_prompt_tokens] = request.token_type_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) self.token_ids_cpu[req_index, @@ -447,6 +469,10 @@ def swap_states(self, i1: int, i2: int) -> None: tmp = self.token_ids_cpu[i1, ...].copy() self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + if self.has_token_types(): + tmp2 = self.token_type_ids_cpu[i1, ...].copy() + self.token_type_ids_cpu[i1, ...] = self.token_type_ids_cpu[i2, ...] + self.token_type_ids_cpu[i2, ...] = tmp2 swap_dict_values(self.generators, i1, i2) swap_dict_values(self.min_tokens, i1, i2) @@ -503,6 +529,9 @@ def condense(self, empty_req_indices: list[int]) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] + if self.has_token_types(): + self.token_type_ids_cpu[empty_index, :num_tokens] = \ + self.token_type_ids_cpu[last_req_index, :num_tokens] self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ last_req_index] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 33036600611..bf68f8fde05 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,7 +6,7 @@ import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np import torch @@ -237,7 +237,7 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) - + self.token_type_ids = None # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -305,6 +305,13 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + def get_token_type_ids(self) -> Optional[torch.Tensor]: + if self.token_type_ids is None: + self.token_type_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int8, + device=self.device) + return self.token_type_ids + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -408,6 +415,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + token_type_ids=new_req_data.token_type_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -627,6 +635,13 @@ def _prepare_inputs( 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) + if self.input_batch.token_type_ids_cpu_tensor is not None: + token_type_ids = torch.index_select( + self.input_batch.token_type_ids_cpu_tensor.flatten(), 0, + torch.from_numpy(token_indices)) + # Copy the tensors to the GPU. + self.get_token_type_ids()[:total_num_scheduled_tokens]\ + .copy_(token_type_ids, non_blocking=True) # Calculate the slot mapping for each KV cache group. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1299,11 +1314,17 @@ def execute_model( else: mm_embeds = [] + has_token_types = self.token_type_ids is not None + model_kwargs = {} + if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] + if has_token_types: + model_kwargs["token_type_ids"] = cast( + torch.Tensor, self.token_type_ids)[:num_scheduled_tokens] if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1319,6 +1340,9 @@ def execute_model( # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] + if has_token_types: + model_kwargs["token_type_ids"] = cast( + torch.Tensor, self.token_type_ids)[:num_input_tokens] inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -1352,6 +1376,7 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **model_kwargs, ) self.maybe_wait_for_kv_save() @@ -1907,7 +1932,6 @@ def _dummy_run( attn_metadata: Optional[dict[str, Any]] = None if capture_attn_cudagraph: - attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. @@ -1925,6 +1949,7 @@ def _dummy_run( max_query_len=num_tokens, ) + attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -1972,6 +1997,7 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2482,7 +2508,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue # TODO: Support other attention modules, e.g., cross-attention - if attn_module.attn_type == AttentionType.DECODER: + # encoder only can also benefit from KV cache for prefix caching + if attn_module.attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_ONLY): if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -2490,17 +2518,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla) + use_mla=use_mla, + attn_type=str(attn_module.attn_type)) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. + use_mla=use_mla, + attn_type=str(attn_module.attn_type)) + elif attn_module.attn_type == AttentionType.ENCODER: + # encoder attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 774caa1a3d9..7e7b80c0dd0 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -394,6 +394,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + token_type_ids=new_req_data.token_type_ids, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -494,6 +495,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=False, + attn_type=str(attn_module.attn_type), ) else: kv_cache_spec[layer_name] = FullAttentionSpec( @@ -502,6 +504,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=False, + attn_type=str(attn_module.attn_type), ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): From 7eb9d28c98502db9f43728b251d9386000cc5fbf Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 14:12:58 -0300 Subject: [PATCH 02/14] Fix CUDA graphs for BERT models The @supports_torch_compile decorator was applied only to the encoder stack leaving the embeddings out, which resulted in a dependency on dynamically allocated tensors. Signed-off-by: Max de Bayser --- vllm/config.py | 5 -- vllm/model_executor/models/bert.py | 2 +- vllm/model_executor/models/roberta.py | 86 ++++++++++++++++++++------- 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d112abd7156..7549c97b4fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -716,11 +716,6 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: self.override_pooler_config = PoolerConfig( **self.override_pooler_config) - # WIP: currently cuda graphs are not working for encoder models. - logger.warning("CUDA graph is not supported for pooling yet, " - "fallback to the eager mode.") - self.enforce_eager = True - pooler_config = self.override_pooler_config or PoolerConfig() base_config = get_pooling_config(self.model, self.revision) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index cb3ac791ccb..d6bb1584118 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -95,7 +95,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -@support_torch_compile class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -313,6 +312,7 @@ def forward(self, hidden_states: torch.Tensor, return hidden_states +@support_torch_compile class BertModel(nn.Module, SupportsQuant): packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 87251236b37..869df157aab 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -10,6 +10,7 @@ from transformers import RobertaConfig from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -59,29 +60,8 @@ def forward( input_shape = input_ids.size() inputs_embeds = self.word_embeddings(input_ids) - zero_pos = torch.where(position_ids == 0)[0] - end_pos = torch.cat((zero_pos[1:], - torch.tensor([position_ids.shape[0]], - device=zero_pos.device))) - seq_lens = end_pos - zero_pos - - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids, seq_lens.tolist()) - - pos_list = [] - for tokens in token_list: - pos_list.append( - create_position_ids_from_input_ids(tokens, self.padding_idx)) - - corrected_positions = torch.cat(pos_list) - # Position embeddings. - position_embeddings = self.position_embeddings(corrected_positions) + position_embeddings = self.position_embeddings(position_ids) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, @@ -121,6 +101,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel): _pooler: An instance of Pooler used for pooling operations. """ + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # Fix Roberta positions here outside of the CUDA graph. + # Because we need the to extract the sequences from + # input_ids the control flow is data dependent. + replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) + + return self.model(input_ids=input_ids, + position_ids=positions, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> Union[BertModel, BertWithRope]: @@ -171,6 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id self.default_activation_function = \ get_cross_encoder_activation_function(config) @@ -215,6 +222,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) return self.roberta(input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds, @@ -246,6 +256,38 @@ def create_position_ids_from_input_ids(input_ids, return incremental_indices.long() + padding_idx +def replace_roberta_positions(input_ids: torch.Tensor, + position_ids: torch.Tensor, + padding_idx: int) -> None: + + seq_lens: Optional[torch.Tensor] = None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: # can be None during warmup + if isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values())) + # TODO: remove "seq_lens_tensor" after V0 is removed + seq_lens = getattr(attn_metadata, "seq_lens_tensor", + getattr(attn_metadata, "seq_lens", None)) + + if seq_lens is not None: + assert isinstance(seq_lens, torch.Tensor) + + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + token_list = torch.split(input_ids, seq_lens.tolist()) + + offset = 0 + for tokens in token_list: + length = tokens.shape[0] + position_ids[offset:offset+length] = \ + create_position_ids_from_input_ids(tokens, padding_idx) + offset = offset + length + + def roberta_task_weights_filter( all_weights: Iterable[tuple[str, torch.Tensor]] ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, From d3099a93e19c75e8e92159ee42c9c3b5e1809b14 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 1 Jul 2025 14:46:45 -0300 Subject: [PATCH 03/14] Fix cuda graph initialization of token type ids Signed-off-by: Max de Bayser --- vllm/engine/arg_utils.py | 2 +- vllm/model_executor/models/roberta.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 54 +++++++++++++++++++++------ 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 120d7ff9679..e53ee2e8e76 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1683,7 +1683,7 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): self.max_num_seqs = min(default_max_num_seqs[usage_context], - self.max_num_batched_tokens) + self.max_num_batched_tokens or sys.maxsize) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 869df157aab..9674c618323 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -278,7 +278,8 @@ def replace_roberta_positions(input_ids: torch.Tensor, # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - token_list = torch.split(input_ids, seq_lens.tolist()) + token_list = torch.split(input_ids[:torch.sum(seq_lens)], + seq_lens.tolist()) offset = 0 for tokens in token_list: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c9e3bb8171e..397fc201162 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,10 +3,11 @@ import copy import gc +import inspect import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch @@ -42,6 +43,8 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, @@ -247,7 +250,8 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) - self.token_type_ids = None + self.token_type_ids: Optional[torch.Tensor] = None + self.supports_token_type_ids: bool = False # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -316,13 +320,19 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} - def get_token_type_ids(self) -> Optional[torch.Tensor]: + def get_token_type_ids(self) -> torch.Tensor: if self.token_type_ids is None: self.token_type_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int8, + dtype=torch.int32, device=self.device) return self.token_type_ids + def _maybe_add_model_args(self, num_tokens: int, model_kwargs: dict[str, + Any]): + if self.supports_token_type_ids: + model_kwargs["token_type_ids"] =\ + self.get_token_type_ids()[:num_tokens] + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -1340,17 +1350,14 @@ def execute_model( else: mm_embeds = [] - has_token_types = self.token_type_ids is not None - model_kwargs = {} + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] - if has_token_types: - model_kwargs["token_type_ids"] = cast( - torch.Tensor, self.token_type_ids)[:num_scheduled_tokens] + self._maybe_add_model_args(num_scheduled_tokens, model_kwargs) if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1366,9 +1373,7 @@ def execute_model( # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] - if has_token_types: - model_kwargs["token_type_ids"] = cast( - torch.Tensor, self.token_type_ids)[:num_input_tokens] + self._maybe_add_model_args(num_input_tokens, model_kwargs) inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -1772,6 +1777,14 @@ def propose_ngram_draft_token_ids( draft_token_ids.append(drafter_output.tolist()) return draft_token_ids + def _get_tokenizer(self) -> AnyTokenizer: + tokenizer_group = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + lora_config=self.lora_config) + + return tokenizer_group.get_lora_tokenizer() + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 @@ -1819,6 +1832,20 @@ def load_model(self) -> None: self.parallel_config, ) + model_supports_token_type_ids = 'token_type_ids' in \ + inspect.getfullargspec(self.model.forward).args + + tokenizer = self._get_tokenizer() + if not isinstance(tokenizer, MistralTokenizer): + tok_output = tokenizer(text="foo") + if "token_type_ids" in tok_output: + assert model_supports_token_type_ids + self.supports_token_type_ids = True + + if self.supports_token_type_ids: + # pre-allocate tensor + self.get_token_type_ids() + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", @@ -2031,6 +2058,8 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model + model_kwargs: dict[str, Any] = {} + self._maybe_add_model_args(num_tokens, model_kwargs) if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2065,6 +2094,7 @@ def _dummy_run( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **model_kwargs, ) if self.use_aux_hidden_state_outputs: From b4f5eade87035210d8e8ab5d669e807311a10498 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 15:36:01 -0300 Subject: [PATCH 04/14] Fix missing args Signed-off-by: Max de Bayser --- tests/tokenization/test_detokenize.py | 21 +++++++++++---------- tests/v1/core/test_kv_cache_utils.py | 7 +++++++ tests/v1/core/test_prefix_caching.py | 18 ++++++++++++++++-- tests/v1/core/test_scheduler.py | 5 +++-- tests/v1/core/test_specialized_manager.py | 4 ++++ tests/v1/engine/test_engine_core_client.py | 4 +++- tests/v1/kv_connector/unit/utils.py | 3 ++- tests/v1/worker/test_gpu_model_runner.py | 3 ++- 8 files changed, 48 insertions(+), 17 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f8aeba8301b..4a29d736245 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -61,16 +61,17 @@ def _run_incremental_decode(tokenizer, skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest("", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, + request = EngineCoreRequest(request_id="", + prompt_token_ids=prompt_token_ids, + token_type_ids=None, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, cache_salt=None, data_parallel_rank=None) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 97fff745647..4eeb9b79247 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -61,6 +62,7 @@ def new_kv_cache_spec(block_size=16, head_size=head_size, dtype=dtype, use_mla=use_mla, + attn_type=AttentionType.DECODER, sliding_window=sliding_window) @@ -75,6 +77,7 @@ def new_sliding_window_spec(block_size=16, head_size=head_size, dtype=dtype, use_mla=use_mla, + attn_type=AttentionType.DECODER, sliding_window=sliding_window) @@ -534,6 +537,7 @@ def test_merge_kv_cache_spec(): head_size=full_spec.head_size, dtype=full_spec.dtype, use_mla=full_spec.use_mla, + attn_type=AttentionType.DECODER, sliding_window=1, ), ] @@ -603,6 +607,7 @@ def test_estimate_max_model_len(model_id, max_model_len, head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, ) # Estimate the maximum model length, 16384 model_len need 8GB estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, @@ -638,6 +643,7 @@ def test_get_max_concurrency_for_kv_cache_config(): head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, ) sliding_window_spec = SlidingWindowSpec( @@ -646,6 +652,7 @@ def test_get_max_concurrency_for_kv_cache_config(): head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, sliding_window=1024, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 7a42778831c..d518d76628f 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,6 +8,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -53,7 +54,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: kv_cache_groups=[ KVCacheGroupSpec( ["layer"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, + 1, + 1, + torch.float32, + False, + attn_type=AttentionType.DECODER), ) ], ) @@ -67,7 +73,12 @@ def make_kv_cache_config_hybrid_model(block_size: int, kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, + 1, + 1, + torch.float32, + False, + attn_type=AttentionType.DECODER), ), KVCacheGroupSpec( ["layer2"], @@ -76,6 +87,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, torch.float32, False, + attn_type=AttentionType.DECODER, sliding_window=2 * block_size), ), KVCacheGroupSpec( @@ -85,6 +97,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, torch.float32, False, + attn_type=AttentionType.DECODER, sliding_window=2 * block_size), ), ], @@ -1218,6 +1231,7 @@ def test_eagle_with_sliding_window(): dtype=torch.float32, sliding_window=block_size, use_mla=False, + attn_type=AttentionType.DECODER, ) manager = KVCacheManager( KVCacheConfig( diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 02d2c83ab15..88c574a7518 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -6,6 +6,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -104,7 +105,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks @@ -1354,7 +1355,7 @@ def create_scheduler_with_priority( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index a9e1898df93..ec3e725ae45 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,6 +3,7 @@ import torch +from vllm.attention import AttentionType from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, KVCacheBlock) @@ -26,6 +27,7 @@ def test_sliding_window_possible_cached_prefix(): dtype=torch.float32, sliding_window=4, use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -92,6 +94,7 @@ def test_sliding_window_remove_skipped_blocks(): dtype=torch.float32, sliding_window=4, use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -160,6 +163,7 @@ def test_get_num_blocks_to_allocate(): dtype=torch.float32, sliding_window=4, # Placeholder value, not related to test result use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 65f1da803fb..7718827d5f3 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -16,6 +16,7 @@ from tests.utils import multi_gpu_test from vllm import SamplingParams +from vllm.attention import AttentionType from vllm.distributed.kv_events import (BlockStored, KVEventBatch, ZmqEventPublisher) from vllm.engine.arg_utils import EngineArgs @@ -544,7 +545,8 @@ def create_mock_executor(vllm_config): num_kv_heads=1, head_size=64, dtype=torch.float16, - use_mla=False) + use_mla=False, + attn_type=AttentionType.DECODER) mock_executor.get_kv_cache_specs.return_value = [{ "default": mock_spec diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 983d900606f..b89e62efa51 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -5,6 +5,7 @@ import torch from vllm import SamplingParams +from vllm.attention import AttentionType from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) from vllm.v1.core.sched.scheduler import Scheduler @@ -99,7 +100,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index d13df553db6..99696c9887b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -6,7 +6,7 @@ import pytest import torch -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) from vllm.platforms import current_platform @@ -38,6 +38,7 @@ def initialize_kv_cache(runner: GPUModelRunner): head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, use_mla=False, + attn_type=AttentionType.DECODER, ) tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( From c4060d1b7e2f64a34326e9337f4d3903fe57698b Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 15:42:50 -0300 Subject: [PATCH 05/14] relax assertion Signed-off-by: Max de Bayser --- vllm/v1/worker/gpu_model_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 947ff3eb3de..7f35ed3a372 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1858,8 +1858,12 @@ def load_model(self) -> None: if not isinstance(tokenizer, MistralTokenizer): tok_output = tokenizer(text="foo") if "token_type_ids" in tok_output: - assert model_supports_token_type_ids - self.supports_token_type_ids = True + if not model_supports_token_type_ids: + logger.warning("Tokenizer returns token_type_ids but " + "but model forward() doesn't support that " + "argument") + else: + self.supports_token_type_ids = True if self.supports_token_type_ids: # pre-allocate tensor From 80930d86c521b5fd3cc52d4b3f10a2b70d5501e8 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 16:42:54 -0300 Subject: [PATCH 06/14] fix missing arg Signed-off-by: Max de Bayser --- tests/v1/core/test_kv_cache_utils.py | 1 + tests/v1/core/test_prefix_caching.py | 1 + tests/v1/core/test_scheduler.py | 3 +++ tests/v1/kv_connector/unit/utils.py | 1 + 4 files changed, 6 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4eeb9b79247..cb3fe0ca842 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -40,6 +40,7 @@ def make_request(request_id, return Request( request_id=request_id, prompt_token_ids=prompt_token_ids, + token_type_ids=None, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions, diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d518d76628f..6e557b2e2ab 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -35,6 +35,7 @@ def make_request(request_id, return Request( request_id=request_id, prompt_token_ids=prompt_token_ids, + token_type_ids=None, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 88c574a7518..6bf959d27ca 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -138,6 +138,7 @@ def create_requests(num_requests: int, request = Request( request_id=f"{i}", prompt_token_ids=[i] * num_tokens, + token_type_ids=None, sampling_params=sampling_params, pooling_params=None, multi_modal_inputs=mm_inputs, @@ -1398,6 +1399,7 @@ def create_requests_with_priority( request = Request( request_id=f"{i}", prompt_token_ids=[i] * num_tokens, + token_type_ids=None, sampling_params=sampling_params, pooling_params=None, multi_modal_inputs=mm_inputs, @@ -1884,6 +1886,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): request = Request( request_id="0", prompt_token_ids=[0, 1], + token_type_ids=None, multi_modal_inputs=None, multi_modal_hashes=None, multi_modal_placeholders=None, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index b89e62efa51..2615f499e67 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -149,6 +149,7 @@ def create_request( req = Request( request_id=f"id-{request_id}", prompt_token_ids=prompt_token_ids, + token_type_ids=None, sampling_params=sampling_params, pooling_params=None, multi_modal_inputs=None, From d881f0acd70bce7cbb8dde8424bc86a4ab3b66ef Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 21:56:27 -0300 Subject: [PATCH 07/14] fix missing arg Signed-off-by: Max de Bayser --- tests/v1/engine/test_engine_core.py | 1 + tests/v1/engine/test_engine_core_client.py | 1 + tests/v1/engine/test_fast_incdec_prefix_err.py | 1 + tests/v1/engine/test_output_processor.py | 5 +++++ 4 files changed, 8 insertions(+) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index bbdc73e9608..f5ddbeeb4fd 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -35,6 +35,7 @@ def make_request() -> EngineCoreRequest: return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=PROMPT_TOKENS, + token_type_ids=None, mm_inputs=None, mm_hashes=None, mm_placeholders=None, diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 7718827d5f3..eb477c0e64f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -52,6 +52,7 @@ def make_request( return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=prompt_tokens_ids, + token_type_ids=None, mm_inputs=None, mm_hashes=None, mm_placeholders=None, diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index f028b4ab1d7..61a4126ff60 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -31,6 +31,7 @@ def test_fast_inc_detok_invalid_utf8_err_case(): None, None, None, + None, params, None, None, diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 949ab764e2e..c59439ed9c0 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -52,6 +52,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, requests = [ EngineCoreRequest(request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, + token_type_ids=None, arrival_time=0, mm_inputs=None, mm_hashes=None, @@ -401,6 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, requests = [ EngineCoreRequest(request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, + token_type_ids=None, arrival_time=0, mm_inputs=None, mm_hashes=None, @@ -566,6 +568,7 @@ def test_stop_token(include_stop_str_in_output: bool, request = EngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_tokens, + token_type_ids=None, arrival_time=0, mm_inputs=None, mm_hashes=None, @@ -665,6 +668,7 @@ def test_stop_string(include_stop_str_in_output: bool, EngineCoreRequest( request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, + token_type_ids=None, arrival_time=0, mm_inputs=None, mm_hashes=None, @@ -781,6 +785,7 @@ def test_iteration_stats(dummy_test_vectors): EngineCoreRequest( request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, + token_type_ids=None, arrival_time=0, mm_inputs=None, mm_hashes=None, From 90a25d0b8dc1ea14216b3b60cfb1e21179f59c4f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 9 Jul 2025 22:00:24 -0300 Subject: [PATCH 08/14] remove model from unsupported list Signed-off-by: Max de Bayser --- tests/v1/test_oracle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 7a7ba346a71..55e30ab1ee5 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -13,7 +13,6 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "state-spaces/mamba-130m-hf", # mamba1 - "BAAI/bge-m3", # embedding ] MODEL = "meta-llama/Llama-3.2-1B-Instruct" From 6686550111108d6ab73449ed94ac4abf64cf8348 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Thu, 10 Jul 2025 07:34:32 -0300 Subject: [PATCH 09/14] fix missing arg Signed-off-by: Max de Bayser --- tests/v1/worker/test_gpu_input_batch.py | 1 + tests/v1/worker/test_gpu_model_runner.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 943a13debad..abf739abc50 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -201,6 +201,7 @@ def _construct_cached_request_state(req_id_suffix: int): return CachedRequestState( req_id=f"req_id_{req_id_suffix}", prompt_token_ids=prompt_token_ids, + token_type_ids=None, sampling_params=_create_sampling_params(), pooling_params=None, mm_inputs=[], diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 99696c9887b..5b6e286af5b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -121,6 +121,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], + token_type_ids=None, mm_inputs=[], mm_hashes=[], mm_positions=[], From 136c9b38188407c3ee38e9133d65e4b88a2310f9 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Thu, 10 Jul 2025 10:01:06 -0300 Subject: [PATCH 10/14] fix tests Signed-off-by: Max de Bayser --- tests/v1/worker/test_gpu_input_batch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index abf739abc50..fb8ad382ead 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -194,6 +194,9 @@ def _construct_cached_request_state(req_id_suffix: int): np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(0, MAX_PROMPT_SIZE)) ] + token_type_ids = [ + np.random.randint(0, 2) for _ in range(len(prompt_token_ids)) + ] output_token_ids = [ np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) @@ -201,7 +204,7 @@ def _construct_cached_request_state(req_id_suffix: int): return CachedRequestState( req_id=f"req_id_{req_id_suffix}", prompt_token_ids=prompt_token_ids, - token_type_ids=None, + token_type_ids=token_type_ids, sampling_params=_create_sampling_params(), pooling_params=None, mm_inputs=[], From e19c7381b1fd8039ac0729341343dc3310d7ec7f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 16 Jul 2025 10:02:54 -0300 Subject: [PATCH 11/14] fix tests Signed-off-by: Max de Bayser --- tests/entrypoints/openai/test_rerank.py | 2 +- tests/v1/core/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4da97fe1369..912313ce133 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -124,4 +124,4 @@ def test_invocations(server: RemoteOpenAIServer): invocation_output["results"]): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.01) + invocations_result["relevance_score"], rel=0.05) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 0b7d8251b64..c7ab5300989 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -4,6 +4,7 @@ import torch +from vllm.attention import AttentionType from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -102,7 +103,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks From e255f302e430384dcfb02debd43ddd6178abf23b Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 16 Jul 2025 10:19:06 -0300 Subject: [PATCH 12/14] fix tests Signed-off-by: Max de Bayser --- tests/v1/core/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index c7ab5300989..5b3095717b1 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -142,6 +142,7 @@ def create_requests( request = Request( request_id=f"{i}", prompt_token_ids=prompt_token_ids, + token_type_ids=None, sampling_params=sampling_params, pooling_params=None, multi_modal_inputs=mm_inputs, From ee5950c96a075388a8074b30a01b64d0e3d94084 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 16 Jul 2025 12:28:51 -0300 Subject: [PATCH 13/14] add missing arg Signed-off-by: Max de Bayser --- vllm/model_executor/models/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index cb07fe7d9e1..31a37c94fea 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import vllm.envs as envs +from vllm.attention import AttentionType from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv @@ -251,7 +252,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, - use_mla=model_config.use_mla).page_size_bytes + use_mla=model_config.use_mla, + attn_type=AttentionType.DECODER).page_size_bytes model_cls = ModelRegistry.resolve_model_cls( model_config._model_info.architecture)[0] From a5cfc846725bc4b16d925897ef81a61610c491fa Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 16 Jul 2025 17:17:07 -0300 Subject: [PATCH 14/14] add missing arg Signed-off-by: Max de Bayser --- tests/v1/tpu/worker/test_tpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 40db0b2afe0..46d4dea894a 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -68,6 +68,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], + token_type_ids=None, mm_inputs=[], mm_hashes=[], mm_positions=[],