Skip to content

Commit 15c89fc

Browse files
authored
Refactor of stable_diffusion_evaluator.py
1 parent c506dd9 commit 15c89fc

File tree

1 file changed

+76
-137
lines changed

1 file changed

+76
-137
lines changed

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/stable_diffusion_evaluator.py

Lines changed: 76 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@
1414
limitations under the License.
1515
"""
1616

17-
from pathlib import Path
1817
import inspect
1918
from typing import Union, List, Optional, Dict
2019
import numpy as np
2120
import cv2
2221
import PIL
2322
from .base_custom_evaluator import BaseCustomEvaluator
2423
from .base_models import BaseCascadeModel
25-
from ...config import ConfigError
26-
from ...utils import UnsupportedPackage, extract_image_representations, get_path
24+
from ...utils import UnsupportedPackage, extract_image_representations
2725
from ...representation import Text2ImageGenerationPrediction
28-
from ...logging import print_info
29-
3026

3127
try:
3228
from diffusers import DiffusionPipeline
@@ -68,45 +64,83 @@ def create_pipeline(self, launcher, netowrk_info=None):
6864
scheduler_config = self.config.get("scheduler_config", {})
6965
scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
7066
netowrk_info = netowrk_info or self.network_info
67+
68+
self.load_models(netowrk_info, launcher, True)
69+
compiled_models = self.get_compiled_models(launcher)
70+
7171
self.pipe = OVStableDiffusionPipeline(
72-
launcher, tokenizer, scheduler, self.network_info,
72+
launcher, tokenizer, scheduler,
73+
models_dict = compiled_models,
7374
seed=self.seed, num_inference_steps=self.num_steps)
7475

76+
def release(self):
77+
del self.pipe
78+
79+
def load_models(self, model_info, launcher, log=False):
80+
if isinstance(model_info, dict):
81+
for model_name, model_dict in model_info.items():
82+
model_dict["name"] = model_name
83+
self.load_model(model_dict, launcher)
84+
else:
85+
for model_dict in model_info:
86+
self.load_model(model_dict, launcher)
87+
88+
if log:
89+
self.print_input_output_info()
90+
91+
92+
def load_model(self, network_info, launcher):
93+
model, weights = self.automatic_model_search(network_info)
94+
if weights:
95+
network = launcher.read_network(str(model), str(weights))
96+
else:
97+
network = launcher.read_network(str(model), None)
98+
setattr(self, f"{network_info['name']}_model", network)
99+
100+
def print_input_output_info(self):
101+
model_parts = ("text_encoder", "unet", "vae_decoder", "vae_encoder")
102+
for part in model_parts:
103+
part_model_id = f"{part}_model"
104+
model = getattr(self, part_model_id, None)
105+
if model is not None:
106+
self.launcher.print_input_output_info(model, part)
107+
108+
def get_compiled_models(self, launcher):
109+
unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs]
110+
if not unet_shapes[0][0].is_dynamic:
111+
unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs]
112+
unet_shapes[0][0] = -1
113+
unet_shapes[2][0] = -1
114+
self.unet_model.reshape(dict(zip(self.unet_model.inputs, unet_shapes)))
115+
height = unet_shapes[0][2].get_length() * 8 if not unet_shapes[0][2].is_dynamic else 512
116+
width = unet_shapes[0][3].get_length() * 8 if not unet_shapes[0][3].is_dynamic else 512
117+
unet = launcher.ie_core.compile_model(self.unet_model, launcher.device)
118+
text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
119+
vae_decoder = launcher.ie_core.compile_model(self.vae_decoder_model, launcher.device)
120+
vae_encoder = None
121+
if self.vae_encoder_model is not None:
122+
vae_encoder = launcher.ie_core.compile_model(self.vae_encoder_model, launcher.device)
123+
124+
return { "unet": unet,
125+
"unet_shape" : (height, width),
126+
"text_encoder": text_encoder,
127+
"vae_decoder": vae_decoder,
128+
"vae_encoder": vae_encoder }
129+
130+
def reset_compiled_models(self):
131+
self.unet = None
132+
self.text_encoder = None
133+
self.vae_decoder = None
134+
self.vae_encoder = None
135+
136+
75137
def predict(self, identifiers, input_data, input_meta):
76138
preds = []
77139
for idx, prompt in zip(identifiers, input_data):
78140
pred = self.pipe(prompt, output_type="np")["sample"][0]
79141
preds.append(Text2ImageGenerationPrediction(idx, pred))
80142
return None, preds
81143

82-
def release(self):
83-
del self.pipe
84-
85-
def load_network(self, network_list, launcher):
86-
if self.pipe is None:
87-
self.create_pipeline(launcher, network_list)
88-
return
89-
self.pipe.reset_compiled_models()
90-
for network_dict in network_list:
91-
self.pipe.load_network(network_dict["model"], network_dict["name"])
92-
self.pipe.compile(launcher)
93-
94-
def load_model(self, network_list, launcher):
95-
if self.pipe is None:
96-
self.create_pipeline(launcher, network_list)
97-
return
98-
self.pipe.reset_compiled_models()
99-
for network_dict in network_list:
100-
self.pipe.load_model(network_dict, launcher)
101-
self.pipe.compile(launcher)
102-
103-
def get_network(self):
104-
models = self.pipe.get_models()
105-
model_list = []
106-
for model_part_name, model in models.items():
107-
model_list.append({"name": model_part_name, "model": model})
108-
return model_list
109-
110144

111145
class StableDiffusionEvaluator(BaseCustomEvaluator):
112146
def __init__(self, dataset_config, launcher, model, orig_config):
@@ -165,56 +199,30 @@ def __init__(
165199
launcher: "BaseLauncher", # noqa: F821
166200
tokenizer: "CLIPTokenizer", # noqa: F821
167201
scheduler: Union["LMSDiscreteScheduler"], # noqa: F821
168-
model_info: Dict,
202+
models_dict: Dict,
169203
seed = None,
170204
num_inference_steps = 50
171205
):
172206
super().__init__()
173207
self.scheduler = scheduler
174208
self.launcher = launcher
175209
self.tokenizer = tokenizer
176-
# self.height = height
177-
# self.width = width
178-
self.load_models(model_info, launcher, True)
179-
self.compile(launcher)
210+
self.unet = models_dict.get('unet')
211+
self.text_encoder = models_dict.get('text_encoder')
212+
self.vae_decoder = models_dict.get('vae_decoder')
213+
self.vae_encoder = models_dict.get('vae_encoder')
214+
(self.height, self.width) = models_dict.get('unet_shape')
215+
self.set_models_outputs()
180216
if seed is not None:
181217
np.random.seed(seed)
182218
self.num_inference_steps = num_inference_steps
183219

184-
def compile(self, launcher):
185-
unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs]
186-
if not unet_shapes[0][0].is_dynamic:
187-
unet_shapes = [inp.get_partial_shape() for inp in self.unet_model.inputs]
188-
unet_shapes[0][0] = -1
189-
unet_shapes[2][0] = -1
190-
self.unet_model.reshape(dict(zip(self.unet_model.inputs, unet_shapes)))
191-
self.unet = launcher.ie_core.compile_model(self.unet_model, launcher.device)
192-
self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
193-
self.vae_decoder = launcher.ie_core.compile_model(self.vae_decoder_model, launcher.device)
194-
if self.vae_encoder_model is not None:
195-
self.vae_encoder = launcher.ie_core.compile_model(self.vae_encoder_model, launcher.device)
220+
def set_models_outputs(self):
196221
self._text_encoder_output = self.text_encoder.output(0)
197222
self._unet_output = self.unet.output(0)
198223
self._vae_d_output = self.vae_decoder.output(0)
199224
self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder is not None else None
200-
self.height = unet_shapes[0][2].get_length() * 8 if not unet_shapes[0][2].is_dynamic else 512
201-
self.width = unet_shapes[0][3].get_length() * 8 if not unet_shapes[0][3].is_dynamic else 512
202-
203-
def get_models(self):
204-
model_dict = {"text_encoder": self.text_encoder_model, "unet": self.unet_model, "vae_decoder": self.vae_decoder}
205-
if self.vae_encoder_model is not None:
206-
model_dict["vae_encoder"] = self.vae_encoder_model
207-
return model_dict
208225

209-
def reset_compiled_models(self):
210-
self._text_encoder_output = None
211-
self._unet_output = None
212-
self._vae_d_output = None
213-
self._vae_e_output = None
214-
self.unet = None
215-
self.text_encoder = None
216-
self.vae_decoder = None
217-
self.vae_encoder = None
218226

219227
def __call__(
220228
self,
@@ -397,75 +405,6 @@ def get_timesteps(self, num_inference_steps: int, strength: float):
397405

398406
return timesteps, num_inference_steps - t_start
399407

400-
def load_models(self, model_info, launcher, log=False):
401-
if isinstance(model_info, dict):
402-
for model_name, model_dict in model_info.items():
403-
model_dict["name"] = model_name
404-
self.load_model(model_dict, launcher)
405-
else:
406-
for model_dict in model_info:
407-
self.load_model(model_dict, launcher)
408-
409-
if log:
410-
self.print_input_output_info()
411-
412-
def load_network(self, model, model_name):
413-
setattr(self, "{}_model".format(model_name), model)
414-
415-
def load_model(self, network_info, launcher):
416-
model, weights = self.automatic_model_search(network_info)
417-
if weights:
418-
network = launcher.read_network(str(model), str(weights))
419-
else:
420-
network = launcher.read_network(str(model), None)
421-
self.load_network(network, network_info["name"])
422-
423-
@staticmethod
424-
def automatic_model_search(network_info):
425-
model = Path(network_info['model'])
426-
model_name = network_info["name"]
427-
if model.is_dir():
428-
is_blob = network_info.get('_model_is_blob')
429-
if is_blob:
430-
model_list = list(model.glob('*{}.blob'.format(model_name)))
431-
if not model_list:
432-
model_list = list(model.glob('*.blob'))
433-
else:
434-
model_list = list(model.glob('*{}*.xml'.format(model_name)))
435-
blob_list = list(model.glob('*{}*.blob'.format(model_name)))
436-
onnx_list = list(model.glob('*{}*.onnx'.format(model_name)))
437-
if not model_list and not blob_list and not onnx_list:
438-
model_list = list(model.glob('*.xml'))
439-
blob_list = list(model.glob('*.blob'))
440-
onnx_list = list(model.glob('*.onnx'))
441-
if not model_list:
442-
model_list = blob_list if blob_list else onnx_list
443-
if not model_list:
444-
raise ConfigError('Suitable model for {} not found'.format(model_name))
445-
if len(model_list) > 1:
446-
raise ConfigError('Several suitable models for {} found'.format(model_name))
447-
model = model_list[0]
448-
accepted_suffixes = ['.xml', '.onnx']
449-
if model.suffix not in accepted_suffixes:
450-
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
451-
print_info('{} - Found model: {}'.format(model_name, model))
452-
if model.suffix in ['.blob', '.onnx']:
453-
return model, None
454-
weights = get_path(network_info.get('weights', model.parent / model.name.replace('xml', 'bin')))
455-
accepted_weights_suffixes = ['.bin']
456-
if weights.suffix not in accepted_weights_suffixes:
457-
raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes))
458-
print_info('{} - Found weights: {}'.format(model_name, weights))
459-
return model, weights
460-
461-
def print_input_output_info(self):
462-
model_parts = ("text_encoder", "unet", "vae_decoder", "vae_encoder")
463-
for part in model_parts:
464-
part_model_id = "{}_model".format(part)
465-
model = getattr(self, part_model_id, None)
466-
if model is not None:
467-
self.launcher.print_input_output_info(model, part)
468-
469408
@staticmethod
470409
def get_w_embedding(w, embedding_dim=512, dtype=torch.float32):
471410
"""

0 commit comments

Comments
 (0)