|
14 | 14 | limitations under the License.
|
15 | 15 | """
|
16 | 16 |
|
17 |
| -from pathlib import Path |
18 | 17 | import inspect
|
19 | 18 | from typing import Union, List, Optional, Dict
|
20 | 19 | import numpy as np
|
21 | 20 | import cv2
|
22 | 21 | import PIL
|
23 | 22 | from .base_custom_evaluator import BaseCustomEvaluator
|
24 | 23 | from .base_models import BaseCascadeModel
|
25 |
| -from ...config import ConfigError |
26 |
| -from ...utils import UnsupportedPackage, extract_image_representations, get_path |
| 24 | +from ...utils import UnsupportedPackage, extract_image_representations |
27 | 25 | from ...representation import Text2ImageGenerationPrediction
|
28 |
| -from ...logging import print_info |
29 |
| - |
30 | 26 |
|
31 | 27 | try:
|
32 | 28 | from diffusers import DiffusionPipeline
|
@@ -68,45 +64,83 @@ def create_pipeline(self, launcher, netowrk_info=None):
|
68 | 64 | scheduler_config = self.config.get("scheduler_config", {})
|
69 | 65 | scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
|
70 | 66 | netowrk_info = netowrk_info or self.network_info
|
| 67 | + |
| 68 | + self.load_models(netowrk_info, launcher, True) |
| 69 | + compiled_models = self.get_compiled_models(launcher) |
| 70 | + |
71 | 71 | self.pipe = OVStableDiffusionPipeline(
|
72 |
| - launcher, tokenizer, scheduler, self.network_info, |
| 72 | + launcher, tokenizer, scheduler, |
| 73 | + models_dict = compiled_models, |
73 | 74 | seed=self.seed, num_inference_steps=self.num_steps)
|
74 | 75 |
|
def release(self):
    """Drop the pipeline object referenced by ``self.pipe``.

    The attribute is deleted from the instance. Calling this method again,
    or calling it on an instance that has no ``pipe`` attribute, is a
    no-op instead of raising ``AttributeError``.
    """
    try:
        del self.pipe
    except AttributeError:
        # Nothing to release: the attribute is already gone or was never set.
        pass
| 78 | + |
def load_models(self, model_info, launcher, log=False):
    """Load every network described by ``model_info`` through ``self.load_model``.

    ``model_info`` is either a dict mapping a model name to its description
    dict, or an iterable of description dicts. In the dict form each
    description is modified in place: its key is stored under ``"name"``
    right before that description is loaded.

    When ``log`` is truthy, ``self.print_input_output_info()`` is called
    once after all networks have been loaded.
    """
    keyed = isinstance(model_info, dict)
    if keyed:
        entries = model_info.items()
    else:
        entries = ((None, description) for description in model_info)

    for model_name, description in entries:
        if keyed:
            description["name"] = model_name
        self.load_model(description, launcher)

    if log:
        self.print_input_output_info()
| 90 | + |
| 91 | + |
def load_model(self, network_info, launcher):
    """Read one network and keep it on the instance as ``<name>_model``.

    The model and weights locations come from
    ``self.automatic_model_search(network_info)``. The network is read with
    ``launcher.read_network``; when no weights were found, ``None`` is
    passed as the weights argument. The attribute name is built from
    ``network_info['name']``.
    """
    model_path, weights_path = self.automatic_model_search(network_info)
    weights_arg = str(weights_path) if weights_path else None
    network = launcher.read_network(str(model_path), weights_arg)
    setattr(self, f"{network_info['name']}_model", network)
| 99 | + |
def print_input_output_info(self):
    """Report inputs/outputs of each loaded pipeline part via the launcher.

    For every part name in a fixed list, the ``<part>_model`` attribute is
    looked up; parts whose attribute is missing or ``None`` are skipped.
    Reporting is delegated to ``self.launcher.print_input_output_info``.
    """
    for part in ("text_encoder", "unet", "vae_decoder", "vae_encoder"):
        model = getattr(self, f"{part}_model", None)
        if model is None:
            continue
        self.launcher.print_input_output_info(model, part)
| 107 | + |
def get_compiled_models(self, launcher):
    """Compile the loaded networks and return them with the UNet image size.

    Reads ``self.unet_model``, ``self.text_encoder_model`` and
    ``self.vae_decoder_model``; ``self.vae_encoder_model`` is optional and
    is treated as absent when the attribute is missing or ``None``.

    If the first dimension of the first UNet input is static, the first
    dimension of UNet inputs 0 and 2 is set to ``-1`` and the UNet is
    reshaped before compilation.

    Returns a dict with keys ``"unet"``, ``"text_encoder"``,
    ``"vae_decoder"`` and ``"vae_encoder"`` (compiled models; the last is
    ``None`` without an encoder network) and ``"unet_shape"``, a
    ``(height, width)`` tuple.
    """
    unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs]
    if not unet_shapes[0][0].is_dynamic:
        # NOTE(review): dimension 0 of inputs 0 and 2 is presumably the batch
        # axis -- confirm against the exported UNet.
        unet_shapes[0][0] = -1
        unet_shapes[2][0] = -1
        self.unet_model.reshape(dict(zip(self.unet_model.inputs, unet_shapes)))

    # Dimensions 2 and 3 of the first UNet input are scaled by 8 when static;
    # 512 is the fallback for dynamic dimensions.
    # NOTE(review): 8 is presumably the VAE downscale factor -- confirm.
    height_dim = unet_shapes[0][2]
    width_dim = unet_shapes[0][3]
    height = 512 if height_dim.is_dynamic else height_dim.get_length() * 8
    width = 512 if width_dim.is_dynamic else width_dim.get_length() * 8

    core = launcher.ie_core
    device = launcher.device
    unet = core.compile_model(self.unet_model, device)
    text_encoder = core.compile_model(self.text_encoder_model, device)
    vae_decoder = core.compile_model(self.vae_decoder_model, device)

    # ``load_model`` only creates ``<name>_model`` attributes for configured
    # networks, so the encoder attribute may not exist at all.
    vae_encoder_model = getattr(self, "vae_encoder_model", None)
    vae_encoder = None
    if vae_encoder_model is not None:
        vae_encoder = core.compile_model(vae_encoder_model, device)

    return {
        "unet": unet,
        "unet_shape": (height, width),
        "text_encoder": text_encoder,
        "vae_decoder": vae_decoder,
        "vae_encoder": vae_encoder,
    }
| 129 | + |
def reset_compiled_models(self):
    """Set the ``unet``, ``text_encoder``, ``vae_decoder`` and ``vae_encoder``
    attributes of this instance to ``None``."""
    for attribute in ("unet", "text_encoder", "vae_decoder", "vae_encoder"):
        setattr(self, attribute, None)
| 135 | + |
| 136 | + |
def predict(self, identifiers, input_data, input_meta):
    """Run the pipeline once per prompt and wrap the results as predictions.

    ``identifiers`` and ``input_data`` are iterated in lockstep. For each
    pair, ``self.pipe`` is called with ``output_type="np"`` and the first
    element of its ``"sample"`` entry is stored in a
    ``Text2ImageGenerationPrediction`` together with the identifier.
    ``input_meta`` is not used.

    Returns a tuple ``(None, predictions)``.
    """
    preds = [
        Text2ImageGenerationPrediction(
            idx, self.pipe(prompt, output_type="np")["sample"][0]
        )
        for idx, prompt in zip(identifiers, input_data)
    ]
    return None, preds
|
81 | 143 |
|
82 |
| - def release(self): |
83 |
| - del self.pipe |
84 |
| - |
85 |
| - def load_network(self, network_list, launcher): |
86 |
| - if self.pipe is None: |
87 |
| - self.create_pipeline(launcher, network_list) |
88 |
| - return |
89 |
| - self.pipe.reset_compiled_models() |
90 |
| - for network_dict in network_list: |
91 |
| - self.pipe.load_network(network_dict["model"], network_dict["name"]) |
92 |
| - self.pipe.compile(launcher) |
93 |
| - |
94 |
| - def load_model(self, network_list, launcher): |
95 |
| - if self.pipe is None: |
96 |
| - self.create_pipeline(launcher, network_list) |
97 |
| - return |
98 |
| - self.pipe.reset_compiled_models() |
99 |
| - for network_dict in network_list: |
100 |
| - self.pipe.load_model(network_dict, launcher) |
101 |
| - self.pipe.compile(launcher) |
102 |
| - |
103 |
| - def get_network(self): |
104 |
| - models = self.pipe.get_models() |
105 |
| - model_list = [] |
106 |
| - for model_part_name, model in models.items(): |
107 |
| - model_list.append({"name": model_part_name, "model": model}) |
108 |
| - return model_list |
109 |
| - |
110 | 144 |
|
111 | 145 | class StableDiffusionEvaluator(BaseCustomEvaluator):
|
112 | 146 | def __init__(self, dataset_config, launcher, model, orig_config):
|
@@ -165,56 +199,30 @@ def __init__(
|
165 | 199 | launcher: "BaseLauncher", # noqa: F821
|
166 | 200 | tokenizer: "CLIPTokenizer", # noqa: F821
|
167 | 201 | scheduler: Union["LMSDiscreteScheduler"], # noqa: F821
|
168 |
| - model_info: Dict, |
| 202 | + models_dict: Dict, |
169 | 203 | seed = None,
|
170 | 204 | num_inference_steps = 50
|
171 | 205 | ):
|
172 | 206 | super().__init__()
|
173 | 207 | self.scheduler = scheduler
|
174 | 208 | self.launcher = launcher
|
175 | 209 | self.tokenizer = tokenizer
|
176 |
| - # self.height = height |
177 |
| - # self.width = width |
178 |
| - self.load_models(model_info, launcher, True) |
179 |
| - self.compile(launcher) |
| 210 | + self.unet = models_dict.get('unet') |
| 211 | + self.text_encoder = models_dict.get('text_encoder') |
| 212 | + self.vae_decoder = models_dict.get('vae_decoder') |
| 213 | + self.vae_encoder = models_dict.get('vae_encoder') |
| 214 | + (self.height, self.width) = models_dict.get('unet_shape') |
| 215 | + self.set_models_outputs() |
180 | 216 | if seed is not None:
|
181 | 217 | np.random.seed(seed)
|
182 | 218 | self.num_inference_steps = num_inference_steps
|
183 | 219 |
|
184 |
| - def compile(self, launcher): |
185 |
| - unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs] |
186 |
| - if not unet_shapes[0][0].is_dynamic: |
187 |
| - unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs] |
188 |
| - unet_shapes[0][0] = -1 |
189 |
| - unet_shapes[2][0] = -1 |
190 |
| - self.unet_model.reshape(dict(zip(self.unet_model.inputs, unet_shapes))) |
191 |
| - self.unet = launcher.ie_core.compile_model(self.unet_model, launcher.device) |
192 |
| - self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device) |
193 |
| - self.vae_decoder = launcher.ie_core.compile_model(self.vae_decoder_model, launcher.device) |
194 |
| - if self.vae_encoder_model is not None: |
195 |
| - self.vae_encoder = launcher.ie_core.compile_model(self.vae_encoder_model, launcher.device) |
def set_models_outputs(self):
    """Cache output 0 of each compiled model on the instance.

    ``self.text_encoder``, ``self.unet`` and ``self.vae_decoder`` must be
    set; ``self.vae_encoder`` may be ``None``, in which case
    ``self._vae_e_output`` is ``None`` as well.
    """
    first_output = 0
    self._text_encoder_output = self.text_encoder.output(first_output)
    self._unet_output = self.unet.output(first_output)
    self._vae_d_output = self.vae_decoder.output(first_output)
    if self.vae_encoder is None:
        self._vae_e_output = None
    else:
        self._vae_e_output = self.vae_encoder.output(first_output)
|
200 |
| - self.height = unet_shapes[0][2].get_length() * 8 if not unet_shapes[0][2].is_dynamic else 512 |
201 |
| - self.width = unet_shapes[0][3].get_length() * 8 if not unet_shapes[0][3].is_dynamic else 512 |
202 |
| - |
203 |
| - def get_models(self): |
204 |
| - model_dict = {"text_encoder": self.text_encoder_model, "unet": self.unet_model, "vae_decoder": self.vae_decoder} |
205 |
| - if self.vae_encoder_model is not None: |
206 |
| - model_dict["vae_encoder"] = self.vae_encoder_model |
207 |
| - return model_dict |
208 | 225 |
|
209 |
| - def reset_compiled_models(self): |
210 |
| - self._text_encoder_output = None |
211 |
| - self._unet_output = None |
212 |
| - self._vae_d_output = None |
213 |
| - self._vae_e_output = None |
214 |
| - self.unet = None |
215 |
| - self.text_encoder = None |
216 |
| - self.vae_decoder = None |
217 |
| - self.vae_encoder = None |
218 | 226 |
|
219 | 227 | def __call__(
|
220 | 228 | self,
|
@@ -397,75 +405,6 @@ def get_timesteps(self, num_inference_steps: int, strength: float):
|
397 | 405 |
|
398 | 406 | return timesteps, num_inference_steps - t_start
|
399 | 407 |
|
400 |
| - def load_models(self, model_info, launcher, log=False): |
401 |
| - if isinstance(model_info, dict): |
402 |
| - for model_name, model_dict in model_info.items(): |
403 |
| - model_dict["name"] = model_name |
404 |
| - self.load_model(model_dict, launcher) |
405 |
| - else: |
406 |
| - for model_dict in model_info: |
407 |
| - self.load_model(model_dict, launcher) |
408 |
| - |
409 |
| - if log: |
410 |
| - self.print_input_output_info() |
411 |
| - |
412 |
| - def load_network(self, model, model_name): |
413 |
| - setattr(self, "{}_model".format(model_name), model) |
414 |
| - |
415 |
| - def load_model(self, network_info, launcher): |
416 |
| - model, weights = self.automatic_model_search(network_info) |
417 |
| - if weights: |
418 |
| - network = launcher.read_network(str(model), str(weights)) |
419 |
| - else: |
420 |
| - network = launcher.read_network(str(model), None) |
421 |
| - self.load_network(network, network_info["name"]) |
422 |
| - |
423 |
| - @staticmethod |
424 |
| - def automatic_model_search(network_info): |
425 |
| - model = Path(network_info['model']) |
426 |
| - model_name = network_info["name"] |
427 |
| - if model.is_dir(): |
428 |
| - is_blob = network_info.get('_model_is_blob') |
429 |
| - if is_blob: |
430 |
| - model_list = list(model.glob('*{}.blob'.format(model_name))) |
431 |
| - if not model_list: |
432 |
| - model_list = list(model.glob('*.blob')) |
433 |
| - else: |
434 |
| - model_list = list(model.glob('*{}*.xml'.format(model_name))) |
435 |
| - blob_list = list(model.glob('*{}*.blob'.format(model_name))) |
436 |
| - onnx_list = list(model.glob('*{}*.onnx'.format(model_name))) |
437 |
| - if not model_list and not blob_list and not onnx_list: |
438 |
| - model_list = list(model.glob('*.xml')) |
439 |
| - blob_list = list(model.glob('*.blob')) |
440 |
| - onnx_list = list(model.glob('*.onnx')) |
441 |
| - if not model_list: |
442 |
| - model_list = blob_list if blob_list else onnx_list |
443 |
| - if not model_list: |
444 |
| - raise ConfigError('Suitable model for {} not found'.format(model_name)) |
445 |
| - if len(model_list) > 1: |
446 |
| - raise ConfigError('Several suitable models for {} found'.format(model_name)) |
447 |
| - model = model_list[0] |
448 |
| - accepted_suffixes = ['.xml', '.onnx'] |
449 |
| - if model.suffix not in accepted_suffixes: |
450 |
| - raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes)) |
451 |
| - print_info('{} - Found model: {}'.format(model_name, model)) |
452 |
| - if model.suffix in ['.blob', '.onnx']: |
453 |
| - return model, None |
454 |
| - weights = get_path(network_info.get('weights', model.parent / model.name.replace('xml', 'bin'))) |
455 |
| - accepted_weights_suffixes = ['.bin'] |
456 |
| - if weights.suffix not in accepted_weights_suffixes: |
457 |
| - raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes)) |
458 |
| - print_info('{} - Found weights: {}'.format(model_name, weights)) |
459 |
| - return model, weights |
460 |
| - |
461 |
| - def print_input_output_info(self): |
462 |
| - model_parts = ("text_encoder", "unet", "vae_decoder", "vae_encoder") |
463 |
| - for part in model_parts: |
464 |
| - part_model_id = "{}_model".format(part) |
465 |
| - model = getattr(self, part_model_id, None) |
466 |
| - if model is not None: |
467 |
| - self.launcher.print_input_output_info(model, part) |
468 |
| - |
469 | 408 | @staticmethod
|
470 | 409 | def get_w_embedding(w, embedding_dim=512, dtype=torch.float32):
|
471 | 410 | """
|
|
0 commit comments