1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ from typing import Union
3
4
4
5
import pytest
5
6
from transformers import AutoModel
6
7
8
+ from vllm .entrypoints .chat_utils import ChatCompletionContentPartImageParam
9
+ from vllm .entrypoints .score_utils import ScoreMultiModalParam
10
+
11
+ from ....conftest import HfRunner , VllmRunner
12
+
7
13
model_name = "jinaai/jina-reranker-m0"
8
14
9
15
mm_processor_kwargs = {
14
20
limit_mm_per_prompt = {"image" : 2 }
15
21
16
22
17
def vllm_reranker(
    vllm_runner: type[VllmRunner],
    model_name: str,
    dtype: str,
    query_strs: list[str],
    document_strs: list[str],
    query_type: str = "text",
    doc_type: str = "text",
):
    """Score ``document_strs`` against ``query_strs`` through vLLM.

    Args:
        vllm_runner: Runner factory; it is called with the model settings
            and used as a context manager.
        model_name: Model identifier passed to the runner.
        dtype: Dtype string passed to the runner.
        query_strs: Query texts, or image URLs when ``query_type`` is
            ``"image"``.
        document_strs: Document texts, or image URLs when ``doc_type`` is
            ``"image"``.
        query_type: ``"text"`` or ``"image"``.
        doc_type: ``"text"`` or ``"image"``.

    Returns:
        A list with ``output.outputs.score`` for every output returned by
        the model's ``score`` call.

    Raises:
        ValueError: If ``query_type`` or ``doc_type`` is neither ``"text"``
            nor ``"image"``.
    """

    def create_image_param(url: str) -> ChatCompletionContentPartImageParam:
        return {"type": "image_url", "image_url": {"url": f"{url}"}}

    def build_input(
        strs: list[str],
        kind: str,
        arg_name: str,
    ) -> Union[list[str], ScoreMultiModalParam]:
        # Text inputs are passed through unchanged; image inputs are wrapped
        # as image_url content parts.
        if kind == "text":
            return strs
        if kind == "image":
            return ScoreMultiModalParam(
                content=[create_image_param(url) for url in strs])
        # Fail early with a clear message instead of hitting an unbound
        # local variable once the model has already been loaded.
        raise ValueError(
            f"{arg_name} must be 'text' or 'image', got {kind!r}")

    # Validate and build both inputs before the engine is constructed.
    query = build_input(query_strs, query_type, "query_type")
    documents = build_input(document_strs, doc_type, "doc_type")

    with vllm_runner(
            model_name,
            task="score",
            dtype=dtype,
            max_num_seqs=2,
            max_model_len=2048,
            mm_processor_kwargs=mm_processor_kwargs,
            limit_mm_per_prompt=limit_mm_per_prompt,
    ) as vllm_model:
        outputs = vllm_model.model.score(query, documents)

    return [output.outputs.score for output in outputs]
44
62
45
63
46
def hf_reranker(
    hf_runner: type[HfRunner],
    model_name: str,
    dtype: str,
    query_strs: list[str],
    document_strs: list[str],
    query_type: str = "text",
    doc_type: str = "text",
):
    """Compute reference scores with the HF model's ``compute_score``.

    Only the first entry of ``query_strs`` is used; it is paired with every
    entry of ``document_strs``.

    Args:
        hf_runner: Runner factory; it is called with the model settings and
            used as a context manager.
        model_name: Model identifier passed to the runner.
        dtype: Dtype string passed to the runner.
        query_strs: Queries; only index 0 is read.
        document_strs: Documents to pair with the first query.
        query_type: Forwarded to ``compute_score`` as ``query_type``.
        doc_type: Forwarded to ``compute_score`` as ``doc_type``.

    Returns:
        Whatever ``compute_score`` returns for the query/document pairs.
    """
    pairs = [[query_strs[0], doc] for doc in document_strs]

    # Prefix remapping handed to the model loader as ``key_mapping``
    # (presumably checkpoint weight names -> HF module names; verify
    # against the transformers version in use).
    key_mapping = {
        "visual.": "model.visual.",
        "model.": "model.language_model.",
    }

    runner = hf_runner(
        model_name,
        dtype=dtype,
        trust_remote_code=True,
        auto_cls=AutoModel,
        model_kwargs={"key_mapping": key_mapping},
    )
    with runner as hf_model:
        scores = hf_model.model.compute_score(
            pairs,
            max_length=2048,
            query_type=query_type,
            doc_type=doc_type,
        )
    return scores
70
91
71
92
72
93
# Visual Documents Reranking
73
94
@pytest.mark.parametrize("model_name", [model_name])
@pytest.mark.parametrize("dtype", ["half"])
def test_model_text_image(hf_runner, vllm_runner, model_name, dtype):
    """Text query against image documents: HF and vLLM scores must agree.

    The first two scores from each backend are compared with a relative
    tolerance of 0.02.
    """
    text_query = ["slm markdown"]
    image_urls = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",  # noqa: E501
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",  # noqa: E501
    ]

    hf_scores = hf_reranker(hf_runner, model_name, dtype, text_query,
                            image_urls, "text", "image")
    vllm_scores = vllm_reranker(vllm_runner, model_name, dtype, text_query,
                                image_urls, "text", "image")

    for idx in (0, 1):
        assert hf_scores[idx] == pytest.approx(vllm_scores[idx], rel=0.02)
87
110
88
111
89
112
# Textual Documents Reranking
90
113
@pytest .mark .parametrize ("model_name" , [model_name ])
91
- def test_model_text_text ( model_name ):
92
-
114
+ @ pytest . mark . parametrize ( "dtype" , [ "half" ])
115
+ def test_model_text_text ( hf_runner , vllm_runner , model_name , dtype ):
93
116
query = ["slm markdown" ]
94
117
documents = [
95
118
"""We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient
@@ -104,18 +127,19 @@ def test_model_text_text(model_name):
104
127
lower computational requirements.""" , # noqa: E501
105
128
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?" ,
106
129
]
107
-
108
- hf_outputs = hf_reranker (model_name , query , documents , "text" , "text" )
109
- vllm_outputs = vllm_reranker (model_name , query , documents , "text" , "text" )
130
+ hf_outputs = hf_reranker (hf_runner , model_name , dtype , query , documents ,
131
+ "text" , "text" )
132
+ vllm_outputs = vllm_reranker (vllm_runner , model_name , dtype , query ,
133
+ documents , "text" , "text" )
110
134
111
135
assert hf_outputs [0 ] == pytest .approx (vllm_outputs [0 ], rel = 0.02 )
112
136
assert hf_outputs [1 ] == pytest .approx (vllm_outputs [1 ], rel = 0.02 )
113
137
114
138
115
139
# Image Querying for Textual Documents
116
140
@pytest .mark .parametrize ("model_name" , [model_name ])
117
- def test_model_image_text ( model_name ):
118
-
141
+ @ pytest . mark . parametrize ( "dtype" , [ "half" ])
142
+ def test_model_image_text ( hf_runner , vllm_runner , model_name , dtype ):
119
143
query = [
120
144
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
121
145
]
@@ -133,17 +157,19 @@ def test_model_image_text(model_name):
133
157
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?" ,
134
158
]
135
159
136
- hf_outputs = hf_reranker (model_name , query , documents , "image" , "text" )
137
- vllm_outputs = vllm_reranker (model_name , query , documents , "image" , "text" )
160
+ hf_outputs = hf_reranker (hf_runner , model_name , dtype , query , documents ,
161
+ "image" , "text" )
162
+ vllm_outputs = vllm_reranker (vllm_runner , model_name , dtype , query ,
163
+ documents , "image" , "text" )
138
164
139
165
assert hf_outputs [0 ] == pytest .approx (vllm_outputs [0 ], rel = 0.02 )
140
166
assert hf_outputs [1 ] == pytest .approx (vllm_outputs [1 ], rel = 0.02 )
141
167
142
168
143
169
# Image Querying for Image Documents
144
170
@pytest .mark .parametrize ("model_name" , [model_name ])
145
- def test_model_image_image ( model_name ):
146
-
171
+ @ pytest . mark . parametrize ( "dtype" , [ "half" ])
172
+ def test_model_image_image ( hf_runner , vllm_runner , model_name , dtype ):
147
173
query = [
148
174
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
149
175
]
@@ -152,9 +178,10 @@ def test_model_image_image(model_name):
152
178
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" ,
153
179
]
154
180
155
- hf_outputs = hf_reranker (model_name , query , documents , "image" , "image" )
156
- vllm_outputs = vllm_reranker (model_name , query , documents , "image" ,
157
- "image" )
181
+ hf_outputs = hf_reranker (hf_runner , model_name , dtype , query , documents ,
182
+ "image" , "image" )
183
+ vllm_outputs = vllm_reranker (vllm_runner , model_name , dtype , query ,
184
+ documents , "image" , "image" )
158
185
159
186
assert hf_outputs [0 ] == pytest .approx (vllm_outputs [0 ], rel = 0.02 )
160
187
assert hf_outputs [1 ] == pytest .approx (vllm_outputs [1 ], rel = 0.02 )
0 commit comments