"""
- Copyright (c) 2024 Intel Corporation
+ Copyright (c) 2024-2025 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import os
import numpy as np
-
+ from PIL import Image
from .base_custom_evaluator import BaseCustomEvaluator
from .base_models import BaseCascadeModel
from ...config import ConfigError
try:
    import open_clip
- except ImportError as error:
-     open_clip = UnsupportedPackage('open_clip', error.msg)
+ except ImportError as clip_error:
+     open_clip = UnsupportedPackage('open_clip', clip_error.msg)
+
+ try:
+     from transformers import AutoModel, AutoTokenizer
+ except ImportError as transformers_error:
+     AutoModel = UnsupportedPackage('AutoModel', transformers_error.msg)
+     AutoTokenizer = UnsupportedPackage('AutoTokenizer', transformers_error.msg)

+ try:
+     import torch
+     import torch.nn.functional as F
+ except ImportError as torch_error:
+     torch = UnsupportedPackage("torch", torch_error.msg)
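+ # Optional dependencies degrade to UnsupportedPackage stubs instead of failing at
+ # import time; raise_error() fires only when an evaluator that needs them is built.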


class OpenVinoClipEvaluator(BaseCustomEvaluator):
    def __init__(self, dataset_config, launcher, model, orig_config):
@@ -42,8 +53,7 @@ def __init__(self, dataset_config, launcher, model, orig_config):
    @classmethod
    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
-
-         model = OpenVinoClipModel(
+         model = OpenVinoClipVitModel(
            config.get('network_info', {}), launcher, config.get('_models', []),
            config.get('_model_is_blob'),
            delayed_model_loading, config
@@ -69,43 +79,52 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
            self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)


- class OpenVinoClipModel(BaseCascadeModel):
+ class OpenVinoJinaClipEvaluator(OpenVinoClipEvaluator):
+     @classmethod
+     def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
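+         # With the 'pytorch' launcher the original HuggingFace model is executed
+         # directly (launcher=None below), giving a reference result the OpenVINO
+         # path can be compared against.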
+         if config['launchers'][0]['framework'] == 'pytorch':
+             dataset_config, launcher = config["datasets"], None
+             delayed_model_loading = False
+         else:
+             dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
+
+         model = OpenVinoJinaClipModel(
+             config.get('network_info', {}), launcher, config.get('_models', []),
+             config.get('_model_is_blob'),
+             delayed_model_loading, config
+         )
+         return cls(dataset_config, launcher, model, orig_config)
+
+
+ class BaseOpenVinoClipModel(BaseCascadeModel):
    def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False, config=None):
        super().__init__(network_info, launcher, delayed_model_loading)
        self.network_info = network_info
        self.launcher = launcher
        self.config = config or {}
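+         # Parts are derived from the network_info keys rather than hardcoded, so
+         # subclasses can declare their own encoder names.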
-         parts = ['text_encoder', 'image_encoder']
-         network_info = self.fill_part_with_model(network_info, parts, models_args, False, delayed_model_loading)
-         if not contains_all(network_info, parts) and not delayed_model_loading:
-             raise ConfigError('configuration for text_encoder/image_encoder does not exist')
+         self.templates_file = None
+         self.parameters_file = None
+         self.templates = ["a photo of a {classname}"]
+         self.parts = network_info.keys()
+         if launcher:
+             network_info = self.fill_part_with_model(network_info, self.parts,
+                                                      models_args, False, delayed_model_loading)
+             if not contains_all(network_info, self.parts) and not delayed_model_loading:
+                 raise ConfigError('configuration for text_encoder/image_encoder does not exist')
        if not delayed_model_loading:
            self.create_pipeline(launcher, network_info)

    def create_pipeline(self, launcher, network_info):
-         orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
-         self.load_models(network_info, launcher, True)
-
-         self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
-         self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)
-
-         unet_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
-         if unet_shapes[0][0].is_dynamic:
-             self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
-         else:
-             self.templates_file = None
+         raise NotImplementedError("Subclasses should implement this method")

-         self.classnames_file = self.config.get("classnames", "classnames.json")
-         self.parameters_file = self.config.get("pretrained_model_params", None)
-         self.tokenizer = open_clip.get_tokenizer(orig_model_name)
+     def get_logits(self, image_features, zeroshot_weights):
+         raise NotImplementedError("Subclasses should implement this method")

    def predict(self, identifiers, input_data, zeroshot_weights):
        preds = []
        for idx, image_data in zip(identifiers, input_data):
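+             # Per-image encoding and scoring are delegated to the subclass hooks.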
-             image = np.expand_dims(image_data, axis=0)
-             image_features = self.encode_image(image)
-             image_features = self.normalize(image_features, axis=-1)
-             logits = 100. * image_features @ zeroshot_weights
+             image_features = self.encode_image(image_data)
+             logits = self.get_logits(image_features, zeroshot_weights)
            preds.append(ClassificationPrediction(idx, np.squeeze(logits, axis=0)))
        return None, preds

@@ -116,23 +135,11 @@ def get_network(self):
            model_list.append({"name": model_part_name, "model": model})
        return model_list

-     def encode_image(self, image):
-         features = self.image_encoder(image)
-         return features[self.image_encoder.output()]
+     def encode_image(self, image_data):
+         raise NotImplementedError("Subclasses should implement this method")

    def encode_text(self, texts, params):
-         text = self.tokenizer(texts).to('cpu')
-         indices = text.detach().cpu().numpy()
-
-         x = params['token_embedding'][indices]
-         x = x + params['positional_embedding']
-         x = x.transpose(1, 0, 2)
-         x = self.text_encoder((x, params['attn_mask']))
-         x = x[self.text_encoder.output()]
-         x = x.transpose(1, 0, 2)
-         x = self.layer_norm(x, params['gamma'], params['beta'])
-         x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
-         return x
+         raise NotImplementedError("Subclasses should implement this method")

    @staticmethod
    def get_pretrained_model_params(path):
@@ -147,14 +154,20 @@ def get_pretrained_model_params(path):
        params['beta'] = open_clip_params['beta']
        return params

+     def get_class_embeddings(self, texts, params):
+         raise NotImplementedError("Subclasses should implement this method")
+
    def zero_shot_classifier(self, data_source):
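+         # Build the zero-shot classification head: each class name is expanded into
+         # the prompt templates, encoded, and the per-class embeddings are stacked
+         # into the weight matrix used to score image features.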
        classnames = read_json(os.path.join(data_source, self.classnames_file))
        if self.templates_file:
            templates = read_json(os.path.join(data_source, self.templates_file))
        else:
-             templates = ["a photo of a {c}"]
+             templates = self.templates
+
+         params = None
+         if self.parameters_file:
+             params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))

-         params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))
        print_info('Encoding zeroshot weights for {} imagenet classes'.format(len(classnames)))

        zeroshot_weights = []
@@ -163,12 +176,9 @@ def zero_shot_classifier(self, data_source):
            iterator = tqdm(classnames, mininterval=2)

        for classname in iterator:
-             texts = [template.format(c=classname) for template in templates]
-             class_embeddings = self.encode_text(texts, params)
-             class_embedding = self.normalize(class_embeddings, axis=-1)
-             class_embedding = np.mean(class_embedding, axis=0)
-             class_embedding /= np.linalg.norm(class_embedding, ord=2)
-             zeroshot_weights.append(class_embedding)
+             texts = [template.format(classname=classname) for template in templates]
+             class_embeddings = self.get_class_embeddings(texts, params)
+             zeroshot_weights.append(class_embeddings)
        return np.stack(zeroshot_weights, axis=1)

    def load_models(self, network_info, launcher, log=False):
@@ -192,7 +202,7 @@ def load_model(self, network_list, launcher):
        setattr(self, "{}_model".format(network_list["name"]), network)

    def print_input_output_info(self):
-         model_parts = ("text_encoder", "image_encoder")
+         model_parts = self.parts
        for part in model_parts:
            part_model_id = "{}_model".format(part)
            model = getattr(self, part_model_id, None)
@@ -218,3 +228,114 @@ def normalize(input_array, p=2, axis=-1, epsilon=1e-12):
        norm = np.maximum(norm, epsilon)
        normalized = input_array / norm
        return normalized
+
+
+ class OpenVinoClipVitModel(BaseOpenVinoClipModel):
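+     # OpenVINO pipeline for open_clip ViT checkpoints: both encoders are compiled
+     # IR models, and the text-side embedding/projection math runs in numpy.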
+     def create_pipeline(self, launcher, network_info):
+         orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
+         self.load_models(network_info, launcher, True)
+         self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
+         self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)
+         unet_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
+         if unet_shapes[0][0].is_dynamic:
+             self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
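+             # Multiple prompt templates need a dynamic batch dimension; a static
+             # text encoder falls back to the single default template.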
+
+         self.classnames_file = self.config.get("classnames", "classnames.json")
+         self.parameters_file = self.config.get("pretrained_model_params", None)
+         self.tokenizer = open_clip.get_tokenizer(orig_model_name)
+
+     def get_logits(self, image_features, zeroshot_weights):
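+         # Standard CLIP scoring: L2-normalized image features against the zero-shot
+         # weights, scaled by the conventional logit factor of 100.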
+         image_features = self.normalize(image_features, axis=-1)
+         logits = 100. * image_features @ zeroshot_weights
+         return logits
+
+     def encode_image(self, image_data):
+         image = np.expand_dims(image_data, axis=0)
+         features = self.image_encoder(image)
+         return features[self.image_encoder.output()]
+
+     def encode_text(self, texts, params):
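+         # Re-implements the open_clip text tower around the compiled transformer:
+         # embed tokens, add positional embeddings, run in LND layout, then take the
+         # features at the EOT token (the highest token id) and project them.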
+         text = self.tokenizer(texts).to('cpu')
+         indices = text.detach().cpu().numpy()
+
+         x = params['token_embedding'][indices]
+         x = x + params['positional_embedding']
+         x = x.transpose(1, 0, 2)
+         x = self.text_encoder((x, params['attn_mask']))
+         x = x[self.text_encoder.output()]
+         x = x.transpose(1, 0, 2)
+         x = self.layer_norm(x, params['gamma'], params['beta'])
+         x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
+         return x
+
+     def get_class_embeddings(self, texts, params):
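+         # Average the normalized per-template embeddings into one class prototype,
+         # then re-normalize it to unit length.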
+         class_embeddings = self.encode_text(texts, params)
+         class_embedding = self.normalize(class_embeddings, axis=-1)
+         class_embedding = np.mean(class_embedding, axis=0)
+         class_embedding /= np.linalg.norm(class_embedding, ord=2)
+         return class_embedding
+
+
+ class OpenVinoJinaClipModel(BaseOpenVinoClipModel):
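+     # jina-clip pipeline: the HuggingFace model always supplies the tokenizer and
+     # image preprocessor; the encoders run either as compiled OpenVINO models or
+     # directly in PyTorch when no launcher is configured.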
+     def create_pipeline(self, launcher, network_info):
+         if isinstance(AutoTokenizer, UnsupportedPackage):
+             AutoTokenizer.raise_error(self.__class__.__name__)
+         if isinstance(AutoModel, UnsupportedPackage):
+             AutoModel.raise_error(self.__class__.__name__)
+         if isinstance(torch, UnsupportedPackage):
+             torch.raise_error(self.__class__.__name__)
+
+         orig_model_name = self.config.get("orig_model_name", "jinaai/jina-clip-v1")
+
+         model = AutoModel.from_pretrained(orig_model_name, trust_remote_code=True)
+         if launcher:
+             self.load_models(network_info, launcher, True)
+             self.text_encoder = launcher.ie_core.compile_model(self.text_model, launcher.device)
+             self.vision_encoder = launcher.ie_core.compile_model(self.vision_model, launcher.device)
+         else:
+             self.text_encoder = model.text_model
+             self.vision_encoder = model.vision_model
+
+         self.templates = ["{classname}"]
+         self.classnames_file = self.config.get("classnames", "classnames.json")
+         self.tokenizer = AutoTokenizer.from_pretrained(orig_model_name, trust_remote_code=True)
+         self.processor = model.get_preprocess()
+
+     def encode_image(self, image_data):
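+         # Handles both backends: the PyTorch encoder returns a torch.Tensor, while
+         # the compiled OpenVINO model returns a result dict, read via index [0].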
+         image = Image.fromarray(image_data)
+         vision_input = self.processor(images=[image], return_tensors="pt")
+         image_embeddings = self.vision_encoder(vision_input["pixel_values"])
+
+         if isinstance(image_embeddings, torch.Tensor):
+             image_embeddings = image_embeddings.detach().numpy()
+         else:
+             image_embeddings = image_embeddings[0]
+
+         return image_embeddings
+
+     def encode_text(self, text_input):
+         text_embeddings = self.text_encoder(text_input["input_ids"])
+
+         if isinstance(text_embeddings, torch.Tensor):
+             text_embeddings = text_embeddings.detach().numpy()
+         else:
+             text_embeddings = text_embeddings[0]
+         return text_embeddings
+
+     def get_logits(self, image_features, zeroshot_weights):
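+         # Scores are raw dot products between image and class embeddings,
+         # softmax-normalized over classes and scaled by 100 to act as logits.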
+         text_embeddings = np.squeeze(zeroshot_weights)
+         similarity = []
+         for emb1 in image_features:
+             temp_similarity = []
+             for emb2 in text_embeddings:
+                 temp_similarity.append(emb1 @ emb2)
+             similarity.append(temp_similarity)
+
+         similarity_tensor = torch.tensor(similarity)
+         logits = 100. * F.softmax(similarity_tensor, dim=-1).numpy()
+         return logits
+
+     def get_class_embeddings(self, texts, params):
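+         # Padding to a fixed 512 tokens keeps the input shape static (likely what
+         # the compiled text encoder expects); params is unused for this model.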
+         text_input = self.tokenizer(texts, return_tensors="pt", padding="max_length",
+                                     max_length=512, truncation=True).to("cpu")
+         return self.encode_text(text_input)