Commit 34c79e0

DarkLight1337 authored and sfeng33 committed
[Core] Move multimodal placeholder from chat utils to model definition (vllm-project#20355)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 678a35f commit 34c79e0

54 files changed: 396 additions & 155 deletions

docs/contributing/model/multimodal.md

Lines changed: 16 additions & 0 deletions

@@ -10,6 +10,22 @@ This document walks you through the steps to extend a basic model so that it acc
 It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic].
 Further update the model as follows:
 
+- Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model.
+
+    ??? Code
+
+        ```python
+        class YourModelForImage2Seq(nn.Module):
+            ...
+
+            @classmethod
+            def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+                if modality.startswith("image"):
+                    return "<image>"
+
+                raise ValueError("Only image modality is supported")
+        ```
+
 - Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example:
 
     ```diff
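
The docs snippet above covers a single modality. For a model that accepts more than one, the same classmethod can branch on the modality prefix. A minimal sketch, reusing the Qwen2-VL-style vision placeholders from the table removed from chat_utils.py below (the class name is hypothetical):

```python
from typing import Optional

import torch.nn as nn


class YourModelForImageAndVideo(nn.Module):
    ...

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        # The returned string must match what the model's chat template expects.
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image and video modalities are supported")
```

Models whose prompt format does not embed a placeholder token at all (for example Fuyu or Pixtral in the removed table) can return None instead.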

tests/async_engine/test_async_llm_engine.py

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ class RequestOutput:
 class MockModelConfig:
     use_async_output_proc = True
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
 
 
 class MockEngine:

tests/engine/test_arg_utils.py

Lines changed: 0 additions & 20 deletions

@@ -263,26 +263,6 @@ def test_media_io_kwargs_parser(arg, expected):
     assert args.media_io_kwargs == expected
 
 
-@pytest.mark.parametrize(("arg", "expected"), [
-    (None, dict()),
-    ('{"video":"<|video_placeholder|>"}', {
-        "video": "<|video_placeholder|>"
-    }),
-    ('{"video":"<|video_placeholder|>", "image": "<|image_placeholder|>"}', {
-        "video": "<|video_placeholder|>",
-        "image": "<|image_placeholder|>"
-    }),
-])
-def test_mm_placeholder_str_override_parser(arg, expected):
-    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    if arg is None:
-        args = parser.parse_args([])
-    else:
-        args = parser.parse_args(["--mm-placeholder-str-override", arg])
-
-    assert args.mm_placeholder_str_override == expected
-
-
 def test_compilation_config():
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
 

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 0 additions & 1 deletion

@@ -41,7 +41,6 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}

vllm/config.py

Lines changed: 1 addition & 7 deletions

@@ -350,8 +350,6 @@ class ModelConfig:
     """Additional args passed to process media inputs, keyed by modalities.
     For example, to set num_frames for video, set
     `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
-    mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
-    """Optionally override placeholder string for given modalities."""
     use_async_output_proc: bool = True
     """Whether to use async output processor."""
     config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
@@ -661,7 +659,7 @@ def architecture(self) -> str:
         return self._architecture
 
     @property
-    def model_info(self) -> dict[str, Any]:
+    def model_info(self):
         return self._model_info
 
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
@@ -701,7 +699,6 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
         return MultiModalConfig(
             limit_per_prompt=self.limit_mm_per_prompt,
             media_io_kwargs=self.media_io_kwargs,
-            mm_placeholder_str_override=self.mm_placeholder_str_override,
             mm_processor_kwargs=self.mm_processor_kwargs,
             disable_mm_preprocessor_cache=self.
             disable_mm_preprocessor_cache)
@@ -3096,9 +3093,6 @@ class MultiModalConfig:
     For example, to set num_frames for video, set
     `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
 
-    mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
-    """Optionally override placeholder string for given modalities."""
-
     mm_processor_kwargs: Optional[dict[str, object]] = None
     """
     Overrides for the multi-modal processor obtained from

vllm/engine/arg_utils.py

Lines changed: 0 additions & 6 deletions

@@ -373,8 +373,6 @@ class EngineArgs:
     media_io_kwargs: dict[str, dict[str,
                               Any]] = get_field(MultiModalConfig,
                                                 "media_io_kwargs")
-    mm_placeholder_str_override: dict[str, str] = \
-        get_field(MultiModalConfig, "mm_placeholder_str_override")
     mm_processor_kwargs: Optional[Dict[str, Any]] = \
         MultiModalConfig.mm_processor_kwargs
     disable_mm_preprocessor_cache: bool = \
@@ -759,9 +757,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             **multimodal_kwargs["limit_per_prompt"])
         multimodal_group.add_argument("--media-io-kwargs",
                                       **multimodal_kwargs["media_io_kwargs"])
-        multimodal_group.add_argument(
-            "--mm-placeholder-str-override",
-            **multimodal_kwargs["mm_placeholder_str_override"])
         multimodal_group.add_argument(
             "--mm-processor-kwargs",
             **multimodal_kwargs["mm_processor_kwargs"])
@@ -987,7 +982,6 @@ def create_model_config(self) -> ModelConfig:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             media_io_kwargs=self.media_io_kwargs,
-            mm_placeholder_str_override=self.mm_placeholder_str_override,
             use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,

vllm/entrypoints/chat_utils.py

Lines changed: 9 additions & 85 deletions

@@ -6,7 +6,7 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Awaitable, Iterable
-from functools import cache, lru_cache, partial
+from functools import cached_property, lru_cache, partial
 from pathlib import Path
 from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
                     cast)
@@ -37,6 +37,8 @@
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model_cls
+from vllm.model_executor.models import SupportsMultiModal
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
 # yapf: disable
@@ -492,6 +494,10 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
     def model_config(self) -> ModelConfig:
         return self._model_config
 
+    @cached_property
+    def model_cls(self):
+        return get_model_cls(self.model_config)
+
     @property
     def allowed_local_media_path(self):
         return self._model_config.allowed_local_media_path
@@ -500,96 +506,14 @@ def allowed_local_media_path(self):
     def mm_registry(self):
         return MULTIMODAL_REGISTRY
 
-    @staticmethod
-    @cache
-    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
-        return tokenizer.decode(token_index)
-
-    def _placeholder_str(self, modality: ModalityStr,
-                         current_count: int) -> Optional[str]:
-        if modality in self._model_config.mm_placeholder_str_override:
-            return self._model_config.mm_placeholder_str_override[modality]
-
-        # TODO: Let user specify how to insert image tokens into prompt
-        # (similar to chat template)
-        hf_config = self._model_config.hf_config
-        model_type = hf_config.model_type
-
-        if modality in ("image", "image_embeds"):
-            if model_type == "chatglm":
-                return "<|begin_of_image|><|endoftext|><|end_of_image|>"
-            if model_type == "glm4v":
-                return "<|begin_of_image|><|image|><|end_of_image|>"
-            if model_type in ("phi3_v", "phi4mm"):
-                return f"<|image_{current_count}|>"
-            if model_type in ("minicpmo", "minicpmv"):
-                return "(<image>./</image>)"
-            if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
-                              "pixtral", "mistral3"):
-                # These models do not use image tokens in the prompt
-                return None
-            if model_type == "qwen":
-                return f"Picture {current_count}: <img></img>"
-            if model_type.startswith("llava"):
-                return self._cached_token_str(self._tokenizer,
-                                              hf_config.image_token_index)
-
-            if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
-                              "internvl_chat", "ovis", "skywork_chat",
-                              "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
-                return "<image>"
-            if model_type in ("mllama", "llama4"):
-                return "<|image|>"
-            if model_type in ("qwen2_vl", "qwen2_5_vl", "keye", "Keye"):
-                return "<|vision_start|><|image_pad|><|vision_end|>"
-            if model_type == "qwen2_5_omni":
-                return "<|vision_start|><|IMAGE|><|vision_end|>"
-            if model_type == "molmo":
-                return ""
-            if model_type == "aria":
-                return "<|fim_prefix|><|img|><|fim_suffix|>"
-            if model_type == "gemma3":
-                return "<start_of_image>"
-            if model_type == "kimi_vl":
-                return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"  # noqa: E501
-
-            raise TypeError(f"Unknown {modality} model type: {model_type}")
-        elif modality == "audio":
-            if model_type in ("ultravox", "granite_speech"):
-                return "<|audio|>"
-            if model_type == "phi4mm":
-                return f"<|audio_{current_count}|>"
-            if model_type in ("qwen2_audio", "qwen2_5_omni"):
-                return (f"Audio {current_count}: "
-                        f"<|audio_bos|><|AUDIO|><|audio_eos|>")
-            if model_type == "minicpmo":
-                return "(<audio>./</audio>)"
-            raise TypeError(f"Unknown model type: {model_type}")
-        elif modality == "video":
-            if model_type == "internvl_chat":
-                return "<video>"
-            if model_type == "glm4v":
-                return "<|begin_of_video|><|video|><|end_of_video|>"
-            if model_type in ("qwen2_vl", "qwen2_5_vl", "keye", "Keye"):
-                return "<|vision_start|><|video_pad|><|vision_end|>"
-            if model_type == "qwen2_5_omni":
-                return "<|vision_start|><|VIDEO|><|vision_end|>"
-            if model_type in ("minicpmo", "minicpmv"):
-                return "(<video>./</video>)"
-            if model_type.startswith("llava"):
-                return self._cached_token_str(self._tokenizer,
-                                              hf_config.video_token_index)
-            raise TypeError(f"Unknown {modality} model type: {model_type}")
-        else:
-            raise TypeError(f"Unknown modality: {modality}")
-
     def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
         """
         Add a multi-modal item to the current prompt and returns the
         placeholder string to use, if any.
         """
         mm_registry = self.mm_registry
         model_config = self.model_config
+        model_cls = cast(SupportsMultiModal, self.model_cls)
 
         input_modality = modality.replace("_embeds", "")
 
@@ -614,7 +538,7 @@ def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
 
         self._items_by_modality[modality].append(item)
 
-        return self._placeholder_str(modality, current_count)
+        return model_cls.get_placeholder_str(modality, current_count)
 
     @abstractmethod
     def create_parser(self) -> "BaseMultiModalContentParser":
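
With the hard-coded per-model table gone, the content parser defers entirely to the model class. A minimal sketch of the new lookup path, assuming a ModelConfig for a model that implements SupportsMultiModal (the helper name is hypothetical; chat_utils itself caches the class via the cached_property shown above):

```python
from typing import Optional, cast

from vllm.config import ModelConfig
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsMultiModal


def resolve_placeholder(model_config: ModelConfig, modality: str,
                        count: int) -> Optional[str]:
    # Resolve the registered architecture class for this config, then ask it
    # for the placeholder string that represents the `count`-th item.
    model_cls = cast(SupportsMultiModal, get_model_cls(model_config))
    return model_cls.get_placeholder_str(modality, count)
```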

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 12 additions & 6 deletions

@@ -5,6 +5,7 @@
 import math
 import time
 from collections.abc import AsyncGenerator
+from functools import cached_property
 from math import ceil
 from typing import Callable, Literal, Optional, TypeVar, Union, cast
 
@@ -24,7 +25,8 @@
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.utils import get_model_architecture
+from vllm.model_executor.model_loader import get_model_cls
+from vllm.model_executor.models import SupportsTranscription
 from vllm.outputs import RequestOutput
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
@@ -76,24 +78,29 @@ def __init__(
         self.model_sr = processor.feature_extractor.sampling_rate
         self.hop_length = processor.feature_extractor.hop_length
         self.task_type = task_type
-        self.model_cls, _ = get_model_architecture(model_config)
 
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
                 self.default_sampling_params)
 
+    @cached_property
+    def model_cls(self):
+        return get_model_cls(self.model_config)
+
     async def _preprocess_speech_to_text(
         self,
         request: SpeechToTextRequest,
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
+        model_cls = cast(SupportsTranscription, self.model_cls)
+
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
         lang = request.language or "en"
-        self.model_cls.validate_language(lang)  # type: ignore[attr-defined]
+        model_cls.validate_language(lang)
 
         if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
             raise ValueError("Maximum file size exceeded.")
@@ -117,9 +124,8 @@ async def _preprocess_speech_to_text(
                 },
             },
             "decoder_prompt":
-            self.model_cls.
-            get_decoder_prompt(  # type: ignore[attr-defined]
-                lang, self.task_type, request.prompt)
+            model_cls.get_decoder_prompt(lang, self.task_type,
+                                         request.prompt)
         }
         prompts.append(cast(PromptType, prompt))
         return prompts, duration
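
The transcription path follows the same pattern: the interface hooks are looked up on the resolved model class instead of being stored at construction time. A rough usage sketch, assuming a ModelConfig for a model that implements SupportsTranscription (the helper name and the concrete argument values are illustrative):

```python
from typing import Optional, cast

from vllm.config import ModelConfig
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsTranscription


def build_decoder_prompt(model_config: ModelConfig, lang: str, task_type: str,
                         user_prompt: Optional[str]) -> str:
    # Both hooks are classmethods, so no model instance is required here.
    model_cls = cast(SupportsTranscription, get_model_cls(model_config))
    model_cls.validate_language(lang)  # raises for unsupported languages
    return model_cls.get_decoder_prompt(lang, task_type, user_prompt)
```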

vllm/model_executor/model_loader/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,7 @@
                                                          ShardedStateLoader)
 from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
 from vllm.model_executor.model_loader.utils import (
-    get_architecture_class_name, get_model_architecture)
+    get_architecture_class_name, get_model_architecture, get_model_cls)
 
 
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
@@ -65,6 +65,7 @@ def get_model(*,
     "get_model_loader",
     "get_architecture_class_name",
     "get_model_architecture",
+    "get_model_cls",
     "BaseModelLoader",
     "BitsAndBytesModelLoader",
     "GGUFModelLoader",

vllm/model_executor/model_loader/tensorizer.py

Lines changed: 5 additions & 3 deletions

@@ -13,7 +13,7 @@
 from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, BinaryIO, Optional, Union
+from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union
 
 import regex as re
 import torch
@@ -24,12 +24,14 @@
 import vllm.envs as envs
 from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
                          set_current_vllm_config)
-from vllm.engine.arg_utils import EngineArgs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.utils import FlexibleArgumentParser, PlaceholderModule
 
+if TYPE_CHECKING:
+    from vllm.engine.arg_utils import EngineArgs
+
 try:
     from tensorizer import (DecryptionParams, EncryptionParams,
                             TensorDeserializer, TensorSerializer)
@@ -503,7 +505,7 @@ def serialize_vllm_model(
     return model
 
 
-def tensorize_vllm_model(engine_args: EngineArgs,
+def tensorize_vllm_model(engine_args: "EngineArgs",
                          tensorizer_config: TensorizerConfig,
                          generate_keyfile: bool = True):
    """Utility to load a model and then serialize it with Tensorizer
