diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py
index 4da97fe1369..912313ce133 100644
--- a/tests/entrypoints/openai/test_rerank.py
+++ b/tests/entrypoints/openai/test_rerank.py
@@ -124,4 +124,4 @@ def test_invocations(server: RemoteOpenAIServer):
                                                  invocation_output["results"]):
         assert rerank_result.keys() == invocations_result.keys()
         assert rerank_result["relevance_score"] == pytest.approx(
-            invocations_result["relevance_score"], rel=0.01)
+            invocations_result["relevance_score"], rel=0.05)
diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py
index cc9e4102d5b..ba42e389fc1 100644
--- a/tests/models/language/pooling/test_embedding.py
+++ b/tests/models/language/pooling/test_embedding.py
@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
     pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
                  marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
     # [Encoder-only]
-    pytest.param(
-        "BAAI/bge-base-en-v1.5",
-        marks=[
-            # CPU only supports V1
-            pytest.mark.core_model,
-            pytest.mark.skip_v1
-        ]),
-    pytest.param("sentence-transformers/all-MiniLM-L12-v2",
-                 marks=[pytest.mark.skip_v1]),
-    pytest.param("intfloat/multilingual-e5-small",
-                 marks=[pytest.mark.skip_v1]),
+    pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
+    pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
+    pytest.param("intfloat/multilingual-e5-small"),
     pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                  marks=[pytest.mark.skip_v1]),
     # [Cross-Encoder]
diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index 9bfe7411e16..ca3dc45c32b 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -23,6 +23,14 @@
 ]
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
                            model_info: EmbedModelInfo) -> None:
diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py
index c75ff144561..1cf2cdc0132 100644
--- a/tests/models/language/pooling/test_scoring.py
+++ b/tests/models/language/pooling/test_scoring.py
@@ -23,6 +23,15 @@
     "The capital of Germany is Berlin.",
 ]
 
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 DTYPE = "half"
diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py
index f8aeba8301b..4a29d736245 100644
--- a/tests/tokenization/test_detokenize.py
+++ b/tests/tokenization/test_detokenize.py
@@ -61,16 +61,17 @@ def _run_incremental_decode(tokenizer,
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest("",
-                                prompt_token_ids,
-                                None,
-                                None,
-                                None,
-                                params,
-                                None,
-                                None,
-                                0.0,
-                                None,
+    request = EngineCoreRequest(request_id="",
+                                prompt_token_ids=prompt_token_ids,
+                                token_type_ids=None,
+                                mm_inputs=None,
+                                mm_hashes=None,
+                                mm_placeholders=None,
+                                sampling_params=params,
+                                pooling_params=None,
+                                eos_token_id=None,
+                                arrival_time=0.0,
+                                lora_request=None,
                                 cache_salt=None,
                                 data_parallel_rank=None)
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 0676cb3eb65..5fc74460808 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
 
+from vllm.attention import AttentionType
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
@@ -40,6 +41,7 @@ def make_request(request_id,
     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=None,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
@@ -62,6 +64,7 @@ def new_kv_cache_spec(block_size=16,
                              head_size=head_size,
                              dtype=dtype,
                              use_mla=use_mla,
+                             attn_type=AttentionType.DECODER,
                              sliding_window=sliding_window)
 
 
@@ -76,6 +79,7 @@ def new_sliding_window_spec(block_size=16,
                              head_size=head_size,
                              dtype=dtype,
                              use_mla=use_mla,
+                             attn_type=AttentionType.DECODER,
                              sliding_window=sliding_window)
 
 
@@ -544,6 +548,7 @@ def test_merge_kv_cache_spec():
             head_size=full_spec.head_size,
             dtype=full_spec.dtype,
             use_mla=full_spec.use_mla,
+            attn_type=AttentionType.DECODER,
             sliding_window=1,
         ),
     ]
@@ -613,6 +618,7 @@ def test_estimate_max_model_len(model_id, max_model_len,
             head_size=128,
             dtype=torch.float16,
             use_mla=False,
+            attn_type=AttentionType.DECODER,
         )
     # Estimate the maximum model length, 16384 model_len need 8GB
     estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
@@ -648,6 +654,7 @@ def test_get_max_concurrency_for_kv_cache_config():
         head_size=128,
         dtype=torch.float16,
         use_mla=False,
+        attn_type=AttentionType.DECODER,
     )
 
     sliding_window_spec = SlidingWindowSpec(
@@ -656,6 +663,7 @@ def test_get_max_concurrency_for_kv_cache_config():
         head_size=128,
         dtype=torch.float16,
         use_mla=False,
+        attn_type=AttentionType.DECODER,
         sliding_window=1024,
     )
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index f31bdf74f4a..e4dea8127e0 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -8,6 +8,7 @@
 import pytest
 import torch
 
+from vllm.attention import AttentionType
 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
@@ -35,6 +36,7 @@ def make_request(request_id,
     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=None,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
@@ -54,7 +56,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
         kv_cache_groups=[
             KVCacheGroupSpec(
                 ["layer"],
-                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+                FullAttentionSpec(block_size,
+                                  1,
+                                  1,
+                                  torch.float32,
+                                  False,
+                                  attn_type=AttentionType.DECODER),
             )
         ],
     )
@@ -68,7 +75,12 @@ def make_kv_cache_config_hybrid_model(block_size: int,
         kv_cache_groups=[
             KVCacheGroupSpec(
                 ["layer1"],
-                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+                FullAttentionSpec(block_size,
+                                  1,
+                                  1,
+                                  torch.float32,
+                                  False,
+                                  attn_type=AttentionType.DECODER),
             ),
             KVCacheGroupSpec(
                 ["layer2"],
@@ -77,6 +89,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
                                   1,
                                   torch.float32,
                                   False,
+                                  attn_type=AttentionType.DECODER,
                                   sliding_window=2 * block_size),
             ),
             KVCacheGroupSpec(
@@ -86,6 +99,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
                                   1,
                                   torch.float32,
                                   False,
+                                  attn_type=AttentionType.DECODER,
                                   sliding_window=2 * block_size),
             ),
         ],
@@ -1222,6 +1236,7 @@ def test_eagle_with_sliding_window():
         dtype=torch.float32,
         sliding_window=block_size,
         use_mla=False,
+        attn_type=AttentionType.DECODER,
     )
     manager = KVCacheManager(
         KVCacheConfig(
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index a858a4d8c82..3ada907fe1b 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 
+from vllm.attention import AttentionType
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -1290,7 +1291,7 @@ def create_scheduler_with_priority(
         kv_cache_groups=[
             KVCacheGroupSpec(['layer'],
                              FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
+                                               False, AttentionType.DECODER))
         ],
     )
     cache_config.num_gpu_blocks = num_blocks
@@ -1333,6 +1334,7 @@ def create_requests_with_priority(
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=[i] * num_tokens,
+            token_type_ids=None,
             sampling_params=sampling_params,
             pooling_params=None,
             multi_modal_inputs=mm_inputs,
@@ -1819,6 +1821,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     request = Request(
         request_id="0",
         prompt_token_ids=[0, 1],
+        token_type_ids=None,
        multi_modal_inputs=None,
         multi_modal_hashes=None,
         multi_modal_placeholders=None,
diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py
index a9e1898df93..ec3e725ae45 100644
--- a/tests/v1/core/test_specialized_manager.py
+++ b/tests/v1/core/test_specialized_manager.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from vllm.attention import AttentionType
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
                                          KVCacheBlock)
@@ -26,6 +27,7 @@ def test_sliding_window_possible_cached_prefix():
         dtype=torch.float32,
         sliding_window=4,
         use_mla=False,
+        attn_type=AttentionType.DECODER,
     )
 
     block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@@ -92,6 +94,7 @@ def test_sliding_window_remove_skipped_blocks():
         dtype=torch.float32,
         sliding_window=4,
         use_mla=False,
+        attn_type=AttentionType.DECODER,
     )
 
     block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
@@ -160,6 +163,7 @@ def test_get_num_blocks_to_allocate():
         dtype=torch.float32,
         sliding_window=4,  # Placeholder value, not related to test result
         use_mla=False,
+        attn_type=AttentionType.DECODER,
     )
 
     block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 0b7d8251b64..5b3095717b1 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -4,6 +4,7 @@
 
 import torch
 
+from vllm.attention import AttentionType
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -102,7 +103,7 @@ def create_scheduler(
         kv_cache_groups=[
             KVCacheGroupSpec(['layer'],
                              FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
+                                               False, AttentionType.DECODER))
         ],
     )
     cache_config.num_gpu_blocks = num_blocks
@@ -141,6 +142,7 @@ def create_requests(
         request = Request(
             request_id=f"{i}",
            prompt_token_ids=prompt_token_ids,
+            token_type_ids=None,
             sampling_params=sampling_params,
             pooling_params=None,
             multi_modal_inputs=mm_inputs,
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index bbdc73e9608..f5ddbeeb4fd 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -35,6 +35,7 @@ def make_request() -> EngineCoreRequest:
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=PROMPT_TOKENS,
+        token_type_ids=None,
         mm_inputs=None,
         mm_hashes=None,
         mm_placeholders=None,
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 65f1da803fb..eb477c0e64f 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -16,6 +16,7 @@
 
 from tests.utils import multi_gpu_test
 from vllm import SamplingParams
+from vllm.attention import AttentionType
 from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
                                         ZmqEventPublisher)
 from vllm.engine.arg_utils import EngineArgs
@@ -51,6 +52,7 @@ def make_request(
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
         prompt_token_ids=prompt_tokens_ids,
+        token_type_ids=None,
         mm_inputs=None,
         mm_hashes=None,
         mm_placeholders=None,
@@ -544,7 +546,8 @@ def create_mock_executor(vllm_config):
             num_kv_heads=1,
             head_size=64,
             dtype=torch.float16,
-            use_mla=False)
+            use_mla=False,
+            attn_type=AttentionType.DECODER)
 
         mock_executor.get_kv_cache_specs.return_value = [{
             "default": mock_spec
diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py
index f028b4ab1d7..61a4126ff60 100644
--- a/tests/v1/engine/test_fast_incdec_prefix_err.py
+++ b/tests/v1/engine/test_fast_incdec_prefix_err.py
@@ -31,6 +31,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
         None,
         None,
         None,
+        None,
         params,
         None,
         None,
diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index 949ab764e2e..c59439ed9c0 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -52,6 +52,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=f"request-{idx}",
                           prompt_token_ids=prompt_tokens,
+                          token_type_ids=None,
                           arrival_time=0,
                           mm_inputs=None,
                           mm_hashes=None,
@@ -401,6 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     requests = [
         EngineCoreRequest(request_id=request_id_list[idx],
                           prompt_token_ids=prompt_tokens,
+                          token_type_ids=None,
                           arrival_time=0,
                           mm_inputs=None,
                           mm_hashes=None,
@@ -566,6 +568,7 @@ def test_stop_token(include_stop_str_in_output: bool,
     request = EngineCoreRequest(
         request_id=request_id,
         prompt_token_ids=prompt_tokens,
+        token_type_ids=None,
         arrival_time=0,
         mm_inputs=None,
         mm_hashes=None,
@@ -665,6 +668,7 @@ def test_stop_string(include_stop_str_in_output: bool,
         EngineCoreRequest(
             request_id=request_id_list[idx],
             prompt_token_ids=prompt_tokens,
+            token_type_ids=None,
             arrival_time=0,
             mm_inputs=None,
             mm_hashes=None,
@@ -781,6 +785,7 @@ def test_iteration_stats(dummy_test_vectors):
         EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt_token_ids=prompt_tokens,
+            token_type_ids=None,
             arrival_time=0,
             mm_inputs=None,
             mm_hashes=None,
diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py
index e84b5e3095d..d2da53c6856 100644
--- a/tests/v1/entrypoints/openai/test_multi_api_servers.py
+++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py
@@ -2,11 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import os
-import re
 
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import regex as re
 import requests
 
 from tests.utils import RemoteOpenAIServer
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index cf20d44fbaa..d944afb0ffd 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -7,6 +7,7 @@
 import torch
 
 from vllm import SamplingParams
+from vllm.attention import AttentionType
 from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                          ModelConfig, SchedulerConfig, VllmConfig)
 from vllm.distributed.kv_transfer.kv_connector.factory import (
@@ -106,7 +107,7 @@ def create_scheduler(
         kv_cache_groups=[
             KVCacheGroupSpec(['layer'],
                              FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
+                                               False, AttentionType.DECODER))
         ],
     )
     vllm_config.cache_config.num_gpu_blocks = num_blocks
@@ -155,6 +156,7 @@ def create_request(
     req = Request(
         request_id=f"id-{request_id}",
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=None,
         sampling_params=sampling_params,
         pooling_params=None,
         multi_modal_inputs=None,
diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py
index 7a7ba346a71..55e30ab1ee5 100644
--- a/tests/v1/test_oracle.py
+++ b/tests/v1/test_oracle.py
@@ -13,7 +13,6 @@
     "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
     "state-spaces/mamba-130m-hf",  # mamba1
-    "BAAI/bge-m3",  # embedding
 ]
 
 MODEL = "meta-llama/Llama-3.2-1B-Instruct"
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 40db0b2afe0..46d4dea894a 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -68,6 +68,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             NewRequestData(
                 req_id=req_id,
                 prompt_token_ids=[1, 2, 3],
+                token_type_ids=None,
                 mm_inputs=[],
                 mm_hashes=[],
                 mm_positions=[],
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 943a13debad..fb8ad382ead 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -194,6 +194,9 @@ def _construct_cached_request_state(req_id_suffix: int):
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
     ]
+    token_type_ids = [
+        np.random.randint(0, 2) for _ in range(len(prompt_token_ids))
+    ]
     output_token_ids = [
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
     ]
@@ -201,6 +204,7 @@ def _construct_cached_request_state(req_id_suffix: int):
     return CachedRequestState(
         req_id=f"req_id_{req_id_suffix}",
         prompt_token_ids=prompt_token_ids,
+        token_type_ids=token_type_ids,
         sampling_params=_create_sampling_params(),
         pooling_params=None,
         mm_inputs=[],
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 0bdf1f9820d..9d7cb8607ea 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig, set_current_vllm_config)
 from vllm.platforms import current_platform
@@ -38,6 +38,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
         head_size=runner.model_config.get_head_size(),
         dtype=runner.kv_cache_dtype,
         use_mla=False,
+        attn_type=AttentionType.DECODER,
     )
     tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
     kv_cache_config = KVCacheConfig(
@@ -120,6 +121,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             NewRequestData(
                 req_id=req_id,
                 prompt_token_ids=[1, 2, 3],
+                token_type_ids=None,
                 mm_inputs=[],
                 mm_hashes=[],
                 mm_positions=[],
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ae5eb46fa96..eda6714921f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1692,7 +1692,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,
 
         if (self.max_num_seqs is None
                 and usage_context in default_max_num_seqs):
-            self.max_num_seqs = default_max_num_seqs[usage_context]
+            self.max_num_seqs = min(default_max_num_seqs[usage_context],
+                                    self.max_num_batched_tokens or sys.maxsize)
 
             logger.debug("Setting max_num_seqs to %d for %s usage context.",
                          self.max_num_seqs, use_context_value)
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 65e6428f491..c91933ff3b5 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -12,7 +12,6 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -25,7 +24,7 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
@@ -56,7 +55,6 @@ def __init__(self, config: BertConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -100,7 +98,6 @@ def forward(
         return pooled_output
 
 
-@support_torch_compile
 class BertEncoder(nn.Module):
 
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -318,6 +315,7 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
+@support_torch_compile
 class BertModel(nn.Module, SupportsQuant):
 
     packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
@@ -345,13 +343,9 @@ def forward(
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
-            attn_metadata = get_forward_context().attn_metadata
-            assert hasattr(attn_metadata, "seq_lens_tensor")
-            hidden_states = self.embeddings(
-                input_ids=input_ids,
-                seq_lens=attn_metadata.seq_lens_tensor,
-                position_ids=position_ids,
-                token_type_ids=token_type_ids)
+            hidden_states = self.embeddings(input_ids=input_ids,
+                                            position_ids=position_ids,
+                                            token_type_ids=token_type_ids)
         return self.encoder(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str,
@@ -392,7 +386,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
+class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -414,11 +408,13 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
                           position_ids=positions,
+                          token_type_ids=token_type_ids,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
 
@@ -454,8 +450,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
                       softmax=False)
 
 
-class BertForSequenceClassification(nn.Module, SupportsV0Only,
-                                    SupportsCrossEncoding, SupportsQuant):
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index cb07fe7d9e1..31a37c94fea 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
+from vllm.attention import AttentionType
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -251,7 +252,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             num_kv_heads=model_config.get_num_kv_heads(parallel_config),
             head_size=model_config.get_head_size(),
             dtype=kv_cache_dtype,
-            use_mla=model_config.use_mla).page_size_bytes
+            use_mla=model_config.use_mla,
+            attn_type=AttentionType.DECODER).page_size_bytes
 
         model_cls = ModelRegistry.resolve_model_cls(
             model_config._model_info.architecture)[0]
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 55ebb6e9e2a..e3aa556e141 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -9,6 +9,7 @@
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -19,7 +20,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .bert_with_rope import BertWithRope, JinaRobertaModel
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding
 
 
 class RobertaEmbedding(nn.Module):
@@ -51,39 +52,12 @@ def __init__(self, config: RobertaConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        pos_list = []
-        token_list = []
-        offset = 0
-        for seq_len in seq_lens:
-            pos_list.append(position_ids[offset:offset + seq_len])
-            token_list.append(input_ids[offset:offset + seq_len])
-            offset += seq_len
-
-        new_pos_list = []
-        for positions, tokens in zip(pos_list, token_list):
-            # Verify assumption that incoming position are
-            # always a sequence from 0 to N.
-            expected_pos = torch.arange(positions.size()[0],
-                                        dtype=torch.long,
-                                        device=inputs_embeds.device)
-            assert torch.equal(positions, expected_pos)
-            new_pos_list.append(
-                create_position_ids_from_input_ids(tokens, self.padding_idx))
-        position_ids = torch.cat(new_pos_list)
-
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
         if token_type_ids is None:
@@ -125,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
        _pooler: An instance of Pooler used for pooling operations.
     """
 
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # Fix Roberta positions here outside of the CUDA graph.
+        # Because we need to extract the sequences from input_ids,
+        # the control flow is data dependent.
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
+
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          token_type_ids=token_type_ids,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors)
+
     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -153,8 +153,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
-                                       SupportsV0Only):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -180,6 +179,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
 
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
@@ -213,6 +213,9 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
         return self.roberta(input_ids=input_ids,
                             position_ids=positions,
                             inputs_embeds=inputs_embeds,
@@ -242,3 +245,36 @@ def create_position_ids_from_input_ids(input_ids,
                              past_key_values_length) * mask
 
     return incremental_indices.long() + padding_idx
+
+
+def replace_roberta_positions(input_ids: torch.Tensor,
+                              position_ids: torch.Tensor,
+                              padding_idx: int) -> None:
+
+    seq_lens: Optional[torch.Tensor] = None
+    attn_metadata = get_forward_context().attn_metadata
+    if attn_metadata is not None:  # can be None during warmup
+        if isinstance(attn_metadata, dict):
+            attn_metadata = next(iter(attn_metadata.values()))
+        # TODO: remove "seq_lens_tensor" after V0 is removed
+        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
+                           getattr(attn_metadata, "seq_lens", None))
+
+    if seq_lens is not None:
+        assert isinstance(seq_lens, torch.Tensor)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
+                                 seq_lens.tolist())
+
+        offset = 0
+        for tokens in token_list:
+            length = tokens.shape[0]
+            position_ids[offset:offset+length] = \
+                create_position_ids_from_input_ids(tokens, padding_idx)
+            offset = offset + length
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 552c2caf2fa..e18ce8ee218 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -433,11 +433,13 @@ def __init__(
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type not in [
+                AttentionType.DECODER, AttentionType.ENCODER_ONLY
+        ]:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "is not implemented for "
                                       "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -556,7 +558,7 @@ def forward(
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=_get_causal_option(self.attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -759,3 +761,21 @@ def cascade_attention(
     # Merge prefix and suffix outputs, and store the result in output.
     merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                       suffix_lse)
+
+
+def _get_causal_option(attn_type: str) -> bool:
+    """
+    Determine whether the given attention type is suitable for causal
+    attention mechanisms.
+
+    Args:
+        attn_type (AttentionType): The type of attention being evaluated
+
+    Returns:
+        bool: Returns `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise returns `False`.
+    """
+    return not (attn_type == AttentionType.ENCODER
+                or attn_type == AttentionType.ENCODER_ONLY
+                or attn_type == AttentionType.ENCODER_DECODER)
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 6067a127e97..9b9e8f7dd7a 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -944,6 +944,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
             dtype=spec.dtype,
             use_mla=spec.use_mla,
             sliding_window=spec.sliding_window,
+            attn_type=str(spec.attn_type),
         )
 
     if is_hybrid(kv_cache_spec):
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index d34f3932780..e2ebef46522 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -24,6 +24,7 @@ class NewRequestData:
 
     req_id: str
     prompt_token_ids: list[int]
+    token_type_ids: Optional[list[int]]
     mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
@@ -42,6 +43,7 @@ def from_request(
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
+            token_type_ids=request.token_type_ids,
             mm_inputs=request.mm_inputs,
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 921ccd708cd..dcad52959b3 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -49,6 +49,7 @@ class EngineCoreRequest(
 
     request_id: str
     prompt_token_ids: list[int]
+    token_type_ids: Optional[list[int]]
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f5c59bef478..c18f347ea57 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -36,7 +36,7 @@
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -156,6 +156,23 @@ def _initialize_kv_caches(
             zip(kv_cache_specs, available_gpu_memory)
         ]
 
+        for kv_cache_spec_one_worker in kv_cache_specs:
+            for _, spec in kv_cache_spec_one_worker.items():
+                if isinstance(spec, AttentionSpec) and \
+                    spec.attn_type != "decoder":
+
+                    logger.info("Found non-decoder layer. Disabling "
Disabling " + "prefix cache and chunked prefill") + self.vllm_config.cache_config.\ + enable_prefix_caching = False + self.vllm_config.scheduler_config.\ + enable_chunked_prefill = False + self.vllm_config.scheduler_config.\ + chunked_prefill_enabled = False + self.vllm_config.scheduler_config.\ + long_prefill_token_threshold = 0 + break + # Since we use a shared centralized controller, we need the # `kv_cache_config` to be consistent across all workers to make sure # all the memory operators can be applied to all workers. diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7af4ed54a22..c2da1b47866 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -345,6 +345,7 @@ def process_inputs( return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], + token_type_ids=decoder_inputs.get("token_type_ids"), mm_inputs=sorted_mm_inputs, mm_hashes=sorted_mm_hashes, mm_placeholders=sorted_mm_positions, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 43456a987de..fd97feff449 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -75,6 +75,7 @@ class AttentionSpec(KVCacheSpec): head_size: int dtype: torch.dtype use_mla: bool + attn_type: str @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 85f5dcb92eb..ce49f5054ef 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -24,6 +24,7 @@ def __init__( self, request_id: str, prompt_token_ids: list[int], + token_type_ids: Optional[list[int]], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], @@ -74,6 +75,7 @@ def __init__( "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids + self.token_type_ids = token_type_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = self.prompt_token_ids.copy() @@ -119,6 +121,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, + token_type_ids=request.token_type_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..e8ef9ecdcd7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Datastructures defining a GPU input batch +# Datastructures defining an input batch from dataclasses import dataclass from typing import Optional, cast @@ -29,6 +29,7 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] + token_type_ids: Optional[list[int]] mm_inputs: list[MultiModalKwargs] mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] @@ -96,6 +97,8 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.token_type_ids_cpu_tensor = None + self._token_type_ids_cpu = None self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, 
@@ -240,6 +243,22 @@ def __init__(
 
         self.pooling_params: dict[str, PoolingParams] = {}
 
+    @property
+    def token_type_ids_cpu(self) -> np.ndarray:
+        if self._token_type_ids_cpu is None:
+            self.token_type_ids_cpu_tensor = torch.zeros(
+                self.token_ids_cpu_tensor.shape,
+                device="cpu",
+                dtype=torch.int8,
+                pin_memory=False,
+            )
+            self._token_type_ids_cpu = cast(
+                torch.Tensor, self.token_type_ids_cpu_tensor).numpy()
+        return self._token_type_ids_cpu
+
+    def has_token_types(self) -> bool:
+        return self._token_type_ids_cpu is not None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
@@ -284,6 +303,9 @@ def add_request(
         self.num_prompt_tokens[req_index] = num_prompt_tokens
         self.token_ids_cpu[
             req_index, :num_prompt_tokens] = request.prompt_token_ids
+        if request.token_type_ids is not None:
+            self.token_type_ids_cpu[
+                req_index, :num_prompt_tokens] = request.token_type_ids
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
         self.token_ids_cpu[req_index,
@@ -472,6 +494,10 @@ def swap_states(self, i1: int, i2: int) -> None:
         tmp = self.token_ids_cpu[i1, ...].copy()
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
+        if self.has_token_types():
+            tmp2 = self.token_type_ids_cpu[i1, ...].copy()
+            self.token_type_ids_cpu[i1, ...] = self.token_type_ids_cpu[i2, ...]
+            self.token_type_ids_cpu[i2, ...] = tmp2
 
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)
@@ -542,6 +568,9 @@ def condense(self) -> None:
             num_tokens = self.num_tokens[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
+            if self.has_token_types():
+                self.token_type_ids_cpu[empty_index, :num_tokens] = \
+                    self.token_type_ids_cpu[last_req_index, :num_tokens]
             self.num_tokens[empty_index] = num_tokens
             self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                 last_req_index]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index af216539c90..d3eb2c25a7a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
+import inspect
 import time
 import weakref
 from contextlib import contextmanager
@@ -41,6 +42,8 @@
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d,
                         check_use_alibi, get_dtype_size,
@@ -247,7 +250,8 @@ def __init__(
         self.slot_mapping = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
-
+        self.token_type_ids: Optional[torch.Tensor] = None
+        self.supports_token_type_ids: bool = False
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
@@ -316,6 +320,19 @@ def __init__(
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
 
+    def get_token_type_ids(self) -> torch.Tensor:
+        if self.token_type_ids is None:
+            self.token_type_ids = torch.zeros(self.max_num_tokens,
+                                              dtype=torch.int32,
+                                              device=self.device)
+        return self.token_type_ids
+
+    def _maybe_add_model_args(self, num_tokens: int,
+                              model_kwargs: dict[str, Any]):
+        if self.supports_token_type_ids:
+            model_kwargs["token_type_ids"] =\
+                self.get_token_type_ids()[:num_tokens]
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
@@ -414,6 +431,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 self.requests[req_id] = CachedRequestState(
                     req_id=req_id,
                     prompt_token_ids=new_req_data.prompt_token_ids,
+                    token_type_ids=new_req_data.token_type_ids,
                     mm_inputs=new_req_data.mm_inputs,
                     mm_positions=new_req_data.mm_positions,
                     sampling_params=sampling_params,
@@ -636,6 +654,13 @@ def _prepare_inputs(
             0,
             torch.from_numpy(token_indices),
             out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        if self.input_batch.token_type_ids_cpu_tensor is not None:
+            token_type_ids = torch.index_select(
+                self.input_batch.token_type_ids_cpu_tensor.flatten(), 0,
+                torch.from_numpy(token_indices))
+            # Copy the tensors to the GPU.
+            self.get_token_type_ids()[:total_num_scheduled_tokens]\
+                .copy_(token_type_ids, non_blocking=True)
 
         # Calculate the slot mapping for each KV cache group.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -1319,11 +1344,14 @@ def execute_model(
         else:
             mm_embeds = []
 
+        model_kwargs: dict[str, Any] = {}
+
         if self.is_multimodal_model and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+            self._maybe_add_model_args(num_scheduled_tokens, model_kwargs)
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1339,6 +1367,7 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
+            self._maybe_add_model_args(num_input_tokens, model_kwargs)
             inputs_embeds = None
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
@@ -1372,6 +1401,7 @@ def execute_model(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
+                **model_kwargs,
             )
 
         self.maybe_wait_for_kv_save()
@@ -1738,6 +1768,14 @@ def update_config(self, overrides: dict[str, Any]) -> None:
             new_config = update_config(config, config_overrides)
             setattr(self, config_name, new_config)
 
+    def _get_tokenizer(self) -> AnyTokenizer:
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+            lora_config=self.lora_config)
+
+        return tokenizer_group.get_lora_tokenizer()
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
@@ -1785,6 +1823,24 @@ def load_model(self) -> None:
                 self.parallel_config,
             )
 
+        model_supports_token_type_ids = 'token_type_ids' in \
+            inspect.getfullargspec(self.model.forward).args
+
+        tokenizer = self._get_tokenizer()
+        if not isinstance(tokenizer, MistralTokenizer):
+            tok_output = tokenizer(text="foo")
+            if "token_type_ids" in tok_output:
+                if not model_supports_token_type_ids:
+                    logger.warning("Tokenizer returns token_type_ids but "
+                                   "the model forward() doesn't support that "
+                                   "argument")
+                else:
+                    self.supports_token_type_ids = True
+
+        if self.supports_token_type_ids:
+            # pre-allocate tensor
+            self.get_token_type_ids()
+
     def save_tensorized_model(
         self,
         tensorizer_config: "TensorizerConfig",
@@ -1998,6 +2054,8 @@ def _dummy_run(
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
+            model_kwargs: dict[str, Any] = {}
+            self._maybe_add_model_args(num_tokens, model_kwargs)
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2032,6 +2090,7 @@ def _dummy_run(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
+                **model_kwargs,
             )
             if self.use_aux_hidden_state_outputs:
                 hidden_states, _ = outputs
@@ -2609,7 +2668,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 continue
 
             # TODO: Support other attention modules, e.g., cross-attention
-            if attn_module.attn_type == AttentionType.DECODER:
+            if attn_module.attn_type in (AttentionType.DECODER,
+                                         AttentionType.ENCODER_ONLY):
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
@@ -2617,17 +2677,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         head_size=attn_module.head_size,
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
-                        use_mla=use_mla)
+                        use_mla=use_mla,
+                        attn_type=str(attn_module.attn_type))
                 else:
                     kv_cache_spec[layer_name] = FullAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
                         dtype=self.kv_cache_dtype,
-                        use_mla=use_mla)
-            elif attn_module.attn_type in (AttentionType.ENCODER,
-                                           AttentionType.ENCODER_ONLY):
-                # encoder-only attention does not need KV cache.
+                        use_mla=use_mla,
+                        attn_type=str(attn_module.attn_type))
+            elif attn_module.attn_type == AttentionType.ENCODER:
+                # encoder attention does not need KV cache.
                 continue
             elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                 raise NotImplementedError
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index ad62d204381..d5a54c132ef 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -415,6 +415,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                token_type_ids=new_req_data.token_type_ids,
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
@@ -517,6 +518,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     dtype=self.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window,
                     use_mla=False,
+                    attn_type=str(attn_module.attn_type),
                 )
             else:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -525,6 +527,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     head_size=attn_module.head_size,
                     dtype=self.kv_cache_dtype,
                     use_mla=False,
+                    attn_type=str(attn_module.attn_type),
                 )
         elif attn_module.attn_type in (AttentionType.ENCODER,
                                        AttentionType.ENCODER_ONLY):