# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark and validate Jina Embeddings V4 against the HuggingFace implementation.

This script compares embeddings generated by vLLM vs. HuggingFace to ensure
accuracy and measure performance differences.
"""

import argparse
import time
from typing import Any, List, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt

# Vision token IDs
VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653
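# These IDs correspond to the <|vision_start|> / <|vision_end|> special tokens
# used in the image prompts below (the Qwen2.5-VL tokenizer family that Jina
# Embeddings V4 builds on). Both backends pool hidden states over the span
# these tokens delimit, which keeps their embeddings directly comparable.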


def create_test_cases() -> List[Tuple[str, str, Any]]:
    """Create comprehensive test cases for validation."""
    test_cases = []

    # Text-only test cases
    test_cases.extend([
        ("text", "Query: What is artificial intelligence?", None),
        ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None),
        ("text", "Query: 你好世界", None),  # Chinese text
        ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
    ])

    # Image test cases
    for color in ["red", "green", "blue"]:
        img = Image.new('RGB', (224, 224), color=color)
        test_cases.append(("image", f"{color} image", img))

    # Complex image
    complex_img = Image.new('RGB', (224, 224))
    pixels = complex_img.load()
    for i in range(224):
        for j in range(224):
            pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
    test_cases.append(("image", "complex pattern", complex_img))

    return test_cases


def compute_hf_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the HuggingFace implementation."""
    print("Loading HuggingFace model...")
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).cuda().eval()

    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True
    )
72
+
73
+ embeddings = []
74
+
75
+ print ("Computing HuggingFace embeddings..." )
76
+ start_time = time .time ()
77
+
78
+ for case_type , text , image in test_cases :
79
+ if case_type == "text" :
80
+ inputs = processor (text = text , return_tensors = "pt" ).to ("cuda" )
81
+ else : # image
82
+ inputs = processor (
83
+ text = "<|im_start|>user\n <|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n " ,
84
+ images = image ,
85
+ return_tensors = "pt"
86
+ ).to ("cuda" )
87
+
88
+ with torch .no_grad ():
89
+ outputs = model (** inputs )
90
+ # Extract embeddings based on model output structure
91
+ if hasattr (outputs , 'embeddings' ):
92
+ embedding = outputs .embeddings [0 ]
93
+ else :
94
+ # Fallback to last hidden state with custom pooling
95
+ hidden_states = outputs .last_hidden_state [0 ]
96
+
97
+ # Apply token-type-aware pooling
98
+ input_ids = inputs ['input_ids' ][0 ]
99
+ vision_mask = (
100
+ (input_ids >= VISION_START_TOKEN_ID ) &
101
+ (input_ids <= VISION_END_TOKEN_ID )
102
+ )
103
+
104
+ if vision_mask .any ():
105
+ embedding = hidden_states [vision_mask ].mean (dim = 0 )
106
+ else :
107
+ embedding = hidden_states .mean (dim = 0 )
108
+
109
+ embedding = torch .nn .functional .normalize (embedding , p = 2 , dim = - 1 )
110
+
111
+ embeddings .append (embedding .cpu ())
112
+
113
+ hf_time = time .time () - start_time
114
+ print (f"HuggingFace processing time: { hf_time :.2f} s" )
115
+
116
+ return embeddings


def compute_vllm_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the vLLM implementation."""
    print("\nLoading vLLM model...")
    model = LLM(
        model=model_name,
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )
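    # pooling_type="ALL" makes vLLM return one vector per prompt token instead
    # of a single pooled vector, and normalize=False leaves them unnormalized,
    # so the span-aware pooling and L2 normalization below mirror the HF path.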

    embeddings = []
    prompts = []

    # Prepare prompts
    for case_type, text, image in test_cases:
        if case_type == "text":
            prompt = TextPrompt(prompt=text)
        else:  # image
            prompt = TextPrompt(
                prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                multi_modal_data={"image": image},
            )
        prompts.append(prompt)

    print("Computing vLLM embeddings...")
    start_time = time.time()

    # Process all prompts at once for better throughput
    outputs = model.encode(prompts)

    for output in outputs:
        # Pool over the vision span when present, otherwise over all tokens
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            img_start = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
            img_end = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
            embedding_data = output.outputs.data[img_start:img_end + 1]
        else:
            embedding_data = output.outputs.data

        # Mean-pool and L2-normalize
        pooled = embedding_data.mean(dim=0, dtype=torch.float32)
        normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        embeddings.append(normalized.cpu())

    vllm_time = time.time() - start_time
    print(f"vLLM processing time: {vllm_time:.2f}s")

    return embeddings


def compare_embeddings(
    hf_embeddings: List[torch.Tensor],
    vllm_embeddings: List[torch.Tensor],
    test_cases: List[Tuple[str, str, Any]]
) -> None:
    """Compare embeddings and report differences."""
    print("\n" + "=" * 60)
    print("EMBEDDING COMPARISON RESULTS")
    print("=" * 60)

    similarities = []
    max_diffs = []

    for i, (case_type, desc, _) in enumerate(test_cases):
        hf_emb = hf_embeddings[i]
        vllm_emb = vllm_embeddings[i]

        # Compute cosine similarity
        similarity = torch.nn.functional.cosine_similarity(
            hf_emb.unsqueeze(0),
            vllm_emb.unsqueeze(0)
        ).item()

        # Compute max absolute difference
        max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()

        similarities.append(similarity)
        max_diffs.append(max_diff)

        print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
        print(f"  Cosine similarity: {similarity:.6f}")
        print(f"  Max absolute diff: {max_diff:.6f}")
        print(f"  HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")

        # Flag significant differences
        if similarity < 0.99:
            print("  ⚠️ WARNING: Low similarity detected!")

    # Summary statistics
    print("\n" + "-" * 60)
    print("SUMMARY STATISTICS")
    print("-" * 60)
    print(f"Average cosine similarity: {np.mean(similarities):.6f}")
    print(f"Min cosine similarity: {np.min(similarities):.6f}")
    print(f"Max absolute difference: {np.max(max_diffs):.6f}")

    # Overall assessment
    if np.min(similarities) > 0.99:
        print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
    else:
        print("\n❌ VALIDATION FAILED: Significant differences detected")


def main():
    parser = argparse.ArgumentParser(
        description="Validate Jina Embeddings V4 implementation"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v4-vllm-retrieval",
        help="Model name to test"
    )
    parser.add_argument(
        "--skip-hf",
        action="store_true",
        help="Skip HuggingFace comparison (for performance testing only)"
    )

    args = parser.parse_args()

    # Create test cases
    test_cases = create_test_cases()
    print(f"Created {len(test_cases)} test cases")

    # Compute vLLM embeddings
    vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)

    if not args.skip_hf:
        # Compute HuggingFace embeddings
        hf_embeddings = compute_hf_embeddings(args.model, test_cases)

        # Compare results
        compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
    else:
        print("\nSkipping HuggingFace comparison")
        print(f"vLLM processed {len(test_cases)} embeddings successfully")


if __name__ == "__main__":
    main()