# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark and validate Jina Embeddings V4 against the HuggingFace implementation.

This script compares embeddings generated by vLLM and HuggingFace to verify
accuracy and measure performance differences.
"""

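# Example invocation (a sketch; the file name below is an assumed placeholder,
# not defined anywhere in this script):
#
#     python jina_embeddings_v4_validation.py \
#         --model jinaai/jina-embeddings-v4-vllm-retrieval
#     python jina_embeddings_v4_validation.py --skip-hf   # vLLM-only timing run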

import argparse
import time
from typing import Any, List, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt


def create_test_cases() -> List[Tuple[str, str, Any]]:
    """Create comprehensive test cases for validation."""
    test_cases = []

    # Text-only test cases
    test_cases.extend([
        ("text", "Query: What is artificial intelligence?", None),
        ("text", "Passage: AI is a field of computer science focusing on creating intelligent machines.", None),
        ("text", "Query: 你好世界", None),  # Chinese text
        ("text", "Passage: " + " ".join(["word"] * 100), None),  # Long text
    ])

    # Image test cases
    for color in ["red", "green", "blue"]:
        img = Image.new('RGB', (224, 224), color=color)
        test_cases.append(("image", f"{color} image", img))

    # Complex image
    complex_img = Image.new('RGB', (224, 224))
    pixels = complex_img.load()
    for i in range(224):
        for j in range(224):
            pixels[i, j] = (i % 256, j % 256, (i + j) % 256)
    test_cases.append(("image", "complex pattern", complex_img))

    return test_cases


def compute_hf_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the HuggingFace implementation."""
    print("Loading HuggingFace model...")
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).cuda().eval()

    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    embeddings = []

    print("Computing HuggingFace embeddings...")
    start_time = time.time()

    for case_type, text, image in test_cases:
        if case_type == "text":
            inputs = processor(text=text, return_tensors="pt").to("cuda")
        else:  # image
            inputs = processor(
                text="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                images=image,
                return_tensors="pt"
            ).to("cuda")

        with torch.no_grad():
            outputs = model(**inputs)
            # Extract embeddings based on model output structure
            if hasattr(outputs, 'embeddings'):
                embedding = outputs.embeddings[0]
            else:
                # Fallback to last hidden state with custom pooling
                hidden_states = outputs.last_hidden_state[0]

                # Apply token-type-aware pooling
                input_ids = inputs['input_ids'][0]
                vision_mask = (
                    (input_ids >= 151652) &
                    (input_ids <= 151653)
                )

                if vision_mask.any():
                    embedding = hidden_states[vision_mask].mean(dim=0)
                else:
                    embedding = hidden_states.mean(dim=0)

                embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)

        embeddings.append(embedding.cpu())

    hf_time = time.time() - start_time
    print(f"HuggingFace processing time: {hf_time:.2f}s")

    return embeddings


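# Note on the magic token ids used above (and again in compute_vllm_embeddings
# below): 151652 / 151653 are assumed to be the <|vision_start|> / <|vision_end|>
# ids of the underlying Qwen2.5-VL tokenizer that Jina Embeddings V4 builds on.
# If that assumption ever breaks, the ids can be looked up at runtime, e.g.:
#
#     vision_start_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
#     vision_end_id = processor.tokenizer.convert_tokens_to_ids("<|vision_end|>")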
def compute_vllm_embeddings(
    model_name: str,
    test_cases: List[Tuple[str, str, Any]]
) -> List[torch.Tensor]:
    """Compute embeddings using the vLLM implementation."""
    print("\nLoading vLLM model...")
    model = LLM(
        model=model_name,
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    embeddings = []
    prompts = []

    # Prepare prompts
    for case_type, text, image in test_cases:
        if case_type == "text":
            prompt = TextPrompt(prompt=text)
        else:  # image
            prompt = TextPrompt(
                prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
                multi_modal_data={"image": image},
            )
        prompts.append(prompt)

    print("Computing vLLM embeddings...")
    start_time = time.time()

    # Process all at once for better performance
    outputs = model.encode(prompts)

    for output in outputs:
        # Extract based on token type
        if 151652 in output.prompt_token_ids:  # VISION_START_TOKEN_ID
            img_start = output.prompt_token_ids.index(151652)
            img_end = output.prompt_token_ids.index(151653)
            embedding_data = output.outputs.data[img_start:img_end + 1]
        else:
            embedding_data = output.outputs.data

        # Pool and normalize
        pooled = embedding_data.mean(dim=0, dtype=torch.float32)
        normalized = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        embeddings.append(normalized.cpu())

    vllm_time = time.time() - start_time
    print(f"vLLM processing time: {vllm_time:.2f}s")

    return embeddings


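# A note on the pooling setup above: PoolerConfig(pooling_type="ALL",
# normalize=False) is what makes model.encode() return one vector per prompt
# token in output.outputs.data (roughly shape [num_prompt_tokens, hidden_size]),
# so the loop can slice the vision-token span and mean-pool it manually to
# mirror the HuggingFace fallback path. With a default pre-pooled config, only
# a single vector per prompt would be available and the slicing would not work.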
def compare_embeddings(
    hf_embeddings: List[torch.Tensor],
    vllm_embeddings: List[torch.Tensor],
    test_cases: List[Tuple[str, str, Any]]
) -> None:
    """Compare embeddings and report differences."""
    print("\n" + "=" * 60)
    print("EMBEDDING COMPARISON RESULTS")
    print("=" * 60)

    similarities = []
    max_diffs = []

    for i, (case_type, desc, _) in enumerate(test_cases):
        hf_emb = hf_embeddings[i]
        vllm_emb = vllm_embeddings[i]

        # Compute cosine similarity
        similarity = torch.nn.functional.cosine_similarity(
            hf_emb.unsqueeze(0),
            vllm_emb.unsqueeze(0)
        ).item()

        # Compute max absolute difference
        max_diff = torch.max(torch.abs(hf_emb - vllm_emb)).item()

        similarities.append(similarity)
        max_diffs.append(max_diff)

        print(f"\nTest case {i + 1}: {case_type} - {desc[:50]}...")
        print(f"  Cosine similarity: {similarity:.6f}")
        print(f"  Max absolute diff: {max_diff:.6f}")
        print(f"  HF norm: {hf_emb.norm():.6f}, vLLM norm: {vllm_emb.norm():.6f}")

        # Flag significant differences
        if similarity < 0.99:
            print("  ⚠️ WARNING: Low similarity detected!")

    # Summary statistics
    print("\n" + "-" * 60)
    print("SUMMARY STATISTICS")
    print("-" * 60)
    print(f"Average cosine similarity: {np.mean(similarities):.6f}")
    print(f"Min cosine similarity: {np.min(similarities):.6f}")
    print(f"Max absolute difference: {np.max(max_diffs):.6f}")

    # Overall assessment
    if np.min(similarities) > 0.99:
        print("\n✅ VALIDATION PASSED: vLLM implementation matches HuggingFace")
    else:
        print("\n❌ VALIDATION FAILED: Significant differences detected")


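# The 0.99 cosine-similarity cutoff above (applied per test case and to the
# minimum across cases) is an empirical tolerance for fp16 numerical drift
# (an assumption of this script, not a value mandated by the model); loosen or
# tighten it here if your hardware shows slightly different deviations.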
def main():
    parser = argparse.ArgumentParser(
        description="Validate Jina Embeddings V4 implementation"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v4-vllm-retrieval",
        help="Model name to test"
    )
    parser.add_argument(
        "--skip-hf",
        action="store_true",
        help="Skip HuggingFace comparison (for performance testing only)"
    )

    args = parser.parse_args()

    # Create test cases
    test_cases = create_test_cases()
    print(f"Created {len(test_cases)} test cases")

    # Compute vLLM embeddings
    vllm_embeddings = compute_vllm_embeddings(args.model, test_cases)

    if not args.skip_hf:
        # Compute HuggingFace embeddings
        hf_embeddings = compute_hf_embeddings(args.model, test_cases)

        # Compare results
        compare_embeddings(hf_embeddings, vllm_embeddings, test_cases)
    else:
        print("\nSkipping HuggingFace comparison")
        print(f"vLLM processed {len(test_cases)} embeddings successfully")


if __name__ == "__main__":
    main()