Skip to content

Commit b5517bc

Browse files
authored
Add custom evaluator for LDM Super Resolution (#3985)
* Create ldm_super_resolution_evaluator.py * Refactor of stable_diffusion_evaluator.py * Add pre/post processing to base_custom_evaluator.py * Fix torch_utils import and setting torch seed * Update ldm_super_resolution_evaluator.py * Silence W0237 pylint warning
1 parent e7df86d commit b5517bc

File tree

3 files changed

+379
-142
lines changed

3 files changed

+379
-142
lines changed

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,20 @@
1818
from ...progress_reporters import ProgressReporter
1919
from ..quantization_model_evaluator import create_dataset_attributes
2020
from ...launcher import create_launcher
21-
21+
from ...postprocessor import PostprocessingExecutor
22+
from ...preprocessor import PreprocessingExecutor
2223

2324
# base class for custom evaluators
2425
class BaseCustomEvaluator(BaseEvaluator):
25-
def __init__(self, dataset_config, launcher, orig_config):
26+
def __init__(self, dataset_config, launcher, orig_config,
27+
preprocessor=None, postprocessor=None):
2628
self.dataset_config = dataset_config
2729
self.dataset = None
30+
self.input_feeder = None
31+
self.adapter = None
2832
self.preprocessing_executor = None
29-
self.preprocessor = None
30-
self.postprocessor = None
33+
self.preprocessor = preprocessor
34+
self.postprocessor = postprocessor
3135
self.metric_executor = None
3236
self.launcher = launcher
3337
self._metrics_results = []
@@ -47,6 +51,30 @@ def get_dataset_and_launcher_info(config):
4751
launcher = create_launcher(launcher_config, delayed_model_loading=True)
4852
return dataset_config, launcher, launcher_config
4953

54+
@classmethod
def get_evaluator_init_info(cls, model_config, delayed_annotation_loading=False):
    """Build the shared dependencies a custom evaluator needs from its config.

    Reads the first launcher entry and the first dataset entry of
    ``model_config`` and constructs the pre/post-processing executors plus a
    launcher created with delayed model loading.

    Args:
        model_config: configuration dict with 'launchers' and 'datasets' lists.
        delayed_annotation_loading: accepted for signature symmetry with other
            init helpers but not used by this implementation.

    Returns:
        Tuple of (datasets list, launcher, PreprocessingExecutor,
        PostprocessingExecutor).
    """
    launcher_config = model_config['launchers'][0]
    datasets = model_config['datasets']
    dataset_config = datasets[0]
    dataset_name = dataset_config['name']

    # Runtime (inference-engine side) preprocessing is only meaningful for
    # OpenVINO-based launchers and must be opted into via '_ie_preprocessing'.
    runtime_framework = launcher_config['framework']
    enable_runtime_preprocessing = False
    if runtime_framework in ['dlsdk', 'openvino']:
        enable_runtime_preprocessing = dataset_config.get('_ie_preprocessing', False)
    preprocessor = PreprocessingExecutor(
        dataset_config.get('preprocessing'), dataset_name,
        enable_runtime_preprocessing=enable_runtime_preprocessing, runtime_framework=runtime_framework
    )

    # dlsdk launchers fall back to CPU when no device was configured.
    if launcher_config['framework'] == 'dlsdk' and 'device' not in launcher_config:
        launcher_config['device'] = 'CPU'
    launcher = create_launcher(launcher_config, delayed_model_loading=True)
    dataset_metadata = {}
    postprocessor = PostprocessingExecutor(dataset_config.get('postprocessing'), dataset_name, dataset_metadata)

    # NOTE(review): the full 'datasets' list is returned, not the single
    # dataset_config used to build the executors — callers unpack it as their
    # dataset configuration; confirm multi-dataset behavior is intended.
    return datasets, launcher, preprocessor, postprocessor
77+
5078
def process_dataset(self, subset=None, num_images=None, check_progress=False, dataset_tag='',
5179
output_callback=None, allow_pairwise_subset=False, dump_prediction_to_annotation=False,
5280
calculate_metrics=True, **kwargs):
@@ -164,7 +192,7 @@ def register_metric(self, metric_config):
164192
elif isinstance(metric_config, dict):
165193
self.metric_executor.register_metric(metric_config)
166194
else:
167-
raise ValueError('Unsupported metric configuration type {}'.format(type(metric_config)))
195+
raise ValueError(f'Unsupported metric configuration type {type(metric_config)}')
168196

169197
def get_metrics_attributes(self):
170198
if not self.metric_executor:
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""
2+
Copyright (c) 2024 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import inspect
18+
from typing import Union, List, Optional
19+
import numpy as np
20+
import PIL
21+
from .base_custom_evaluator import BaseCustomEvaluator
22+
from .base_models import BaseCascadeModel
23+
from ...utils import UnsupportedPackage, extract_image_representations
24+
from ...representation import SuperResolutionPrediction
25+
26+
try:
27+
from diffusers import DiffusionPipeline
28+
except ImportError as err_diff:
29+
DiffusionPipeline = UnsupportedPackage("diffusers", err_diff)
30+
31+
try:
32+
from diffusers import LMSDiscreteScheduler
33+
except ImportError as err_diff:
34+
LMSDiscreteScheduler = UnsupportedPackage("diffusers", err_diff)
35+
36+
try:
37+
from diffusers import DDIMScheduler
38+
except ImportError as err_diff:
39+
DDIMScheduler = UnsupportedPackage("diffusers", err_diff)
40+
41+
try:
42+
from diffusers.utils import torch_utils
43+
except ImportError as err_diff:
44+
torch_utils = UnsupportedPackage("diffusers.utils", err_diff)
45+
46+
try:
47+
import torch
48+
except ImportError as err_torch:
49+
torch = UnsupportedPackage("torch", err_torch)
50+
51+
52+
class PipelinedModel(BaseCascadeModel):
    """Cascade wrapper for the LDM super-resolution parts (unet, vqvae).

    Reads and compiles both networks through the launcher and exposes them
    via an OVLdmSuperResolutionPipeline stored in ``self.pipe``.
    """

    def __init__(self, network_info, launcher, models_args, delayed_model_loading=False, config=None):
        super().__init__(network_info, launcher, delayed_model_loading)
        self.network_info = network_info
        self.launcher = launcher
        self.pipe = None
        self.config = config or {}
        # Diffusion hyper-parameters, forwarded to the pipeline at creation.
        self.seed = self.config.get("seed", 42)
        self.num_steps = self.config.get("num_inference_steps", 100)
        part_names = network_info.keys()
        network_info = self.fill_part_with_model(network_info, part_names, models_args, False, delayed_model_loading)
        if not delayed_model_loading:
            self.create_pipeline(launcher)

    def create_pipeline(self, launcher, netowrk_info=None):
        """Load, compile and wire both networks into the diffusion pipeline.

        The misspelled keyword parameter name is kept as-is so existing
        keyword callers stay compatible.
        """
        info = netowrk_info or self.network_info
        self.load_models(info, launcher, True)
        compiled_unet = launcher.ie_core.compile_model(self.unet_model, launcher.device)
        compiled_vqvae = launcher.ie_core.compile_model(self.vqvae_model, launcher.device)
        scheduler = LMSDiscreteScheduler.from_config(self.config.get("scheduler_config", {}))
        self.pipe = OVLdmSuperResolutionPipeline(
            launcher, scheduler, compiled_unet, compiled_vqvae,
            seed=self.seed, num_inference_steps=self.num_steps
        )

    def release(self):
        # Drop the pipeline (and the compiled models it references).
        del self.pipe

    def load_models(self, model_info, launcher, log=False):
        """Read every network described by model_info; optionally log their I/O."""
        if isinstance(model_info, dict):
            for part_name, part_config in model_info.items():
                part_config["name"] = part_name
                self.load_model(part_config, launcher)
        else:
            for part_config in model_info:
                self.load_model(part_config, launcher)

        if log:
            self.print_input_output_info()

    def load_model(self, network_list, launcher):
        """Locate one part's model files and store the parsed network as
        an attribute named '<part>_model'."""
        model_path, weights_path = self.automatic_model_search(network_list)
        weights_arg = str(weights_path) if weights_path else None
        network = launcher.read_network(str(model_path), weights_arg)
        setattr(self, f"{network_list['name']}_model", network)

    def print_input_output_info(self):
        # Only report parts that were actually loaded.
        for part in ("unet", "vqvae"):
            network = getattr(self, f"{part}_model", None)
            if network is not None:
                self.launcher.print_input_output_info(network, part)

    def predict(self, identifiers, input_data, input_meta):
        """Run the pipeline image-by-image, wrapping each result as a
        SuperResolutionPrediction (input_meta is unused here)."""
        return [
            SuperResolutionPrediction(identifier, self.pipe(image, eta=1, output_type="np")["hr_sample"][0])
            for identifier, image in zip(identifiers, input_data)
        ]
119+
120+
121+
class LdmSuperResolutionEvaluator(BaseCustomEvaluator):
    """Custom evaluator driving the LDM super-resolution cascade model."""

    def __init__(self, model, dataset_config, launcher, preprocessor, postprocessor, orig_config):
        super().__init__(dataset_config, launcher, orig_config, preprocessor, postprocessor)
        self.model = model

    @classmethod
    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        """Construct the evaluator, its launcher and pre/post-processors from config."""
        init_info = BaseCustomEvaluator.get_evaluator_init_info(config, delayed_annotation_loading=False)
        dataset_config, launcher, preprocessor, postprocessor = init_info

        model = PipelinedModel(
            config.get('network_info', {}), launcher, config.get('_models', []),
            delayed_model_loading, config
        )

        return cls(model, dataset_config, launcher, preprocessor, postprocessor, orig_config)

    def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
        """Iterate the dataset: preprocess, predict, postprocess, score, report."""
        for batch_index, batch in enumerate(self.dataset):
            input_ids, annotation, raw_input, identifiers = batch
            preprocessed = self.preprocessor.process(raw_input, annotation)

            batch_data, batch_meta = extract_image_representations(preprocessed)
            prediction = self.model.predict(identifiers, batch_data, batch_meta)
            annotation, prediction = self.postprocessor.process_batch(annotation, prediction, batch_meta)

            metrics_result = self._get_metrics_result(input_ids, annotation, prediction, calculate_metrics)

            if output_callback:
                output_callback(
                    batch_raw_prediction=None, metrics_result=metrics_result,
                    element_identifiers=identifiers, dataset_indices=input_ids
                )
            self._update_progress(progress_reporter, metric_config, batch_index, len(prediction), csv_file)
167+
168+
169+
class OVLdmSuperResolutionPipeline(DiffusionPipeline):
    """OpenVINO-backed latent-diffusion super-resolution pipeline.

    Runs the unet and vqvae as compiled OpenVINO models while reusing a
    diffusers scheduler for the denoising loop.
    """

    def __init__(
        self,
        launcher: "BaseLauncher",  # noqa: F821
        scheduler: Union[DDIMScheduler, LMSDiscreteScheduler],
        unet,
        vqvae,
        seed=None,
        num_inference_steps=100
    ):
        super().__init__()
        self.launcher = launcher
        self.scheduler = scheduler
        self.unet = unet    # compiled OpenVINO model (noise predictor)
        self.vqvae = vqvae  # compiled OpenVINO model (latent decoder)
        # Cache the single output handle of each compiled model.
        self._unet_output = self.unet.output(0)
        self._vqvae_output = self.vqvae.output(0)
        if seed is not None:
            # Seed torch globally so randn_tensor() produces reproducible noise.
            torch.manual_seed(seed)
        self.num_inference_steps = num_inference_steps

    def __call__(
        self,
        image: Union[torch.Tensor, np.ndarray, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        output_type: Optional[str] = "pil",
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        return_dict: bool = True
    ):
        """Super-resolve *image* and return {"hr_sample": ndarray batch}.

        NOTE(review): guidance_scale, output_type and return_dict are accepted
        but not used by this implementation.
        """
        batch_size, image = self.preprocess_image(image)
        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        latents_shape = (batch_size, 3, height, width)
        latents = torch_utils.randn_tensor(latents_shape, generator=generator)
        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(self.num_inference_steps)
        timesteps_tensor = self.scheduler.timesteps
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        latents = latents.numpy()
        # Only pass eta when the configured scheduler's step() accepts it
        # (DDIM does; LMSDiscrete does not).
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps_tensor):
            # concat latents and low resolution image in the channel dimension.
            latents_input = np.concatenate([latents, image], axis=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)
            # predict the noise residual
            noise_pred = self.unet([latents_input, t])[self._unet_output]
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents)
            )["prev_sample"].numpy()

        # decode the image latents with the VQVAE
        image = self.vqvae(latents)[self._vqvae_output]

        # Map decoder output from [-1, 1] to uint8 [0, 255] NHWC.
        image = self.postprocess_image(image, std=255, mean=0)

        return {"hr_sample": image}

    @staticmethod
    def preprocess_image(image):
        """Normalize *image* to a NCHW float tensor in [-1, 1].

        Returns (batch_size, tensor).

        NOTE(review): the torch.Tensor branch reads batch_size from dim 0,
        yet ``image[None]`` below always prepends a batch axis and the
        4-axis transpose assumes a 3-D HWC input — a batched 4-D tensor
        would fail here; confirm the expected input rank.
        """
        if isinstance(image, PIL.Image.Image):
            w, h = image.size
            w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
            image = image.resize((w, h), resample=PIL.Image.Resampling.LANCZOS)
            image = np.array(image)
            batch_size = 1
        elif isinstance(image, torch.Tensor):
            image = np.array(image)
            batch_size = image.shape[0]
        elif isinstance(image, np.ndarray):
            batch_size = 1
        else:
            raise ValueError(
                f"`image` has to be of type `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` but is {type(image)}"
            )

        # HWC uint8 [0, 255] -> NCHW float [-1, 1].
        image = image.astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        timage = 2.0 * image - 1.0

        return batch_size, timage

    @staticmethod
    def postprocess_image(image: np.ndarray, std=255, mean=0):
        """Convert NCHW decoder output in [-1, 1] to uint8 NHWC via
        image * std + mean, clipped to [0, 255]."""
        image = image / 2 + 0.5
        image = image.transpose(0, 2, 3, 1)
        image *= np.array(std, dtype=image.dtype)
        image += np.array(mean, dtype=image.dtype)

        image = np.clip(image, 0., 255.)
        image = image.astype(np.uint8)
        return image

0 commit comments

Comments
 (0)