Skip to content

Commit c506dd9

Browse files
authored
Create ldm_super_resolution_evaluator.py
1 parent e7df86d commit c506dd9

File tree

1 file changed

+265
-0
lines changed

1 file changed

+265
-0
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
"""
2+
Copyright (c) 2024 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import inspect
18+
from typing import Union, List, Optional
19+
import numpy as np
20+
import PIL
21+
from diffusers.utils import torch_utils
22+
from .base_custom_evaluator import BaseCustomEvaluator
23+
from .base_models import BaseCascadeModel
24+
from ...utils import UnsupportedPackage, extract_image_representations
25+
from ...representation import SuperResolutionPrediction
26+
27+
try:
28+
from diffusers import DiffusionPipeline
29+
except ImportError as err_diff:
30+
DiffusionPipeline = UnsupportedPackage("diffusers", err_diff)
31+
32+
try:
33+
from diffusers import LMSDiscreteScheduler
34+
except ImportError as err_diff:
35+
LMSDiscreteScheduler = UnsupportedPackage("diffusers", err_diff)
36+
37+
try:
38+
from diffusers import DDIMScheduler
39+
except ImportError as err_diff:
40+
DDIMScheduler = UnsupportedPackage("diffusers", err_diff)
41+
42+
try:
43+
import torch
44+
except ImportError as err_torch:
45+
torch = UnsupportedPackage("torch", err_torch)
46+
47+
48+
class PipelinedModel(BaseCascadeModel):
    """Cascade model wrapping the OpenVINO LDM super-resolution pipeline.

    Loads the ``unet`` and ``vqvae`` parts described in ``network_info``,
    compiles them through the launcher's inference core and exposes
    prediction via an ``OVLdmSuperResolutionPipeline`` instance.
    """

    def __init__(self, network_info, launcher, models_args, delayed_model_loading=False, config=None):
        super().__init__(network_info, launcher, delayed_model_loading)
        self.network_info = network_info
        self.launcher = launcher
        self.pipe = None
        self.config = config or {}
        # Generation parameters are configurable, with deterministic defaults.
        self.seed = self.config.get("seed", 42)
        self.num_steps = self.config.get("num_inference_steps", 100)
        parts = network_info.keys()
        network_info = self.fill_part_with_model(
            network_info, parts, models_args, False, delayed_model_loading
        )
        if not delayed_model_loading:
            self.create_pipeline(launcher)

    def create_pipeline(self, launcher, network_info=None):
        """Compile the model parts and build the diffusion pipeline.

        ``network_info`` falls back to the info captured at construction time.
        (Fix: parameter was previously misspelled ``netowrk_info``.)
        """
        network_info = network_info or self.network_info
        scheduler_config = self.config.get("scheduler_config", {})
        scheduler = LMSDiscreteScheduler.from_config(scheduler_config)

        self.load_models(network_info, launcher, True)
        unet = launcher.ie_core.compile_model(self.unet_model, launcher.device)
        vqvae = launcher.ie_core.compile_model(self.vqvae_model, launcher.device)

        self.pipe = OVLdmSuperResolutionPipeline(
            launcher, scheduler, unet, vqvae,
            seed=self.seed, num_inference_steps=self.num_steps
        )

    def release(self):
        """Drop the pipeline reference; safe to call more than once.

        (Fix: ``del self.pipe`` raised AttributeError on a second call.)
        """
        self.pipe = None

    def load_models(self, model_info, launcher, log=False):
        """Read every model part described by ``model_info`` (dict or list).

        When ``model_info`` is a dict, the key is injected into each entry
        as its ``name`` so ``load_model`` can derive the attribute name.
        """
        if isinstance(model_info, dict):
            for model_name, model_dict in model_info.items():
                model_dict["name"] = model_name
                self.load_model(model_dict, launcher)
        else:
            for model_dict in model_info:
                self.load_model(model_dict, launcher)

        if log:
            self.print_input_output_info()

    def load_model(self, network_info, launcher):
        """Read a single network and store it as the ``<name>_model`` attribute."""
        model, weights = self.automatic_model_search(network_info)
        if weights:
            network = launcher.read_network(str(model), str(weights))
        else:
            network = launcher.read_network(str(model), None)
        setattr(self, f"{network_info['name']}_model", network)

    def print_input_output_info(self):
        """Log input/output tensor info for each loaded model part."""
        for part in ("unet", "vqvae"):
            model = getattr(self, f"{part}_model", None)
            if model is not None:
                self.launcher.print_input_output_info(model, part)

    def predict(self, identifiers, input_data, input_meta):
        """Run the pipeline per image and wrap results as predictions.

        ``input_meta`` is accepted for interface compatibility but unused.
        """
        preds = []
        for idx, image in zip(identifiers, input_data):
            pred = self.pipe(image, eta=1, output_type="np")["hr_sample"][0]
            preds.append(SuperResolutionPrediction(idx, pred))
        return preds
115+
116+
117+
class LdmSuperResolutionEvaluator(BaseCustomEvaluator):
    """Custom evaluator driving the LDM super-resolution cascade model."""

    def __init__(self, model, dataset_config, launcher, preprocessor, postprocessor, orig_config):
        super().__init__(dataset_config, launcher, orig_config, preprocessor, postprocessor)
        self.model = model

    @classmethod
    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        """Alternate constructor: assemble evaluator and model from a config dict."""
        init_info = BaseCustomEvaluator.get_evaluator_init_info(
            config, delayed_annotation_loading=False
        )
        dataset_config, launcher, preprocessor, postprocessor = init_info

        model = PipelinedModel(
            config.get('network_info', {}), launcher, config.get('_models', []),
            delayed_model_loading, config
        )

        return cls(
            model, dataset_config, launcher, preprocessor, postprocessor, orig_config
        )

    def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
        """Per-batch loop: preprocess, predict, postprocess, then score and report."""
        for batch_id, batch in enumerate(self.dataset):
            input_ids, annotation, raw_input, identifiers = batch

            processed_input = self.preprocessor.process(raw_input, annotation)
            data, meta = extract_image_representations(processed_input)

            prediction = self.model.predict(identifiers, data, meta)
            annotation, prediction = self.postprocessor.process_batch(
                annotation, prediction, meta
            )

            metrics_result = self._get_metrics_result(
                input_ids, annotation, prediction, calculate_metrics
            )

            if output_callback:
                output_callback(
                    batch_raw_prediction=None, metrics_result=metrics_result,
                    element_identifiers=identifiers, dataset_indices=input_ids
                )
            self._update_progress(
                progress_reporter, metric_config, batch_id, len(prediction), csv_file
            )
163+
164+
165+
class OVLdmSuperResolutionPipeline(DiffusionPipeline):
    """Latent-diffusion super-resolution pipeline backed by OpenVINO-compiled
    unet/vqvae models instead of the stock diffusers modules.

    The scheduler drives the denoising loop; the compiled ``unet`` predicts
    the noise residual and the compiled ``vqvae`` decodes the final latents.
    """

    def __init__(
        self,
        launcher: "BaseLauncher",  # noqa: F821
        scheduler: Union[DDIMScheduler, LMSDiscreteScheduler],
        unet,
        vqvae,
        seed=None,
        num_inference_steps=100
    ):
        super().__init__()
        self.launcher = launcher
        self.scheduler = scheduler
        self.unet = unet
        self.vqvae = vqvae
        # Cache compiled-model output handles so they are not looked up per step.
        self._unet_output = self.unet.output(0)
        self._vqvae_output = self.vqvae.output(0)
        if seed is not None:
            # NOTE(review): this seeds the *global* NumPy RNG, affecting every
            # other np.random user in the process — confirm this is intended.
            np.random.seed(seed)
        self.num_inference_steps = num_inference_steps

    def __call__(
        self,
        image: Union[torch.Tensor, np.ndarray, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        output_type: Optional[str] = "pil",
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        return_dict: bool = True
    ):
        """Run the full denoising loop on a low-resolution ``image``.

        Returns a dict with key ``"hr_sample"`` holding the uint8 decoded
        image (regardless of ``return_dict``). Note: the ``batch_size``
        argument is immediately overwritten by ``preprocess_image``, and
        ``guidance_scale``/``output_type``/``return_dict`` are unused here.
        ``eta`` is forwarded to the scheduler only if its ``step`` accepts it.
        """

        batch_size, image = self.preprocess_image(image)
        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        latents_shape = (batch_size, 3, height, width)
        latents = torch_utils.randn_tensor(latents_shape, generator=generator)
        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(self.num_inference_steps)
        timesteps_tensor = self.scheduler.timesteps
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        latents = latents.numpy()
        # Only pass `eta` if this scheduler's step() signature accepts it
        # (e.g. DDIM does, LMSDiscrete does not).
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps_tensor):
            # concat latents and low resolution image in the channel dimension.
            latents_input = np.concatenate([latents, image], axis=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)
            # predict the noise residual
            noise_pred = self.unet([latents_input, t])[self._unet_output]
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents)
            )["prev_sample"].numpy()

        # decode the image latents with the VQVAE
        image = self.vqvae(latents)[self._vqvae_output]

        image = self.postprocess_image(image, std=255, mean=0)

        return {"hr_sample": image}

    @staticmethod
    def preprocess_image(image):
        """Normalize the input to an NCHW float tensor in [-1, 1].

        Accepts a PIL image (resized down to a multiple of 32), a torch
        tensor, or a numpy array; returns ``(batch_size, tensor)``.
        """
        if isinstance(image, PIL.Image.Image):
            w, h = image.size
            w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
            image = image.resize((w, h), resample=PIL.Image.Resampling.LANCZOS)
            image = np.array(image)
            batch_size = 1
        elif isinstance(image, torch.Tensor):
            # NOTE(review): this branch looks wrong for anything but an HWC
            # 3-D tensor — batch_size becomes shape[0] (the height for HWC
            # input), and the unconditional `image[None].transpose(0, 3, 1, 2)`
            # below assumes a single HWC image. Confirm expected tensor layout.
            image = np.array(image)
            batch_size = image.shape[0]
        elif isinstance(image, np.ndarray):
            # Assumed to be a single HWC image; batch dim is added below.
            batch_size = 1
        else:
            raise ValueError(
                f"`image` has to be of type `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` but is {type(image)}"
            )

        # HWC uint8 [0, 255] -> NCHW float32 [-1, 1]
        image = image.astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        timage = 2.0 * image - 1.0

        return batch_size, timage

    @staticmethod
    def postprocess_image(image: np.ndarray, std=255, mean=0):
        """Map decoded latents from [-1, 1] to clipped uint8 pixel values.

        ``std``/``mean`` act as scale/offset after the [-1, 1] -> [0, 1]
        remap; the defaults produce standard [0, 255] images.
        """
        image = image / 2 + 0.5
        image *= np.array(std, dtype=image.dtype)
        image += np.array(mean, dtype=image.dtype)

        image = np.clip(image, 0., 255.)
        image = image.astype(np.uint8)
        return image

0 commit comments

Comments
 (0)