
Commit 9499e26

zucchini-nlp, Isotr0py, and DarkLight1337 authored
[Model] Support VLMs with transformers backend (#20543)
Signed-off-by: raushan <raushan@huggingface.co>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 51ba839 commit 9499e26

7 files changed: 625 additions, 87 deletions


docs/models/supported_models.md

Lines changed: 6 additions & 3 deletions
@@ -18,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models]
 
 ### Transformers
 
-vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned!
+vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs and require setting `--disable_mm_preprocessor_cache` when running. Support for video inputs and caching of multi-modal preprocessors will be added in future releases.
 
 To check if the modeling backend is Transformers, you can simply do this:
 
@@ -28,14 +28,17 @@ llm = LLM(model=..., task="generate") # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
 ```
 
-If it is `TransformersForCausalLM` then it means it's based on Transformers!
+If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
 
 !!! tip
     You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).
 
 !!! note
     vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
 
+!!! note
+    For vision-language models loaded with `dtype="auto"`, vLLM loads the whole model using the config's `dtype` if it exists. In contrast, native Transformers respects the `dtype` attribute of each backbone in the model. This might cause a slight difference in performance.
+
 #### Custom models
 
 If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM!
@@ -99,7 +102,7 @@ Here is what happens in the background when this model is loaded:
 
 1. The config is loaded.
 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
-3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
+3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
 
 That's it!
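For readers who want to try the new backend end to end, a minimal offline-inference sketch follows. It is not part of this commit: the checkpoint, prompt format, and image are placeholders borrowed from the InternVL3 test settings added in this PR, and `disable_mm_preprocessor_cache=True` mirrors the requirement stated in the docs change above.

```python
# Hypothetical usage sketch (not part of this commit): run a vision-language
# model through vLLM's Transformers backend. The checkpoint and prompt format
# below are borrowed from the InternVL3 test entry added in this PR.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="OpenGVLab/InternVL3-1B-hf",
    model_impl="transformers",           # force the Transformers backend
    disable_mm_preprocessor_cache=True,  # currently required for VLMs
)

image = Image.new("RGB", (448, 448))     # stand-in for a real image
prompt = ("<|im_start|>User\n<IMG_CONTEXT>\nDescribe the image.<|im_end|>\n"
          "<|im_start|>Assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```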

tests/models/multimodal/generation/test_common.py

Lines changed: 75 additions & 0 deletions
@@ -35,6 +35,8 @@
 REQUIRES_V0_MODELS = [
     # V1 Test: not enough KV cache space in C1.
     "fuyu",
+    # V1 Test: Deadlock issue when processing mm_inputs
+    "llava-onevision-transformers",
 ]
 
 # yapf: disable
@@ -170,6 +172,79 @@
         hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    #### Transformers fallback to test
+    ## To reduce test burden, we only test batching arbitrary image size
+    # Dynamic image length and number of patches
+    "llava-onevision-transformers": VLMTestInfo(
+        models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        max_model_len=16384,
+        hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),  # noqa: E501
+        auto_cls=AutoModelForImageTextToText,
+        vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
+        image_size_factors=[(0.25, 0.5, 1.0)],
+        vllm_runner_kwargs={
+            "model_impl": "transformers",
+            "disable_mm_preprocessor_cache": True,
+            "enable_prefix_caching": False,
+        },
+        marks=[pytest.mark.core_model],
+    ),
+    # FIXME(Isotr0py): Enable this test after
+    # https://github.com/huggingface/transformers/pull/39470 released
+    # "idefics3-transformers": VLMTestInfo(
+    #     models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
+    #     test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+    #     prompt_formatter=lambda img_prompt: f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:",  # noqa: E501
+    #     img_idx_to_prompt=lambda idx: "<image>",
+    #     max_model_len=8192,
+    #     max_num_seqs=2,
+    #     auto_cls=AutoModelForImageTextToText,
+    #     hf_output_post_proc=model_utils.idefics3_trunc_hf_output,
+    #     image_size_factors=[(0.25, 0.5, 1.0)],
+    #     vllm_runner_kwargs={
+    #         "model_impl": "transformers",
+    #         "disable_mm_preprocessor_cache": True,
+    #         "enable_prefix_caching": False,
+    #     },
+    #     marks=[pytest.mark.core_model],
+    # ),
+    # Pixel values from processor are not 4D or 5D arrays
+    "qwen2_5_vl-transformers": VLMTestInfo(
+        models=["Qwen/Qwen2.5-VL-3B-Instruct"],
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForImageTextToText,
+        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+        image_size_factors=[(0.25, 0.2, 0.15)],
+        vllm_runner_kwargs={
+            "model_impl": "transformers",
+            "disable_mm_preprocessor_cache": True,
+            "enable_prefix_caching": False,
+        },
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
+    # Check "auto" with fallback to transformers
+    "internvl-transformers": VLMTestInfo(
+        models=["OpenGVLab/InternVL3-1B-hf"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>",
+        max_model_len=4096,
+        use_tokenizer_eos=True,
+        image_size_factors=[(0.25, 0.5, 1.0)],
+        vllm_runner_kwargs={
+            "model_impl": "auto",
+            "disable_mm_preprocessor_cache": True,
+            "enable_prefix_caching": False,
+        },
+        auto_cls=AutoModelForImageTextToText,
+        marks=[pytest.mark.core_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -499,6 +499,7 @@ def check_available_online(
 
 _TRANSFORMERS_MODELS = {
     "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
+    "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
 }
 
 _EXAMPLE_MODELS = {

vllm/config.py

Lines changed: 30 additions & 9 deletions
@@ -562,6 +562,10 @@ def __post_init__(self) -> None:
 
         self.task = "embed"
 
+        model_info, arch = self.registry.inspect_model_cls(self.architectures)
+        self._model_info = model_info
+        self._architecture = arch
+
         all_supported_tasks = self._get_supported_tasks(self.task)
         logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
         supported_runner_types = self._get_supported_runner_types(
@@ -587,10 +591,6 @@ def __post_init__(self) -> None:
         else:
             self.truncation_side = "right"
 
-        model_info, arch = self.registry.inspect_model_cls(self.architectures)
-        self._model_info = model_info
-        self._architecture = arch
-
         self.pooler_config = self._init_pooler_config()
 
         self.dtype = _get_and_verify_dtype(
@@ -674,14 +674,36 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
                 "max_model_len must be an integer after __post_init__.")
         return self
 
+    def _get_transformers_backend_cls(self) -> str:
+        """Determine which Transformers backend class will be used if
+        `model_impl` is set to `transformers` or `auto`."""
+        if self.hf_config != self.hf_text_config:
+            # `hf_text_config` is usually the same as `hf_config`. If not, it
+            # is probably a composite config, i.e. multimodal
+            return "TransformersForMultimodalLM"
+        else:
+            return "TransformersForCausalLM"
+
     @property
     def registry(self):
         return me_models.ModelRegistry
 
     @property
     def architectures(self) -> list[str]:
         # architectures in the model config.
-        return getattr(self.hf_config, "architectures", [])
+        architectures = getattr(self.hf_config, "architectures", [])
+        # The registry assumes that it can always inspect the vLLM model class
+        # for a given architecture. This assumption breaks down for the
+        # Transformers backend, which may use a different class depending on
+        # the model type. To work around this, we add the correct Transformers
+        # backend class to the architectures list. We must do this here because
+        # we need access to the `hf_config` to determine the backend class.
+        transformers_backend_cls = self._get_transformers_backend_cls()
+        if (self.model_impl != ModelImpl.VLLM.value
+                and all(arch != transformers_backend_cls
+                        for arch in architectures)):
+            architectures.append(transformers_backend_cls)
+        return architectures
 
     @property
     def architecture(self) -> str:
@@ -827,10 +849,9 @@ def _get_preferred_pooling_task(
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
-        _, arch = self.registry.inspect_model_cls(architectures)
 
         for suffix, pref_task in suffix_to_preferred_task:
-            if arch.endswith(suffix):
+            if self.architecture.endswith(suffix):
                 return pref_task
 
         return "embed"
@@ -944,10 +965,10 @@ def _resolve_runner(
             ("EmbeddingModel", "pooling"),
             ("RewardModel", "pooling"),
         ]
-        _, arch = self.registry.inspect_model_cls(self.architectures)
 
         for suffix, pref_runner in suffix_to_preferred_runner:
-            if arch.endswith(suffix) and pref_runner in supported_runner_types:
+            if self.architecture.endswith(
+                    suffix) and pref_runner in supported_runner_types:
                 return pref_runner
 
         if "generate" in supported_runner_types:

vllm/model_executor/model_loader/utils.py

Lines changed: 26 additions & 23 deletions
@@ -25,6 +25,7 @@
                                                   as_reward_model,
                                                   as_seq_cls_model)
 from vllm.model_executor.models.interfaces import SupportsQuant
+from vllm.model_executor.models.registry import _TRANSFORMERS_MODELS
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -169,9 +170,22 @@ def device_loading_context(module: torch.nn.Module,
 
 def resolve_transformers_arch(model_config: ModelConfig,
                               architectures: list[str]):
+    if model_config.model_impl == ModelImpl.VLLM:
+        raise ValueError(
+            "Attempting to resolve architecture from the Transformers library "
+            "but the model implementation is set to vLLM. This should never "
+            "happen.")
+
     for i, arch in enumerate(architectures):
-        if arch == "TransformersForCausalLM":
+        if arch in _TRANSFORMERS_MODELS:
             continue
+
+        if model_config.model_impl == ModelImpl.AUTO:
+            logger.warning(
+                "%s has no vLLM implementation, falling back to Transformers "
+                "implementation. Some features may not be supported and "
+                "performance may not be optimal.", arch)
+
         auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                            None) or dict()
         # Make sure that config class is always initialized before model class,
@@ -199,25 +213,13 @@ def resolve_transformers_arch(model_config: ModelConfig,
                 "not present in the model config's 'auto_map' (relevant "
                 "if the model is custom).")
         model_module = auto_modules["AutoModel"]
-        # TODO(Isotr0py): Further clean up these raises.
-        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
-        if model_config.model_impl == ModelImpl.TRANSFORMERS:
-            if not model_module.is_backend_compatible():
-                raise ValueError(
-                    f"The Transformers implementation of {arch} is not "
-                    "compatible with vLLM.")
-            architectures[i] = "TransformersForCausalLM"
-        if model_config.model_impl == ModelImpl.AUTO:
-            if not model_module.is_backend_compatible():
-                raise ValueError(
-                    f"{arch} has no vLLM implementation and the Transformers "
-                    "implementation is not compatible with vLLM. Try setting "
-                    "VLLM_USE_V1=0.")
-            logger.warning(
-                "%s has no vLLM implementation, falling back to Transformers "
-                "implementation. Some features may not be supported and "
-                "performance may not be optimal.", arch)
-            architectures[i] = "TransformersForCausalLM"
+
+        if not model_module.is_backend_compatible():
+            raise ValueError(
+                f"The Transformers implementation of '{arch}' is not "
+                "compatible with vLLM.")
+
+        architectures[i] = model_config._get_transformers_backend_cls()
     return architectures
 
 
@@ -237,8 +239,9 @@ def get_model_architecture(
     ]
 
     vllm_supported_archs = ModelRegistry.get_supported_archs()
-    vllm_not_supported = not any(arch in vllm_supported_archs
-                                 for arch in architectures)
+    is_supported = lambda arch: (arch in vllm_supported_archs and arch not in
+                                 _TRANSFORMERS_MODELS)
+    vllm_not_supported = not any(is_supported(arch) for arch in architectures)
 
     if vllm_not_supported:
         # try automatic conversion in adapters.py
@@ -259,7 +262,7 @@ def get_model_architecture(
             break
 
     if (model_config.model_impl == ModelImpl.TRANSFORMERS or
-            model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
+            model_config.model_impl == ModelImpl.AUTO and vllm_not_supported):
         architectures = resolve_transformers_arch(model_config, architectures)
         logger.debug_once("Resolve transformers arch %s", str(architectures))
     elif (model_config.quantization is not None
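After this change, the fallback path in `resolve_transformers_arch` reduces to four steps: skip architectures that already are Transformers backend classes, warn when `model_impl="auto"` falls back, reject modules whose `is_backend_compatible()` check fails, and substitute the class chosen by `ModelConfig._get_transformers_backend_cls()`. A condensed sketch of that control flow, with stand-in arguments rather than the real config and module objects, might look like this:

```python
# Condensed sketch of the post-change fallback flow; `backend_cls` and
# `is_compatible` are stand-ins for ModelConfig._get_transformers_backend_cls()
# and the Transformers module's is_backend_compatible() check.
import logging

logger = logging.getLogger(__name__)

_TRANSFORMERS_BACKENDS = {"TransformersForCausalLM", "TransformersForMultimodalLM"}


def resolve_archs_sketch(architectures: list[str], model_impl: str,
                         backend_cls: str, is_compatible) -> list[str]:
    if model_impl == "vllm":
        raise ValueError("Resolving a Transformers arch with model_impl='vllm' "
                         "should never happen.")
    for i, arch in enumerate(architectures):
        if arch in _TRANSFORMERS_BACKENDS:
            continue  # nothing to resolve
        if model_impl == "auto":
            logger.warning("%s has no vLLM implementation, falling back to "
                           "the Transformers implementation.", arch)
        if not is_compatible(arch):
            raise ValueError(f"The Transformers implementation of '{arch}' "
                             "is not compatible with vLLM.")
        # Substitute the causal-LM or multimodal backend class chosen upstream.
        architectures[i] = backend_cls
    return architectures
```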

vllm/model_executor/models/registry.py

Lines changed: 9 additions & 3 deletions
@@ -253,6 +253,7 @@
 }
 
 _TRANSFORMERS_MODELS = {
+    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }
 # yapf: enable
@@ -504,9 +505,14 @@ def _normalize_archs(
             if causal_lm_arch in self.models:
                 normalized_arch.append(arch)
 
-        # make sure Transformers backend is put at the last as a fallback
-        if len(normalized_arch) != len(architectures):
-            normalized_arch.append("TransformersForCausalLM")
+        # NOTE(Isotr0py): Be careful of architectures' order!
+        # Make sure Transformers backend architecture is at the end of the
+        # list, otherwise pooling models automatic conversion will fail!
+        for arch in normalized_arch:
+            if arch.startswith("TransformersFor"):
+                normalized_arch.remove(arch)
+                normalized_arch.append(arch)
+
         return normalized_arch
 
     def inspect_model_cls(
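The reordering above matters because downstream suffix-based conversion (e.g. to pooling models) should match against the native architecture name before the Transformers fallback. A toy demonstration of the push-to-back behaviour is shown below; unlike the committed code, this sketch iterates over a copy of the list while mutating it, and the function and architecture names are illustrative only.

```python
# Toy demonstration (not the registry code itself): move any Transformers
# backend architecture to the end of the list so suffix-based automatic
# conversion still sees the native architecture first.
def push_transformers_backend_last(archs: list[str]) -> list[str]:
    archs = list(archs)
    for arch in list(archs):  # iterate over a copy while mutating
        if arch.startswith("TransformersFor"):
            archs.remove(arch)
            archs.append(arch)
    return archs


print(push_transformers_backend_last(
    ["TransformersForMultimodalLM", "InternVLChatModel"]))
# -> ['InternVLChatModel', 'TransformersForMultimodalLM']
```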
