
[Model] Support VLMs with transformers backend #13754

Closed
wants to merge 28 commits into from
Changes from 5 commits
Commits
28 commits
26a9f1b
tmp
zucchini-nlp Feb 19, 2025
a502988
dump
zucchini-nlp Feb 21, 2025
e0b534b
clean up
zucchini-nlp Feb 24, 2025
7e8f0d8
clean up 2
zucchini-nlp Feb 24, 2025
57c2d85
use arbitrary high resolution in dummy inputs
zucchini-nlp Feb 24, 2025
de54bbf
tmp
zucchini-nlp Mar 27, 2025
739216d
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp Apr 8, 2025
4b4f8b7
still ugly but works with latest processor update
zucchini-nlp Apr 9, 2025
c5aac3e
update
zucchini-nlp May 21, 2025
d26c81b
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp May 21, 2025
60300c4
fix issues
zucchini-nlp May 21, 2025
0c69ade
update
zucchini-nlp May 29, 2025
66a1a10
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp May 29, 2025
d36ab67
style
zucchini-nlp May 29, 2025
bf08a9e
need to update dummy builder after rebase
zucchini-nlp May 29, 2025
ba1143a
delet meta to device
zucchini-nlp May 29, 2025
267a57f
add tests
zucchini-nlp May 30, 2025
2c73f88
style
zucchini-nlp Jun 2, 2025
8c1f220
i dont get the style guidelines
zucchini-nlp Jun 2, 2025
8d5d67e
Update vllm/model_executor/models/transformers.py
zucchini-nlp Jun 3, 2025
be850dc
address some comments
zucchini-nlp Jun 3, 2025
e730323
forgot to add `@support_torch_compile` decorator
zucchini-nlp Jun 3, 2025
cfa1998
cant compile yet + clean up commented code
zucchini-nlp Jun 4, 2025
52bda05
fix param dtype
Isotr0py Jun 16, 2025
9aec5ac
Merge remote-tracking branch 'upstream/main' into vlm-transformers
zucchini-nlp Jun 17, 2025
6ef7b35
mention VLMs in the docs
zucchini-nlp Jun 17, 2025
d1e6d95
v0 backward compatibility
Isotr0py Jun 18, 2025
81fccb0
Merge remote-tracking branch upstream/main into vlm-transformers
zucchini-nlp Jul 2, 2025
1 change: 1 addition & 0 deletions vllm/entrypoints/llm.py
@@ -691,6 +691,7 @@ def chat(
]

tokenizer = self.get_tokenizer()

model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
268 changes: 245 additions & 23 deletions vllm/model_executor/models/transformers.py
@@ -19,7 +19,7 @@

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers import AutoModel, PreTrainedModel, LlavaConfig
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention, AttentionMetadata
@@ -37,11 +37,16 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry, MultiModalKwargs
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalInputs, PlaceholderRange

from .interfaces import SupportsQuant
from .interfaces import SupportsQuant, SupportsMultiModal
from .utils import maybe_prefix

logger = init_logger(__name__)


def vllm_flash_attention_forward(
@@ -119,10 +124,180 @@
)


class TransformersModel(nn.Module, SupportsQuant):
class MultiModalProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
# NOTE: this means we don't check whether the returned config type matches the requested one.
# vLLM, by contrast, always checks. In which cases could the config types differ, though?
return self.ctx.model_config.hf_config

def get_supported_mm_limits(self):
return {"image": None, "video": None}

def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens(), "video": 100}

def get_max_image_tokens(self) -> int:
# This is already an attribute in some VLMs, and now there is reason to make it a required attribute.
# TODO: @raushan add it for all VLM configs
return self.get_hf_config().image_seq_length

def get_hf_processor(self):
processor = cached_get_processor(self.ctx.model_config.model)
return processor


class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder):
def get_dummy_processor_inputs(
self,
seq_len,
mm_counts,
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
num_frames = 8

processor = self.info.get_hf_processor()
image_token = getattr(processor, "image_token", None)
video_token = getattr(processor, "video_token", None)

# TODO: raushan, we could have a processor attribute like `processor.max_output_size` which would
# infer the max features for the model on the HF side. But imo we can just set a very high resolution
# and the processor will return us pixels with the correct max shape. A 3000x3000 resolution is high enough.

target_width = target_height = 3000

# NOTE: with the new API used in MLLMs, we can pass videos/images/audio to any processor.
# The HF processor will take the modalities needed by the model and ignore all others.
mm_data = {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images
),
"video": self._get_dummy_videos(

Check failure on line 176 in vllm/model_executor/models/transformers.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/model_executor/models/transformers.py:176:81: E501 Line too long (93 > 80)
width=target_width,
height=target_height,
num_frames=num_frames,
num_videos=num_videos,
)
}

prompt_text = video_token*num_videos if video_token is not None else image_token*num_images
return ProcessorInputs(
prompt_text=prompt_text,
mm_data=mm_data,
)
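
A minimal sketch of how the dummy prompt above is assembled, assuming a hypothetical "<image>" placeholder string (illustrative only, not part of this diff):

# Sketch: the dummy prompt repeats the modality's placeholder token once per item;
# the HF processor later expands each placeholder into the real number of feature tokens.
image_token, video_token = "<image>", None  # hypothetical placeholder strings
num_images, num_videos = 2, 0
prompt_text = (video_token * num_videos
               if video_token is not None
               else image_token * num_images)
print(prompt_text)  # "<image><image>"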


class MultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements(
self,
mm_items,
hf_processor_mm_kwargs,
out_mm_kwargs: MultiModalKwargs,
):
return

def _get_mm_fields_config(
self,
hf_inputs,
hf_processor_mm_kwargs,
):
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
mm_token_type_ids=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.batched("video"),
image_embeds=MultiModalFieldConfig.batched("image"),
video_embeds=MultiModalFieldConfig.batched("video"),
)

def _apply_hf_processor_text_mm(
self,
prompt_text,
mm_items,
hf_processor_mm_kwargs,
):
"""
Apply the HF processor on the prompt text and multi-modal data
together.

In addition, return the multi-modal token type IDs.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True

processed_data = self._call_hf_processor(
prompt=prompt_text,
mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs,
)
processed_data.update(passthrough_data)

prompt_ids, = processed_data.pop("input_ids").tolist()
mm_token_type_ids = processed_data.pop("mm_token_type_ids")

mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)

return prompt_ids, mm_kwargs, mm_token_type_ids

def apply(
self,
prompt,
mm_data,
hf_processor_mm_kwargs,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.

Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
mm_items = self._to_mm_items(mm_data)
prompt_ids, mm_kwargs, mm_token_type_ids = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)

# The HF processor will return `mm_token_type_ids`, from which
# we can infer mm_placeholders. Until then, hardcode to make the code run.
# Tested below on Llava. Prompts and `mm_token_type_ids` always have batch size 1.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
mm_tokens_per_modality = hf_processor._get_num_mm_tokens(
image_inputs=mm_kwargs.get_hf_inputs("image"),
video_inputs=mm_kwargs.get_hf_inputs("video"),
)

mm_placeholders = {}
for modality in mm_tokens_per_modality:
split_sizes = mm_tokens_per_modality[modality]
if split_sizes != 0:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
ranges = [
PlaceholderRange(offset=positions[0].item(), length=positions.shape[0])
for positions in chunked_mm_positions
]
mm_placeholders = {modality: ranges}

return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=None,
mm_placeholders=mm_placeholders,
)
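
A self-contained sketch of how the placeholder ranges computed in apply() fall out of mm_token_type_ids, with made-up token counts (illustrative only, not part of this diff):

import torch

# Sketch: positions where mm_token_type_ids == 1 are multimodal tokens; splitting them by the
# per-item token counts yields one (offset, length) pair per item, i.e. one PlaceholderRange.
mm_token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0, 1, 1, 1, 0]])  # bs=1, two 3-token images
mm_positions = torch.where(mm_token_type_ids == 1)[1]  # tensor([1, 2, 3, 6, 7, 8])
split_sizes = [3, 3]  # tokens per image, as reported by the HF processor
chunked_mm_positions = torch.split(mm_positions, split_sizes)
ranges = [(chunk[0].item(), chunk.shape[0]) for chunk in chunked_mm_positions]
print(ranges)  # [(1, 3), (6, 3)]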


@MULTIMODAL_REGISTRY.register_processor(MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder)
class TransformersModel(nn.Module, SupportsQuant, SupportsMultiModal):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@@ -132,12 +307,13 @@
cache_config = vllm_config.cache_config

self.config = config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
self.text_config = config.get_text_config()
self.vocab_size = self.text_config.vocab_size
self.unpadded_vocab_size = self.text_config.vocab_size

self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
attn_implementation={"text_config": "vllm", "vision_config": "eager"},
torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code,
)
@@ -150,47 +326,47 @@
tp_size = get_tensor_model_parallel_world_size()
self.attention_instances = [
Attention(
num_heads=divide(config.num_attention_heads, tp_size),
head_size=config.head_dim,
num_heads=divide(self.text_config.num_attention_heads, tp_size),
head_size=self.text_config.head_dim,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in vllm_flash_attention_forward
scale=config.head_dim**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size),
scale=self.text_config.head_dim**-0.5,
num_kv_heads=divide(self.text_config.num_key_value_heads, tp_size),
cache_config=cache_config,
quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
prefix=f"{i}.attn") for i in range(self.text_config.num_hidden_layers)
]

# Model modifications
self.replace_vocab_embed_class(self.model)

# ForCausalLM modifications
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
self.lm_head = ParallelLMHead(self.text_config.vocab_size,
self.text_config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
if self.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.vocab_size, logit_scale)
self.sampler = get_sampler()

def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
"""
if (self.config.base_model_tp_plan is None
if (self.text_config.base_model_tp_plan is None
and get_tensor_model_parallel_world_size() > 1):
raise ValueError(
"Trying to run tensor parallelization but the model does not "
"support it yet!")

for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in self.config.base_model_tp_plan.items():
for pattern, style in self.text_config.base_model_tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(child_module, style,
@@ -204,8 +380,8 @@
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
self.text_config.hidden_size,
org_num_embeddings=self.vocab_size,
quant_config=None,
)
log_replacement("input embedding", self.model.get_input_embeddings(),
@@ -222,7 +398,8 @@
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(
input_ids[None, ...],
input_ids[None, ...] if input_ids is not None else None,
inputs_embeds=inputs_embeds[None, ...] if inputs_embeds is not None else None,
use_cache=False,
position_ids=positions[None, ...],
attn_metadata=attn_metadata,
@@ -252,11 +429,56 @@
loaded_params = set[str]()
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
# In MLLMs the head is usually part of the LM, so we might want to strip the `language_model.` prefix.
# Very bad workaround; needs something better.
if "lm_head" in name:
name = name.replace("language_model.", "")
else:
name = f"{self.model.base_model_prefix}.{name}"
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None and image_embeds is None:
return None

if pixel_values is not None:
vision_embeddings = self.model.get_image_features(
# Think about pixels being batched again, adding an extra dim
# TODO: find out whether we really need that extra dim
pixel_values.flatten(0, 1),
vision_feature_layer=self.config.vision_feature_layer,
vision_feature_select_strategy=self.config.vision_feature_select_strategy,
)
return vision_embeddings

if image_embeds is not None:
return image_embeds

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if multimodal_embeddings is not None:
# Most supported VLMs merge like this; otherwise we can add a special
# `merge_multimodal_embeddings` method on the HF side
mask = (input_ids == self.config.image_token_index)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)

# FIXME: The returned multimodal_embeddings must be either a 3D torch.Tensor of shape
# (num_items, feature_size, hidden_size), or a list/tuple of 2D torch.Tensors of shape
# (feature_size, hidden_size), so that multimodal_embeddings[i] retrieves the embeddings
# generated from the i-th multimodal data item (e.g. image) of the request.
inputs_embeds = inputs_embeds.masked_scatter(mask, multimodal_embeddings)
return inputs_embeds
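
A toy example of the masked_scatter merge in get_input_embeddings, with an assumed image token id and tiny tensor sizes (illustrative only, not part of this diff):

import torch

# Sketch: image features are scattered into the text embedding sequence at the positions of the
# image placeholder token. The token id, shapes, and values here are made up.
hidden_size, image_token_index = 4, 32000
input_ids = torch.tensor([[1, 32000, 32000, 2]])    # two image placeholder positions
inputs_embeds = torch.zeros(1, 4, hidden_size)      # stand-in for the text embeddings
multimodal_embeddings = torch.ones(2, hidden_size)  # one feature row per placeholder token

mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, multimodal_embeddings)
print(merged[0, 1])  # row of ones: the first image feature row
print(merged[0, 0])  # row of zeros: a text position, left untouched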
9 changes: 9 additions & 0 deletions vllm/multimodal/inputs.py
@@ -702,6 +702,15 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
self._validate_modality("get_items", modality)
return self._items_by_modality[modality]

def get_hf_inputs(self, modality: str) -> dict[str, NestedTensors]:
modality_items = self._items_by_modality.get(modality, None)
hf_inputs = defaultdict[str, list[NestedTensors]](list)
if modality_items is not None:
for mm_kwargs_item in modality_items:
for key, value in mm_kwargs_item.items():
hf_inputs[key].append(value.data)
hf_inputs = {key: torch.stack(value) for key, value in hf_inputs.items()}
return hf_inputs
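
What get_hf_inputs effectively does, shown with plain dicts standing in for MultiModalKwargsItem (illustrative only, not part of this diff):

from collections import defaultdict

import torch

# Sketch: per-item keyword arguments are regrouped into batched, HF-style tensors for one
# modality by stacking each field across the items of that modality.
items = [{"pixel_values": torch.zeros(3, 8, 8)},  # stand-ins for MultiModalKwargsItem entries
         {"pixel_values": torch.ones(3, 8, 8)}]
hf_inputs = defaultdict(list)
for item in items:
    for key, value in item.items():
        hf_inputs[key].append(value)
hf_inputs = {key: torch.stack(value) for key, value in hf_inputs.items()}
print(hf_inputs["pixel_values"].shape)  # torch.Size([2, 3, 8, 8])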

MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""