Commit 3b430d0

Support passing raw multimodal data to model
1 parent 4bf0214 commit 3b430d0

File tree — 3 files changed: +75 −5 lines

- vllm/model_executor/models/interfaces.py
- vllm/model_executor/models/registry.py
- vllm/v1/worker/gpu_model_runner.py

vllm/model_executor/models/interfaces.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -120,6 +120,42 @@ def supports_multimodal(
 
     return isinstance(model, SupportsMultiModal)
 
+@runtime_checkable
+class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
+    """The interface for multi-modal models that consume raw multi-modal data."""
+
+    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports multi-modal inputs and processes
+    them in their raw form, not as embeddings.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+@runtime_checkable
+class _SupportsMultiModalWithRawInput(Protocol):
+    supports_multimodal_raw_input: ClassVar[Literal[True]]
+
+
+@overload
+def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
+    ...
+
+
+@overload
+def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
+    ...
+
+
+def supports_multimodal_raw_input(
+    model: Union[type[object], object]
+) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsMultiModalWithRawInput)
+
+    return isinstance(model, SupportsMultiModalWithRawInput)
 
 @runtime_checkable
 class SupportsLoRA(Protocol):
```
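The public/private protocol pair mirrors the existing `supports_multimodal` helper: the private `_SupportsMultiModalWithRawInput` protocol declares only the flag, which lets the `isinstance` check work on class objects as well as on instances. A minimal sketch of how a model would opt in (the model class name below is hypothetical, not part of this commit):

```python
import torch.nn as nn

from vllm.model_executor.models.interfaces import (
    SupportsMultiModalWithRawInput, supports_multimodal_raw_input)


class MyRawInputModel(nn.Module, SupportsMultiModalWithRawInput):
    """Hypothetical model class. It inherits
    supports_multimodal_raw_input = True from the interface's ClassVar,
    so the flag does not need to be redefined here."""


# The helper narrows both class objects and instances.
assert supports_multimodal_raw_input(MyRawInputModel)
assert supports_multimodal_raw_input(MyRawInputModel())
```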

vllm/model_executor/models/registry.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -23,8 +23,9 @@
 
 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
-                         supports_multimodal, supports_pp,
-                         supports_transcription, supports_v0_only)
+                         supports_multimodal, supports_multimodal_raw_input,
+                         supports_pp, supports_transcription,
+                         supports_v0_only)
 from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
@@ -267,6 +268,7 @@ class _ModelInfo:
     is_pooling_model: bool
     supports_cross_encoding: bool
     supports_multimodal: bool
+    supports_multimodal_raw_input: bool
     supports_pp: bool
     has_inner_state: bool
     is_attention_free: bool
@@ -283,6 +285,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
             is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
+            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
@@ -519,6 +522,13 @@ def is_multimodal_model(
     ) -> bool:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_multimodal
+
+    def supports_multimodal_raw_input(
+        self,
+        architectures: Union[str, list[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.supports_multimodal_raw_input
 
     def is_pp_supported_model(
         self,
```
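With the flag recorded in `_ModelInfo`, downstream code can query it per architecture. A short usage sketch; the architecture string is illustrative, not something this commit registers:

```python
from vllm.model_executor.models.registry import ModelRegistry

# "PrithviGeoSpatialMAE" is an illustrative architecture name; substitute
# any architecture known to the registry in your installation.
if ModelRegistry.supports_multimodal_raw_input("PrithviGeoSpatialMAE"):
    print("This architecture consumes raw multi-modal inputs.")
```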

vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -552,10 +552,34 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed or batch_reordered:
             self.input_batch.refresh()
 
-    def _maybe_add_model_args(self, num_tokens: int,
-                              model_kwargs: dict[str, Any],
+    def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any],
+                                             scheduler_output: "SchedulerOutput"):
+        # Multi-modal data.
+        if scheduler_output:
+            multi_modal_kwargs_list = []
+            for req in scheduler_output.scheduled_new_reqs:
+                req_mm_inputs = req.mm_inputs
+                if not isinstance(req_mm_inputs, list):
+                    req_mm_inputs = list(req_mm_inputs)
+                multi_modal_kwargs_list.extend(req_mm_inputs)
+            multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+        else:
+            # SchedulerOutput is None only for a dummy run, so use dummy data.
+            dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len=1)
+            multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data])
+
+        model_kwargs.update(multi_modal_kwargs)
+
+    def _maybe_add_model_args(self, num_tokens: int,
+                              model_kwargs: dict[str, Any],
                               scheduler_output: "SchedulerOutput" = None):
-        pass
+        if self.supports_token_type_ids:
+            model_kwargs["token_type_ids"] = \
+                self.get_token_type_ids()[:num_tokens]
+
+        if self.model_supports_multimodal_raw_input:
+            self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output)
 
     def _maybe_compute_attn_prefix(
         self,
```
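In effect, the runner batches the per-request `MultiModalKwargs` of all newly scheduled requests and merges the result into `model_kwargs`, so raw tensors reach the model's `forward` directly instead of being converted to embeddings first. A standalone sketch of that batching step; the `pixel_values` key and shapes are illustrative assumptions, not taken from this commit:

```python
import torch

from vllm.multimodal import MultiModalKwargs

# Two hypothetical requests, each carrying one raw image tensor.
per_request = [
    MultiModalKwargs({"pixel_values": torch.rand(3, 224, 224)}),
    MultiModalKwargs({"pixel_values": torch.rand(3, 224, 224)}),
]

model_kwargs: dict = {}
# batch() merges per-request entries into batched tensors keyed by field
# name; the runner then splats them into model.forward(**model_kwargs).
model_kwargs.update(MultiModalKwargs.batch(per_request))
```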
