diff --git a/requirements/test.in b/requirements/test.in index 907d90201a2..1c725df7e60 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.52.4 +transformers==4.53.2 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. diff --git a/requirements/test.txt b/requirements/test.txt index 2f3ccc4f61d..6f500992bb5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -800,7 +800,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.52.4 +transformers==4.53.2 # via # -r requirements/test.in # genai-perf diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index ce449489965..98461676aa4 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -318,6 +318,7 @@ num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], auto_cls=AutoModelForImageTextToText, + marks=[large_gpu_mark(min_gb=32)], ), "glm4_1v-video": VLMTestInfo( models=["THUDM/GLM-4.1V-9B-Thinking"], @@ -331,8 +332,7 @@ inputs=custom_inputs.video_with_metadata_glm4_1v(), limit_mm_per_prompt={"video": 1}, )], - # This is needed to run on machine with 24GB VRAM - vllm_runner_kwargs={"gpu_memory_utilization": 0.95}, + marks=[large_gpu_mark(min_gb=32)], ), "h2ovl": VLMTestInfo( models = [ diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0f33225eda2..ab21941fae9 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -159,6 +159,7 @@ def _test_processing_correctness( _ADD_SPECIAL_TOKENS_OVERRIDES = { "mllama": False, "ovis": False, + "paligemma": False, "ultravox": False, "whisper": False, } diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 76726c0c820..07ded1e5880 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -31,7 +31,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): model_info.check_transformers_version(on_fail="skip") # FIXME: Possible memory leak in the previous tests? 
- if model_arch in ("GraniteSpeechForConditionalGeneration", + if model_arch in ("Glm4vForConditionalGeneration", + "GraniteSpeechForConditionalGeneration", "KimiVLForConditionalGeneration"): pytest.skip("Avoid OOM") @@ -46,9 +47,14 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: n_group = getattr(text_config, 'n_group', None) num_experts = n_group * 2 if n_group is not None else 2 + # we use three layers for Gemma-3n to check + # both normal layer and kv_shared_layer + num_hidden_layers = (3 if model_arch + == "Gemma3nForConditionalGeneration" else 1) + text_config.update({ "num_layers": 1, - "num_hidden_layers": 1, + "num_hidden_layers": num_hidden_layers, "num_experts": num_experts, "num_experts_per_tok": 2, "num_local_experts": num_experts, @@ -56,6 +62,8 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: "first_k_dense_replace": 0, # To avoid OOM on DeepSeek-V3 "n_routed_experts": num_experts, + # For Gemma-3n + "num_kv_shared_layers": 1, }) if hasattr(hf_config, "vision_config"): diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index fc6e190e548..66e78833f52 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,9 +5,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch -from packaging.version import Version from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from transformers import __version__ as TRANSFORMERS_VERSION from typing_extensions import TypeVar from vllm.jsontree import JSONTree, json_map_leaves @@ -130,13 +128,9 @@ def get_hf_processor( /, **kwargs: object, ) -> _P: - # Transformers 4.53.0 has issue with passing tokenizer to - # initialize processor. We disable it for this version. - # See: https://github.com/vllm-project/vllm/issues/20224 - if Version(TRANSFORMERS_VERSION) != Version("4.53.0"): - kwargs["tokenizer"] = self.tokenizer return super().get_hf_processor( typ, + tokenizer=self.tokenizer, **kwargs, ) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 817c6bb9a7f..c4f6144ed91 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -189,10 +189,13 @@ def __init__( layer_idx = extract_layer_index(prefix) layer_has_sliding_window = ( - getattr(config, "sliding_window_pattern", False) - and (layer_idx + 1) % self.config.sliding_window_pattern != 0) + getattr(config, "sliding_window_pattern", False) and + (layer_idx + 1) % self.config.sliding_window_pattern + != 0) or (getattr(config, "layer_types", False) + and config.layer_types[layer_idx] == "sliding_attention") self.sliding_window = (interleaved_sliding_window + or config.sliding_window if layer_has_sliding_window else None) self.attn = Attention(self.num_heads, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 26c8f80d5a0..558d4fbb4de 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -175,12 +175,21 @@ def _call_hf_processor( # Original output: (1, num_images, Pn, Px * Py * C) # New output: (num_images, Pn, Px * Py * C) - assert (isinstance(image_patches, list) - and len(image_patches) == 1) - assert (isinstance(image_patches[0], torch.Tensor) - and len(image_patches[0]) == len(images)) - - processed_outputs["image_patches"] = image_patches[0] + # image_patches is a list with shape: + # (1, num_images, Pn, Px * Py * C) + # before Transformers 4.53 + if isinstance(image_patches, list): + assert len(image_patches) == 1 
+ assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + processed_outputs["image_patches"] = image_patches[0] + # image_patches is a tensor with shape: + # (num_images, Pn, Px * Py * C) + # after Transformers 4.53 + elif isinstance(image_patches, torch.Tensor): + assert len(image_patches) == len(images) + else: + raise AssertionError("This line should be unreachable.") return processed_outputs @@ -193,8 +202,10 @@ def _apply_hf_processor_tokens_only( vocab = tokenizer.get_vocab() boa_token_id = vocab["<0x04>"] + if prompt_tokens[-1] != boa_token_id: + prompt_tokens.append(boa_token_id) - return prompt_tokens + [boa_token_id] + return prompt_tokens def _get_mm_fields_config( self, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 954e48d25f6..1a2ce65d1e4 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -149,14 +149,17 @@ def __init__(self, # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) self.is_sliding = (getattr( - config, "interleaved_sliding_window", None) is not None and bool( - (layer_idx + 1) % config.sliding_window_pattern)) + config, "interleaved_sliding_window", None) is not None and (bool( + (layer_idx + 1) % config.sliding_window_pattern))) or ( + getattr(config, "layer_types", None) is not None + and config.layer_types[layer_idx] == "sliding_attention") # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. self.rope_theta = config.rope_local_base_freq self.rope_scaling = {"rope_type": "default"} - self.sliding_window = config.interleaved_sliding_window + self.sliding_window = (config.interleaved_sliding_window + or config.sliding_window) else: # Global attention. Use the values in config.json. 
self.rope_theta = config.rope_theta diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 71593d4bb89..4e4fc3d5c76 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -30,8 +30,10 @@ from torch import nn from transformers import BatchFeature, PretrainedConfig from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.whisper.modeling_whisper import ( - ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) +from transformers.models.whisper.modeling_whisper import (ACT2FN, + WhisperAttention, + WhisperConfig, + WhisperEncoder) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig @@ -378,14 +380,13 @@ class MiniCPMWhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig, layer_idx: int): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[ - config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - config=config, - layer_idx=layer_idx, - ) + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 77197abe571..b1f2e53b0c7 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -125,7 +125,7 @@ def _call_hf_processor( ) -> BatchFeature: tokenizer = self.info.get_tokenizer() if not mm_data: - prompt_ids = tokenizer.encode(prompt) + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return super()._call_hf_processor( diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 377a34f2088..c5a5c10d950 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -144,8 +144,16 @@ def get_hf_processor( ) -> Qwen2_5OmniProcessor: if fps is not None: kwargs["fps"] = fps + + # Monkey patch for Transformers v4.53 + processor_class = Qwen2_5OmniProcessor + if processor_class.image_processor_class != "AutoImageProcessor": + processor_class.image_processor_class = "AutoImageProcessor" + if processor_class.video_processor_class != "AutoVideoProcessor": + processor_class.video_processor_class = "AutoVideoProcessor" + processor = self.ctx.get_hf_processor( - Qwen2_5OmniProcessor, + processor_class, image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 344d6fc8f45..ee1cfd7d713 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -634,7 +634,14 @@ def get_hf_config(self) -> WhisperConfig: def get_hf_processor(self, sampling_rate: Optional[int] = None ) -> WhisperProcessor: - return self.ctx.get_hf_processor(WhisperProcessor) + # HACK: Transformers 4.53.0 has issue with whisper tokenizer to + # initialize processor. We use a monkeypatch to fix it here. 
+ # See: https://github.com/vllm-project/vllm/issues/20224 + processor_class = WhisperProcessor + tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast") + if processor_class.tokenizer_class != tokenizer_class: + processor_class.tokenizer_class = tokenizer_class + return self.ctx.get_hf_processor(processor_class) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": 1}
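
For context on the Fuyu change above: before Transformers 4.53 the processor returns image_patches as a one-element list wrapping a single (num_images, Pn, Px * Py * C) tensor, while 4.53 and later return that tensor directly. Below is a minimal sketch of the same normalization as a standalone helper; the name _normalize_image_patches is hypothetical, and the patch inlines this logic in _call_hf_processor rather than factoring it out.

from typing import Union

import torch


def _normalize_image_patches(
        image_patches: Union[list[torch.Tensor], torch.Tensor],
        num_images: int) -> torch.Tensor:
    """Return a (num_images, Pn, Px * Py * C) tensor for either layout."""
    if isinstance(image_patches, list):
        # Transformers < 4.53: a one-element list wrapping the batch tensor.
        assert len(image_patches) == 1
        patches = image_patches[0]
    elif isinstance(image_patches, torch.Tensor):
        # Transformers >= 4.53: already the batch tensor.
        patches = image_patches
    else:
        raise AssertionError("This line should be unreachable.")

    assert isinstance(patches, torch.Tensor)
    assert len(patches) == num_images
    return patches

Either way, downstream code receives one tensor row per input image.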
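
The Command-R and Gemma-3 edits both cope with the same config change: Transformers 4.53 configs describe the attention schedule through a per-layer layer_types list, while older configs only expose sliding_window_pattern (plus interleaved_sliding_window for the window size). A simplified sketch of that detection order, assuming a config that carries one convention or the other; the helper name _is_sliding_attention_layer is hypothetical, and the patch itself keeps both checks combined with `or` inside each model.

from transformers import PretrainedConfig


def _is_sliding_attention_layer(config: PretrainedConfig,
                                layer_idx: int) -> bool:
    # Transformers >= 4.53: the config lists the attention type per layer.
    layer_types = getattr(config, "layer_types", None)
    if layer_types is not None:
        return layer_types[layer_idx] == "sliding_attention"

    # Older configs: every pattern-th layer (1-indexed) uses global
    # attention, the rest use sliding-window attention.
    pattern = getattr(config, "sliding_window_pattern", None)
    if pattern:
        return (layer_idx + 1) % pattern != 0

    return False

The same fallback appears for the window size itself, where the patch reads interleaved_sliding_window or config.sliding_window depending on which attribute the config provides.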
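
Finally, the Whisper and Qwen2.5-Omni processor changes share one workaround for the stricter processor initialization in Transformers 4.53 (see https://github.com/vllm-project/vllm/issues/20224, as linked in the patch): rewrite the processor's class-level *_class attributes before instantiating it, so the expected tokenizer / image processor / video processor classes get resolved. A standalone sketch of the Whisper variant; the function name _patched_whisper_processor_class is hypothetical and only illustrates the class-attribute patch, not the full get_hf_processor flow.

from transformers import WhisperProcessor


def _patched_whisper_processor_class() -> type[WhisperProcessor]:
    """Accept both the slow and the fast Whisper tokenizer."""
    wanted = ("WhisperTokenizer", "WhisperTokenizerFast")
    # tokenizer_class is shared class state, so patching it once is enough
    # and repeated calls are no-ops.
    if WhisperProcessor.tokenizer_class != wanted:
        WhisperProcessor.tokenizer_class = wanted
    return WhisperProcessor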