Skip to content

Commit b616f6a

Browse files
authored
[Misc] Small: Fix video loader return type annotations. (#20389)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
1 parent 2e25bb1 commit b616f6a

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

tests/multimodal/test_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,10 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
172172
"num_frames": num_frames,
173173
}})
174174

175-
video_sync = connector.fetch_video(video_url)
176-
video_async = await connector.fetch_video_async(video_url)
177-
assert np.array_equal(video_sync[0], video_async[0])
175+
video_sync, metadata_sync = connector.fetch_video(video_url)
176+
video_async, metadata_async = await connector.fetch_video_async(video_url)
177+
assert np.array_equal(video_sync, video_async)
178+
assert metadata_sync == metadata_async
178179

179180

180181
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.

vllm/multimodal/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def fetch_video(
228228
video_url: str,
229229
*,
230230
image_mode: str = "RGB",
231-
) -> npt.NDArray:
231+
) -> tuple[npt.NDArray, dict[str, Any]]:
232232
"""
233233
Load video from a HTTP or base64 data URL.
234234
"""
@@ -248,7 +248,7 @@ async def fetch_video_async(
248248
video_url: str,
249249
*,
250250
image_mode: str = "RGB",
251-
) -> npt.NDArray:
251+
) -> tuple[npt.NDArray, dict[str, Any]]:
252252
"""
253253
Asynchronously load video from a HTTP or base64 data URL.
254254

vllm/multimodal/video.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from functools import partial
77
from io import BytesIO
88
from pathlib import Path
9+
from typing import Any
910

1011
import numpy as np
1112
import numpy.typing as npt
@@ -57,7 +58,7 @@ class VideoLoader:
5758
def load_bytes(cls,
5859
data: bytes,
5960
num_frames: int = -1,
60-
**kwargs) -> npt.NDArray:
61+
**kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
6162
raise NotImplementedError
6263

6364

@@ -106,7 +107,7 @@ def get_cv2_video_api(self):
106107
def load_bytes(cls,
107108
data: bytes,
108109
num_frames: int = -1,
109-
**kwargs) -> npt.NDArray:
110+
**kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
110111
import cv2
111112

112113
backend = cls().get_cv2_video_api()
@@ -179,12 +180,13 @@ def __init__(
179180
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
180181
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
181182

182-
def load_bytes(self, data: bytes) -> npt.NDArray:
183+
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
183184
return self.video_loader.load_bytes(data,
184185
num_frames=self.num_frames,
185186
**self.kwargs)
186187

187-
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
188+
def load_base64(self, media_type: str,
189+
data: str) -> tuple[npt.NDArray, dict[str, Any]]:
188190
if media_type.lower() == "video/jpeg":
189191
load_frame = partial(
190192
self.image_io.load_base64,
@@ -194,11 +196,11 @@ def load_base64(self, media_type: str, data: str) -> npt.NDArray:
194196
return np.stack([
195197
np.asarray(load_frame(frame_data))
196198
for frame_data in data.split(",")
197-
])
199+
]), {}
198200

199201
return self.load_bytes(base64.b64decode(data))
200202

201-
def load_file(self, filepath: Path) -> npt.NDArray:
203+
def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
202204
with filepath.open("rb") as f:
203205
data = f.read()
204206

0 commit comments

Comments
 (0)