
Commit ad667ce

Support passing raw multimodal data to model
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent dbdb7db commit ad667ce

File tree: 3 files changed, +88 −2 lines changed

vllm/model_executor/models/interfaces.py

Lines changed: 36 additions & 0 deletions
@@ -129,6 +129,42 @@ def supports_multimodal(
 
     return isinstance(model, SupportsMultiModal)
 
+
+@runtime_checkable
+class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
+    """The interface required for multi-modal models that consume multi-modal
+    data in its raw form rather than as precomputed embeddings."""
+
+    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports multi-modal inputs and processes
+    them in their raw form, not as embeddings.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+
+@runtime_checkable
+class _SupportsMultiModalWithRawInput(Protocol):
+    supports_multimodal_raw_input: ClassVar[Literal[True]]
+
+
+@overload
+def supports_multimodal_raw_input(
+        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
+    ...
+
+
+@overload
+def supports_multimodal_raw_input(
+        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
+    ...
+
+
+def supports_multimodal_raw_input(
+    model: Union[type[object], object]
+) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
+           TypeIs[SupportsMultiModalWithRawInput]]:
+    # For classes, check against the attribute-only protocol: isinstance on a
+    # runtime_checkable protocol reduces to attribute-presence checks, which
+    # also work on class objects.
+    if isinstance(model, type):
+        return isinstance(model, _SupportsMultiModalWithRawInput)
+
+    return isinstance(model, SupportsMultiModalWithRawInput)
 
 
 @runtime_checkable
 class SupportsLoRA(Protocol):
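
The new interface and checker can be exercised as in the following minimal sketch (not part of this commit): the model class name is hypothetical, and the class-level check relies on the attribute-presence semantics of runtime_checkable protocols noted above.

# Hypothetical model opting in to raw multi-modal inputs.
import torch.nn as nn

from vllm.model_executor.models.interfaces import (
    SupportsMultiModalWithRawInput, supports_multimodal_raw_input)


class MyRawInputModel(nn.Module, SupportsMultiModalWithRawInput):
    """Inheriting the interface puts it in the MRO, so the
    supports_multimodal_raw_input flag does not need to be redefined here."""


assert supports_multimodal_raw_input(MyRawInputModel)    # class-level check
assert supports_multimodal_raw_input(MyRawInputModel())  # instance check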

vllm/model_executor/models/registry.py

Lines changed: 12 additions & 2 deletions
@@ -23,8 +23,9 @@
 
 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
-                         supports_multimodal, supports_pp,
-                         supports_transcription, supports_v0_only)
+                         supports_multimodal, supports_multimodal_raw_input,
+                         supports_pp, supports_transcription,
+                         supports_v0_only)
 from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)

@@ -275,6 +276,7 @@ class _ModelInfo:
     is_pooling_model: bool
     supports_cross_encoding: bool
     supports_multimodal: bool
+    supports_multimodal_raw_input: bool
     supports_pp: bool
     has_inner_state: bool
     is_attention_free: bool

@@ -291,6 +293,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
             is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
+            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),

@@ -527,6 +530,13 @@ def is_multimodal_model(
     ) -> bool:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_multimodal
+
+    def supports_multimodal_raw_input(
+        self,
+        architectures: Union[str, list[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.supports_multimodal_raw_input
 
     def is_pp_supported_model(
         self,
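
Downstream code can then query this per architecture, mirroring is_multimodal_model above. A hedged usage sketch (not part of this commit; the architecture name is made up and would need to be registered, and ModelRegistry is assumed to be the module's global registry instance):

from vllm.model_executor.models.registry import ModelRegistry

# True only if the class resolved for this architecture inherits
# SupportsMultiModalWithRawInput (or sets the flag itself).
if ModelRegistry.supports_multimodal_raw_input("MyRawInputModel"):
    print("model consumes raw multi-modal data")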

vllm/v1/worker/gpu_model_runner.py

Lines changed: 40 additions & 0 deletions
@@ -556,6 +556,46 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
+    def _add_multimodal_inputs_to_model_args(
+            self, model_kwargs: dict[str, Any],
+            scheduler_output: Optional["SchedulerOutput"]):
+        # Multi-modal data.
+        if scheduler_output:
+            multi_modal_kwargs_list = []
+            for req in scheduler_output.scheduled_new_reqs:
+                req_mm_inputs = req.mm_inputs
+                if not isinstance(req_mm_inputs, list):
+                    req_mm_inputs = list(req_mm_inputs)
+                multi_modal_kwargs_list.extend(req_mm_inputs)
+            multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+        else:
+            # The only case where the scheduler output is None is a dummy run,
+            # so substitute dummy multi-modal data.
+            dummy_data = self.mm_registry.get_decoder_dummy_data(
+                model_config=self.model_config, seq_len=1)
+            multi_modal_kwargs = MultiModalKwargs.batch(
+                [dummy_data.multi_modal_data])
+
+        model_kwargs.update(multi_modal_kwargs)
+
+    def _maybe_add_model_args(self, num_tokens: int,
+                              model_kwargs: dict[str, Any],
+                              scheduler_output: Optional["SchedulerOutput"] = None):
+        if self.supports_token_type_ids:
+            model_kwargs["token_type_ids"] = \
+                self.get_token_type_ids()[:num_tokens]
+
+        if self.model_supports_multimodal_raw_input:
+            self._add_multimodal_inputs_to_model_args(model_kwargs,
+                                                      scheduler_output)
+
+    def _maybe_compute_attn_prefix(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> list[int]:
+        return [0] * len(self.kv_cache_config.kv_cache_groups)
+
+    def _maybe_prepare_additional_inputs(self,
+                                         scheduler_output: "SchedulerOutput",
+                                         token_indices: torch.Tensor):
+        pass
+
     def _get_cumsum_and_arange(
         self,
         num_tokens: np.ndarray,
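
The helper above funnels each scheduled request's mm_inputs through MultiModalKwargs.batch and merges the result into model_kwargs. A self-contained sketch of that batching step (not part of this commit; the key name and tensor shapes are illustrative, and MultiModalKwargs is assumed to accept a plain dict of tensors):

from typing import Any

import torch

from vllm.multimodal import MultiModalKwargs

# One entry per scheduled request, as collected from req.mm_inputs.
per_request = [
    MultiModalKwargs({"pixel_values": torch.zeros(3, 224, 224)}),
    MultiModalKwargs({"pixel_values": torch.zeros(3, 224, 224)}),
]

model_kwargs: dict[str, Any] = {}
model_kwargs.update(MultiModalKwargs.batch(per_request))
# model_kwargs["pixel_values"] now holds both requests' raw tensors, ready to
# be passed to the model's forward alongside input_ids and positions.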
