diff --git a/benchmarks/jina_embeddings_v4_validation.py b/benchmarks/jina_embeddings_v4_validation.py new file mode 100644 index 00000000000..f0eba1f9f45 --- /dev/null +++ b/benchmarks/jina_embeddings_v4_validation.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark and validate Jina Embeddings V4 against HuggingFace implementation. + +This script compares embeddings generated by vLLM vs HuggingFace to ensure +accuracy and measure performance differences. +""" + +import argparse +import time + +import numpy as np +import torch +from PIL import Image +from transformers import AutoModel, AutoProcessor + +from vllm import LLM +from vllm.config import PoolerConfig +from vllm.inputs.data import TextPrompt + +# Vision token IDs +VISION_START_TOKEN_ID = 151652 +VISION_END_TOKEN_ID = 151653 + + +def create_test_cases() -> list[tuple[str, str, any]]: + """Create comprehensive test cases for validation.""" + test_cases = [] + + # Text-only test cases + test_cases.extend( + [ + ("text", "Query: What is artificial intelligence?", None), + ( + "text", + "Passage: AI is a field of computer science focusing on " + "creating intelligent machines.", + None, + ), + ("text", "Query: 你好世界", None), # Chinese text + ("text", "Passage: " + " ".join(["word"] * 100), None), # Long text + ] + ) + + # Image test cases + for color in ["red", "green", "blue"]: + img = Image.new("RGB", (224, 224), color=color) + test_cases.append(("image", f"{color} image", img)) + + # Complex image + complex_img = Image.new("RGB", (224, 224)) + pixels = complex_img.load() + for i in range(224): + for j in range(224): + pixels[i, j] = (i % 256, j % 256, (i + j) % 256) + test_cases.append(("image", "complex pattern", complex_img)) + + return test_cases + + +def compute_hf_embeddings( + model_name: str, test_cases: list[tuple[str, str, any]] +) -> list[torch.Tensor]: + """Compute embeddings using HuggingFace implementation.""" + print("Loading HuggingFace model...") + model = ( + AutoModel.from_pretrained( + model_name, trust_remote_code=True, torch_dtype=torch.float16 + ) + .cuda() + .eval() + ) + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + embeddings = [] + + print("Computing HuggingFace embeddings...") + start_time = time.time() + + for case_type, text, image in test_cases: + if case_type == "text": + inputs = processor(text=text, return_tensors="pt").to("cuda") + else: # image + inputs = processor( + text="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + images=image, + return_tensors="pt", + ).to("cuda") + + with torch.no_grad(): + outputs = model(**inputs) + # Extract embeddings based on model output structure + if hasattr(outputs, "embeddings"): + embedding = outputs.embeddings[0] + else: + # Fallback to last hidden state with custom pooling + hidden_states = outputs.last_hidden_state[0] + + # Apply token-type-aware pooling + input_ids = inputs["input_ids"][0] + vision_mask = (input_ids >= VISION_START_TOKEN_ID) & ( + input_ids <= VISION_END_TOKEN_ID + ) + + if vision_mask.any(): + embedding = hidden_states[vision_mask].mean(dim=0) + else: + embedding = hidden_states.mean(dim=0) + + embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1) + + embeddings.append(embedding.cpu()) + + hf_time = time.time() - start_time + print(f"HuggingFace processing time: {hf_time:.2f}s") + + return embeddings + + +def compute_vllm_embeddings( + 
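The token-type-aware fallback above (mean over vision-token positions when present, otherwise over all tokens, then L2 normalization) can be exercised in isolation. A minimal sketch on random tensors, with illustrative shapes and no model involved:

import torch

VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

def token_type_aware_pool(hidden_states: torch.Tensor,
                          input_ids: torch.Tensor) -> torch.Tensor:
    # Mean-pool vision-token positions if any exist, else all positions.
    vision_mask = (input_ids >= VISION_START_TOKEN_ID) & \
                  (input_ids <= VISION_END_TOKEN_ID)
    pooled = (hidden_states[vision_mask].mean(dim=0)
              if vision_mask.any() else hidden_states.mean(dim=0))
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)

hidden = torch.randn(6, 2048)                                 # [seq_len, hidden_size]
text_ids = torch.tensor([101, 102, 103, 104, 105, 106])       # no vision tokens
vision_ids = torch.tensor([101, 151652, 151653, 104, 105, 106])
assert torch.allclose(token_type_aware_pool(hidden, text_ids).norm(),
                      torch.tensor(1.0), atol=1e-5)
assert torch.allclose(token_type_aware_pool(hidden, vision_ids).norm(),
                      torch.tensor(1.0), atol=1e-5)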
model_name: str, test_cases: list[tuple[str, str, any]] +) -> list[torch.Tensor]: + """Compute embeddings using vLLM implementation.""" + print("\nLoading vLLM model...") + model = LLM( + model=model_name, + task="embed", + override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), + dtype="float16", + ) + + embeddings = [] + prompts = [] + + # Prepare prompts + for case_type, text, image in test_cases: + if case_type == "text": + prompt = TextPrompt(prompt=text) + else: # image + prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image}, + ) + prompts.append(prompt) + + print("Computing vLLM embeddings...") + start_time = time.time() + + # Process all at once for better performance + outputs = model.encode(prompts) + + for output in outputs: + # Extract based on token type + if VISION_START_TOKEN_ID in output.prompt_token_ids: + img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID) + img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID) + embedding_data = output.outputs.data[img_start : img_end + 1] + else: + embedding_data = output.outputs.data + + # Pool and normalize + pooled = embedding_data.mean(dim=0, dtype=torch.float32) + normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1) + embeddings.append(normalized.cpu()) + + vllm_time = time.time() - start_time + print(f"vLLM processing time: {vllm_time:.2f}s") + + return embeddings + + +def compare_embeddings( + hf_embeddings: list[torch.Tensor], + vllm_embeddings: list[torch.Tensor], + test_cases: list[tuple[str, str, any]], +) -> None: + """Compare embeddings and report differences.""" + print("\n" + "=" * 60) + print("EMBEDDING COMPARISON RESULTS") + print("=" * 60) + + similarities = [] + max_diffs = [] + + for i, (case_type, desc, _) in enumerate(test_cases): + hf_emb = hf_embeddings[i] + vllm_emb = vllm_embeddings[i] + + # Compute cosine similarity + similarity = torch.nn.functional.cosine_similarity( + hf_emb.unsqueeze(0), vllm_emb.unsqueeze(0) + ).item() + + # Compute max absolute difference + max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item() + + similarities.append(similarity) + max_diffs.append(max_diff) + + print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...") + print(f" Cosine similarity: {similarity:.6f}") + print(f" Max absolute diff: {max_diff:.6f}") + print(f" HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}") + + # Flag significant differences + if similarity < 0.99: + print(" ⚠️ WARNING: Low similarity detected!") + + # Summary statistics + print("\n" + "-" * 60) + print("SUMMARY STATISTICS") + print("-" * 60) + print(f"Average cosine similarity: {np.mean(similarities):.6f}") + print(f"Min cosine similarity: {np.min(similarities):.6f}") + print(f"Max absolute difference: {np.max(max_diffs):.6f}") + + # Overall assessment + if np.min(similarities) > 0.99: + print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace") + else: + print("\n❌ VALIDATION FAILED: Significant differences detected") + + +def main(): + parser = argparse.ArgumentParser( + description="Validate Jina Embeddings V4 implementation" + ) + parser.add_argument( + "--model", + type=str, + default="jinaai/jina-embeddings-v4-vllm-retrieval", + help="Model name to test", + ) + parser.add_argument( + "--skip-hf", + action="store_true", + help="Skip HuggingFace comparison (for performance testing only)", + ) + + args = parser.parse_args() + + # Create test cases + test_cases 
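For reference, the benchmark is driven entirely by the two argparse flags it defines; typical invocations look like the following (the model name shown earlier is the flag's default):

    python benchmarks/jina_embeddings_v4_validation.py
    python benchmarks/jina_embeddings_v4_validation.py --skip-hf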
= create_test_cases() + print(f"Created {len(test_cases)} test cases") + + # Compute vLLM embeddings + vllm_embeddings = compute_vllm_embeddings(args.model, test_cases) + + if not args.skip_hf: + # Compute HuggingFace embeddings + hf_embeddings = compute_hf_embeddings(args.model, test_cases) + + # Compare results + compare_embeddings(hf_embeddings, vllm_embeddings, test_cases) + else: + print("\nSkipping HuggingFace comparison") + print(f"vLLM processed {len(test_cases)} embeddings successfully") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/jina_embeddings_v4.py b/examples/offline_inference/jina_embeddings_v4.py new file mode 100644 index 00000000000..c3716b5e09f --- /dev/null +++ b/examples/offline_inference/jina_embeddings_v4.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Example of using Jina Embeddings V4 with vLLM for multimodal embeddings. + +This example demonstrates: +1. Text-only embeddings +2. Image-only embeddings +3. Mixed text and image embeddings +""" + +import torch + +from vllm import LLM +from vllm.config import PoolerConfig +from vllm.inputs.data import TextPrompt +from vllm.multimodal.utils import fetch_image + + +def get_embeddings(outputs): + """Extract and normalize embeddings from model outputs.""" + VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653 + + embeddings = [] + for output in outputs: + if VISION_START_TOKEN_ID in output.prompt_token_ids: + # For vision inputs, extract only vision token embeddings + img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID) + img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID) + embeddings_tensor = output.outputs.data.detach().clone()[ + img_start_pos : img_end_pos + 1 + ] + else: + # For text-only inputs, use all token embeddings + embeddings_tensor = output.outputs.data.detach().clone() + + # Pool and normalize embeddings + pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32) + embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1)) + return embeddings + + +def main(): + # Initialize the model + model = LLM( + model="jinaai/jina-embeddings-v4-vllm-retrieval", + task="embed", + override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False), + dtype="float16", + ) + + # Example 1: Text-only embeddings + print("=== Text Embeddings ===") + query = "Overview of climate change impacts on coastal cities" + query_prompt = TextPrompt(prompt=f"Query: {query}") + + passage = """The impacts of climate change on coastal cities are significant + and multifaceted. 
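Because get_embeddings() returns L2-normalized vectors, the torch.dot() calls in this example are exactly cosine similarities. A quick self-contained sanity check on random data (illustrative only, no model needed):

import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(2048), dim=-1)
b = F.normalize(torch.randn(2048), dim=-1)
# For unit-norm vectors, dot product and cosine similarity coincide.
assert torch.isclose(torch.dot(a, b), F.cosine_similarity(a, b, dim=0))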
Rising sea levels threaten infrastructure, while increased + storm intensity poses risks to populations and economies.""" + passage_prompt = TextPrompt(prompt=f"Passage: {passage}") + + # Generate embeddings + text_outputs = model.encode([query_prompt, passage_prompt]) + text_embeddings = get_embeddings(text_outputs) + + # Calculate similarity + similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item() + print(f"Query: {query[:50]}...") + print(f"Passage: {passage[:50]}...") + print(f"Similarity: {similarity:.4f}\n") + + # Example 2: Image embeddings + print("=== Image Embeddings ===") + # Fetch sample images + image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" + image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + + image1 = fetch_image(image1_url) + image2 = fetch_image(image2_url) + + # Create image prompts with the required format + image1_prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image1}, + ) + + image2_prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image2}, + ) + + # Generate embeddings + image_outputs = model.encode([image1_prompt, image2_prompt]) + image_embeddings = get_embeddings(image_outputs) + + # Calculate similarity + similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item() + print(f"Image 1: {image1_url.split('/')[-1]}") + print(f"Image 2: {image2_url.split('/')[-1]}") + print(f"Similarity: {similarity:.4f}\n") + + # Example 3: Cross-modal similarity (text vs image) + print("=== Cross-modal Similarity ===") + query = "scientific paper with markdown formatting" + query_prompt = TextPrompt(prompt=f"Query: {query}") + + # Generate embeddings for text query and second image + cross_outputs = model.encode([query_prompt, image2_prompt]) + cross_embeddings = get_embeddings(cross_outputs) + + # Calculate cross-modal similarity + similarity = torch.dot(cross_embeddings[0], cross_embeddings[1]).item() + print(f"Text query: {query}") + print(f"Image: {image2_url.split('/')[-1]}") + print(f"Cross-modal similarity: {similarity:.4f}") + + +if __name__ == "__main__": + main() diff --git a/tests/models/pooling/test_jina_embeddings_v4.py b/tests/models/pooling/test_jina_embeddings_v4.py new file mode 100644 index 00000000000..6baa8d859d7 --- /dev/null +++ b/tests/models/pooling/test_jina_embeddings_v4.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +import time +from array import array +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import torch +from PIL import Image + +from vllm import LLM +from vllm.config import PoolerConfig +from vllm.inputs.data import TextPrompt +from vllm.sequence import SequenceData + +model_name = "jinaai/jina-embeddings-v4-vllm-retrieval" + +# Vision token IDs +VISION_START_TOKEN_ID = 151652 +VISION_END_TOKEN_ID = 151653 + + +@pytest.fixture(scope="module") +def model(): + """Initialize model once for all tests.""" + return LLM( + model=model_name, + task="embed", + override_pooler_config=PoolerConfig(pooling_type="ALL", + normalize=False), + dtype="float16", + max_model_len=2048, + ) + + +def extract_embeddings(output): + """Extract embeddings based on token type.""" + if 
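Every image prompt in these files repeats the same Qwen2-VL-style chat template. A small hypothetical helper (not part of this PR) that keeps the template in one place:

from vllm.inputs.data import TextPrompt

IMAGE_PROMPT = ("<|im_start|>user\n<|vision_start|><|image_pad|>"
                "<|vision_end|>Describe the image.<|im_end|>\n")

def make_image_prompt(image) -> TextPrompt:
    # Pairs the fixed chat template with a PIL image as multi-modal data.
    return TextPrompt(prompt=IMAGE_PROMPT, multi_modal_data={"image": image})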
VISION_START_TOKEN_ID in output.prompt_token_ids: + # Extract vision tokens only + img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID) + img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID) + embeddings = output.outputs.data[img_start:img_end + 1] + else: + # Use all tokens for text + embeddings = output.outputs.data + + # Mean pool and normalize + pooled = embeddings.mean(dim=0, dtype=torch.float32) + return torch.nn.functional.normalize(pooled, dim=-1) + + +class TestBasicFunctionality: + """Test basic embedding generation functionality.""" + + def test_text_only_embeddings(self, model): + """Test text-only embedding generation.""" + prompts = [ + TextPrompt(prompt="Query: What is machine learning?"), + TextPrompt(prompt="Passage: Machine learning is a subset of " + "artificial intelligence.") + ] + + outputs = model.encode(prompts) + embeddings = [extract_embeddings(output) for output in outputs] + + # Check embeddings are normalized + for emb in embeddings: + assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) + + # Check similarity is reasonable + similarity = torch.dot(embeddings[0], embeddings[1]).item() + assert 0.0 <= similarity <= 1.0 + + def test_image_embeddings(self, model): + """Test image embedding generation.""" + # Create a dummy image + image = Image.new('RGB', (224, 224), color='red') + + prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image}, + ) + + outputs = model.encode([prompt]) + embedding = extract_embeddings(outputs[0]) + + # Check embedding is normalized + assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) + + # Check dimension + assert embedding.shape[ + 0] == model.llm_engine.model_config.hf_config.hidden_size + + def test_mixed_batch(self, model): + """Test mixed text and image batch processing.""" + image = Image.new('RGB', (224, 224), color='blue') + + prompts = [ + TextPrompt(prompt="Query: blue color"), + TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": image}, + ), + TextPrompt(prompt="Passage: The sky is blue.") + ] + + outputs = model.encode(prompts) + embeddings = [extract_embeddings(output) for output in outputs] + + # All embeddings should be normalized + for emb in embeddings: + assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) + + # Text query about blue should have some similarity to blue image + text_image_sim = torch.dot(embeddings[0], embeddings[1]).item() + assert text_image_sim > 0.0 # Should have positive similarity + + +class TestThreadSafety: + """Test thread safety of the pooling implementation.""" + + def test_concurrent_requests(self, model): + """Test handling of concurrent embedding requests.""" + num_threads = 4 + requests_per_thread = 5 + + def process_request(thread_id): + results = [] + for i in range(requests_per_thread): + prompt = TextPrompt( + prompt=f"Query from thread {thread_id}, request {i}") + outputs = model.encode([prompt]) + embedding = extract_embeddings(outputs[0]) + results.append(embedding) + return results + + # Run concurrent requests + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(process_request, i) for i in range(num_threads) + ] + + all_results = [] + for future in as_completed(futures): + thread_results = future.result() + all_results.extend(thread_results) + + # Verify all embeddings 
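These are ordinary pytest tests; assuming a GPU is available for the module-scoped model fixture, individual groups can be selected by class name, for example:

    pytest tests/models/pooling/test_jina_embeddings_v4.py -v -k TestBasicFunctionality
    pytest tests/models/pooling/test_jina_embeddings_v4.py -v -k TestThreadSafety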
are valid + assert len(all_results) == num_threads * requests_per_thread + for emb in all_results: + assert torch.allclose(emb.norm(), torch.tensor(1.0), atol=1e-3) + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_empty_input_handling(self, model): + """Test handling of empty inputs.""" + # This should not crash but return empty outputs + outputs = model.encode([]) + assert len(outputs) == 0 + + def test_very_long_sequence(self, model): + """Test handling of sequences near max length.""" + # Create a long text that approaches max_model_len + long_text = " ".join(["word"] * 1000) + prompt = TextPrompt(prompt=f"Query: {long_text}") + + # Should handle gracefully without crashing + outputs = model.encode([prompt]) + embedding = extract_embeddings(outputs[0]) + assert torch.allclose(embedding.norm(), torch.tensor(1.0), atol=1e-3) + + def test_invalid_image_format(self, model): + """Test handling of invalid image inputs.""" + # Create an invalid image (too small) + tiny_image = Image.new('RGB', (1, 1), color='red') + + prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe the image.<|im_end|>\n", + multi_modal_data={"image": tiny_image}, + ) + + # Should handle gracefully + try: + outputs = model.encode([prompt]) + # If it doesn't crash, check output is valid + if outputs: + embedding = extract_embeddings(outputs[0]) + assert embedding.shape[ + 0] == model.llm_engine.model_config.hf_config.hidden_size + except Exception as e: + # Should provide meaningful error message + assert "image" in str(e).lower() or "size" in str(e).lower() + + +class TestMemoryManagement: + """Test memory management and cleanup.""" + + def test_memory_cleanup(self, model): + """Test that memory is properly cleaned up after processing.""" + # Get initial memory usage + torch.cuda.empty_cache() + if torch.cuda.is_available(): + initial_memory = torch.cuda.memory_allocated() + + # Process multiple large batches + for _ in range(5): + prompts = [ + TextPrompt(prompt=f"Query: test {i}") for i in range(10) + ] + outputs = model.encode(prompts) + del outputs + gc.collect() + + # Check memory usage hasn't grown significantly + if torch.cuda.is_available(): + torch.cuda.empty_cache() + final_memory = torch.cuda.memory_allocated() + memory_growth = final_memory - initial_memory + # Allow some growth but not excessive + assert memory_growth < 100 * 1024 * 1024 # Less than 100MB growth + + +class TestPerformance: + """Test performance characteristics.""" + + def test_pooling_performance(self, model): + """Test that custom pooling is performant.""" + # Create test prompts + text_prompts = [ + TextPrompt(prompt=f"Query: test {i}") for i in range(10) + ] + + # Time text-only pooling + start_time = time.time() + text_outputs = model.encode(text_prompts) + text_time = time.time() - start_time + + # Create image prompts + image = Image.new('RGB', (224, 224), color='green') + image_prompts = [ + TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Describe.<|im_end|>\n", + multi_modal_data={"image": image}, + ) for _ in range(10) + ] + + # Time vision pooling + start_time = time.time() + image_outputs = model.encode(image_prompts) + image_time = time.time() - start_time + + # Vision pooling should not be significantly slower + # (allowing 2x slower due to additional processing) + assert image_time < text_time * 2.0 + + # Verify outputs are valid + for output in text_outputs + image_outputs: + embedding = 
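The memory check in TestMemoryManagement can be factored into a reusable helper built on the same torch.cuda calls; a minimal sketch (hypothetical, not used by the tests above):

import gc
import torch

def cuda_memory_delta(fn, *args, **kwargs):
    # Returns (result, bytes of CUDA memory growth) around a callable.
    torch.cuda.empty_cache()
    before = torch.cuda.memory_allocated()
    result = fn(*args, **kwargs)
    gc.collect()
    torch.cuda.empty_cache()
    return result, torch.cuda.memory_allocated() - before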
extract_embeddings(output) + assert torch.allclose(embedding.norm(), + torch.tensor(1.0), + atol=1e-3) + + +class TestPoolingMetadataIntegration: + """Test proper integration with PoolingMetadata.""" + + def test_seq_data_access(self): + """Test that token IDs are properly accessible via seq_data.""" + # Create mock sequence data + prompt_tokens = array('l', [ + 101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID, + VISION_END_TOKEN_ID, 104 + ]) + seq_data = SequenceData(prompt_tokens) + + # Verify prompt_token_ids_array property works + assert hasattr(seq_data, 'prompt_token_ids_array') + retrieved_tokens = seq_data.prompt_token_ids_array + assert list(retrieved_tokens) == list(prompt_tokens) + + # Verify vision tokens can be detected + token_tensor = torch.tensor(list(retrieved_tokens)) + vision_mask = ((token_tensor >= VISION_START_TOKEN_ID) & + (token_tensor <= VISION_END_TOKEN_ID)) + assert vision_mask.any() + assert vision_mask.sum() == 3 # Start, middle, end tokens + + +class TestAccuracyValidation: + """Test accuracy against expected behavior.""" + + @pytest.mark.parametrize("text", [ + "Short text", + "A much longer text that contains multiple sentences for testing", + "特殊字符测试 🚀 emoji test", "Numbers 12345 and symbols !@#$%" + ]) + def test_text_embedding_consistency(self, model, text): + """Test that same text produces consistent embeddings.""" + prompt = TextPrompt(prompt=f"Query: {text}") + + # Generate embeddings multiple times + embeddings = [] + for _ in range(3): + outputs = model.encode([prompt]) + emb = extract_embeddings(outputs[0]) + embeddings.append(emb) + + # All should be identical + for i in range(1, len(embeddings)): + assert torch.allclose(embeddings[0], embeddings[i], atol=1e-5) + + def test_vision_only_pooling(self, model): + """Test that vision pooling extracts only vision tokens.""" + # Create an image with known characteristics + image = Image.new('RGB', (224, 224), color='red') + + # Two prompts with same image but different text + prompt1 = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Red image<|im_end|>\n", + multi_modal_data={"image": image}, + ) + prompt2 = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|>" + "<|vision_end|>Blue sky green grass<|im_end|>\n", + multi_modal_data={"image": image}, + ) + + outputs = model.encode([prompt1, prompt2]) + emb1 = extract_embeddings(outputs[0]) + emb2 = extract_embeddings(outputs[1]) + + # Since both use the same image and vision-only pooling, + # embeddings should be very similar despite different text + similarity = torch.dot(emb1, emb2).item() + assert similarity > 0.99 # Should be nearly identical diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index b378a3db032..278243e14fe 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -15,9 +15,12 @@ PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.triton_utils import tl, triton from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata +HAS_TRITON = triton is not None + PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] @@ -658,3 +661,41 @@ def forward( ]) return build_output(scores) + + +if HAS_TRITON: + + @triton.jit + def extract_vision_tokens_kernel( + hidden_states_ptr, + token_ids_ptr, + output_ptr, + seq_start, + 
seq_len, + hidden_size, + vision_start_id: tl.constexpr, + vision_end_id: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + """Triton kernel to extract and pool vision tokens efficiently.""" + pid = tl.program_id(0) + + if pid >= hidden_size: + return + + # Find vision token range + vision_count = 0 + accumulator = 0.0 + + for i in range(seq_len): + token_id = tl.load(token_ids_ptr + seq_start + i) + if token_id >= vision_start_id and token_id <= vision_end_id: + hidden_val = tl.load(hidden_states_ptr + + (seq_start + i) * hidden_size + pid) + accumulator += hidden_val + vision_count += 1 + + # Store mean pooled result + result = accumulator / vision_count if vision_count > 0 else 0.0 + + tl.store(output_ptr + pid, result) diff --git a/vllm/model_executor/models/jina_embeddings_v4.py b/vllm/model_executor/models/jina_embeddings_v4.py new file mode 100644 index 00000000000..5dc47d41867 --- /dev/null +++ b/vllm/model_executor/models/jina_embeddings_v4.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from array import array +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn.functional as F + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, + extract_vision_tokens_kernel) +# yapf: disable +from vllm.model_executor.pooling_metadata import ( + PoolingMetadata as V0PoolingMetadata) +from vllm.model_executor.pooling_metadata import PoolingTensors +# yapf: enable +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata + +from .interfaces import SupportsCrossEncoding, SupportsMultiModal +from .qwen2_vl import (Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, + Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix + +logger = init_logger(__name__) + +# Vision token IDs for Jina V4 +VISION_START_TOKEN_ID = 151652 +VISION_END_TOKEN_ID = 151653 + +PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] + + +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) +class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, + SupportsCrossEncoding, SupportsMultiModal): + # Weight mapping for HuggingFace checkpoint compatibility + weight_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + "visual.": "visual.", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "qwen2_vl")) + + self.hidden_size = vllm_config.model_config.hf_config.hidden_size + pooler_config = vllm_config.model_config.pooler_config + self.observability_config = vllm_config.observability_config + + # Configuration for vision pooling backend + self.pooling_backend = getattr(vllm_config.model_config, + "jina_pooling_backend", "triton") + if self.pooling_backend not in ("triton", "pytorch"): + logger.warning( + "Invalid jina_pooling_backend '%s'. " + "Must be 'triton' or 'pytorch'. 
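A plain PyTorch reference for what extract_vision_tokens_kernel computes for a single sequence: the per-dimension mean over positions whose token ID lies in [vision_start_id, vision_end_id], and zeros when no such position exists. Useful as a correctness check for the kernel; inputs here are whatever tensors you construct for the test:

import torch

def extract_vision_tokens_reference(hidden_states, token_ids, seq_start,
                                    seq_len, vision_start_id, vision_end_id):
    # Slice out this sequence from the flat batch, then mask vision tokens.
    states = hidden_states[seq_start:seq_start + seq_len]
    tokens = token_ids[seq_start:seq_start + seq_len]
    mask = (tokens >= vision_start_id) & (tokens <= vision_end_id)
    if mask.any():
        return states[mask].mean(dim=0)
    return torch.zeros(hidden_states.shape[-1], dtype=hidden_states.dtype,
                       device=hidden_states.device)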
Defaulting to 'triton'.", + self.pooling_backend) + self.pooling_backend = "triton" + + # Initialize base pooler for fallback + self._base_pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.MEAN, + normalize=True, + softmax=False) + + # Performance tracking + self._pooling_time_ms = 0.0 + self._pooling_count = 0 + + logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") + + def _extract_token_ids_safe( + self, pooling_metadata: PoolingMetadata + ) -> tuple[list[array], list[int]]: + """Safely extract token IDs from pooling metadata.""" + token_ids_list: list[array] = [] + try: + if isinstance(pooling_metadata, V1PoolingMetadata): + # For V1, we get token IDs and sequence indices directly + for i, num in enumerate(pooling_metadata.prompt_lens): + token_ids = pooling_metadata.prompt_token_ids[ + i, :num].tolist() + token_ids_list.append(array('l', token_ids)) + + # V1 metadata does not have explicit seq_ids, so we use indices + seq_ids = list(range(len(token_ids_list))) + return token_ids_list, seq_ids + + # For V0, we extract from seq_groups and seq_data + seq_ids = [] + for seq_group, _ in pooling_metadata.seq_groups: + for seq_id in seq_group: + if seq_id not in pooling_metadata.seq_data: + logger.warning("Sequence %s not found in seq_data", + seq_id) + continue + + seq_data = pooling_metadata.seq_data[seq_id] + + # Get prompt token IDs safely + if hasattr(seq_data, 'prompt_token_ids_array'): + token_ids = seq_data.prompt_token_ids_array + elif hasattr(seq_data, '_prompt_token_ids'): + token_ids = seq_data._prompt_token_ids + else: + logger.warning("No token IDs found for sequence %s", + seq_id) + continue + + seq_ids.append(seq_id) + token_ids_list.append(token_ids) + + return token_ids_list, seq_ids + + except Exception as e: + logger.error( + "Error extracting token IDs: %s. 
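What the V1 branch of _extract_token_ids_safe does, in isolation: each row of the padded prompt_token_ids tensor is cut down to its true prompt length. A shape-only illustration with made-up token IDs:

from array import array
import torch

prompt_token_ids = torch.tensor([[101, 102,    103,    0,   0],
                                 [101, 151652, 151653, 104, 0]])
prompt_lens = torch.tensor([3, 4])

token_ids_list = [array('l', prompt_token_ids[i, :int(n)].tolist())
                  for i, n in enumerate(prompt_lens)]
assert [len(t) for t in token_ids_list] == [3, 4]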
" + "Extracted %d sequences before failure", e, + len(token_ids_list)) + raise + + def _apply_vision_pooling_optimized( + self, + hidden_states: torch.Tensor, + token_ids_list: list[array], + prompt_lens: torch.Tensor, + ) -> list[torch.Tensor]: + """Apply optimized vision token pooling using Triton kernels.""" + if not HAS_TRITON: + logger.debug( + "Triton not available, falling back to PyTorch implementation") + return self._apply_vision_pooling_pytorch(hidden_states, + token_ids_list, + prompt_lens) + + pooled_outputs = [] + offset = 0 + device = hidden_states.device + + for i, (token_ids, + prompt_len) in enumerate(zip(token_ids_list, prompt_lens)): + prompt_len = int(prompt_len.item()) + + # Convert token IDs to tensor + token_tensor = torch.tensor(list(token_ids), + dtype=torch.long, + device=device) + + # Allocate output tensor + output = torch.zeros(self.hidden_size, + device=device, + dtype=hidden_states.dtype) + + # Check for vision tokens + has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID) + & (token_tensor <= VISION_END_TOKEN_ID)) + + if has_vision: + # Use Triton kernel for vision token extraction + grid = (self.hidden_size, ) + extract_vision_tokens_kernel[grid]( + hidden_states, + token_tensor, + output, + offset, + prompt_len, + self.hidden_size, + VISION_START_TOKEN_ID, + VISION_END_TOKEN_ID, + BLOCK_SIZE=1024, + ) + else: + # Regular mean pooling for text + seq_states = hidden_states[offset:offset + prompt_len] + output = seq_states.mean(dim=0) + + # Normalize and handle potential NaNs by replacing with zeros + output = F.normalize(output, p=2, dim=-1, eps=1e-12) + pooled_outputs.append(output) + + offset += prompt_len + + return pooled_outputs + + def _apply_vision_pooling_pytorch( + self, + hidden_states: torch.Tensor, + token_ids_list: list[array], + prompt_lens: torch.Tensor, + ) -> list[torch.Tensor]: + """PyTorch fallback for vision token pooling.""" + pooled_outputs = [] + offset = 0 + + for token_ids, prompt_len in zip(token_ids_list, prompt_lens): + prompt_len = int(prompt_len.item()) + + # Extract sequence states and tokens + seq_states = hidden_states[offset:offset + prompt_len] + + # Convert array to tensor for processing + seq_tokens = torch.tensor(list(token_ids[:prompt_len]), + dtype=torch.long, + device=hidden_states.device) + + # Check for vision tokens + vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) & + (seq_tokens <= VISION_END_TOKEN_ID)) + + if vision_mask.any(): + # Pool only vision tokens + vision_states = seq_states[vision_mask] + if vision_states.numel() == 0: + logger.warning( + "No vision states found despite vision mask") + pooled = seq_states.mean(dim=0) + else: + pooled = vision_states.mean(dim=0) + else: + # Pool all tokens for text + pooled = seq_states.mean(dim=0) + + # Normalize embeddings + pooled = F.normalize(pooled, p=2, dim=-1, eps=1e-12) + pooled_outputs.append(pooled) + + offset += prompt_len + + return pooled_outputs + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + """Thread-safe pooler with production error handling.""" + start_time = time.time() if self.observability_config else None + + # Validate inputs + if hidden_states is None or hidden_states.numel() == 0: + logger.warning("Empty hidden states received") + return PoolerOutput(outputs=[]) + + # Extract token IDs safely from metadata + token_ids_list, seq_ids = self._extract_token_ids_safe( + pooling_metadata) + + if not token_ids_list: + logger.warning("No valid sequences found 
for pooling") + # Fallback to base pooler + return self._base_pooler(hidden_states, pooling_metadata) + + # Get prompt lengths based on metadata type + if isinstance(pooling_metadata, V1PoolingMetadata): + prompt_lens = pooling_metadata.prompt_lens + else: + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + # Validate lengths match + assert len(token_ids_list) == len(prompt_lens), ( + f"Mismatch: {len(token_ids_list)} sequences vs " + f"{len(prompt_lens)} lengths") + + # Apply pooling based on configured backend + if self.pooling_backend == "triton": + pooled_data = self._apply_vision_pooling_optimized( + hidden_states, token_ids_list, prompt_lens) + else: # self.pooling_backend == "pytorch" + pooled_data = self._apply_vision_pooling_pytorch( + hidden_states, token_ids_list, prompt_lens) + + # Build output + pooled_outputs = [ + PoolingSequenceGroupOutput(data) for data in pooled_data + ] + + # Record metrics + if self.observability_config: + elapsed_ms = (time.time() - start_time) * 1000 + self._pooling_time_ms += elapsed_ms + self._pooling_count += 1 + + if self._pooling_count % 100 == 0: + avg_time = self._pooling_time_ms / self._pooling_count + logger.debug("Average pooling time: %.2fms", avg_time) + + return PoolerOutput(outputs=pooled_outputs) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """Load weights with validation and error handling.""" + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights, + mapper=self.weight_mapper) + return loaded_weights diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index bc936500bdc..b26190a84f3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -170,6 +170,8 @@ # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), + # Multimodal embedding model with token-type-aware pooling + "JinaVLForEmbedding": ("jina_embeddings_v4", "JinaVLForEmbedding"), } _CROSS_ENCODER_MODELS = {