[Model] Add support for Jina Embeddings V4 #20802


Open
wants to merge 13 commits into main
265 changes: 265 additions & 0 deletions benchmarks/jina_embeddings_v4_validation.py
@@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark and validate Jina Embeddings V4 against the HuggingFace implementation.
This script compares embeddings generated by vLLM and by HuggingFace to verify
accuracy and measure performance differences.
"""

import argparse
import time
from typing import Any

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt

# Vision token IDs
VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653


def create_test_cases() -> list[tuple[str, str, Any]]:
"""Create comprehensive test cases for validation."""
test_cases = []

# Text-only test cases
test_cases.extend(
[
("text", "Query: What is artificial intelligence?", None),
(
"text",
"Passage: AI is a field of computer science focusing on "
"creating intelligent machines.",
None,
),
("text", "Query: 你好世界", None), # Chinese text
("text", "Passage: " + " ".join(["word"] * 100), None), # Long text
]
)

# Image test cases
for color in ["red", "green", "blue"]:
img = Image.new("RGB", (224, 224), color=color)
test_cases.append(("image", f"{color} image", img))

# Complex image
complex_img = Image.new("RGB", (224, 224))
pixels = complex_img.load()
for i in range(224):
for j in range(224):
pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
test_cases.append(("image", "complex pattern", complex_img))

return test_cases


def compute_hf_embeddings(
    model_name: str, test_cases: list[tuple[str, str, Any]]
) -> list[torch.Tensor]:
"""Compute embeddings using HuggingFace implementation."""
print("Loading HuggingFace model...")
model = (
AutoModel.from_pretrained(
model_name, trust_remote_code=True, torch_dtype=torch.float16
)
.cuda()
.eval()
)

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

embeddings = []

print("Computing HuggingFace embeddings...")
start_time = time.time()

for case_type, text, image in test_cases:
if case_type == "text":
inputs = processor(text=text, return_tensors="pt").to("cuda")
else: # image
inputs = processor(
text="<|im_start|>user\n<|vision_start|><|image_pad|>"
"<|vision_end|>Describe the image.<|im_end|>\n",
images=image,
return_tensors="pt",
).to("cuda")

with torch.no_grad():
outputs = model(**inputs)
# Extract embeddings based on model output structure
if hasattr(outputs, "embeddings"):
embedding = outputs.embeddings[0]
else:
# Fallback to last hidden state with custom pooling
hidden_states = outputs.last_hidden_state[0]

# Apply token-type-aware pooling
input_ids = inputs["input_ids"][0]
vision_mask = (input_ids >= VISION_START_TOKEN_ID) & (
input_ids <= VISION_END_TOKEN_ID
)

if vision_mask.any():
embedding = hidden_states[vision_mask].mean(dim=0)
else:
embedding = hidden_states.mean(dim=0)

embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)

embeddings.append(embedding.cpu())

hf_time = time.time() - start_time
print(f"HuggingFace processing time: {hf_time:.2f}s")

return embeddings


def compute_vllm_embeddings(
    model_name: str, test_cases: list[tuple[str, str, Any]]
) -> list[torch.Tensor]:
"""Compute embeddings using vLLM implementation."""
print("\nLoading vLLM model...")
model = LLM(
model=model_name,
task="embed",
override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
dtype="float16",
)

embeddings = []
prompts = []

# Prepare prompts
for case_type, text, image in test_cases:
if case_type == "text":
prompt = TextPrompt(prompt=text)
else: # image
prompt = TextPrompt(
prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
"<|vision_end|>Describe the image.<|im_end|>\n",
multi_modal_data={"image": image},
)
prompts.append(prompt)

print("Computing vLLM embeddings...")
start_time = time.time()

# Process all at once for better performance
outputs = model.encode(prompts)

for output in outputs:
# Extract based on token type
if 151652 in output.prompt_token_ids: # VISION_START_TOKEN_ID
img_start = output.prompt_token_ids.index(151652)
img_end = output.prompt_token_ids.index(151653)
Comment on lines +158 to +160 (Member):
Replace magic numbers with constants
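
A possible way to apply this suggestion (a sketch only; it reuses the module-level
VISION_START_TOKEN_ID and VISION_END_TOKEN_ID constants defined at the top of this file):

    # Hypothetical revision of the flagged lines, using the named constants
    if VISION_START_TOKEN_ID in output.prompt_token_ids:
        img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
        img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID)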

embedding_data = output.outputs.data[img_start : img_end + 1]
else:
embedding_data = output.outputs.data

# Pool and normalize
pooled = embedding_data.mean(dim=0, dtype=torch.float32)
normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
embeddings.append(normalized.cpu())

vllm_time = time.time() - start_time
print(f"vLLM processing time: {vllm_time:.2f}s")

return embeddings


def compare_embeddings(
hf_embeddings: list[torch.Tensor],
vllm_embeddings: list[torch.Tensor],
    test_cases: list[tuple[str, str, Any]],
) -> None:
"""Compare embeddings and report differences."""
print("\n" + "=" * 60)
print("EMBEDDING COMPARISON RESULTS")
print("=" * 60)

similarities = []
max_diffs = []

for i, (case_type, desc, _) in enumerate(test_cases):
hf_emb = hf_embeddings[i]
vllm_emb = vllm_embeddings[i]

# Compute cosine similarity
similarity = torch.nn.functional.cosine_similarity(
hf_emb.unsqueeze(0), vllm_emb.unsqueeze(0)
).item()

# Compute max absolute difference
max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()

similarities.append(similarity)
max_diffs.append(max_diff)

print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
print(f" Cosine similarity: {similarity:.6f}")
print(f" Max absolute diff: {max_diff:.6f}")
print(f" HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")

# Flag significant differences
if similarity < 0.99:
print(" ⚠️ WARNING: Low similarity detected!")

# Summary statistics
print("\n" + "-" * 60)
print("SUMMARY STATISTICS")
print("-" * 60)
print(f"Average cosine similarity: {np.mean(similarities):.6f}")
print(f"Min cosine similarity: {np.min(similarities):.6f}")
print(f"Max absolute difference: {np.max(max_diffs):.6f}")

# Overall assessment
if np.min(similarities) > 0.99:
print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
else:
print("\n❌ VALIDATION FAILED: Significant differences detected")


def main():
parser = argparse.ArgumentParser(
description="Validate Jina Embeddings V4 implementation"
)
parser.add_argument(
"--model",
type=str,
default="jinaai/jina-embeddings-v4-vllm-retrieval",
help="Model name to test",
)
parser.add_argument(
"--skip-hf",
action="store_true",
help="Skip HuggingFace comparison (for performance testing only)",
)

args = parser.parse_args()

# Create test cases
test_cases = create_test_cases()
print(f"Created {len(test_cases)} test cases")

# Compute vLLM embeddings
vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)

if not args.skip_hf:
# Compute HuggingFace embeddings
hf_embeddings = compute_hf_embeddings(args.model, test_cases)

# Compare results
compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
else:
print("\nSkipping HuggingFace comparison")
print(f"vLLM processed {len(test_cases)} embeddings successfully")


if __name__ == "__main__":
main()
121 changes: 121 additions & 0 deletions examples/offline_inference/jina_embeddings_v4.py
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using Jina Embeddings V4 with vLLM for multimodal embeddings.
This example demonstrates:
1. Text-only embeddings
2. Image-only embeddings
3. Cross-modal similarity between a text query and an image
"""

import torch

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image


def get_embeddings(outputs):
"""Extract and normalize embeddings from model outputs."""
VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

embeddings = []
for output in outputs:
if VISION_START_TOKEN_ID in output.prompt_token_ids:
# For vision inputs, extract only vision token embeddings
img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
embeddings_tensor = output.outputs.data.detach().clone()[
img_start_pos : img_end_pos + 1
]
else:
# For text-only inputs, use all token embeddings
embeddings_tensor = output.outputs.data.detach().clone()

# Pool and normalize embeddings
pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32)
embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
return embeddings


def main():
# Initialize the model
model = LLM(
model="jinaai/jina-embeddings-v4-vllm-retrieval",
task="embed",
override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
dtype="float16",
)

# Example 1: Text-only embeddings
print("=== Text Embeddings ===")
query = "Overview of climate change impacts on coastal cities"
query_prompt = TextPrompt(prompt=f"Query: {query}")

passage = """The impacts of climate change on coastal cities are significant
and multifaceted. Rising sea levels threaten infrastructure, while increased
storm intensity poses risks to populations and economies."""
passage_prompt = TextPrompt(prompt=f"Passage: {passage}")

# Generate embeddings
text_outputs = model.encode([query_prompt, passage_prompt])
text_embeddings = get_embeddings(text_outputs)

# Calculate similarity
similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item()
print(f"Query: {query[:50]}...")
print(f"Passage: {passage[:50]}...")
print(f"Similarity: {similarity:.4f}\n")

# Example 2: Image embeddings
print("=== Image Embeddings ===")
# Fetch sample images
image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"

image1 = fetch_image(image1_url)
image2 = fetch_image(image2_url)

# Create image prompts with the required format
image1_prompt = TextPrompt(
prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
"<|vision_end|>Describe the image.<|im_end|>\n",
multi_modal_data={"image": image1},
)

image2_prompt = TextPrompt(
prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
"<|vision_end|>Describe the image.<|im_end|>\n",
multi_modal_data={"image": image2},
)

# Generate embeddings
image_outputs = model.encode([image1_prompt, image2_prompt])
image_embeddings = get_embeddings(image_outputs)

# Calculate similarity
similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item()
print(f"Image 1: {image1_url.split('/')[-1]}")
print(f"Image 2: {image2_url.split('/')[-1]}")
print(f"Similarity: {similarity:.4f}\n")

# Example 3: Cross-modal similarity (text vs image)
print("=== Cross-modal Similarity ===")
query = "scientific paper with markdown formatting"
query_prompt = TextPrompt(prompt=f"Query: {query}")

# Generate embeddings for text query and second image
cross_outputs = model.encode([query_prompt, image2_prompt])
cross_embeddings = get_embeddings(cross_outputs)

# Calculate cross-modal similarity
similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item()
print(f"Text query: {query}")
print(f"Image: {image2_url.split('/')[-1]}")
print(f"Cross-modal similarity: {similarity:.4f}")


if __name__ == "__main__":
main()