diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index cc9e4102d5b..cc07179f0cd 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -39,22 +39,13 @@ def v1(run_with_both_engines): pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] - pytest.param( - "BAAI/bge-base-en-v1.5", - marks=[ - # CPU only supports V1 - pytest.mark.core_model, - pytest.mark.skip_v1 - ]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2", - marks=[pytest.mark.skip_v1]), - pytest.param("intfloat/multilingual-e5-small", - marks=[pytest.mark.skip_v1]), + pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2"), + pytest.param("intfloat/multilingual-e5-small"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), + marks=[pytest.mark.skip_v0]), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) def test_models( diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 9bfe7411e16..ca3dc45c32b 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -23,6 +23,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index c75ff144561..1cf2cdc0132 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -23,6 +23,15 @@ "The capital of Germany is Berlin.", ] + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + DTYPE = "half" diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 01b2260abe8..66c03ff872c 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -51,9 +51,9 @@ def test_registry_imports(model_arch): ("LlamaForCausalLM", False, False, False), ("MllamaForConditionalGeneration", True, False, False), ("LlavaForConditionalGeneration", True, True, False), - ("BertForSequenceClassification", False, False, True), - ("RobertaForSequenceClassification", False, False, True), - ("XLMRobertaForSequenceClassification", False, False, True), + ("BertForSequenceClassification", True, False, True), + ("RobertaForSequenceClassification", True, False, True), + ("XLMRobertaForSequenceClassification", True, False, True), ]) def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): assert ModelRegistry.is_multimodal_model(model_arch) is is_mm diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e80ad8a6815..4eeb9b79247 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm.attention 
import AttentionType from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -61,6 +62,7 @@ def new_kv_cache_spec(block_size=16, head_size=head_size, dtype=dtype, use_mla=use_mla, + attn_type=AttentionType.DECODER, sliding_window=sliding_window) @@ -75,6 +77,7 @@ def new_sliding_window_spec(block_size=16, head_size=head_size, dtype=dtype, use_mla=use_mla, + attn_type=AttentionType.DECODER, sliding_window=sliding_window) @@ -534,6 +537,7 @@ def test_merge_kv_cache_spec(): head_size=full_spec.head_size, dtype=full_spec.dtype, use_mla=full_spec.use_mla, + attn_type=AttentionType.DECODER, sliding_window=1, ), ] @@ -603,6 +607,7 @@ def test_estimate_max_model_len(model_id, max_model_len, head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, ) # Estimate the maximum model length, 16384 model_len need 8GB estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, @@ -638,6 +643,7 @@ def test_get_max_concurrency_for_kv_cache_config(): head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, ) sliding_window_spec = SlidingWindowSpec( @@ -646,6 +652,7 @@ def test_get_max_concurrency_for_kv_cache_config(): head_size=128, dtype=torch.float16, use_mla=False, + attn_type=AttentionType.DECODER, sliding_window=1024, ) @@ -916,4 +923,4 @@ def test_get_kv_cache_config(): ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) \ No newline at end of file + ]) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 7a42778831c..d518d76628f 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,6 +8,7 @@ import pytest import torch +from vllm.attention import AttentionType from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -53,7 +54,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: kv_cache_groups=[ KVCacheGroupSpec( ["layer"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, + 1, + 1, + torch.float32, + False, + attn_type=AttentionType.DECODER), ) ], ) @@ -67,7 +73,12 @@ def make_kv_cache_config_hybrid_model(block_size: int, kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, + 1, + 1, + torch.float32, + False, + attn_type=AttentionType.DECODER), ), KVCacheGroupSpec( ["layer2"], @@ -76,6 +87,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, torch.float32, False, + attn_type=AttentionType.DECODER, sliding_window=2 * block_size), ), KVCacheGroupSpec( @@ -85,6 +97,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, torch.float32, False, + attn_type=AttentionType.DECODER, sliding_window=2 * block_size), ), ], @@ -1218,6 +1231,7 @@ def test_eagle_with_sliding_window(): dtype=torch.float32, sliding_window=block_size, use_mla=False, + attn_type=AttentionType.DECODER, ) manager = KVCacheManager( KVCacheConfig( diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 02d2c83ab15..88c574a7518 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -6,6 +6,7 @@ import pytest import torch +from vllm.attention import AttentionType 
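The recurring change in the KV-cache tests above is that FullAttentionSpec and SlidingWindowSpec now take an explicit attn_type argument instead of implicitly assuming decoder self-attention. A minimal sketch of the constructor shape these call sites rely on (field values are arbitrary placeholders; the FullAttentionSpec import path is assumed from the vllm.v1.kv_cache_interface module touched later in this diff):

    import torch

    from vllm.attention import AttentionType
    from vllm.v1.kv_cache_interface import FullAttentionSpec

    # Decoder self-attention is now spelled out explicitly on the spec.
    spec = FullAttentionSpec(
        block_size=16,
        num_kv_heads=1,
        head_size=64,
        dtype=torch.float16,
        use_mla=False,
        attn_type=AttentionType.DECODER,
    )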
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -104,7 +105,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks @@ -1354,7 +1355,7 @@ def create_scheduler_with_priority( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) cache_config.num_gpu_blocks = num_blocks diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index a9e1898df93..ec3e725ae45 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,6 +3,7 @@ import torch +from vllm.attention import AttentionType from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, KVCacheBlock) @@ -26,6 +27,7 @@ def test_sliding_window_possible_cached_prefix(): dtype=torch.float32, sliding_window=4, use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -92,6 +94,7 @@ def test_sliding_window_remove_skipped_blocks(): dtype=torch.float32, sliding_window=4, use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -160,6 +163,7 @@ def test_get_num_blocks_to_allocate(): dtype=torch.float32, sliding_window=4, # Placeholder value, not related to test result use_mla=False, + attn_type=AttentionType.DECODER, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 65f1da803fb..7718827d5f3 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -16,6 +16,7 @@ from tests.utils import multi_gpu_test from vllm import SamplingParams +from vllm.attention import AttentionType from vllm.distributed.kv_events import (BlockStored, KVEventBatch, ZmqEventPublisher) from vllm.engine.arg_utils import EngineArgs @@ -544,7 +545,8 @@ def create_mock_executor(vllm_config): num_kv_heads=1, head_size=64, dtype=torch.float16, - use_mla=False) + use_mla=False, + attn_type=AttentionType.DECODER) mock_executor.get_kv_cache_specs.return_value = [{ "default": mock_spec diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index e84b5e3095d..d2da53c6856 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import os -import re import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re import requests from tests.utils import RemoteOpenAIServer diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index cf20d44fbaa..e7206daa31f 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -7,6 +7,7 @@ import torch from vllm import SamplingParams +from vllm.attention import AttentionType from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, 
VllmConfig) from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -106,7 +107,7 @@ def create_scheduler( kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + False, AttentionType.DECODER)) ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 7a7ba346a71..55e30ab1ee5 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -13,7 +13,6 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "state-spaces/mamba-130m-hf", # mamba1 - "BAAI/bge-m3", # embedding ] MODEL = "meta-llama/Llama-3.2-1B-Instruct" diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index d13df553db6..99696c9887b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -6,7 +6,7 @@ import pytest import torch -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) from vllm.platforms import current_platform @@ -38,6 +38,7 @@ def initialize_kv_cache(runner: GPUModelRunner): head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, use_mla=False, + attn_type=AttentionType.DECODER, ) tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f47499309d8..a2af06a6f32 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1744,7 +1744,8 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): - self.max_num_seqs = default_max_num_seqs[usage_context] + self.max_num_seqs = min(default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c60a566f585..12d987124bd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1295,39 +1295,18 @@ def _cross_encoding_score( input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - if self.llm_engine.model_config.is_multimodal_model: - - model_config = self.llm_engine.model_config - - for q, d in input_pairs: - _, engine_prompt = get_score_prompt( - model_config=model_config, - data_1=q, - data_2=d, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - ) - - parsed_prompts.append(engine_prompt) - - else: + model_config = self.llm_engine.model_config + + for q, d in input_pairs: + _, engine_prompt = get_score_prompt( + model_config=model_config, + data_1=q, + data_2=d, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + ) - for q, t in input_pairs: - if self.llm_engine.model_config.use_pad_token: - # cross_encoder models defaults to using pad_token. - prompt_inputs = tokenizer( - text=q, # type: ignore[arg-type] - text_pair=t, # type: ignore[arg-type] - **tokenization_kwargs) - else: - # `llm as reranker` models defaults to not using pad_token. 
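For the _set_default_args_v1 change above: the default max_num_seqs is now capped by max_num_batched_tokens, and the `or sys.maxsize` guard keeps the usage-context default whenever no token budget is configured. A minimal sketch of that expression with hypothetical values (the helper name is illustrative, not vLLM code):

    import sys

    def clamp_max_num_seqs(default_max_num_seqs: int,
                           max_num_batched_tokens: int | None) -> int:
        # Mirrors min(default, max_num_batched_tokens or sys.maxsize):
        # a None token budget means "no cap", so the default survives.
        return min(default_max_num_seqs,
                   max_num_batched_tokens or sys.maxsize)

    assert clamp_max_num_seqs(256, 8192) == 256   # budget above default: keep 256
    assert clamp_max_num_seqs(256, 128) == 128    # tiny budget caps max_num_seqs
    assert clamp_max_num_seqs(256, None) == 256   # no budget configured: keep 256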
- prompt_inputs = tokenizer( - text=q + t, # type: ignore[operator] - **tokenization_kwargs) - engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["input_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - parsed_prompts.append(engine_prompt) + parsed_prompts.append(engine_prompt) self._validate_and_add_requests( prompts=parsed_prompts, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 8d47a417f9c..67a555d2558 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -191,56 +191,19 @@ async def _cross_encoding_score( input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - if self.model_config.is_multimodal_model: - - preprocess_async = make_async(self._preprocess_score, - executor=self._tokenizer_executor) - - preprocessed_prompts = await asyncio.gather( - *(preprocess_async(request=request, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - data_1=t1, - data_2=t2) for t1, t2 in input_pairs)) - - for full_prompt, engine_prompt in preprocessed_prompts: - request_prompts.append(full_prompt) - engine_prompts.append(engine_prompt) - - else: - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) - use_pad_token = self.model_config.use_pad_token - - if use_pad_token: - # cross_encoder models defaults to using pad_token. - tokenized_prompts = await asyncio.gather(*( - tokenize_async( - text=t1, # type: ignore[arg-type] - text_pair=t2, # type: ignore[arg-type] - **tokenization_kwargs) for t1, t2 in input_pairs)) - else: - # `llm as reranker` models defaults to not using pad_token. - tokenized_prompts = await asyncio.gather(*( - tokenize_async( - text=t1 + # type: ignore[operator] - t2, - **tokenization_kwargs) for t1, t2 in input_pairs)) - - for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): - sep_token = tokenizer.sep_token if (tokenizer.sep_token - and use_pad_token) else '' - request_prompt = f"{t1}{sep_token}{t2}" - - input_ids = prompt_inputs["input_ids"] - text_token_prompt = \ - self._validate_input(request, input_ids, request_prompt) - engine_prompt = TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - - request_prompts.append(request_prompt) - engine_prompts.append(engine_prompt) + preprocess_async = make_async(self._preprocess_score, + executor=self._tokenizer_executor) + + preprocessed_prompts = await asyncio.gather( + *(preprocess_async(request=request, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + data_1=t1, + data_2=t2) for t1, t2 in input_pairs)) + + for full_prompt, engine_prompt in preprocessed_prompts: + request_prompts.append(full_prompt) + engine_prompts.append(engine_prompt) # Schedule the request and get the result generator. 
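With the two refactors above, the offline LLM._cross_encoding_score path and the online serving path both build prompts through the same helper, so the pad-token / text-pair branching lives in one place (get_score_prompt in score_utils.py, further down in this diff). Roughly, each call site reduces to a sketch like this (query/document strings are placeholders; model_config, tokenizer and tokenization_kwargs are assumed to be the objects already in scope at the call site):

    from vllm.entrypoints.score_utils import get_score_prompt

    def build_score_prompt(model_config, tokenizer, query: str, document: str):
        # One code path for cross-encoders, "LLM as reranker" models, and
        # models with a score template; the helper picks the tokenization.
        _, engine_prompt = get_score_prompt(
            model_config=model_config,
            data_1=query,
            data_2=document,
            tokenizer=tokenizer,
            tokenization_kwargs={},
        )
        return engine_prompt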
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index f3f042355c9..229fc10391d 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -5,6 +5,7 @@ from torch.nn import CosineSimilarity from typing_extensions import Required, TypeAlias, TypedDict +import vllm.envs as envs from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( BaseMultiModalItemTracker, ChatCompletionContentPartImageEmbedsParam, @@ -184,13 +185,30 @@ def get_score_prompt( model_config, tokenizer, ) + from vllm.model_executor.model_loader import get_model_cls - full_prompt = apply_score_template(model_config, prompt_1, prompt_2) - - prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + model = get_model_cls(model_config) + if supports_score_template(model): + full_prompt = apply_score_template(model_config, prompt_1, prompt_2) + prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + elif model_config.use_pad_token: + # cross_encoder models defaults to using pad_token. + prompt_inputs = tokenizer(text=prompt_1, + text_pair=prompt_2, + **tokenization_kwargs) + else: + # `llm as reranker` models defaults to not using pad_token. + prompt_inputs = tokenizer(text=prompt_1 + prompt_2, + **tokenization_kwargs) engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"]) + if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None: + if envs.VLLM_USE_V1: + mm_data = {"token_type_ids": token_type_ids, **(mm_data or {})} + else: + engine_prompt["token_type_ids"] = token_type_ids + post_process_tokens(model_config, engine_prompt) if mm_data is not None: diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6e955e1c512..c6f01224bff 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,18 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable -from typing import Optional +from collections.abc import Iterable, Mapping, Sequence +from typing import Optional, Union import torch from torch import nn -from transformers import BertConfig +from transformers import BatchFeature, BertConfig from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -24,9 +23,19 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalFlatField, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, + PlaceholderRange) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsCrossEncoding, 
SupportsQuant, SupportsV0Only +from .interfaces import (MultiModalEmbeddings, SupportsCrossEncoding, + SupportsMultiModal, SupportsQuant) from .utils import WeightsMapper, maybe_prefix @@ -54,29 +63,41 @@ def __init__(self, config: BertConfig): def forward( self, - input_ids: torch.Tensor, - seq_lens: torch.Tensor, - position_ids: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + apply_layer_norm: bool = True, ) -> torch.Tensor: - input_shape = input_ids.size() - # Input embeddings. - inputs_embeds = self.word_embeddings(input_ids) + # forward was called directly without going + # throught the multi-modal flow + if input_ids is not None and position_ids is not None \ + and inputs_embeds is None and token_type_ids is None: + token_type_ids = torch.zeros(input_ids.size(), + dtype=torch.long, + device=input_ids.device) + + tensors_to_add: list[torch.Tensor] = [] - # Position embeddings. - position_embeddings = self.position_embeddings(position_ids) + if inputs_embeds is not None: + tensors_to_add.append(inputs_embeds) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) + if token_type_ids is not None: + tensors_to_add.append(self.token_type_embeddings(token_type_ids)) - token_type_embeddings = self.token_type_embeddings(token_type_ids) + if position_ids is not None: + tensors_to_add.append(self.position_embeddings(position_ids)) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings - embeddings = self.LayerNorm(embeddings) - return embeddings + if input_ids is not None: + tensors_to_add.append(self.word_embeddings(input_ids)) + + embeds = torch.stack(tensors_to_add, dim=0).sum(dim=0) + + if apply_layer_norm: + return self.LayerNorm(embeds) + else: + return embeds class BertPooler(nn.Module): @@ -95,7 +116,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -@support_torch_compile class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -313,6 +333,7 @@ def forward(self, hidden_states: torch.Tensor, return hidden_states +@support_torch_compile class BertModel(nn.Module, SupportsQuant): packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} @@ -337,16 +358,11 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - attn_metadata = get_forward_context().attn_metadata - assert hasattr(attn_metadata, "seq_lens_tensor") - hidden_states = self.embeddings( - input_ids=input_ids, - seq_lens=attn_metadata.seq_lens_tensor, - position_ids=position_ids, - token_type_ids=token_type_ids) + + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds) return self.encoder(hidden_states) def load_weights(self, weights: Iterable[tuple[str, @@ -386,7 +402,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): +class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
This class encapsulates the BertModel and provides an interface for @@ -409,11 +425,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, + token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -444,8 +462,147 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: softmax=False) -class BertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding, SupportsQuant): +TOKEN_TYPES = "token_type_ids" + + +class TokenTypeMultiModalProcessor(BaseMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + raise NotImplementedError + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + + assert isinstance(prompt, list) + + mm_data_item = mm_data[TOKEN_TYPES] + if isinstance(mm_data_item, list): + mm_data_item = torch.tensor(mm_data_item) + + prompt_len = len(prompt) + + mm_placeholders = { + TOKEN_TYPES: [PlaceholderRange( + offset=0, + length=prompt_len, + )] + } + + field = MultiModalFlatField([slice(0, prompt_len)]) + mm_item = MultiModalKwargsItem.from_elems( + field.build_elems(modality=TOKEN_TYPES, + key=TOKEN_TYPES, + data=mm_data_item)) + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=prompt, + mm_kwargs=MultiModalKwargs.from_items([mm_item]), + mm_hashes=None, + mm_placeholders=mm_placeholders, + ) + + +class TokenTypeProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {TOKEN_TYPES: 1} + + +class TokenTypeInputBuilder(BaseDummyInputsBuilder[TokenTypeProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + raise NotImplementedError + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + raise NotImplementedError + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + + dummy_prompt = [0] * seq_len + dummy_mm_data = { + TOKEN_TYPES: torch.zeros(seq_len, dtype=torch.int32), + } + return ProcessorInputs(prompt=dummy_prompt, mm_data=dummy_mm_data) + + +class BertMMTokenIdsMixin(SupportsMultiModal): + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + token_type_ids = kwargs.pop(TOKEN_TYPES, None) + + if token_type_ids is None: + return [] + + if not isinstance(token_type_ids, torch.Tensor): + raise ValueError("Incorrect type token_type_ids. 
" + f"Got type: {type(token_type_ids)}") + + return self.get_language_model().embeddings( + token_type_ids=token_type_ids, apply_layer_norm=False) + + def maybe_store_input_ids(self, input_ids: torch.Tensor): + pass + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + token_type_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + # save for forward() + self.maybe_store_input_ids(input_ids) + + token_type_ids: Optional[torch.Tensor] = None + + if token_type_embeddings is not None: + assert isinstance(token_type_embeddings, list) + token_type_embeddings = torch.cat(token_type_embeddings) + else: + token_type_ids = torch.zeros(input_ids.size(), + dtype=torch.long, + device=input_ids.device) + + return self.get_language_model().embeddings( + input_ids=input_ids, + inputs_embeds=token_type_embeddings, + token_type_ids=token_type_ids, + apply_layer_norm=False) + + +@MULTIMODAL_REGISTRY.register_processor(TokenTypeMultiModalProcessor, + info=TokenTypeProcessingInfo, + dummy_inputs=TokenTypeInputBuilder) +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, + BertMMTokenIdsMixin, SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -469,6 +626,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = ClassifierPooler(vllm_config.model_config, self.classifier, self.bert.pooler) + def get_language_model(self) -> torch.nn.Module: + return self.bert + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): self_weights = [] diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 048fa827fb2..b757aa2d6c3 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -10,17 +10,23 @@ from transformers import RobertaConfig from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel +from vllm.model_executor.models.bert import (BertEmbeddingModel, + BertMMTokenIdsMixin, BertModel, + TokenTypeInputBuilder, + TokenTypeMultiModalProcessor, + TokenTypeProcessingInfo) from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, PoolerOutput from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import SupportsCrossEncoding class RobertaEmbedding(nn.Module): @@ -49,51 +55,41 @@ def __init__(self, config: RobertaConfig): def forward( self, - input_ids: torch.Tensor, - seq_lens: torch.Tensor, - position_ids: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + apply_layer_norm: bool = True, ) -> torch.Tensor: - input_shape = input_ids.size() - inputs_embeds = self.word_embeddings(input_ids) - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding 
tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - pos_list = [] - token_list = [] - offset = 0 - for seq_len in seq_lens: - pos_list.append(position_ids[offset:offset + seq_len]) - token_list.append(input_ids[offset:offset + seq_len]) - offset += seq_len - - new_pos_list = [] - for positions, tokens in zip(pos_list, token_list): - # Verify assumption that incoming position are - # always a sequence from 0 to N. - expected_pos = torch.arange(positions.size()[0], - dtype=torch.long, - device=inputs_embeds.device) - assert torch.equal(positions, expected_pos) - new_pos_list.append( - create_position_ids_from_input_ids(tokens, self.padding_idx)) - position_ids = torch.cat(new_pos_list) - - # Position embeddings. - position_embeddings = self.position_embeddings(position_ids) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, + # forward was called directly without going + # through the multi-modal flow + if input_ids is not None and position_ids is not None \ + and inputs_embeds is None and token_type_ids is None: + token_type_ids = torch.zeros(input_ids.size(), dtype=torch.long, - device=inputs_embeds.device) + device=input_ids.device) + + tensors_to_add: list[torch.Tensor] = [] + + if inputs_embeds is not None: + tensors_to_add.append(inputs_embeds) + + if token_type_ids is not None: + tensors_to_add.append(self.token_type_embeddings(token_type_ids)) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings - embeddings = self.LayerNorm(embeddings) - return embeddings + if position_ids is not None: + tensors_to_add.append(self.position_embeddings(position_ids)) + + if input_ids is not None: + tensors_to_add.append(self.word_embeddings(input_ids)) + + embeds = torch.stack(tensors_to_add, dim=0).sum(dim=0) + + if apply_layer_norm: + return self.LayerNorm(embeds) + else: + return embeds # Adapted from transformers @@ -124,6 +120,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel): _pooler: An instance of Pooler used for pooling operations. """ + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # Fix Roberta positions here outside of the CUDA graph. + # Because we need to extract the sequences from + # input_ids the control flow is data dependent.
+ replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) + + return self.model(input_ids=input_ids, + position_ids=positions, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> Union[BertModel, BertWithRope]: @@ -148,8 +170,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert len(loaded), "Unable to load RobertaEmbeddingModel" -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsV0Only): +@MULTIMODAL_REGISTRY.register_processor(TokenTypeMultiModalProcessor, + info=TokenTypeProcessingInfo, + dummy_inputs=TokenTypeInputBuilder) +class RobertaForSequenceClassification(nn.Module, BertMMTokenIdsMixin, + SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -175,6 +200,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, @@ -185,6 +211,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pooler = ClassifierPooler(vllm_config.model_config, self.classifier) + self.input_ids: Optional[torch.Tensor] = None + + def maybe_store_input_ids(self, input_ids: torch.Tensor): + self.input_ids = input_ids + + def get_language_model(self) -> torch.nn.Module: + return self.roberta def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): bert_weights, task_weights = roberta_task_weights_filter(weights) @@ -216,6 +249,10 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + _input_ids = input_ids if input_ids is not None else self.input_ids + replace_roberta_positions(input_ids=_input_ids, + position_ids=positions, + padding_idx=self.padding_idx) return self.roberta(input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds, @@ -247,6 +284,39 @@ def create_position_ids_from_input_ids(input_ids, return incremental_indices.long() + padding_idx +def replace_roberta_positions(input_ids: torch.Tensor, + position_ids: torch.Tensor, + padding_idx: int) -> None: + + seq_lens: Optional[torch.Tensor] = None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: # can be None during warmup + if isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values())) + # TODO: remove "seq_lens_tensor" after V0 is removed + seq_lens = getattr(attn_metadata, "seq_lens_tensor", + getattr(attn_metadata, "seq_lens", None)) + + if seq_lens is not None: + assert isinstance(seq_lens, torch.Tensor) + + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + token_list = torch.split(input_ids[:torch.sum(seq_lens)], + 
seq_lens.tolist()) + + offset = 0 + for tokens in token_list: + length = tokens.shape[0] + position_ids[offset:offset+length] = \ + create_position_ids_from_input_ids(tokens, padding_idx) + offset = offset + length + + def roberta_task_weights_filter( all_weights: Iterable[tuple[str, torch.Tensor]] ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index 598a0e97e51..fb29d51eae8 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import re from collections.abc import Sequence from typing import Optional, Union +import regex as re from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fbc13c06c65..4748f522af4 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -429,11 +429,13 @@ def __init__( FlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/decoder cross-attention " + "is not implemented for " "FlashAttentionImpl") + self.attn_type = attn_type self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -552,7 +554,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=True, + causal=_get_causal_option(self.attn_type), alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, @@ -755,3 +757,21 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. 
+ """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2fbcb569e3d..45f95fbabee 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -917,6 +917,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: dtype=spec.dtype, use_mla=spec.use_mla, sliding_window=spec.sliding_window, + attn_type=str(spec.attn_type), ) if is_hybrid(kv_cache_spec): diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11..bb367a641ab 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -36,7 +36,7 @@ from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -150,6 +150,23 @@ def _initialize_kv_caches( zip(kv_cache_specs, available_gpu_memory) ] + for kv_cache_spec_one_worker in kv_cache_specs: + for _, spec in kv_cache_spec_one_worker.items(): + if isinstance(spec, AttentionSpec) and \ + spec.attn_type != "decoder": + + logger.info("Found non-decoder layer. Disabling " + "prefix cache and chunked prefill") + self.vllm_config.cache_config.\ + enable_prefix_caching = False + self.vllm_config.scheduler_config.\ + enable_chunked_prefill = False + self.vllm_config.scheduler_config.\ + chunked_prefill_enabled = False + self.vllm_config.scheduler_config.\ + long_prefill_token_threshold = 0 + break + # Since we use a shared centralized controller, we need the # `kv_cache_config` to be consistent across all workers to make sure # all the memory operators can be applied to all workers. diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 43456a987de..fd97feff449 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -75,6 +75,7 @@ class AttentionSpec(KVCacheSpec): head_size: int dtype: torch.dtype use_mla: bool + attn_type: str @property def page_size_bytes(self) -> int: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..fe3bd87177f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Datastructures defining a GPU input batch +# Datastructures defining an input batch from dataclasses import dataclass from typing import Optional, cast diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f3279fa5fa8..1c3fcf2cad7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -247,7 +247,6 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) - # None in the first PP rank. The rest are set after load_model. 
self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -2597,7 +2596,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue # TODO: Support other attention modules, e.g., cross-attention - if attn_module.attn_type == AttentionType.DECODER: + if attn_module.attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_ONLY): if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -2605,17 +2605,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla) + use_mla=use_mla, + attn_type=str(attn_module.attn_type)) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. + use_mla=use_mla, + attn_type=str(attn_module.attn_type)) + elif attn_module.attn_type == AttentionType.ENCODER: + # encoder attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5af052e6851..f947a9c86ef 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -506,6 +506,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=False, + attn_type=str(attn_module.attn_type), ) else: kv_cache_spec[layer_name] = FullAttentionSpec( @@ -514,6 +515,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=False, + attn_type=str(attn_module.attn_type), ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
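Two closing notes on the changes above. First, the engine-core hunk disables prefix caching and chunked prefill as soon as any worker reports an attention layer whose attn_type is not plain decoder self-attention, since both features assume causal decoder attention. Condensed into a hypothetical standalone check (simplified; the real code also resets long_prefill_token_threshold and mutates the configs in place):

    from vllm.v1.kv_cache_interface import AttentionSpec

    def has_non_decoder_layer(kv_cache_specs: list[dict[str, object]]) -> bool:
        # One {layer_name: spec} dict per worker; the specs store attn_type as
        # a string, and the engine core compares it against "decoder".
        return any(
            isinstance(spec, AttentionSpec) and spec.attn_type != "decoder"
            for specs_one_worker in kv_cache_specs
            for spec in specs_one_worker.values())

Second, a worked illustration of why replace_roberta_positions exists: RoBERTa position ids must start at padding_idx + 1 and leave padding tokens at padding_idx, which is what create_position_ids_from_input_ids computes per sequence. A toy restatement of that rule (standalone sketch, not the vLLM function itself):

    import torch

    def roberta_position_ids(input_ids: torch.Tensor,
                             padding_idx: int) -> torch.Tensor:
        # Non-pad tokens get consecutive positions starting at padding_idx + 1;
        # pad tokens keep position padding_idx.
        mask = input_ids.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=0) * mask).long() + padding_idx

    ids = torch.tensor([101, 2054, 2003, 1, 1])  # padding_idx == 1 for RoBERTa
    print(roberta_position_ids(ids, padding_idx=1))
    # tensor([2, 3, 4, 1, 1]) -- positions start at 2, padding stays at 1

replace_roberta_positions applies this per sequence by splitting the flattened input_ids with the seq_lens taken from the attention metadata, which is why it runs outside the compiled/CUDA-graph region: the split is data-dependent control flow.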