
Commit 2c73f88

style

Signed-off-by: raushan <raushan@huggingface.co>

1 parent 267a57f commit 2c73f88

File tree: 4 files changed, +50 −38 lines

tests/models/test_transformers.py

Lines changed: 18 additions & 11 deletions
@@ -3,14 +3,14 @@
 from typing import Any, Optional, Union
 
 import pytest
+from transformers import AutoModelForImageTextToText
 
 from vllm.platforms import current_platform
 
 from ..conftest import HfRunner, VllmRunner
 from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
 from ..utils import multi_gpu_test
 from .utils import check_logprobs_close
-from transformers import AutoModelForImageTextToText
 
 
 def check_implementation(
@@ -75,23 +75,30 @@ def test_models(
 @pytest.mark.parametrize(
     "model,model_impl",
     [
-        ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "transformers"), # dynamic image length and number of patches
-        ("HuggingFaceTB/SmolVLM-256M-Instruct", "transformers"), # has col/row special token between patches
-        ("Qwen/Qwen2.5-VL-3B-Instruct", "transformers"), # pixel values from processor are not 4D or 5D arraya
-    ]) # no custom code support because custom models don't follow the standard yet!
+        ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+         "transformers"), # dynamic image length and number of patches
+        ("HuggingFaceTB/SmolVLM-256M-Instruct",
+         "transformers"), # has col/row special token between patches
+        ("Qwen/Qwen2.5-VL-3B-Instruct", "transformers"
+         ), # pixel values from processor are not 4D or 5D arraya
+    ]
+) # no custom code support because custom models don't follow the standard yet!
 def test_models_multimodal(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
     example_prompts: list[str],
     model: str,
     model_impl: str,
 ) -> None:
-    check_implementation(hf_runner,
-                         vllm_runner,
-                         example_prompts,
-                         model,
-                         model_impl=model_impl,
-                         kwargs_ref={"auto_cls": AutoModelForImageTextToText},)
+    check_implementation(
+        hf_runner,
+        vllm_runner,
+        example_prompts,
+        model,
+        model_impl=model_impl,
+        kwargs_ref={"auto_cls": AutoModelForImageTextToText},
+    )
+
 
 def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
     prompts, _, _ = prep_prompts(4, (800, 801))
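
As context for the `kwargs_ref={"auto_cls": AutoModelForImageTextToText}` argument above: the reference run loads each parametrized model through a Transformers auto class rather than a plain causal-LM class. A minimal, hedged sketch of such a reference-side load (model id taken from the parametrize list; this is not the test's actual runner code and assumes a transformers release that ships AutoModelForImageTextToText):

    # Illustrative reference load, separate from vLLM's HfRunner fixtures.
    from transformers import AutoModelForImageTextToText, AutoProcessor

    model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"
    processor = AutoProcessor.from_pretrained(model_id)  # prepares image + text inputs
    model = AutoModelForImageTextToText.from_pretrained(model_id)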

vllm/model_executor/models/registry.py

Lines changed: 3 additions & 3 deletions
@@ -231,7 +231,7 @@
 }
 
 _TRANSFORMERS_MODELS = {
-    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),
+    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }
 # yapf: enable
@@ -457,8 +457,8 @@ def _normalize_archs(
 
     # make sure Transformers backend is put at the last as a fallback
     if len(normalized_arch) != len(architectures):
-        # The order matters. If causal comes first, checks on MM model fails because it is not registered in MultimodalRegistry
-        # TODO: needs help from vLLM team
+        # The order matters. If the CausalLM comes first, then checks for
+        # registered model in MultimodalRegistry fail
        normalized_arch.extend(
            ["TransformersForMultimodalLM", "TransformersForCausalLM"])
    return normalized_arch
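
To illustrate the ordering note in the comment above: when some requested architectures cannot be resolved, the Transformers-backend fallbacks are appended with the multimodal variant first, so multimodal registry checks match before the plain causal-LM fallback. A simplified standalone sketch of that idea (not vLLM's actual _normalize_archs):

    # Hedged sketch: fallback ordering for unresolved architectures.
    def add_transformers_fallback(normalized_arch: list[str],
                                  architectures: list[str]) -> list[str]:
        # Only append fallbacks if some architectures could not be resolved.
        if len(normalized_arch) != len(architectures):
            # The multimodal fallback must precede the causal-LM fallback,
            # otherwise a multimodal model would resolve to the non-multimodal class.
            normalized_arch.extend(
                ["TransformersForMultimodalLM", "TransformersForCausalLM"])
        return normalized_arch

    print(add_transformers_fallback([], ["MyCustomVLM"]))
    # ['TransformersForMultimodalLM', 'TransformersForCausalLM']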

vllm/model_executor/models/transformers.py

Lines changed: 23 additions & 21 deletions
@@ -41,12 +41,12 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalInputs,
-                                    PlaceholderRange, MultiModalDataDict)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputs, PlaceholderRange)
 from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processor import cached_get_processor
 
@@ -124,8 +124,9 @@ def replace_linear_class(
 @contextmanager
 def init_on_device_without_buffers(device: torch.device):
     """
-    A context manager under which models are initialized with all parameters on the specified device.
-    However buffers are not initialized on specified device.
+    A context manager under which models are initialized with all
+    parameters on the specified device. However buffers are not
+    initialized on specified device.
 
     Args:
         device (`torch.device`):
@@ -162,8 +163,7 @@ def wrapper(*args, **kwargs):
         yield
     finally:
         nn.Module.register_parameter = old_register_parameter
-        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(
-        ):
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
 
 
@@ -216,7 +216,7 @@ def get_dummy_mm_data(
 
         target_width, target_height = self.info.get_max_image_size()
 
-        return {
+        return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
@@ -253,13 +253,11 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs,
         num_image_patches: torch.Tensor = None,
     ):
-        hf_inputs.pop(
-            "attention_mask",
-            None) # processors always return a mask but vLLM doesn't need it
+        # HF Processors always return a mask but vLLM doesn't need it
+        hf_inputs.pop("attention_mask", None)
         mm_fields = {
-            key: MultiModalFieldConfig.flat_from_sizes("image",
-                                                       num_image_patches)
-            for key in hf_inputs.keys()
+            key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
+            for key in hf_inputs
         }
         mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
             "image", num_image_patches)
@@ -311,13 +309,17 @@ def apply(
         """
         if return_mm_hashes:
             raise ValueError(
-                "TransformersMultimodalLM doesn't support mm hashing yet! Probably you did not set "
-                "`disable_mm_preprocessor_cache=True`.")
+                "TransformersMultimodalLM doesn't support mm hashing yet! "
+                "Probably you did not set `disable_mm_preprocessor_cache=True`")
 
         mm_items = self._to_mm_items(mm_data)
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
 
-        prompt_ids, processed_data, mm_token_type_ids = self._apply_hf_processor_text_mm(
+        (
+            prompt_ids,
+            processed_data,
+            mm_token_type_ids
+        ) = self._apply_hf_processor_text_mm(
             prompt_text=prompt,
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -435,7 +437,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
435437
config_override = ConfigOverride(
436438
config, sliding_window=config.interleaved_sliding_window)
437439

438-
# Set correct attn impl and init on "meta" to delay allocating GPU tensors
440+
# Set correct attn and init on "meta" to delay allocating GPU tensors
439441
self.text_config._attn_implementation = "vllm"
440442
with init_on_device_without_buffers("meta"):
441443
# FIXME(Isotr0py): We need to refactor this part in the future to
@@ -870,9 +872,9 @@ def get_multimodal_embeddings(self, **kwargs):
             if vision_embeddings.ndim == 2:
                 vision_embeddings = vision_embeddings.unsqueeze(0)
 
-            # Embeddings have to be 2D tensors of length `num_images` but transformers
-            # returns concat tensors if each patch is of different size. We split it back
-            # to make vLLM assertions happy
+            # Embeddings have to be 2D tensors of length `num_images`
+            # but transformers returns concat tensors if each patch
+            # is of different size. We split it back to make vLLM happy
             vision_embeddings = torch.split(vision_embeddings,
                                             num_image_patches.tolist())
             vision_embeddings = [
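
The comment in the last hunk describes splitting one concatenated patch-embedding tensor back into a 2D tensor per image. A tiny standalone torch example with made-up sizes shows the mechanism (shapes are illustrative only, not the model's real dimensions):

    import torch

    # Suppose the vision tower returned 7 patch embeddings of width 8,
    # concatenated across two images (3 patches + 4 patches).
    vision_embeddings = torch.randn(7, 8)
    num_image_patches = torch.tensor([3, 4])

    per_image = torch.split(vision_embeddings, num_image_patches.tolist())
    print([t.shape for t in per_image])
    # [torch.Size([3, 8]), torch.Size([4, 8])]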

vllm/v1/engine/mm_input_cache.py

Lines changed: 6 additions & 3 deletions
@@ -34,10 +34,13 @@ class MirroredProcessingCache:
 
     def __init__(self, model_config):
         mm_config = model_config.multimodal_config
-        disable_mm_preprocessor_cache = mm_config is not None and mm_config.disable_mm_preprocessor_cache
+        disable_mm_preprocessor_cache = (
+            mm_config is not None and mm_config.disable_mm_preprocessor_cache
+        )
         self.use_cache = not disable_mm_preprocessor_cache
-        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
-                                                      MultiModalKwargs)
+        self.mm_cache = ProcessingCache.get_lru_cache(
+            VLLM_MM_INPUT_CACHE_GIB, MultiModalKwargs
+        )
 
     def get_and_update_p0(
         self,
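
The reformatted __init__ above derives a use_cache flag from the multimodal config and sizes an LRU cache for processed inputs. A rough standalone analogue of that gate, using cachetools instead of vLLM's ProcessingCache and an item count instead of a GiB budget (illustrative only):

    from cachetools import LRUCache

    class MirroredCacheSketch:
        """Stand-in for the caching gate, not vLLM's MirroredProcessingCache."""

        def __init__(self, disable_mm_preprocessor_cache: bool, max_items: int = 256):
            self.use_cache = not disable_mm_preprocessor_cache
            self.mm_cache = LRUCache(maxsize=max_items) if self.use_cache else None

        def get_or_compute(self, key, compute):
            # Bypass the cache entirely when preprocessing caching is disabled.
            if not self.use_cache:
                return compute()
            if key not in self.mm_cache:
                self.mm_cache[key] = compute()
            return self.mm_cache[key]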

0 commit comments
