Skip to content

Commit 9a06b55

Browse files
Support passing raw multimodal data to model
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent 7392c45 commit 9a06b55

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

vllm/model_executor/models/interfaces.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,42 @@ def supports_multimodal(
145145

146146
return isinstance(model, SupportsMultiModal)
147147

148+
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
    """The interface for multi-modal models that consume multi-modal data
    in its raw form (e.g. pixel values) rather than as precomputed
    embeddings."""

    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs and processes
    them in their raw form and not embeddings.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """


@runtime_checkable
class _SupportsMultiModalWithRawInput(Protocol):
    # Structural twin of the public protocol without a default value, so
    # that ``isinstance`` checks on *classes* match on attribute presence.
    supports_multimodal_raw_input: ClassVar[Literal[True]]


@overload
def supports_multimodal_raw_input(
        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
    ...


@overload
def supports_multimodal_raw_input(
        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
    ...


def supports_multimodal_raw_input(
    model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
           TypeIs[SupportsMultiModalWithRawInput]]:
    """Check whether a model class or instance accepts raw multi-modal
    inputs (as opposed to multi-modal embeddings)."""
    if isinstance(model, type):
        # For a class object, match on the presence of the class-level
        # flag via the structural protocol.
        return isinstance(model, _SupportsMultiModalWithRawInput)

    return isinstance(model, SupportsMultiModalWithRawInput)
148184

149185
@runtime_checkable
150186
class SupportsScoreTemplate(Protocol):

vllm/model_executor/models/registry.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222

2323
from .interfaces import (has_inner_state, has_noops, is_attention_free,
2424
is_hybrid, supports_cross_encoding,
25-
supports_multimodal, supports_pp,
26-
supports_transcription, supports_v0_only)
25+
supports_multimodal, supports_multimodal_raw_input,
26+
supports_pp, supports_transcription,
27+
supports_v0_only)
2728
from .interfaces_base import is_text_generation_model
2829

2930
logger = init_logger(__name__)
@@ -281,6 +282,7 @@ class _ModelInfo:
281282
is_pooling_model: bool
282283
supports_cross_encoding: bool
283284
supports_multimodal: bool
285+
supports_multimodal_raw_input: bool
284286
supports_pp: bool
285287
has_inner_state: bool
286288
is_attention_free: bool
@@ -298,6 +300,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
298300
is_pooling_model=True, # Can convert any model into a pooling model
299301
supports_cross_encoding=supports_cross_encoding(model),
300302
supports_multimodal=supports_multimodal(model),
303+
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
301304
supports_pp=supports_pp(model),
302305
has_inner_state=has_inner_state(model),
303306
is_attention_free=is_attention_free(model),
@@ -536,6 +539,13 @@ def is_multimodal_model(
536539
) -> bool:
537540
model_cls, _ = self.inspect_model_cls(architectures)
538541
return model_cls.supports_multimodal
542+
543+
def supports_multimodal_raw_input(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    """Whether the model identified by ``architectures`` consumes raw
    multi-modal data instead of multi-modal embeddings."""
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_multimodal_raw_input
539549

540550
def is_pp_supported_model(
541551
self,

vllm/v1/worker/gpu_model_runner.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,46 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
559559
# Refresh batch metadata with any pending updates.
560560
self.input_batch.refresh_metadata()
561561

562+
def _add_multimodal_inputs_to_model_args(
        self, model_kwargs: dict[str, Any],
        scheduler_output: "SchedulerOutput"):
    """Batch the scheduled requests' raw multi-modal inputs into
    ``model_kwargs``.

    When ``scheduler_output`` is ``None`` (only the case for a dummy /
    profiling run), dummy multi-modal data from the registry is batched
    instead so the model still receives well-formed inputs.
    """
    if scheduler_output:
        multi_modal_kwargs_list = []
        for req in scheduler_output.scheduled_new_reqs:
            # ``extend`` accepts any iterable, so no list() conversion
            # of req.mm_inputs is needed.
            multi_modal_kwargs_list.extend(req.mm_inputs)
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
    else:
        # The only case where SchedulerOutput is None is for a dummy run,
        # so fall back to registry-provided dummy multi-modal data.
        dummy_data = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config, seq_len=1)
        multi_modal_kwargs = MultiModalKwargs.batch(
            [dummy_data.multi_modal_data])

    model_kwargs.update(multi_modal_kwargs)
580+
def _maybe_add_model_args(
        self,
        num_tokens: int,
        model_kwargs: dict[str, Any],
        scheduler_output: "Optional[SchedulerOutput]" = None):
    """Populate optional forward-pass kwargs supported by this model.

    Adds ``token_type_ids`` (truncated to ``num_tokens``) when the model
    consumes them, and raw multi-modal inputs when the model processes
    multi-modal data in raw form.
    """
    if self.supports_token_type_ids:
        model_kwargs["token_type_ids"] = \
            self.get_token_type_ids()[:num_tokens]

    if self.model_supports_multimodal_raw_input:
        self._add_multimodal_inputs_to_model_args(model_kwargs,
                                                  scheduler_output)
590+
591+
def _maybe_compute_attn_prefix(
    self,
    scheduler_output: "SchedulerOutput",
) -> list[int]:
    """Return a zero attention prefix for each KV-cache group."""
    return [0 for _ in self.kv_cache_config.kv_cache_groups]
596+
597+
def _maybe_prepare_additional_inputs(self,
                                     scheduler_output: "SchedulerOutput",
                                     token_indices: torch.Tensor):
    """Hook for preparing extra per-step inputs; intentionally a no-op
    in this base implementation."""
601+
562602
def _get_cumsum_and_arange(
563603
self,
564604
num_tokens: np.ndarray,

0 commit comments

Comments
 (0)