
Commit d7d6b60

feat: jina support

1 parent 77f77a9 commit d7d6b60

File tree

5 files changed: +1078 -0 lines changed
Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark and validate Jina Embeddings V4 against the HuggingFace implementation.

This script compares embeddings generated by vLLM and HuggingFace to verify
accuracy and measure performance differences.
"""

import argparse
import time
from typing import Any, List, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt

# Vision token IDs
VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653


def create_test_cases() -> List[Tuple[str, str, Any]]:
    """Create comprehensive test cases for validation."""
    test_cases = []

    # Text-only test cases
    test_cases.extend([
        ("text", "Query: What is artificial intelligence?", None),
        ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None),
        ("text", "Query: 你好世界", None),  # Chinese text
        ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
    ])

    # Image test cases
    for color in ["red", "green", "blue"]:
        img = Image.new('RGB', (224, 224), color=color)
        test_cases.append(("image", f"{color} image", img))

    # Complex image
    complex_img = Image.new('RGB', (224, 224))
    pixels = complex_img.load()
    for i in range(224):
        for j in range(224):
            pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
    test_cases.append(("image", "complex pattern", complex_img))

    return test_cases


def compute_hf_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the HuggingFace implementation."""
    print("Loading HuggingFace model...")
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).cuda().eval()

    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    embeddings = []

    print("Computing HuggingFace embeddings...")
    start_time = time.time()

    for case_type, text, image in test_cases:
        if case_type == "text":
            inputs = processor(text=text, return_tensors="pt").to("cuda")
        else:  # image
            inputs = processor(
                text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                images=image,
                return_tensors="pt"
            ).to("cuda")

        with torch.no_grad():
            outputs = model(**inputs)
            # Extract embeddings based on the model output structure
            if hasattr(outputs, 'embeddings'):
                embedding = outputs.embeddings[0]
            else:
                # Fall back to the last hidden state with custom pooling
                hidden_states = outputs.last_hidden_state[0]

                # Apply token-type-aware pooling
                input_ids = inputs['input_ids'][0]
                vision_mask = (
                    (input_ids >= VISION_START_TOKEN_ID) &
                    (input_ids <= VISION_END_TOKEN_ID)
                )

                if vision_mask.any():
                    embedding = hidden_states[vision_mask].mean(dim=0)
                else:
                    embedding = hidden_states.mean(dim=0)

            embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)

        embeddings.append(embedding.cpu())

    hf_time = time.time() - start_time
    print(f"HuggingFace processing time: {hf_time:.2f}s")

    return embeddings


def compute_vllm_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the vLLM implementation."""
    print("\nLoading vLLM model...")
    model = LLM(
        model=model_name,
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    embeddings = []
    prompts = []

    # Prepare prompts
    for case_type, text, image in test_cases:
        if case_type == "text":
            prompt = TextPrompt(prompt=text)
        else:  # image
            prompt = TextPrompt(
                prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                multi_modal_data={"image": image},
            )
        prompts.append(prompt)

    print("Computing vLLM embeddings...")
    start_time = time.time()

    # Process all prompts at once for better performance
    outputs = model.encode(prompts)

    for output in outputs:
        # Extract embeddings based on token type
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
            img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
            embedding_data = output.outputs.data[img_start:img_end + 1]
        else:
            embedding_data = output.outputs.data

        # Pool and normalize
        pooled = embedding_data.mean(dim=0, dtype=torch.float32)
        normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        embeddings.append(normalized.cpu())

    vllm_time = time.time() - start_time
    print(f"vLLM processing time: {vllm_time:.2f}s")

    return embeddings


def compare_embeddings(
    hf_embeddings: List[torch.Tensor],
    vllm_embeddings: List[torch.Tensor],
    test_cases: List[Tuple[str, str, Any]]
) -> None:
    """Compare embeddings and report differences."""
    print("\n" + "=" * 60)
    print("EMBEDDING COMPARISON RESULTS")
    print("=" * 60)

    similarities = []
    max_diffs = []

    for i, (case_type, desc, _) in enumerate(test_cases):
        hf_emb = hf_embeddings[i]
        vllm_emb = vllm_embeddings[i]

        # Compute cosine similarity
        similarity = torch.nn.functional.cosine_similarity(
            hf_emb.unsqueeze(0),
            vllm_emb.unsqueeze(0)
        ).item()

        # Compute max absolute difference
        max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()

        similarities.append(similarity)
        max_diffs.append(max_diff)

        print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
        print(f"  Cosine similarity: {similarity:.6f}")
        print(f"  Max absolute diff: {max_diff:.6f}")
        print(f"  HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")

        # Flag significant differences
        if similarity < 0.99:
            print("  ⚠️ WARNING: Low similarity detected!")

    # Summary statistics
    print("\n" + "-" * 60)
    print("SUMMARY STATISTICS")
    print("-" * 60)
    print(f"Average cosine similarity: {np.mean(similarities):.6f}")
    print(f"Min cosine similarity: {np.min(similarities):.6f}")
    print(f"Max absolute difference: {np.max(max_diffs):.6f}")

    # Overall assessment
    if np.min(similarities) > 0.99:
        print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
    else:
        print("\n❌ VALIDATION FAILED: Significant differences detected")

def main():
226+
parser = argparse.ArgumentParser(
227+
description="Validate Jina Embeddings V4 implementation"
228+
)
229+
parser.add_argument(
230+
"--model",
231+
type=str,
232+
default="jinaai/jina-embeddings-v4-vllm-retrieval",
233+
help="Model name to test"
234+
)
235+
parser.add_argument(
236+
"--skip-hf",
237+
action="store_true",
238+
help="Skip HuggingFace comparison (for performance testing only)"
239+
)
240+
241+
args = parser.parse_args()
242+
243+
# Create test cases
244+
test_cases = create_test_cases()
245+
print(f"Created {len(test_cases)} test cases")
246+
247+
# Compute vLLM embeddings
248+
vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)
249+
250+
if not args.skip_hf:
251+
# Compute HuggingFace embeddings
252+
hf_embeddings = compute_hf_embeddings(args.model, test_cases)
253+
254+
# Compare results
255+
compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
256+
else:
257+
print("\nSkipping HuggingFace comparison")
258+
print(f"vLLM processed {len(test_cases)} embeddings successfully")
259+
260+
261+
if __name__ == "__main__":
262+
main()
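
Both code paths above implement the same pooling rule: when vision tokens are present, mean-pool only the rows between the vision start and end tokens, otherwise mean-pool every token, then L2-normalize. A minimal standalone sketch of that rule with toy tensors (illustrative only, not part of the committed files):

import torch

VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

# Hypothetical prompt: two text tokens, a vision span, then one more text token.
token_ids = [101, 102, VISION_START_TOKEN_ID, 7, 7, VISION_END_TOKEN_ID, 103]
hidden_states = torch.randn(len(token_ids), 8)  # (seq_len, hidden_dim)

start = token_ids.index(VISION_START_TOKEN_ID)
end = token_ids.index(VISION_END_TOKEN_ID)
pooled = hidden_states[start:end + 1].mean(dim=0)  # mean-pool the vision span
embedding = torch.nn.functional.normalize(pooled, p=2, dim=-1)
print(embedding.shape)  # torch.Size([8])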
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using Jina Embeddings V4 with vLLM for multimodal embeddings.

This example demonstrates:
1. Text-only embeddings
2. Image-only embeddings
3. Mixed text and image embeddings
"""

import torch
from PIL import Image

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image


def get_embeddings(outputs):
    """Extract and normalize embeddings from model outputs."""
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # For vision inputs, extract only the vision token embeddings
            img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
            img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # For text-only inputs, use all token embeddings
            embeddings_tensor = output.outputs.data.detach().clone()

        # Pool and normalize embeddings
        pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32)
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings


def main():
    # Initialize the model
    model = LLM(
        model="jinaai/jina-embeddings-v4-vllm-retrieval",
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    # Example 1: Text-only embeddings
    print("=== Text Embeddings ===")
    query = "Overview of climate change impacts on coastal cities"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    passage = """The impacts of climate change on coastal cities are significant
and multifaceted. Rising sea levels threaten infrastructure, while increased
storm intensity poses risks to populations and economies."""
    passage_prompt = TextPrompt(prompt=f"Passage: {passage}")

    # Generate embeddings
    text_outputs = model.encode([query_prompt, passage_prompt])
    text_embeddings = get_embeddings(text_outputs)

    # Calculate similarity
    similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item()
    print(f"Query: {query[:50]}...")
    print(f"Passage: {passage[:50]}...")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 2: Image embeddings
    print("=== Image Embeddings ===")
    # Fetch sample images
    image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
    image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"

    image1 = fetch_image(image1_url)
    image2 = fetch_image(image2_url)

    # Create image prompts with the required format
    image1_prompt = TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image1},
    )

    image2_prompt = TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image2},
    )

    # Generate embeddings
    image_outputs = model.encode([image1_prompt, image2_prompt])
    image_embeddings = get_embeddings(image_outputs)

    # Calculate similarity
    similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item()
    print(f"Image 1: {image1_url.split('/')[-1]}")
    print(f"Image 2: {image2_url.split('/')[-1]}")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 3: Cross-modal similarity (text vs image)
    print("=== Cross-modal Similarity ===")
    query = "scientific paper with markdown formatting"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    # Generate embeddings for the text query and the second image
    cross_outputs = model.encode([query_prompt, image2_prompt])
    cross_embeddings = get_embeddings(cross_outputs)

    # Calculate cross-modal similarity
    similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item()
    print(f"Text query: {query}")
    print(f"Image: {image2_url.split('/')[-1]}")
    print(f"Cross-modal similarity: {similarity:.4f}")


if __name__ == "__main__":
    main()
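
Because get_embeddings() returns unit-normalized vectors, the pairwise torch.dot calls above generalize to a full cosine-similarity matrix over any number of inputs. A small follow-on sketch; the random placeholder vectors and the dimension 2048 are assumptions for illustration, not values taken from the commit:

import torch

# Stand-ins for the unit-normalized vectors returned by get_embeddings().
embeddings = [
    torch.nn.functional.normalize(torch.randn(2048), dim=-1) for _ in range(3)
]

matrix = torch.stack(embeddings)  # (n, dim)
similarity = matrix @ matrix.T    # (n, n) cosine similarities
print(similarity)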
