
Commit 4bed167

[Model][VLM] Support JinaVL Reranker (#20260)
Signed-off-by: shineran96 <shinewang96@gmail.com>
1 parent b140416 commit 4bed167

File tree

15 files changed: +993 −133 lines


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ steps:
   - python3 offline_inference/llm_engine_example.py
   - python3 offline_inference/audio_language.py --seed 0
   - python3 offline_inference/vision_language.py --seed 0
-  - python3 offline_inference/vision_language_embedding.py --seed 0
+  - python3 offline_inference/vision_language_pooling.py --seed 0
   - python3 offline_inference/vision_language_multi_image.py --seed 0
   - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
   - python3 offline_inference/encoder_decoder.py

docs/models/supported_models.md

Lines changed: 8 additions & 0 deletions
@@ -712,6 +712,14 @@ The following table lists those that are tested in vLLM.
 
 ---
 
+#### Scoring
+
+Specified using `--task score`.
+
+| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
+|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|
+| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ |
+
 ## Model Support Policy
 
 At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support:
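For context, the offline path behind this table reduces to constructing an `LLM` with `task="score"` and calling `llm.score()` on a query string plus a multi-modal document parameter. The following is a minimal sketch condensed from the `run_jinavl_reranker` and `run_score` helpers added in `examples/offline_inference/vision_language_pooling.py` further down in this commit; the engine arguments are copied from that example and may need adjusting for other hardware.

```python
# Minimal offline scoring sketch, condensed from the run_jinavl_reranker and
# run_score helpers introduced in this commit.
from vllm import LLM

llm = LLM(
    model="jinaai/jina-reranker-m0",
    task="score",
    max_model_len=32768,
    trust_remote_code=True,
    mm_processor_kwargs={"min_pixels": 3136, "max_pixels": 602112},
    limit_mm_per_prompt={"image": 1},
)

query = "slm markdown"
documents = {
    "content": [
        {
            "type": "image_url",
            "image_url": {
                "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
            },
        },
        {
            "type": "image_url",
            "image_url": {
                "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
            },
        },
    ]
}

# One relevance score is returned per document in documents["content"].
outputs = llm.score(query, documents)
print([output.outputs.score for output in outputs])
```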

docs/serving/openai_compatible_server.md

Lines changed: 51 additions & 3 deletions
@@ -537,7 +537,7 @@ The following extra parameters are supported:
 
 ### Score API
 
-Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
+Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence or multimodal pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
 Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
 
 You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
@@ -676,6 +676,55 @@ The total number of pairs is `len(text_2)`.
 }
 ```
 
+#### Multi-modal inputs
+
+You can pass multi-modal inputs to scoring models by including a `content` list of multi-modal items (image, etc.) in the request. Refer to the examples below for illustration.
+
+=== "JinaVL-Reranker"
+
+    To serve the model:
+
+    ```bash
+    vllm serve jinaai/jina-reranker-m0
+    ```
+
+    Since the request schema is not defined by the OpenAI client, we post a request to the server using the lower-level `requests` library:
+
+    ??? Code
+
+        ```python
+        import requests
+
+        response = requests.post(
+            "http://localhost:8000/v1/score",
+            json={
+                "model": "jinaai/jina-reranker-m0",
+                "text_1": "slm markdown",
+                "text_2": {
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
+                            },
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+                            },
+                        },
+                    ]
+                }
+            },
+        )
+        response.raise_for_status()
+        response_json = response.json()
+        print("Scoring output:", response_json["data"][0]["score"])
+        print("Scoring output:", response_json["data"][1]["score"])
+        ```
+    Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_multimodal.py>
+
 #### Extra parameters
 
 The following [pooling parameters][pooling-params] are supported.
@@ -695,8 +744,7 @@ The following extra parameters are supported:
 ### Re-rank API
 
 Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
-each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
-a scale of 0 to 1.
+each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences or multi-modal inputs (image, etc.), on a scale of 0 to 1.
 
 You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
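A small follow-up sketch of how an application might consume the Score API documented above: it scores one query against several multi-modal documents and sorts them by relevance. The endpoint and payload mirror the documented request; the assumption, flagged in the comments, is that entries in the response's `data` list come back in the same order as the items in `text_2["content"]`, as the `data[0]`/`data[1]` prints in the example suggest.

```python
import requests


def rank_documents(query: str, image_urls: list[str],
                   base_url: str = "http://localhost:8000"):
    """Score one query against several image documents and rank them.

    Sketch only: assumes the response returns one entry in "data" per item of
    text_2["content"], in the same order as the request.
    """
    documents = {
        "content": [
            {"type": "image_url", "image_url": {"url": url}} for url in image_urls
        ]
    }
    resp = requests.post(
        f"{base_url}/v1/score",
        json={"model": "jinaai/jina-reranker-m0", "text_1": query, "text_2": documents},
    )
    resp.raise_for_status()
    scores = [item["score"] for item in resp.json()["data"]]
    # Pair each URL with its score, highest first.
    return sorted(zip(image_urls, scores), key=lambda pair: pair[1], reverse=True)
```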

examples/offline_inference/vision_language_embedding.py renamed to examples/offline_inference/vision_language_pooling.py

Lines changed: 89 additions & 7 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 This example shows how to use vLLM for running offline inference with
-the correct prompt format on vision language models for multimodal embedding.
+the correct prompt format on vision language models for multimodal pooling.
 
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
@@ -15,6 +15,7 @@
 from PIL.Image import Image
 
 from vllm import LLM, EngineArgs
+from vllm.entrypoints.score_utils import ScoreMultiModalParam
 from vllm.multimodal.utils import fetch_image
 from vllm.utils import FlexibleArgumentParser
 
@@ -35,14 +36,22 @@ class TextImageQuery(TypedDict):
     image: Image
 
 
-QueryModality = Literal["text", "image", "text+image"]
-Query = Union[TextQuery, ImageQuery, TextImageQuery]
+class TextImagesQuery(TypedDict):
+    modality: Literal["text+images"]
+    text: str
+    image: ScoreMultiModalParam
+
+
+QueryModality = Literal["text", "image", "text+image", "text+images"]
+Query = Union[TextQuery, ImageQuery, TextImageQuery, TextImagesQuery]
 
 
 class ModelRequestData(NamedTuple):
     engine_args: EngineArgs
-    prompt: str
-    image: Optional[Image]
+    prompt: Optional[str] = None
+    image: Optional[Image] = None
+    query: Optional[str] = None
+    documents: Optional[ScoreMultiModalParam] = None
 
 
 def run_e5_v(query: Query) -> ModelRequestData:
@@ -107,6 +116,29 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
     )
 
 
+def run_jinavl_reranker(query: Query) -> ModelRequestData:
+    if query["modality"] != "text+images":
+        raise ValueError(f"Unsupported query modality: '{query['modality']}'")
+
+    engine_args = EngineArgs(
+        model="jinaai/jina-reranker-m0",
+        task="score",
+        max_model_len=32768,
+        trust_remote_code=True,
+        mm_processor_kwargs={
+            "min_pixels": 3136,
+            "max_pixels": 602112,
+        },
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        query=query["text"],
+        documents=query["image"],
+    )
+
+
 def get_query(modality: QueryModality):
     if modality == "text":
         return TextQuery(modality="text", text="A dog sitting in the grass")
@@ -128,6 +160,28 @@ def get_query(modality: QueryModality):
             ),
         )
 
+    if modality == "text+images":
+        return TextImagesQuery(
+            modality="text+images",
+            text="slm markdown",
+            image={
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
+                        },
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+                        },
+                    },
+                ]
+            },
+        )
+
     msg = f"Modality {modality} is not supported."
     raise ValueError(msg)
 
@@ -162,16 +216,31 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
     print("-" * 50)
 
 
+def run_score(model: str, modality: QueryModality, seed: Optional[int]):
+    query = get_query(modality)
+    req_data = model_example_map[model](query)
+
+    engine_args = asdict(req_data.engine_args) | {"seed": seed}
+    llm = LLM(**engine_args)
+
+    outputs = llm.score(req_data.query, req_data.documents)
+
+    print("-" * 30)
+    print([output.outputs.score for output in outputs])
+    print("-" * 30)
+
+
 model_example_map = {
     "e5_v": run_e5_v,
     "vlm2vec": run_vlm2vec,
+    "jinavl_reranker": run_jinavl_reranker,
 }
 
 
 def parse_args():
     parser = FlexibleArgumentParser(
         description="Demo on using vLLM for offline inference with "
-        "vision language models for multimodal embedding"
+        "vision language models for multimodal pooling tasks."
     )
     parser.add_argument(
         "--model-name",
@@ -181,6 +250,14 @@ def parse_args():
         choices=model_example_map.keys(),
         help="The name of the embedding model.",
    )
+    parser.add_argument(
+        "--task",
+        "-t",
+        type=str,
+        default="embedding",
+        choices=["embedding", "scoring"],
+        help="The task type.",
+    )
     parser.add_argument(
         "--modality",
         type=str,
@@ -198,7 +275,12 @@
 
 
 def main(args: Namespace):
-    run_encode(args.model_name, args.modality, args.seed)
+    if args.task == "embedding":
+        run_encode(args.model_name, args.modality, args.seed)
+    elif args.task == "scoring":
+        run_score(args.model_name, args.modality, args.seed)
+    else:
+        raise ValueError(f"Unsupported task: {args.task}")
 
 
 if __name__ == "__main__":
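As a usage note, the renamed example can now exercise the scoring path directly via the new `--task` flag, e.g. `python examples/offline_inference/vision_language_pooling.py --model-name jinavl_reranker --task scoring --modality text+images` (this assumes the existing `--modality` option admits the new `text+images` value; its choices are not shown in the hunk above).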

examples/online_serving/openai_cross_encoder_score_for_multimodal.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Example online usage of Score API.
+
+Run `vllm serve <model> --task score` to start up the server in vLLM.
+"""
+
+import argparse
+import pprint
+
+import requests
+
+
+def post_http_request(prompt: dict, api_url: str) -> requests.Response:
+    headers = {"User-Agent": "Test Client"}
+    response = requests.post(api_url, headers=headers, json=prompt)
+    return response
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--model", type=str, default="jinaai/jina-reranker-m0")
+    return parser.parse_args()
+
+
+def main(args):
+    api_url = f"http://{args.host}:{args.port}/score"
+    model_name = args.model
+
+    text_1 = "slm markdown"
+    text_2 = {
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
+                },
+            },
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+                },
+            },
+        ]
+    }
+    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
+    score_response = post_http_request(prompt=prompt, api_url=api_url)
+    print("\nPrompt when text_1 is a string and text_2 is an image list:")
+    pprint.pprint(prompt)
+    print("\nScore Response:")
+    pprint.pprint(score_response.json())
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
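To run this example end to end, start the server with `vllm serve jinaai/jina-reranker-m0 --task score` (per the module docstring) and then run `python examples/online_serving/openai_cross_encoder_score_for_multimodal.py`. The script posts the query/image pair to `http://localhost:8000/score` and pretty-prints both the request payload and the score response; `--host`, `--port`, and `--model` override the defaults shown above.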
