
Commit 01e9bd5

DarkLight1337 authored and py-andy-c committed
[CI/Build] Fix OOM issue in Jina-VL test (vllm-project#20907)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent fdae463 commit 01e9bd5
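
In short, the OOM is avoided by making the test much cheaper: both rerankers now go through the shared hf_runner/vllm_runner fixtures (so each model is torn down when its with-block exits), the vLLM engine runs in half precision with max_num_seqs=2, and max_model_len drops from 32768 to 2048. A minimal sketch of the new vLLM-side pattern, with values copied from the diff below (the VllmRunner fixture signature is assumed from the test conftest):

    with vllm_runner(
            model_name,
            task="score",
            dtype="half",        # run in half precision
            max_num_seqs=2,      # cap concurrent sequences
            max_model_len=2048,  # much shorter context than the old 32768
            mm_processor_kwargs=mm_processor_kwargs,
            limit_mm_per_prompt=limit_mm_per_prompt,
    ) as vllm_model:
        outputs = vllm_model.model.score(query, documents)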

1 file changed: 85 additions & 58 deletions

@@ -1,9 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union

 import pytest
 from transformers import AutoModel

+from vllm.entrypoints.chat_utils import ChatCompletionContentPartImageParam
+from vllm.entrypoints.score_utils import ScoreMultiModalParam
+
+from ....conftest import HfRunner, VllmRunner
+
 model_name = "jinaai/jina-reranker-m0"

 mm_processor_kwargs = {
@@ -14,82 +20,99 @@
 limit_mm_per_prompt = {"image": 2}


-def vllm_reranker(model_name,
-                  query,
-                  documents,
-                  query_type="text",
-                  doc_type="text"):
-    from vllm import LLM
-
-    model = LLM(
-        model=model_name,
-        task="score",
-        max_model_len=32768,
-        mm_processor_kwargs=mm_processor_kwargs,
-        limit_mm_per_prompt=limit_mm_per_prompt,
-    )
+def vllm_reranker(
+    vllm_runner: type[VllmRunner],
+    model_name: str,
+    dtype: str,
+    query_strs: list[str],
+    document_strs: list[str],
+    query_type: str = "text",
+    doc_type: str = "text",
+):

-    def create_image_param(url: str):
+    def create_image_param(url: str) -> ChatCompletionContentPartImageParam:
         return {"type": "image_url", "image_url": {"url": f"{url}"}}

-    if query_type == "image":
-        query = {"content": [create_image_param(url) for url in query]}
-
-    if doc_type == "image":
-        documents = {"content": [create_image_param(url) for url in documents]}
-
-    outputs = model.score(query, documents)
+    query: Union[list[str], ScoreMultiModalParam]
+    if query_type == "text":
+        query = query_strs
+    elif query_type == "image":
+        query = ScoreMultiModalParam(
+            content=[create_image_param(url) for url in query_strs])
+
+    documents: Union[list[str], ScoreMultiModalParam]
+    if doc_type == "text":
+        documents = document_strs
+    elif doc_type == "image":
+        documents = ScoreMultiModalParam(
+            content=[create_image_param(url) for url in document_strs])
+
+    with vllm_runner(
+            model_name,
+            task="score",
+            dtype=dtype,
+            max_num_seqs=2,
+            max_model_len=2048,
+            mm_processor_kwargs=mm_processor_kwargs,
+            limit_mm_per_prompt=limit_mm_per_prompt,
+    ) as vllm_model:
+        outputs = vllm_model.model.score(query, documents)

     return [output.outputs.score for output in outputs]


-def hf_reranker(model_name,
-                query,
-                documents,
-                query_type="text",
-                doc_type="text"):
-
+def hf_reranker(
+    hf_runner: type[HfRunner],
+    model_name: str,
+    dtype: str,
+    query_strs: list[str],
+    document_strs: list[str],
+    query_type: str = "text",
+    doc_type: str = "text",
+):
     checkpoint_to_hf_mapper = {
         "visual.": "model.visual.",
         "model.": "model.language_model.",
     }

-    model = AutoModel.from_pretrained(
-        model_name,
-        torch_dtype="auto",
-        trust_remote_code=True,
-        key_mapping=checkpoint_to_hf_mapper).to("cuda").eval()
+    data_pairs = [[query_strs[0], d] for d in document_strs]

-    data_pairs = [[query[0], d] for d in documents]
-
-    scores = model.compute_score(data_pairs,
-                                 max_length=2048,
-                                 query_type=query_type,
-                                 doc_type=doc_type)
-    return scores
+    with hf_runner(
+            model_name,
+            dtype=dtype,
+            trust_remote_code=True,
+            auto_cls=AutoModel,
+            model_kwargs={"key_mapping": checkpoint_to_hf_mapper},
+    ) as hf_model:
+        return hf_model.model.compute_score(data_pairs,
+                                            max_length=2048,
+                                            query_type=query_type,
+                                            doc_type=doc_type)


 # Visual Documents Reranking
 @pytest.mark.parametrize("model_name", [model_name])
-def test_model_text_image(model_name):
-
+@pytest.mark.parametrize("dtype", ["half"])
+def test_model_text_image(hf_runner, vllm_runner, model_name, dtype):
     query = ["slm markdown"]
     documents = [
         "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
         "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
     ]

-    hf_outputs = hf_reranker(model_name, query, documents, "text", "image")
-    vllm_outputs = vllm_reranker(model_name, query, documents, "text", "image")
+    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
+                             "text", "image")
+    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
+                                 documents, "text", "image")

     assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
     assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)


 # Textual Documents Reranking
 @pytest.mark.parametrize("model_name", [model_name])
-def test_model_text_text(model_name):
-
+@pytest.mark.parametrize("dtype", ["half"])
+def test_model_text_text(hf_runner, vllm_runner, model_name, dtype):
     query = ["slm markdown"]
     documents = [
         """We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient
@@ -104,18 +127,19 @@ def test_model_text_text(model_name):
         lower computational requirements.""",  # noqa: E501
         "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
     ]
-
-    hf_outputs = hf_reranker(model_name, query, documents, "text", "text")
-    vllm_outputs = vllm_reranker(model_name, query, documents, "text", "text")
+    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
+                             "text", "text")
+    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
+                                 documents, "text", "text")

     assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
     assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)


 # Image Querying for Textual Documents
 @pytest.mark.parametrize("model_name", [model_name])
-def test_model_image_text(model_name):
-
+@pytest.mark.parametrize("dtype", ["half"])
+def test_model_image_text(hf_runner, vllm_runner, model_name, dtype):
     query = [
         "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
     ]
@@ -133,17 +157,19 @@ def test_model_image_text(model_name):
         "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
     ]

-    hf_outputs = hf_reranker(model_name, query, documents, "image", "text")
-    vllm_outputs = vllm_reranker(model_name, query, documents, "image", "text")
+    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
+                             "image", "text")
+    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
+                                 documents, "image", "text")

     assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
     assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)


 # Image Querying for Image Documents
 @pytest.mark.parametrize("model_name", [model_name])
-def test_model_image_image(model_name):
-
+@pytest.mark.parametrize("dtype", ["half"])
+def test_model_image_image(hf_runner, vllm_runner, model_name, dtype):
     query = [
         "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
     ]
@@ -152,9 +178,10 @@ def test_model_image_image(model_name):
         "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
     ]

-    hf_outputs = hf_reranker(model_name, query, documents, "image", "image")
-    vllm_outputs = vllm_reranker(model_name, query, documents, "image",
-                                 "image")
+    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
+                             "image", "image")
+    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
+                                 documents, "image", "image")

     assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02)
     assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02)
