
Commit 9ad8c9e

Isotr0py authored and Chen-zexi committed
[CI/Build] Ensure compatibility with Transformers v4.53 (vllm-project#20541)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent eb42508 commit 9ad8c9e

13 files changed (+74, -38 lines)

requirements/test.in

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
 mteb[bm25s]>=1.38.11, <2 # required for mteb test
-transformers==4.52.4
+transformers==4.53.2
 tokenizers==0.21.1
 huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
 schemathesis>=3.39.15 # Required for openai schema test.

requirements/test.txt

Lines changed: 1 addition & 1 deletion
@@ -800,7 +800,7 @@ tqdm==4.66.6
     #   transformers
 tqdm-multiprocess==0.0.11
     # via lm-eval
-transformers==4.52.4
+transformers==4.53.2
     # via
     #   -r requirements/test.in
     #   genai-perf

tests/models/multimodal/generation/test_common.py

Lines changed: 2 additions & 2 deletions
@@ -318,6 +318,7 @@
         num_logprobs=10,
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         auto_cls=AutoModelForImageTextToText,
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     "glm4_1v-video": VLMTestInfo(
         models=["THUDM/GLM-4.1V-9B-Thinking"],
@@ -331,8 +332,7 @@
             inputs=custom_inputs.video_with_metadata_glm4_1v(),
             limit_mm_per_prompt={"video": 1},
         )],
-        # This is needed to run on machine with 24GB VRAM
-        vllm_runner_kwargs={"gpu_memory_utilization": 0.95},
+        marks=[large_gpu_mark(min_gb=32)],
     ),
     "h2ovl": VLMTestInfo(
         models = [
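
The marks=[large_gpu_mark(min_gb=32)] entries replace the old 24 GB gpu_memory_utilization workaround by skipping these cases on GPUs with less memory. A hedged sketch of what such a mark amounts to; the body below is illustrative and not vLLM's actual test-utility implementation:

import pytest
import torch

def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
    # Skip the test unless the current GPU has at least min_gb of memory.
    total_gb = 0.0
    if torch.cuda.is_available():
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    return pytest.mark.skipif(
        total_gb < min_gb,
        reason=f"Need at least {min_gb} GB of GPU memory",
    )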

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ def _test_processing_correctness(
 _ADD_SPECIAL_TOKENS_OVERRIDES = {
     "mllama": False,
     "ovis": False,
+    "paligemma": False,
     "ultravox": False,
     "whisper": False,
 }

tests/models/test_initialization.py

Lines changed: 10 additions & 2 deletions
@@ -31,7 +31,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info.check_transformers_version(on_fail="skip")
 
     # FIXME: Possible memory leak in the previous tests?
-    if model_arch in ("GraniteSpeechForConditionalGeneration",
+    if model_arch in ("Glm4vForConditionalGeneration",
+                      "GraniteSpeechForConditionalGeneration",
                       "KimiVLForConditionalGeneration"):
         pytest.skip("Avoid OOM")
 
@@ -46,16 +47,23 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         n_group = getattr(text_config, 'n_group', None)
         num_experts = n_group * 2 if n_group is not None else 2
 
+        # we use three layers for Gemma-3n to check
+        # both normal layer and kv_shared_layer
+        num_hidden_layers = (3 if model_arch
+                             == "Gemma3nForConditionalGeneration" else 1)
+
         text_config.update({
             "num_layers": 1,
-            "num_hidden_layers": 1,
+            "num_hidden_layers": num_hidden_layers,
             "num_experts": num_experts,
             "num_experts_per_tok": 2,
             "num_local_experts": num_experts,
             # Otherwise there will not be any expert layers
             "first_k_dense_replace": 0,
             # To avoid OOM on DeepSeek-V3
             "n_routed_experts": num_experts,
+            # For Gemma-3n
+            "num_kv_shared_layers": 1,
        })
 
        if hasattr(hf_config, "vision_config"):
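
A small illustration of the override logic added above, using the same names as the diff: Gemma-3n gets three tiny hidden layers so the initialization test exercises both a regular attention layer and a KV-shared layer (enabled via num_kv_shared_layers=1), while every other architecture keeps a single layer.

def tiny_layer_overrides(model_arch: str) -> dict:
    # Three layers only for Gemma-3n, so both normal and kv_shared layers run.
    num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration"
                         else 1)
    return {
        "num_hidden_layers": num_hidden_layers,
        # For Gemma-3n; other configs simply carry the extra key unused.
        "num_kv_shared_layers": 1,
    }

print(tiny_layer_overrides("Gemma3nForConditionalGeneration"))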

vllm/inputs/registry.py

Lines changed: 1 addition & 7 deletions
@@ -5,9 +5,7 @@
 from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
 
 import torch
-from packaging.version import Version
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
-from transformers import __version__ as TRANSFORMERS_VERSION
 from typing_extensions import TypeVar
 
 from vllm.jsontree import JSONTree, json_map_leaves
@@ -137,13 +135,9 @@ def get_hf_processor(
         /,
         **kwargs: object,
     ) -> _P:
-        # Transformers 4.53.0 has issue with passing tokenizer to
-        # initialize processor. We disable it for this version.
-        # See: https://github.com/vllm-project/vllm/issues/20224
-        if Version(TRANSFORMERS_VERSION) != Version("4.53.0"):
-            kwargs["tokenizer"] = self.tokenizer
         return super().get_hf_processor(
             typ,
+            tokenizer=self.tokenizer,
             **kwargs,
         )

vllm/model_executor/models/commandr.py

Lines changed: 5 additions & 2 deletions
@@ -189,10 +189,13 @@ def __init__(
 
         layer_idx = extract_layer_index(prefix)
         layer_has_sliding_window = (
-            getattr(config, "sliding_window_pattern", False)
-            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+            getattr(config, "sliding_window_pattern", False) and
+            (layer_idx + 1) % self.config.sliding_window_pattern
+            != 0) or (getattr(config, "layer_types", False)
+                      and config.layer_types[layer_idx] == "sliding_attention")
 
         self.sliding_window = (interleaved_sliding_window
+                               or config.sliding_window
                                if layer_has_sliding_window else None)
 
         self.attn = Attention(self.num_heads,
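
A minimal sketch of the check this diff converges on, assuming a generic HF config object: older Cohere configs describe interleaving through sliding_window_pattern, while Transformers 4.53 configs expose a per-layer layer_types list, and either signal marks the layer as sliding.

def layer_has_sliding_window(config, layer_idx: int) -> bool:
    # Legacy configs: every N-th layer (the pattern boundary) is global.
    pattern = getattr(config, "sliding_window_pattern", None)
    if pattern and (layer_idx + 1) % pattern != 0:
        return True
    # Transformers 4.53 configs: an explicit per-layer attention-type list.
    layer_types = getattr(config, "layer_types", None)
    return bool(layer_types) and layer_types[layer_idx] == "sliding_attention"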

vllm/model_executor/models/fuyu.py

Lines changed: 18 additions & 7 deletions
@@ -175,12 +175,21 @@ def _call_hf_processor(
 
         # Original output: (1, num_images, Pn, Px * Py * C)
         # New output: (num_images, Pn, Px * Py * C)
-        assert (isinstance(image_patches, list)
-                and len(image_patches) == 1)
-        assert (isinstance(image_patches[0], torch.Tensor)
-                and len(image_patches[0]) == len(images))
-
-        processed_outputs["image_patches"] = image_patches[0]
+        # image_patches is a list with shape:
+        # (1, num_images, Pn, Px * Py * C)
+        # before Transformers 4.53
+        if isinstance(image_patches, list):
+            assert len(image_patches) == 1
+            assert (isinstance(image_patches[0], torch.Tensor)
+                    and len(image_patches[0]) == len(images))
+            processed_outputs["image_patches"] = image_patches[0]
+        # image_patches is a tensor with shape:
+        # (num_images, Pn, Px * Py * C)
+        # after Transformers 4.53
+        elif isinstance(image_patches, torch.Tensor):
+            assert len(image_patches) == len(images)
+        else:
+            raise AssertionError("This line should be unreachable.")
 
         return processed_outputs
 
@@ -193,8 +202,10 @@ def _apply_hf_processor_tokens_only(
         vocab = tokenizer.get_vocab()
 
         boa_token_id = vocab["<0x04>"]
+        if prompt_tokens[-1] != boa_token_id:
+            prompt_tokens.append(boa_token_id)
 
-        return prompt_tokens + [boa_token_id]
+        return prompt_tokens
 
     def _get_mm_fields_config(
         self,
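
A standalone sketch of the branching this diff introduces, under the same assumption it documents: before Transformers 4.53 the Fuyu processor returns image_patches as a one-element list, from 4.53 onward it is a plain tensor of shape (num_images, Pn, Px * Py * C). The helper name below is hypothetical, not vLLM's method.

import torch

def normalize_image_patches(image_patches, num_images: int) -> torch.Tensor:
    if isinstance(image_patches, list):  # Transformers < 4.53
        assert len(image_patches) == 1
        patches = image_patches[0]
    elif isinstance(image_patches, torch.Tensor):  # Transformers >= 4.53
        patches = image_patches
    else:
        raise AssertionError("unexpected image_patches type")
    assert isinstance(patches, torch.Tensor) and len(patches) == num_images
    return patches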

vllm/model_executor/models/gemma3.py

Lines changed: 6 additions & 3 deletions
@@ -149,14 +149,17 @@ def __init__(self,
         # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
         self.is_sliding = (getattr(
-            config, "interleaved_sliding_window", None) is not None and bool(
-                (layer_idx + 1) % config.sliding_window_pattern))
+            config, "interleaved_sliding_window", None) is not None and (bool(
+                (layer_idx + 1) % config.sliding_window_pattern))) or (
+                    getattr(config, "layer_types", None) is not None
+                    and config.layer_types[layer_idx] == "sliding_attention")
         # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
             self.rope_theta = config.rope_local_base_freq
             self.rope_scaling = {"rope_type": "default"}
-            self.sliding_window = config.interleaved_sliding_window
+            self.sliding_window = (config.interleaved_sliding_window
+                                   or config.sliding_window)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
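
A minimal sketch of the window-size fallback added above, assuming a Gemma-3 style config: prefer the legacy interleaved_sliding_window attribute and fall back to sliding_window, the field Transformers 4.53 populates.

def resolve_sliding_window(config) -> int:
    # Older configs set interleaved_sliding_window; newer ones set sliding_window.
    return (getattr(config, "interleaved_sliding_window", None)
            or config.sliding_window)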

vllm/model_executor/models/minicpmo.py

Lines changed: 11 additions & 10 deletions
@@ -30,8 +30,10 @@
 from torch import nn
 from transformers import BatchFeature, PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.whisper.modeling_whisper import (
-    ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
+from transformers.models.whisper.modeling_whisper import (ACT2FN,
+                                                           WhisperAttention,
+                                                           WhisperConfig,
+                                                           WhisperEncoder)
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -378,14 +380,13 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[
-            config._attn_implementation](
-                embed_dim=self.embed_dim,
-                num_heads=config.encoder_attention_heads,
-                dropout=config.attention_dropout,
-                config=config,
-                layer_idx=layer_idx,
-            )
+        self.self_attn = WhisperAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+            layer_idx=layer_idx,
+        )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
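
With the private WHISPER_ATTENTION_CLASSES mapping no longer imported, the encoder layer builds WhisperAttention directly. A minimal construction sketch, using a default WhisperConfig as a stand-in for MiniCPM-o's actual audio config:

from transformers import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperAttention

config = WhisperConfig()  # stand-in; MiniCPM-o supplies its own audio config
self_attn = WhisperAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
    config=config,
    layer_idx=0,
)
print(type(self_attn).__name__)  # WhisperAttention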
