
Commit 34f3e7f

feat: jina support
1 parent 77f77a9 commit 34f3e7f

File tree

5 files changed: +1074 -0 lines changed

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark and validate Jina Embeddings V4 against the HuggingFace implementation.

This script compares embeddings generated by vLLM vs HuggingFace to ensure
accuracy and measure performance differences.
"""

import argparse
import time
from typing import Any, List, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt


def create_test_cases() -> List[Tuple[str, str, Any]]:
    """Create comprehensive test cases for validation."""
    test_cases = []

    # Text-only test cases
    test_cases.extend([
        ("text", "Query: What is artificial intelligence?", None),
        ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None),
        ("text", "Query: 你好世界", None),  # Chinese text
        ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
    ])

    # Image test cases
    for color in ["red", "green", "blue"]:
        img = Image.new('RGB', (224, 224), color=color)
        test_cases.append(("image", f"{color} image", img))

    # Complex image with a gradient pattern
    complex_img = Image.new('RGB', (224, 224))
    pixels = complex_img.load()
    for i in range(224):
        for j in range(224):
            pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
    test_cases.append(("image", "complex pattern", complex_img))

    return test_cases


def compute_hf_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the HuggingFace implementation."""
    print("Loading HuggingFace model...")
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).cuda().eval()

    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    embeddings = []

    print("Computing HuggingFace embeddings...")
    start_time = time.time()

    for case_type, text, image in test_cases:
        if case_type == "text":
            inputs = processor(text=text, return_tensors="pt").to("cuda")
        else:  # image
            inputs = processor(
                text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                images=image,
                return_tensors="pt"
            ).to("cuda")

        with torch.no_grad():
            outputs = model(**inputs)
            # Extract embeddings based on the model output structure
            if hasattr(outputs, 'embeddings'):
                embedding = outputs.embeddings[0]
            else:
                # Fall back to the last hidden state with custom pooling
                hidden_states = outputs.last_hidden_state[0]

                # Apply token-type-aware pooling
                input_ids = inputs['input_ids'][0]
                vision_mask = (
                    (input_ids >= 151652) &
                    (input_ids <= 151653)
                )

                if vision_mask.any():
                    embedding = hidden_states[vision_mask].mean(dim=0)
                else:
                    embedding = hidden_states.mean(dim=0)

        embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)

        embeddings.append(embedding.cpu())

    hf_time = time.time() - start_time
    print(f"HuggingFace processing time: {hf_time:.2f}s")

    return embeddings


def compute_vllm_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the vLLM implementation."""
    print("\nLoading vLLM model...")
    model = LLM(
        model=model_name,
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    embeddings = []
    prompts = []

    # Prepare prompts
    for case_type, text, image in test_cases:
        if case_type == "text":
            prompt = TextPrompt(prompt=text)
        else:  # image
            prompt = TextPrompt(
                prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                multi_modal_data={"image": image},
            )
        prompts.append(prompt)

    print("Computing vLLM embeddings...")
    start_time = time.time()

    # Process all prompts at once for better performance
    outputs = model.encode(prompts)

    for output in outputs:
        # Extract based on token type
        if 151652 in output.prompt_token_ids:  # VISION_START_TOKEN_ID
            img_start = output.prompt_token_ids.index(151652)
            img_end = output.prompt_token_ids.index(151653)
            embedding_data = output.outputs.data[img_start:img_end + 1]
        else:
            embedding_data = output.outputs.data

        # Pool and normalize
        pooled = embedding_data.mean(dim=0, dtype=torch.float32)
        normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        embeddings.append(normalized.cpu())

    vllm_time = time.time() - start_time
    print(f"vLLM processing time: {vllm_time:.2f}s")

    return embeddings


def compare_embeddings(
    hf_embeddings: List[torch.Tensor],
    vllm_embeddings: List[torch.Tensor],
    test_cases: List[Tuple[str, str, Any]]
) -> None:
    """Compare embeddings and report differences."""
    print("\n" + "=" * 60)
    print("EMBEDDING COMPARISON RESULTS")
    print("=" * 60)

    similarities = []
    max_diffs = []

    for i, (case_type, desc, _) in enumerate(test_cases):
        hf_emb = hf_embeddings[i]
        vllm_emb = vllm_embeddings[i]

        # Compute cosine similarity
        similarity = torch.nn.functional.cosine_similarity(
            hf_emb.unsqueeze(0),
            vllm_emb.unsqueeze(0)
        ).item()

        # Compute max absolute difference
        max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()

        similarities.append(similarity)
        max_diffs.append(max_diff)

        print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
        print(f"  Cosine similarity: {similarity:.6f}")
        print(f"  Max absolute diff: {max_diff:.6f}")
        print(f"  HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")

        # Flag significant differences
        if similarity < 0.99:
            print("  ⚠️ WARNING: Low similarity detected!")

    # Summary statistics
    print("\n" + "-" * 60)
    print("SUMMARY STATISTICS")
    print("-" * 60)
    print(f"Average cosine similarity: {np.mean(similarities):.6f}")
    print(f"Min cosine similarity: {np.min(similarities):.6f}")
    print(f"Max absolute difference: {np.max(max_diffs):.6f}")

    # Overall assessment
    if np.min(similarities) > 0.99:
        print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
    else:
        print("\n❌ VALIDATION FAILED: Significant differences detected")


def main():
    parser = argparse.ArgumentParser(
        description="Validate Jina Embeddings V4 implementation"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v4-vllm-retrieval",
        help="Model name to test"
    )
    parser.add_argument(
        "--skip-hf",
        action="store_true",
        help="Skip HuggingFace comparison (for performance testing only)"
    )

    args = parser.parse_args()

    # Create test cases
    test_cases = create_test_cases()
    print(f"Created {len(test_cases)} test cases")

    # Compute vLLM embeddings
    vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)

    if not args.skip_hf:
        # Compute HuggingFace embeddings
        hf_embeddings = compute_hf_embeddings(args.model, test_cases)

        # Compare results
        compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
    else:
        print("\nSkipping HuggingFace comparison")
        print(f"vLLM processed {len(test_cases)} embeddings successfully")


if __name__ == "__main__":
    main()
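
Both the HuggingFace and vLLM paths above reduce per-token outputs to a single vector in the same way: image prompts are mean-pooled over the span between the vision start/end markers (token IDs 151652 and 151653), text prompts over all tokens, and the result is L2-normalized. The following is a minimal, self-contained sketch of that pooling step, not part of the commit; the tensors are synthetic stand-ins rather than real model outputs.

import torch

VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653


def pool_prompt_embedding(token_embeddings: torch.Tensor,
                          prompt_token_ids: list[int]) -> torch.Tensor:
    """Mean-pool per-token embeddings the way the scripts in this commit do."""
    if VISION_START_TOKEN_ID in prompt_token_ids:
        # Image prompt: pool only the vision-token span (inclusive of markers).
        start = prompt_token_ids.index(VISION_START_TOKEN_ID)
        end = prompt_token_ids.index(VISION_END_TOKEN_ID)
        token_embeddings = token_embeddings[start:end + 1]
    pooled = token_embeddings.mean(dim=0, dtype=torch.float32)
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)


# Synthetic example: 6 tokens with hidden size 8 and a fake vision span.
fake_embeddings = torch.randn(6, 8)
fake_ids = [1, 151652, 5, 5, 151653, 2]
print(pool_prompt_embedding(fake_embeddings, fake_ids).shape)  # torch.Size([8])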
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using Jina Embeddings V4 with vLLM for multimodal embeddings.

This example demonstrates:
1. Text-only embeddings
2. Image-only embeddings
3. Mixed text and image embeddings
"""

import torch
from PIL import Image

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image


def get_embeddings(outputs):
    """Extract and normalize embeddings from model outputs."""
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # For vision inputs, extract only vision token embeddings
            img_start_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
            )[0][0]
            img_end_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
            )[0][0]
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # For text-only inputs, use all token embeddings
            embeddings_tensor = output.outputs.data.detach().clone()

        # Pool and normalize embeddings
        pooled_output = (
            embeddings_tensor.sum(dim=0, dtype=torch.float32)
            / embeddings_tensor.shape[0]
        )
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings


def main():
    # Initialize the model
    model = LLM(
        model="jinaai/jina-embeddings-v4-vllm-retrieval",
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    # Example 1: Text-only embeddings
    print("=== Text Embeddings ===")
    query = "Overview of climate change impacts on coastal cities"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    passage = """The impacts of climate change on coastal cities are significant
and multifaceted. Rising sea levels threaten infrastructure, while increased
storm intensity poses risks to populations and economies."""
    passage_prompt = TextPrompt(prompt=f"Passage: {passage}")

    # Generate embeddings
    text_outputs = model.encode([query_prompt, passage_prompt])
    text_embeddings = get_embeddings(text_outputs)

    # Calculate similarity
    similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item()
    print(f"Query: {query[:50]}...")
    print(f"Passage: {passage[:50]}...")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 2: Image embeddings
    print("=== Image Embeddings ===")
    # Fetch sample images
    image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
    image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"

    image1 = fetch_image(image1_url)
    image2 = fetch_image(image2_url)

    # Create image prompts with the required format
    image1_prompt = TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image1},
    )

    image2_prompt = TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image2},
    )

    # Generate embeddings
    image_outputs = model.encode([image1_prompt, image2_prompt])
    image_embeddings = get_embeddings(image_outputs)

    # Calculate similarity
    similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item()
    print(f"Image 1: {image1_url.split('/')[-1]}")
    print(f"Image 2: {image2_url.split('/')[-1]}")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 3: Cross-modal similarity (text vs image)
    print("=== Cross-modal Similarity ===")
    query = "scientific paper with markdown formatting"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    # Generate embeddings for text query and second image
    cross_outputs = model.encode([query_prompt, image2_prompt])
    cross_embeddings = get_embeddings(cross_outputs)

    # Calculate cross-modal similarity
    similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item()
    print(f"Text query: {query}")
    print(f"Image: {image2_url.split('/')[-1]}")
    print(f"Cross-modal similarity: {similarity:.4f}")


if __name__ == "__main__":
    main()
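
Because get_embeddings returns L2-normalized vectors, the pairwise dot products in the example are already cosine similarities, and a whole batch can be scored with a single matrix product. Below is a small sketch of that batch scoring, not part of the commit; the vectors are synthetic stand-ins for the outputs of get_embeddings().

import torch


def similarity_matrix(embeddings: list[torch.Tensor]) -> torch.Tensor:
    """Pairwise cosine similarities for already L2-normalized vectors."""
    stacked = torch.stack(embeddings)  # shape: (n, hidden_size)
    return stacked @ stacked.T         # shape: (n, n); dot product equals cosine here


# Synthetic stand-ins for the normalized vectors returned by get_embeddings().
vectors = [torch.nn.functional.normalize(torch.randn(8), dim=-1) for _ in range(3)]
print(similarity_matrix(vectors))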
