From aa2ebecbc1b4138aadc037c6e15d52f45d3fac3a Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 7 Jul 2025 07:05:37 +0200 Subject: [PATCH 1/8] add vision LLMs Signed-off-by: raushan --- docs/models/supported_models.md | 11 +- requirements/test.txt | 22 +- tests/models/test_transformers.py | 31 ++ vllm/model_executor/model_loader/utils.py | 18 +- vllm/model_executor/models/registry.py | 7 +- vllm/model_executor/models/transformers.py | 495 +++++++++++++++++++-- vllm/multimodal/inputs.py | 14 + 7 files changed, 539 insertions(+), 59 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 4c279b795ad..1caac8f8a1d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -21,7 +21,7 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned! +vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! To check if the modeling backend is Transformers, you can simply do this: @@ -31,14 +31,17 @@ llm = LLM(model=..., task="generate") # Name or path of your model llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersForCausalLM` then it means it's based on Transformers! +If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers! !!! tip - You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][serving-openai-compatible-server]. + You can force the use of `Transformers` model by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][serving-openai-compatible-server]. !!! note vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. +!!! note + In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Trasnformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. + #### Custom models If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! @@ -102,7 +105,7 @@ Here is what happens in the background when this model is loaded: 1. The config is loaded. 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into `TransformersForCausalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. +3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! diff --git a/requirements/test.txt b/requirements/test.txt index 16d8ee54adc..e9e7f24e611 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -31,6 +31,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -141,6 +145,11 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval +exceptiongroup==1.3.0 + # via + # anyio + # hypothesis + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -690,7 +699,6 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter - # torch # triton shellingham==1.5.4 # via typer @@ -753,8 +761,13 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # black + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -828,13 +841,18 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via + # anyio + # black + # exceptiongroup # huggingface-hub # librosa # mistral-common # mteb + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer # typing-inspection diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b7b99ce41cb..0513975df4c 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -4,6 +4,7 @@ from typing import Any, Optional, Union import pytest +from transformers import AutoModelForImageTextToText from vllm.platforms import current_platform @@ -72,6 +73,36 @@ def test_models( model_impl=model_impl) +@pytest.mark.parametrize( + "model,model_impl", + [ + # Dynamic image length and number of patches + ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "transformers"), + # Has col/row special token between patches + ("HuggingFaceTB/SmolVLM-256M-Instruct", "transformers"), + # Pixel values from processor are not 4D or 5D arrays + ("Qwen/Qwen2.5-VL-3B-Instruct", "transformers"), + # Check "auto" with fallback to transformers + ("BAAI/Emu3-Chat-hf", "auto"), + ] +) # no custom code support because custom models don't follow the standard yet! +def test_models_multimodal( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + example_prompts: list[str], + model: str, + model_impl: str, +) -> None: + check_implementation( + hf_runner, + vllm_runner, + example_prompts, + model, + model_impl=model_impl, + kwargs_ref={"auto_cls": AutoModelForImageTextToText}, + ) + + def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None: prompts, _, _ = prep_prompts(4, (800, 801)) kwargs_ref = {"max_model_len": 8192, "enforce_eager": True} diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 159e7b1e6b0..60250f9edb3 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -170,7 +170,7 @@ def device_loading_context(module: torch.nn.Module, def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): for i, arch in enumerate(architectures): - if arch == "TransformersForCausalLM": + if arch in ["TransformersForCausalLM", "TransformersForMultimodalLM"]: continue auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", None) or dict() @@ -206,7 +206,13 @@ def resolve_transformers_arch(model_config: ModelConfig, raise ValueError( f"The Transformers implementation of {arch} is not " "compatible with vLLM.") - architectures[i] = "TransformersForCausalLM" + # Check if text-config is `self`. If not most probably it is + # a composite config, i.e. mutlimodal + if model_config.hf_config.get_text_config( + ) != model_config.hf_config: + architectures[i] = "TransformersForMultimodalLM" + else: + architectures[i] = "TransformersForCausalLM" if model_config.model_impl == ModelImpl.AUTO: if not model_module.is_backend_compatible(): raise ValueError( @@ -217,7 +223,13 @@ def resolve_transformers_arch(model_config: ModelConfig, "%s has no vLLM implementation, falling back to Transformers " "implementation. Some features may not be supported and " "performance may not be optimal.", arch) - architectures[i] = "TransformersForCausalLM" + # Check if text-config is `self`. If not most probably it is + # a composite config, i.e. mutlimodal + if model_config.hf_config.get_text_config( + ) != model_config.hf_config: + architectures[i] = "TransformersForMultimodalLM" + else: + architectures[i] = "TransformersForCausalLM" return architectures diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8aefbc206d8..974d115d155 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -244,6 +244,7 @@ } _TRANSFORMERS_MODELS = { + "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), } # yapf: enable @@ -469,7 +470,10 @@ def _normalize_archs( # make sure Transformers backend is put at the last as a fallback if len(normalized_arch) != len(architectures): - normalized_arch.append("TransformersForCausalLM") + # The order matters. If the CausalLM comes first, then checks for + # registered model in MultimodalRegistry fail + normalized_arch.extend( + ["TransformersForMultimodalLM", "TransformersForCausalLM"]) return normalized_arch def inspect_model_cls( @@ -477,7 +481,6 @@ def inspect_model_cls( architectures: Union[str, list[str]], ) -> tuple[_ModelInfo, str]: architectures = self._normalize_archs(architectures) - for arch in architectures: model_info = self._try_inspect_model_cls(arch) if model_info is not None: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 04ee3a454f9..af862dc1abe 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -15,8 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around `transformers` models""" -from collections.abc import Iterable -from contextlib import nullcontext +from collections.abc import Iterable, Mapping +from contextlib import contextmanager, nullcontext from typing import Literal, Optional, Union import regex as re @@ -41,9 +41,18 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, PlaceholderRange) +from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo) +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, + SupportsQuant) from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix) @@ -112,6 +121,260 @@ def replace_linear_class( ) +# Copied from `accelerate` +@contextmanager +def init_on_device_without_buffers(device: torch.device): + """ + A context manager under which models are initialized with all + parameters on the specified device. However buffers are not + initialized on specified device. + + Args: + device (`torch.device`): + Device to initialize all parameters on. + """ + + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs) + + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + try: + nn.Module.register_parameter = register_empty_parameter + for torch_function_name in tensor_constructors_to_patch: + setattr( + torch, torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + nn.Module.register_parameter = old_register_parameter + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items( + ): + setattr(torch, torch_function_name, old_torch_function) + + +class MultiModalProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_supported_mm_limits(self): + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len, mm_counts): + return {"image": self.get_max_image_tokens()} + + def get_max_image_tokens(self) -> int: + width, height = self.get_max_image_size() + processor = self.get_hf_processor() + mm_processor_kwargs = self.ctx.model_config.mm_processor_kwargs or {} + mm_tokens = processor._get_num_multimodal_tokens( + image_sizes=([height, width], ), **mm_processor_kwargs) + image_tokens = mm_tokens["num_image_tokens"][0] + return image_tokens + + def get_hf_processor(self): + processor = cached_get_processor(self.ctx.model_config.model) + return processor + + def get_max_image_size(self): + return 10_000, 10_000 # hardcode for arbitrary very large size + + +class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + if "gemma3" in processor.__class__.__name__.lower(): + image_token = processor.boi_token + else: + image_token = getattr(processor, "image_token", "") + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_max_image_size() + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class MultiModalProcessor(BaseMultiModalProcessor): + + def _get_prompt_updates( + self, + mm_items, + hf_processor_mm_kwargs, + out_mm_kwargs, + ): + """ + Given the original multi-modal items for this modality + and HF-processed data, output the updates to perform. + + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. + """ + return None + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs, + num_image_patches: torch.Tensor = None, + ): + # HF Processors always return a mask but vLLM doesn't need it + hf_inputs.pop("attention_mask", None) + mm_fields = { + key: MultiModalFieldConfig.flat_from_sizes("image", + num_image_patches) + for key in hf_inputs + } + mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( + "image", num_image_patches) + mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") + return mm_fields + + def _apply_hf_processor_text_mm( + self, + prompt_text, + mm_items, + hf_processor_mm_kwargs, + ): + """ + Apply the HF processor on the prompt text and multi-modal data + together. + + In addition, return whether prompt replacements have been applied. + """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + processor_data["return_mm_token_type_ids"] = True + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + ) + processed_data.update(passthrough_data) + + prompt_ids, = processed_data.pop("input_ids").tolist() + mm_token_type_ids = processed_data.pop( + "mm_token_type_ids" + ) if "mm_token_type_ids" in processed_data else processed_data.pop( + "token_type_ids") # for gemma3 only + + return prompt_ids, processed_data, mm_token_type_ids + + def apply( + self, + prompt, + mm_data, + hf_processor_mm_kwargs, + return_mm_hashes=False, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + """ + if return_mm_hashes: + raise ValueError( + "TransformersForMultimodalLM doesn't support mm hashing yet! " + "Probably you didn't set `disable_mm_preprocessor_cache=True`") + + mm_items = self._to_mm_items(mm_data) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + (prompt_ids, processed_data, + mm_token_type_ids) = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + # HF processor will return `mm_token_type_ids` from which + # we can infer mm_placeholders. Until then hardcode to make code run + # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1 + mm_positions = torch.where(mm_token_type_ids == 1)[1] + images = mm_items.get_items("image", ImageProcessorItems) + mm_processor_kwargs = self.info.ctx.model_config.mm_processor_kwargs or {} + image_sizes = [] + for item_idx in range(len(images)): + image_size = images.get_image_size(item_idx) + image_sizes.append((image_size.height, image_size.width)) + + mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( + image_sizes=image_sizes, **mm_processor_kwargs) + + mm_placeholders = {} + split_sizes = mm_tokens_per_modality["num_image_tokens"] + if split_sizes: + chunked_mm_positions = torch.split(mm_positions, split_sizes) + mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] + chunked_mm_tokens = torch.split(mm_tokens, split_sizes) + ranges = [ + PlaceholderRange( + offset=positions[0].item(), + length=positions.shape[0], + is_embed=(mm_tokens == hf_processor.image_token_id).bool()) + for positions, mm_tokens in zip(chunked_mm_positions, + chunked_mm_tokens) + ] + mm_placeholders = {"image": ranges} + + num_image_patches = torch.tensor( + mm_tokens_per_modality["num_image_patches"] + ) if "num_image_patches" in mm_tokens_per_modality else None + processed_data['num_image_patches'] = num_image_patches + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, + num_image_patches), + ) + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=None, + mm_placeholders=mm_placeholders, + ) + + class ConfigOverride: """Context manager to temporarily override config attributes.""" @@ -153,6 +416,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config: QuantizationConfig = vllm_config.quant_config self.config = config + self.text_config = config.get_text_config() self.cache_config = cache_config self.device_config = device_config self.model_config = model_config @@ -173,14 +437,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config_override = ConfigOverride( config, sliding_window=config.interleaved_sliding_window) - # Use meta device to delay allocating GPU tensors - with torch.device("meta"), config_override: + # Set correct attn and init on "meta" to delay allocating GPU tensors + self.text_config._attn_implementation = "vllm" + with init_on_device_without_buffers("meta"): # FIXME(Isotr0py): We need to refactor this part in the future to # avoid registering an extra model layer, otherwise we will need a # weights mapper to rename weights. self.model: PreTrainedModel = AutoModel.from_config( config, - attn_implementation="vllm", torch_dtype=model_config.dtype, trust_remote_code=model_config.trust_remote_code, ) @@ -189,27 +453,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.tensor_parallel() # Input embeddings + text_config = config.get_text_config() if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): self.model.set_input_embeddings( VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, + text_config.vocab_size, + text_config.hidden_size, + org_num_embeddings=text_config.vocab_size, quant_config=quant_config, )) # Attention layers self.attention_instances = self.create_attention_instances() - # Initialize buffers (e.g. rotary embedding inverse frequency) - self.init_buffers(self.model) - # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + text_config.hidden_size)) def pipeline_parallel(self): """ @@ -240,14 +502,15 @@ def pipeline_parallel(self): # Layers before module list for name in pp_plan[:module_list_idx]: - if self.pp_group.is_first_rank or (self.config.tie_word_embeddings - and self.pp_group.is_last_rank): + if self.pp_group.is_first_rank or ( + self.text_config.tie_word_embeddings + and self.pp_group.is_last_rank): continue setattr(self.model, name, PPMissingLayer()) # Module list - start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers, - self.pp_rank, self.pp_size) + start_layer, end_layer = get_pp_indices( + self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) layers_name = pp_plan[module_list_idx] layers = getattr(self.model, layers_name) for i in range(len(layers)): @@ -298,7 +561,7 @@ def create_attention_instances(self) -> dict[int, Attention]: self.parallel_config) head_size = self.model_config.get_head_size() num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - start, end = get_pp_indices(self.config.num_hidden_layers, + start, end = get_pp_indices(self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) attention_instances = {} @@ -323,35 +586,6 @@ def create_attention_instances(self) -> dict[int, Attention]: prefix=f"{i}.attn") return attention_instances - def init_buffers(self, module: nn.Module): - """ - If a `buffer` is on the `meta` device, then its parent - `module` is the original module created by: - - ```python - with torch.device("meta"): - self.model: PreTrainedModel = AutoModel.from_config(...) - ``` - - This means that: - - `type(module)` is a class from `transformers` - - This class is constructed using a `PretrainedConfig` - """ - for name, buffer in module.named_buffers(recurse=False): - if buffer.device == torch.device("meta"): - if module == self.model: - logger.warning( - "To initialize buffers correctly, we instantiate the " - "parent module and and extract the value of the " - "buffer from it. In this case, the parent module is " - "the base model. Instantiating the entire model here " - "risks GPU OOM. Could this buffer be moved to a child " - "module?") - new_buffer = getattr(type(module)(self.config), name) - setattr(module, name, new_buffer) - for child in module.children(): - self.init_buffers(child) - def init_parameters(self, module: nn.Module): """ If a `parameter` is on the `meta` device, then its parent @@ -391,11 +625,16 @@ def forward( if inputs_embeds is not None: inputs_embeds = inputs_embeds[None, ...] + if self.model_config.uses_mrope: + positions = positions[:, None] + else: + positions = positions[None, ...] + hidden_states = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, use_cache=False, - position_ids=positions[None, ...], + position_ids=positions, attention_instances=self.attention_instances, return_dict=False)[0][0, ...] # we remove batch dimension for now @@ -507,3 +746,163 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder) +class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, + SupportsPP, SupportsMultiModal): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: PretrainedConfig = vllm_config.model_config.hf_config + quant_config: QuantizationConfig = vllm_config.quant_config + + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) + text_config = config.get_text_config() + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = text_config.vocab_size + self.lm_head = ParallelLMHead( + text_config.vocab_size, + text_config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if text_config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings()) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + text_config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + @property + def hf_to_vllm_mapper(self): + # Backwards compatibility for prev released models + # State dicts back then had different formats + # and cannot be loaded with `AutoModel` mapping + # as is + prefix_mapper = { + "language_model.model": "model.language_model", + "text_model.model": "model.text_model", + "vision_tower": "model.vision_tower", + "vqmodel": "model.vqmodel", + "vision_model": "model.vision_model", + "vision_embed_tokens": "model.vision_embed_tokens", + "image_newline": "model.image_newline", + "multi_modal_projector": "model.multi_modal_projector", + "text_model.lm_head": "lm_head", + "language_model.lm_head": "lm_head", + } + # Don't change the order for QwenVL + if 'Qwen2' in self.config.__class__.__name__: + prefix_mapper["model"] = "model.language_model" + prefix_mapper["visual"] = "model.visual" + + return WeightsMapper(orig_to_new_prefix=prefix_mapper, ) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=([ + "lm_head." + ] if self.config.get_text_config().tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_multimodal_embeddings(self, **kwargs): + pixel_values = kwargs.pop("pixel_values", None) + pixel_values = pixel_values if pixel_values is not None else kwargs.pop( + "image_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + num_image_patches = kwargs.pop("num_image_patches") + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.flatten(0, 1).to(self.dtype) + else: + pixel_values = torch.cat(pixel_values).to(self.dtype) + + if isinstance(num_image_patches, list): + num_image_patches = torch.cat(num_image_patches) + + vision_embeddings = self.model.model.get_image_features( + pixel_values, + **{ + k: v.flatten(0, 1) + for k, v in kwargs.items() + }, + ) + + if isinstance(vision_embeddings, torch.Tensor): + if vision_embeddings.ndim == 2: + vision_embeddings = vision_embeddings.unsqueeze(0) + + # Embeddings have to be 2D tensors of length `num_images` + # but transformers returns concat tensors if each patch + # is of different size. We split it back to make vLLM happy + vision_embeddings = torch.split( + vision_embeddings, + num_image_patches.flatten().tolist()) + vision_embeddings = [ + embed.flatten(start_dim=0, end_dim=-2) + for embed in vision_embeddings + ] + + return vision_embeddings + + if image_embeds is not None: + return image_embeds + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + ) -> torch.Tensor: + inputs_embeds = self.model.model.get_input_embeddings()(input_ids) + if multimodal_embeddings is not None: + mask = (input_ids == self.config.image_token_id) + mask = mask.unsqueeze(-1).expand_as(inputs_embeds) + multimodal_embeddings = torch.cat(multimodal_embeddings) + + inputs_embeds = inputs_embeds.masked_scatter( + mask, multimodal_embeddings) + return inputs_embeds diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 18aae35c6fd..d31844c54b0 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -817,6 +817,20 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: self._validate_modality("get_items", modality) return self._items_by_modality[modality] + def get_hf_inputs(self, modality: str) -> dict[str, NestedTensors]: + modality_items = self._items_by_modality.get(modality, None) + hf_inputs = defaultdict[str, list[NestedTensors]](list) + if modality_items is not None: + for mm_kwargs_item in modality_items: + for key, value in mm_kwargs_item.items(): + hf_inputs[key].append(value.data) + + hf_inputs_as_tensors = { + key: torch.stack(value) + for key, value in hf_inputs.items() + } + return hf_inputs_as_tensors + MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] """ From 75d8ca7f5e3fc540706c093bd00bee1e9ec2f377 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 16 Jun 2025 18:00:49 +0800 Subject: [PATCH 2/8] fix param dtype Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index af862dc1abe..1d74a8d291e 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -600,6 +600,7 @@ def init_parameters(self, module: nn.Module): if param.device == torch.device("meta"): new_param = nn.Parameter( torch.empty_like(param.data, + dtype=self.model_config.dtype, device=self.device_config.device)) setattr(module, name, new_param) for child in module.children(): From 5da435898dff53e96c57cf4d69aaa5528fcf911b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 18 Jun 2025 16:22:36 +0800 Subject: [PATCH 3/8] v0 backward compatibility Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 1d74a8d291e..32886ccaebc 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -822,7 +822,17 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + if inputs_embeds is None: + multimodal_embeds = self.get_multimodal_embeddings(**kwargs) + if multimodal_embeds is not None: + inputs_embeds = self.get_input_embeddings(input_ids, multimodal_embeds) + input_ids = None + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output @@ -851,11 +861,11 @@ def get_multimodal_embeddings(self, **kwargs): pixel_values = pixel_values if pixel_values is not None else kwargs.pop( "image_patches", None) image_embeds = kwargs.pop("image_embeds", None) - num_image_patches = kwargs.pop("num_image_patches") if pixel_values is None and image_embeds is None: return None + num_image_patches = kwargs.pop("num_image_patches") if pixel_values is not None: if isinstance(pixel_values, torch.Tensor): pixel_values = pixel_values.flatten(0, 1).to(self.dtype) From f6458bc1a31e58132ef81a37d16866d9974bb2c0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 15 Jul 2025 15:04:57 +0800 Subject: [PATCH 4/8] fix out-of-date signature Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 57 ++++++++++++++++------ 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 32886ccaebc..71f539c93ca 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -44,9 +44,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, PlaceholderRange) -from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo) + BaseProcessingInfo, ProcessingCache) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import cached_get_processor @@ -228,11 +228,31 @@ def get_dummy_mm_data( class MultiModalProcessor(BaseMultiModalProcessor): + def __init__(self, + info: MultiModalProcessingInfo, + dummy_inputs: "BaseDummyInputsBuilder[MultiModalProcessingInfo]", + *, + cache: Optional[ProcessingCache] = None, + ) -> None: + super().__init__( + info=info, + dummy_inputs=dummy_inputs, + cache=cache, + ) + + if self.cache is not None: + logger.warning_once( + "TransformersForMultimodalLM doesn't support mm cache yet! " + "But mm_preprocessor_cache is enabled. Disable it due to the " + "compatibility issue for now." + ) + self.cache = None + def _get_prompt_updates( self, - mm_items, - hf_processor_mm_kwargs, - out_mm_kwargs, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ): """ Given the original multi-modal items for this modality @@ -269,9 +289,10 @@ def _get_mm_fields_config( def _apply_hf_processor_text_mm( self, - prompt_text, - mm_items, - hf_processor_mm_kwargs, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], ): """ Apply the HF processor on the prompt text and multi-modal data @@ -286,6 +307,7 @@ def _apply_hf_processor_text_mm( prompt=prompt_text, mm_data=processor_data, mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, ) processed_data.update(passthrough_data) @@ -299,10 +321,11 @@ def _apply_hf_processor_text_mm( def apply( self, - prompt, - mm_data, - hf_processor_mm_kwargs, - return_mm_hashes=False, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + return_mm_hashes: bool=False, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -311,9 +334,14 @@ def apply( outputting token IDs and processed tensors. """ if return_mm_hashes: - raise ValueError( + logger.warning_once( "TransformersForMultimodalLM doesn't support mm hashing yet! " - "Probably you didn't set `disable_mm_preprocessor_cache=True`") + "But mm_preprocessor_cache is enabled. Disable it due to the " + "compatibility issue for now.") + return_mm_hashes = False + + if tokenization_kwargs is None: + tokenization_kwargs = {} mm_items = self._to_mm_items(mm_data) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -323,6 +351,7 @@ def apply( prompt_text=prompt, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, ) # HF processor will return `mm_token_type_ids` from which From cc1b2234a3c3a954da87e7b9669436c493994eee Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 15 Jul 2025 15:10:51 +0800 Subject: [PATCH 5/8] revert auto mm preprocessor Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/transformers.py | 33 ++++------------------ 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 71f539c93ca..5d80afc8f55 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -46,7 +46,7 @@ MultiModalInputs, PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache) + BaseProcessingInfo) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import cached_get_processor @@ -228,26 +228,6 @@ def get_dummy_mm_data( class MultiModalProcessor(BaseMultiModalProcessor): - def __init__(self, - info: MultiModalProcessingInfo, - dummy_inputs: "BaseDummyInputsBuilder[MultiModalProcessingInfo]", - *, - cache: Optional[ProcessingCache] = None, - ) -> None: - super().__init__( - info=info, - dummy_inputs=dummy_inputs, - cache=cache, - ) - - if self.cache is not None: - logger.warning_once( - "TransformersForMultimodalLM doesn't support mm cache yet! " - "But mm_preprocessor_cache is enabled. Disable it due to the " - "compatibility issue for now." - ) - self.cache = None - def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -325,7 +305,7 @@ def apply( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool=False, + return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -334,11 +314,9 @@ def apply( outputting token IDs and processed tensors. """ if return_mm_hashes: - logger.warning_once( + raise ValueError( "TransformersForMultimodalLM doesn't support mm hashing yet! " - "But mm_preprocessor_cache is enabled. Disable it due to the " - "compatibility issue for now.") - return_mm_hashes = False + "Probably you didn't set `disable_mm_preprocessor_cache=True`") if tokenization_kwargs is None: tokenization_kwargs = {} @@ -859,7 +837,8 @@ def forward( if inputs_embeds is None: multimodal_embeds = self.get_multimodal_embeddings(**kwargs) if multimodal_embeds is not None: - inputs_embeds = self.get_input_embeddings(input_ids, multimodal_embeds) + inputs_embeds = self.get_input_embeddings( + input_ids, multimodal_embeds) input_ids = None model_output = self.model(input_ids, positions, intermediate_tensors, From 7efbdda7d537d0ecc9e5e5eec521683447b3d34f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 15 Jul 2025 15:11:40 +0800 Subject: [PATCH 6/8] address test.txt Signed-off-by: Isotr0py <2037008807@qq.com> --- requirements/test.txt | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index a5c35549c66..3828efae381 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -31,10 +31,6 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration -async-timeout==5.0.1 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -145,11 +141,6 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval -exceptiongroup==1.3.0 - # via - # anyio - # hypothesis - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -699,6 +690,7 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter + # torch # triton shellingham==1.5.4 # via typer @@ -761,13 +753,8 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers -toml==0.10.2 - # via datamodel-code-generator tomli==2.2.1 - # via - # black - # pytest - # schemathesis + # via schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -841,18 +828,13 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via - # anyio - # black - # exceptiongroup # huggingface-hub # librosa # mistral-common # mteb - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer # typing-inspection From 2918b6b3def19b542b15a02a9549074310b37ece Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 15 Jul 2025 15:23:10 +0800 Subject: [PATCH 7/8] fix typo and make pre-commiter happy Signed-off-by: Isotr0py <2037008807@qq.com> --- docs/models/supported_models.md | 2 +- vllm/model_executor/models/transformers.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9d1a1fb1574..ae5d929b86a 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -37,7 +37,7 @@ If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it mean vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. !!! note - In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Trasnformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. + In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. #### Custom models diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 5d80afc8f55..08664e91352 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -164,8 +164,8 @@ def wrapper(*args, **kwargs): yield finally: nn.Module.register_parameter = old_register_parameter - for torch_function_name, old_torch_function in tensor_constructors_to_patch.items( - ): + for torch_function_name, old_torch_function in ( + tensor_constructors_to_patch.items()): setattr(torch, torch_function_name, old_torch_function) @@ -197,7 +197,8 @@ def get_max_image_size(self): return 10_000, 10_000 # hardcode for arbitrary very large size -class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder): +class MultiModalDummyInputsBuilder( + BaseDummyInputsBuilder[MultiModalProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -226,7 +227,7 @@ def get_dummy_mm_data( } -class MultiModalProcessor(BaseMultiModalProcessor): +class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): def _get_prompt_updates( self, @@ -337,7 +338,8 @@ def apply( # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1 mm_positions = torch.where(mm_token_type_ids == 1)[1] images = mm_items.get_items("image", ImageProcessorItems) - mm_processor_kwargs = self.info.ctx.model_config.mm_processor_kwargs or {} + mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs + or {}) image_sizes = [] for item_idx in range(len(images)): image_size = images.get_image_size(item_idx) @@ -446,7 +448,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set correct attn and init on "meta" to delay allocating GPU tensors self.text_config._attn_implementation = "vllm" - with init_on_device_without_buffers("meta"): + with init_on_device_without_buffers("meta"), config_override: # FIXME(Isotr0py): We need to refactor this part in the future to # avoid registering an extra model layer, otherwise we will need a # weights mapper to rename weights. From e7e6869b52fea67439e751900c83107c47646e3e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 15 Jul 2025 17:41:25 +0800 Subject: [PATCH 8/8] add transformers fallback test with image input Signed-off-by: Isotr0py <2037008807@qq.com> --- .../multimodal/generation/test_common.py | 70 +++++++++++++++++++ tests/models/registry.py | 1 + tests/models/test_transformers.py | 31 -------- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 98461676aa4..0faf7b698f7 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -170,6 +170,76 @@ hf_output_post_proc=model_utils.ultravox_trunc_hf_output, marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + #### Transformers fallback to test + ## To reduce test burden, we only test batching arbitrary image size + # Dynamic image length and number of patches + "llava-onevision-transformers": VLMTestInfo( + models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + max_model_len=10240, + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + marks=[pytest.mark.core_model], + ), + # Has col/row special token between patches + "idefics3-transformers": VLMTestInfo( + models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + marks=[pytest.mark.core_model], + ), + # Pixel values from processor are not 4D or 5D arrays + "qwen2_5_vl-transformers": VLMTestInfo( + models=["Qwen/Qwen2.5-VL-3B-Instruct"], + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + marks=[pytest.mark.core_model], + ), + # Check "auto" with fallback to transformers + "internvl-transformers": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + max_model_len=4096, + use_tokenizer_eos=True, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "auto", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + marks=[pytest.mark.core_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], diff --git a/tests/models/registry.py b/tests/models/registry.py index 9d3fc8a1b1c..a1777eef218 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -476,6 +476,7 @@ def check_available_online( _TRANSFORMERS_MODELS = { "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 + "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), } _EXAMPLE_MODELS = { diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 0513975df4c..b7b99ce41cb 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -4,7 +4,6 @@ from typing import Any, Optional, Union import pytest -from transformers import AutoModelForImageTextToText from vllm.platforms import current_platform @@ -73,36 +72,6 @@ def test_models( model_impl=model_impl) -@pytest.mark.parametrize( - "model,model_impl", - [ - # Dynamic image length and number of patches - ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "transformers"), - # Has col/row special token between patches - ("HuggingFaceTB/SmolVLM-256M-Instruct", "transformers"), - # Pixel values from processor are not 4D or 5D arrays - ("Qwen/Qwen2.5-VL-3B-Instruct", "transformers"), - # Check "auto" with fallback to transformers - ("BAAI/Emu3-Chat-hf", "auto"), - ] -) # no custom code support because custom models don't follow the standard yet! -def test_models_multimodal( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - example_prompts: list[str], - model: str, - model_impl: str, -) -> None: - check_implementation( - hf_runner, - vllm_runner, - example_prompts, - model, - model_impl=model_impl, - kwargs_ref={"auto_cls": AutoModelForImageTextToText}, - ) - - def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None: prompts, _, _ = prep_prompts(4, (800, 801)) kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}