
[Model] Support VLMs with transformers backend #13754


Closed
wants to merge 28 commits into from

28 commits
26a9f1b
tmp
zucchini-nlp Feb 19, 2025
a502988
dump
zucchini-nlp Feb 21, 2025
e0b534b
clean up
zucchini-nlp Feb 24, 2025
7e8f0d8
clean up 2
zucchini-nlp Feb 24, 2025
57c2d85
use arbitrary high resolution in dummy inputs
zucchini-nlp Feb 24, 2025
de54bbf
tmp
zucchini-nlp Mar 27, 2025
739216d
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp Apr 8, 2025
4b4f8b7
still ugly but works with latest processor update
zucchini-nlp Apr 9, 2025
c5aac3e
update
zucchini-nlp May 21, 2025
d26c81b
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp May 21, 2025
60300c4
fix issues
zucchini-nlp May 21, 2025
0c69ade
update
zucchini-nlp May 29, 2025
66a1a10
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp May 29, 2025
d36ab67
style
zucchini-nlp May 29, 2025
bf08a9e
need to update dummy builder after rebase
zucchini-nlp May 29, 2025
ba1143a
delet meta to device
zucchini-nlp May 29, 2025
267a57f
add tests
zucchini-nlp May 30, 2025
2c73f88
style
zucchini-nlp Jun 2, 2025
8c1f220
i dont get the style guidelines
zucchini-nlp Jun 2, 2025
8d5d67e
Update vllm/model_executor/models/transformers.py
zucchini-nlp Jun 3, 2025
be850dc
address some comments
zucchini-nlp Jun 3, 2025
e730323
forgot to add `@support_torch_compile` decorator
zucchini-nlp Jun 3, 2025
cfa1998
cant compile yet + clean up commented code
zucchini-nlp Jun 4, 2025
52bda05
fix param dtype
Isotr0py Jun 16, 2025
9aec5ac
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp Jun 17, 2025
6ef7b35
mention VLMs in the docs
zucchini-nlp Jun 17, 2025
d1e6d95
v0 backward compatibility
Isotr0py Jun 18, 2025
81fccb0
Merge remote-tracking branch upstream/main into vlm-transformers
zucchini-nlp Jul 2, 2025
11 changes: 7 additions & 4 deletions docs/models/supported_models.md
@@ -21,7 +21,7 @@ These models are what we list in [supported-text-models][supported-text-models]

### Transformers

vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned!
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported!

To check if the modeling backend is Transformers, you can simply do this:

@@ -31,14 +31,17 @@ llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```

If it is `TransformersForCausalLM` then it means it's based on Transformers!
If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!

!!! tip
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][serving-openai-compatible-server].
You can force the use of the Transformers backend by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][serving-openai-compatible-server].

!!! note
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.

!!! note
In the case of vision language models, if you load with `dtype="auto"`, vLLM loads the whole model using the config's `dtype` if it exists. In contrast, native Transformers respects the `dtype` attribute of each backbone in the model, which might cause a slight difference in performance.
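
For example, here is a minimal offline-inference sketch (the model name is just an illustrative choice) that forces the Transformers backend and pins the dtype explicitly instead of relying on `"auto"`:

```python
from vllm import LLM

llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",  # example VLM
    model_impl="transformers",  # force the Transformers backend
    dtype="bfloat16",           # pin the dtype instead of "auto"
)
llm.apply_model(lambda model: print(type(model)))  # TransformersForMultimodalLM
```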

#### Custom models

If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM!
@@ -102,7 +105,7 @@ Here is what happens in the background when this model is loaded:

1. The config is loaded.
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.

That's it!
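
As a rough pre-flight check for a custom model, you can mirror steps 2 and 3 yourself (a sketch only; the repo id is a placeholder and the `auto_map` key may differ for your model):

```python
from transformers import AutoConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module

repo = "your-org/your-custom-model"  # placeholder repo id
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)

# Resolve the modeling class from the config's auto_map, then ask it whether
# it can run on vLLM's Transformers backend.
class_ref = config.auto_map["AutoModelForCausalLM"]  # key depends on the model
model_cls = get_class_from_dynamic_module(class_ref, repo)
print(model_cls.is_backend_compatible())
```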

22 changes: 20 additions & 2 deletions requirements/test.txt
@@ -31,6 +31,10 @@ argcomplete==3.5.1
# via datamodel-code-generator
arrow==1.3.0
# via isoduration
async-timeout==5.0.1
# via
# aiohttp
# redis
attrs==24.2.0
# via
# aiohttp
@@ -141,6 +145,11 @@ eval-type-backport==0.2.2
# via mteb
evaluate==0.4.3
# via lm-eval
exceptiongroup==1.3.0
# via
# anyio
# hypothesis
# pytest
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@@ -690,7 +699,6 @@ setuptools==77.0.3
# via
# mamba-ssm
# pytablewriter
# torch
# triton
shellingham==1.5.4
# via typer
@@ -753,8 +761,13 @@ tokenizers==0.21.1
# via
# -r requirements/test.in
# transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.2.1
# via schemathesis
# via
# black
# pytest
# schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.7.0+cu128
@@ -828,13 +841,18 @@ types-python-dateutil==2.9.0.20241206
# via arrow
typing-extensions==4.12.2
# via
# anyio
# black
# exceptiongroup
# huggingface-hub
# librosa
# mistral-common
# mteb
# multidict
# pqdm
# pydantic
# pydantic-core
# rich
# torch
# typer
# typing-inspection
31 changes: 31 additions & 0 deletions tests/models/test_transformers.py
@@ -4,6 +4,7 @@
from typing import Any, Optional, Union

import pytest
from transformers import AutoModelForImageTextToText

from vllm.platforms import current_platform

@@ -72,6 +73,36 @@ def test_models(
model_impl=model_impl)


@pytest.mark.parametrize(
"model,model_impl",
[
# Dynamic image length and number of patches
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "transformers"),
# Has col/row special token between patches
("HuggingFaceTB/SmolVLM-256M-Instruct", "transformers"),
# Pixel values from processor are not 4D or 5D arrays
("Qwen/Qwen2.5-VL-3B-Instruct", "transformers"),
# Check "auto" with fallback to transformers
("BAAI/Emu3-Chat-hf", "auto"),
]
) # no custom code support because custom models don't follow the standard yet!
def test_models_multimodal(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
example_prompts: list[str],
model: str,
model_impl: str,
) -> None:
check_implementation(
hf_runner,
vllm_runner,
example_prompts,
model,
model_impl=model_impl,
kwargs_ref={"auto_cls": AutoModelForImageTextToText},
)


def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
prompts, _, _ = prep_prompts(4, (800, 801))
kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
18 changes: 15 additions & 3 deletions vllm/model_executor/model_loader/utils.py
@@ -170,7 +170,7 @@ def device_loading_context(module: torch.nn.Module,
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM":
if arch in ["TransformersForCausalLM", "TransformersForMultimodalLM"]:
continue
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
@@ -206,7 +206,13 @@ def resolve_transformers_arch(
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersForCausalLM"
# Check if the text config is `self`. If not, it is most likely
# a composite config, i.e. multimodal
if model_config.hf_config.get_text_config(
) != model_config.hf_config:
architectures[i] = "TransformersForMultimodalLM"
else:
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not model_module.is_backend_compatible():
raise ValueError(
@@ -217,7 +223,13 @@
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersForCausalLM"
# Check if the text config is `self`. If not, it is most likely
# a composite config, i.e. multimodal
if model_config.hf_config.get_text_config(
) != model_config.hf_config:
architectures[i] = "TransformersForMultimodalLM"
else:
architectures[i] = "TransformersForCausalLM"
return architectures


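For reference, a small illustration of the heuristic used above (model names are just examples): `get_text_config()` returns the config itself for a plain text model, but the nested text config for a composite, i.e. multimodal, one.

```python
from transformers import AutoConfig

text_cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
mm_cfg = AutoConfig.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# Text-only config: get_text_config() is the config itself.
print(text_cfg.get_text_config() is text_cfg)  # True  -> TransformersForCausalLM
# Composite (multimodal) config: get_text_config() is the nested text config.
print(mm_cfg.get_text_config() is mm_cfg)      # False -> TransformersForMultimodalLM
```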
7 changes: 5 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -244,6 +244,7 @@
}

_TRANSFORMERS_MODELS = {
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
@@ -469,15 +470,17 @@ def _normalize_archs(

# make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersForCausalLM")
# The order matters. If CausalLM comes first, the check for a
# registered model in the MultimodalRegistry fails
normalized_arch.extend(
["TransformersForMultimodalLM", "TransformersForCausalLM"])
return normalized_arch

def inspect_model_cls(
self,
architectures: Union[str, list[str]],
) -> tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures)

for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None: