Skip to content

Commit b1e5aa3

Browse files
huachenhelihuydhn
authored andcommitted
[Frontend] Support configurable mm placeholder strings & flexible video sampling policies via CLI flags. (vllm-project#20105)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
1 parent ada7b41 commit b1e5aa3

File tree

12 files changed

+199
-29
lines changed

12 files changed

+199
-29
lines changed

tests/async_engine/test_async_llm_engine.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import uuid
77
from asyncio import CancelledError
88
from copy import copy
9-
from dataclasses import dataclass
10-
from typing import Optional
9+
from dataclasses import dataclass, field
10+
from typing import Any, Optional
1111

1212
import pytest
1313
import pytest_asyncio
@@ -32,6 +32,8 @@ class RequestOutput:
3232
@dataclass
3333
class MockModelConfig:
3434
use_async_output_proc = True
35+
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
36+
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
3537

3638

3739
class MockEngine:

tests/engine/test_arg_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,58 @@ def test_limit_mm_per_prompt_parser(arg, expected):
231231
assert args.limit_mm_per_prompt == expected
232232

233233

234+
@pytest.mark.parametrize(
235+
("arg", "expected"),
236+
[
237+
(None, dict()),
238+
('{"video": {"num_frames": 123} }', {
239+
"video": {
240+
"num_frames": 123
241+
}
242+
}),
243+
(
244+
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
245+
{
246+
"video": {
247+
"num_frames": 123,
248+
"fps": 1.0,
249+
"foo": "bar"
250+
},
251+
"image": {
252+
"foo": "bar"
253+
}
254+
}),
255+
])
256+
def test_media_io_kwargs_parser(arg, expected):
257+
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
258+
if arg is None:
259+
args = parser.parse_args([])
260+
else:
261+
args = parser.parse_args(["--media-io-kwargs", arg])
262+
263+
assert args.media_io_kwargs == expected
264+
265+
266+
@pytest.mark.parametrize(("arg", "expected"), [
267+
(None, dict()),
268+
('{"video":"<|video_placeholder|>"}', {
269+
"video": "<|video_placeholder|>"
270+
}),
271+
('{"video":"<|video_placeholder|>", "image": "<|image_placeholder|>"}', {
272+
"video": "<|video_placeholder|>",
273+
"image": "<|image_placeholder|>"
274+
}),
275+
])
276+
def test_mm_placeholder_str_override_parser(arg, expected):
277+
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
278+
if arg is None:
279+
args = parser.parse_args([])
280+
else:
281+
args = parser.parse_args(["--mm-placeholder-str-override", arg])
282+
283+
assert args.mm_placeholder_str_override == expected
284+
285+
234286
def test_compilation_config():
235287
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
236288

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import asyncio
55
from contextlib import suppress
6-
from dataclasses import dataclass
7-
from typing import Optional
6+
from dataclasses import dataclass, field
7+
from typing import Any, Optional
88
from unittest.mock import MagicMock
99

1010
from vllm.config import MultiModalConfig
@@ -40,6 +40,8 @@ class MockModelConfig:
4040
allowed_local_media_path: str = ""
4141
encoder_config = None
4242
generation_config: str = "auto"
43+
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
44+
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
4345

4446
def get_diff_sampling_param(self):
4547
return self.diff_sampling_param or {}

tests/multimodal/test_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,14 @@ async def test_fetch_image_error_conversion():
167167
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
168168
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
169169
async def test_fetch_video_http(video_url: str, num_frames: int):
170-
connector = MediaConnector()
170+
connector = MediaConnector(
171+
media_io_kwargs={"video": {
172+
"num_frames": num_frames,
173+
}})
171174

172-
video_sync = connector.fetch_video(video_url, num_frames=num_frames)
173-
video_async = await connector.fetch_video_async(video_url,
174-
num_frames=num_frames)
175-
# Check that the video frames are equal and metadata are same
175+
video_sync = connector.fetch_video(video_url)
176+
video_async = await connector.fetch_video_async(video_url)
176177
assert np.array_equal(video_sync[0], video_async[0])
177-
assert video_sync[1] == video_async[1]
178178

179179

180180
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.

tests/multimodal/test_video.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import numpy.typing as npt
55
import pytest
66

7-
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
7+
from vllm import envs
8+
from vllm.multimodal.image import ImageMediaIO
9+
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
10+
VideoMediaIO)
811

912
NUM_FRAMES = 10
1013
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@@ -40,3 +43,46 @@ def test_video_loader_registry():
4043
def test_video_loader_type_doesnt_exist():
4144
with pytest.raises(AssertionError):
4245
VIDEO_LOADER_REGISTRY.load("non_existing_video_loader")
46+
47+
48+
@VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps")
49+
class Assert10Frames1FPSVideoLoader(VideoLoader):
50+
51+
@classmethod
52+
def load_bytes(cls,
53+
data: bytes,
54+
num_frames: int = -1,
55+
fps: float = -1.0,
56+
**kwargs) -> npt.NDArray:
57+
assert num_frames == 10, "bad num_frames"
58+
assert fps == 1.0, "bad fps"
59+
return FAKE_OUTPUT_2
60+
61+
62+
def test_video_media_io_kwargs():
63+
envs.VLLM_VIDEO_LOADER_BACKEND = "assert_10_frames_1_fps"
64+
imageio = ImageMediaIO()
65+
66+
# Verify that different args pass/fail assertions as expected.
67+
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0})
68+
_ = videoio.load_bytes(b"test")
69+
70+
videoio = VideoMediaIO(
71+
imageio, **{
72+
"num_frames": 10,
73+
"fps": 1.0,
74+
"not_used": "not_used"
75+
})
76+
_ = videoio.load_bytes(b"test")
77+
78+
with pytest.raises(AssertionError, match="bad num_frames"):
79+
videoio = VideoMediaIO(imageio, **{})
80+
_ = videoio.load_bytes(b"test")
81+
82+
with pytest.raises(AssertionError, match="bad num_frames"):
83+
videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0})
84+
_ = videoio.load_bytes(b"test")
85+
86+
with pytest.raises(AssertionError, match="bad fps"):
87+
videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0})
88+
_ = videoio.load_bytes(b"test")

vllm/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,12 @@ class ModelConfig:
346346
limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
347347
"""Maximum number of data items per modality per prompt. Only applicable
348348
for multimodal models."""
349+
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
350+
"""Additional args passed to process media inputs, keyed by modalities.
351+
For example, to set num_frames for video, set
352+
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
353+
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
354+
"""Optionally override placeholder string for given modalities."""
349355
use_async_output_proc: bool = True
350356
"""Whether to use async output processor."""
351357
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
@@ -694,6 +700,8 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
694700
if self.registry.is_multimodal_model(self.architectures):
695701
return MultiModalConfig(
696702
limit_per_prompt=self.limit_mm_per_prompt,
703+
media_io_kwargs=self.media_io_kwargs,
704+
mm_placeholder_str_override=self.mm_placeholder_str_override,
697705
mm_processor_kwargs=self.mm_processor_kwargs,
698706
disable_mm_preprocessor_cache=self.
699707
disable_mm_preprocessor_cache)
@@ -3063,6 +3071,14 @@ class MultiModalConfig:
30633071
`{"images": 16, "videos": 2}`
30643072
"""
30653073

3074+
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
3075+
"""Additional args passed to process media inputs, keyed by modalities.
3076+
For example, to set num_frames for video, set
3077+
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
3078+
3079+
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
3080+
"""Optionally override placeholder string for given modalities."""
3081+
30663082
mm_processor_kwargs: Optional[dict[str, object]] = None
30673083
"""
30683084
Overrides for the multi-modal processor obtained from

vllm/engine/arg_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,11 @@ class EngineArgs:
369369
get_field(TokenizerPoolConfig, "extra_config")
370370
limit_mm_per_prompt: dict[str, int] = \
371371
get_field(MultiModalConfig, "limit_per_prompt")
372+
media_io_kwargs: dict[str, dict[str,
373+
Any]] = get_field(MultiModalConfig,
374+
"media_io_kwargs")
375+
mm_placeholder_str_override: dict[str, str] = \
376+
get_field(MultiModalConfig, "mm_placeholder_str_override")
372377
mm_processor_kwargs: Optional[Dict[str, Any]] = \
373378
MultiModalConfig.mm_processor_kwargs
374379
disable_mm_preprocessor_cache: bool = \
@@ -745,6 +750,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
745750
)
746751
multimodal_group.add_argument("--limit-mm-per-prompt",
747752
**multimodal_kwargs["limit_per_prompt"])
753+
multimodal_group.add_argument("--media-io-kwargs",
754+
**multimodal_kwargs["media_io_kwargs"])
755+
multimodal_group.add_argument(
756+
"--mm-placeholder-str-override",
757+
**multimodal_kwargs["mm_placeholder_str_override"])
748758
multimodal_group.add_argument(
749759
"--mm-processor-kwargs",
750760
**multimodal_kwargs["mm_processor_kwargs"])
@@ -969,6 +979,8 @@ def create_model_config(self) -> ModelConfig:
969979
enable_prompt_embeds=self.enable_prompt_embeds,
970980
served_model_name=self.served_model_name,
971981
limit_mm_per_prompt=self.limit_mm_per_prompt,
982+
media_io_kwargs=self.media_io_kwargs,
983+
mm_placeholder_str_override=self.mm_placeholder_str_override,
972984
use_async_output_proc=not self.disable_async_output_proc,
973985
config_format=self.config_format,
974986
mm_processor_kwargs=self.mm_processor_kwargs,

vllm/entrypoints/chat_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,9 @@ def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
507507

508508
def _placeholder_str(self, modality: ModalityStr,
509509
current_count: int) -> Optional[str]:
510+
if modality in self._model_config.mm_placeholder_str_override:
511+
return self._model_config.mm_placeholder_str_override[modality]
512+
510513
# TODO: Let user specify how to insert image tokens into prompt
511514
# (similar to chat template)
512515
hf_config = self._model_config.hf_config
@@ -725,6 +728,7 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
725728
self._tracker = tracker
726729

727730
self._connector = MediaConnector(
731+
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
728732
allowed_local_media_path=tracker.allowed_local_media_path,
729733
)
730734

@@ -763,7 +767,7 @@ def parse_input_audio(self, input_audio: InputAudio) -> None:
763767
return self.parse_audio(audio_url)
764768

765769
def parse_video(self, video_url: str) -> None:
766-
video = self._connector.fetch_video(video_url)
770+
video = self._connector.fetch_video(video_url=video_url)
767771

768772
placeholder = self._tracker.add("video", video)
769773
self._add_placeholder(placeholder)
@@ -776,7 +780,8 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
776780

777781
self._tracker = tracker
778782
self._connector = MediaConnector(
779-
allowed_local_media_path=tracker.allowed_local_media_path,
783+
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
784+
allowed_local_media_path=tracker.allowed_local_media_path
780785
)
781786

782787
def parse_image(self, image_url: str) -> None:
@@ -818,7 +823,7 @@ def parse_input_audio(self, input_audio: InputAudio) -> None:
818823
return self.parse_audio(audio_url)
819824

820825
def parse_video(self, video_url: str) -> None:
821-
video = self._connector.fetch_video_async(video_url)
826+
video = self._connector.fetch_video_async(video_url=video_url)
822827

823828
placeholder = self._tracker.add("video", video)
824829
self._add_placeholder(placeholder)

vllm/multimodal/audio.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ def resample(
8383

8484
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
8585

86+
def __init__(self, **kwargs) -> None:
87+
super().__init__()
88+
89+
# `kwargs` contains custom arguments from
90+
# --media-io-kwargs for this modality.
91+
# They can be passed to the underlying
92+
# media loaders (e.g. custom implementations)
93+
# for flexible control.
94+
self.kwargs = kwargs
95+
8696
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
8797
return librosa.load(BytesIO(data), sr=None)
8898

vllm/multimodal/image.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,16 @@ def convert_image_mode(image: Image.Image, to_mode: str):
4444

4545
class ImageMediaIO(MediaIO[Image.Image]):
4646

47-
def __init__(self, *, image_mode: str = "RGB") -> None:
47+
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
4848
super().__init__()
4949

5050
self.image_mode = image_mode
51+
# `kwargs` contains custom arguments from
52+
# --media-io-kwargs for this modality.
53+
# They can be passed to the underlying
54+
# media loaders (e.g. custom implementations)
55+
# for flexible control.
56+
self.kwargs = kwargs
5157

5258
def load_bytes(self, data: bytes) -> Image.Image:
5359
image = Image.open(BytesIO(data))

0 commit comments

Comments
 (0)