Skip to content

Commit 13b4341

Browse files
authored
fix no eval bug (#790)
1 parent 14890de commit 13b4341

File tree

6 files changed

+70
-95
lines changed

6 files changed

+70
-95
lines changed

ppfleetx/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def build_dataset(config, mode):
5858

5959
def build_dataloader(config, mode):
6060
dataset = build_dataset(config, mode)
61+
if dataset is None:
62+
return None
6163

6264
batch_sampler = None
6365
# build sampler

ppfleetx/data/dataset/multimodal_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ def data_augmentation_for_imagen(img, resolution):
7878
arr = deepcopy(img)
7979
while min(*arr.size) >= 2 * resolution:
8080
arr = arr.resize(
81-
tuple(x // 2 for x in arr.size), resample=Image.Resampling.BOX)
81+
tuple(x // 2 for x in arr.size), resample=Image.BOX)
8282
scale = resolution / min(*arr.size)
8383
arr = arr.resize(
8484
tuple(round(x * scale) for x in arr.size),
85-
resample=Image.Resampling.BICUBIC)
85+
resample=Image.BICUBIC)
8686

8787
arr = np.array(arr.convert("RGB"))
8888
crop_y = (arr.shape[0] - resolution) // 2

ppfleetx/models/multimodal_model/imagen/modeling.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import paddle.vision.transforms as T
2222

2323
from .unet import Unet
24-
from .utils import (GaussianDiffusionContinuousTimes, default, exists,
25-
cast_tuple, first, maybe, eval_decorator, identity,
24+
from .utils import (GaussianDiffusionContinuousTimes, default, cast_tuple,
25+
first, maybe, eval_decorator, identity,
2626
pad_tuple_to_length, right_pad_dims_to, resize_image_to,
2727
normalize_neg_one_to_one, rearrange, repeat, reduce,
2828
unnormalize_zero_to_one, cast_uint8_images_to_float)
@@ -195,9 +195,8 @@ def __init__(self,
195195
# randomly cropping for upsampler training
196196

197197
self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
198-
assert not exists(
199-
first(self.random_crop_sizes)
200-
), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
198+
assert first(
199+
self.random_crop_sizes) is None, 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
201200
# lowres augmentation noise schedule
202201

203202
self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(
@@ -284,22 +283,17 @@ def get_unet(self, unet_number):
284283
assert 0 < unet_number <= len(self.unets)
285284
index = unet_number - 1
286285

287-
if isinstance(self.unets, nn.LayerList):
288-
unets_list = [unet for unet in self.unets]
289-
delattr(self, 'unets')
290-
self.unets = unets_list
291286
self.unet_being_trained_index = index
292287
return self.unets[index]
293288

294289
def reset_unets(self, ):
295-
self.unets = nn.LayerList([*self.unets])
296290
self.unet_being_trained_index = -1
297291

298292
@contextmanager
299293
def one_unet_in_gpu(self, unet_number=None, unet=None):
300-
assert exists(unet_number) ^ exists(unet)
294+
assert (unet_number is not None) ^ (unet is not None)
301295

302-
if exists(unet_number):
296+
if unet_number is not None:
303297
unet = self.unets[unet_number - 1]
304298

305299
yield
@@ -320,7 +314,6 @@ def p_mean_variance(self,
320314
unet,
321315
x,
322316
t,
323-
*,
324317
noise_scheduler,
325318
text_embeds=None,
326319
text_mask=None,
@@ -370,7 +363,6 @@ def p_sample(self,
370363
unet,
371364
x,
372365
t,
373-
*,
374366
noise_scheduler,
375367
t_next=None,
376368
text_embeds=None,
@@ -412,7 +404,6 @@ def p_sample(self,
412404
def p_sample_loop(self,
413405
unet,
414406
shape,
415-
*,
416407
noise_scheduler,
417408
lowres_cond_img=None,
418409
lowres_noise_times=None,
@@ -433,7 +424,7 @@ def p_sample_loop(self,
433424

434425
# prepare inpainting
435426

436-
has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
427+
has_inpainting = inpaint_images is not None and inpaint_masks is not None
437428
resample_times = inpaint_resample_times if has_inpainting else 1
438429

439430
if has_inpainting:
@@ -532,18 +523,18 @@ def sample(
532523
batch_size = text_embeds.shape[0]
533524

534525
assert not (
535-
self.condition_on_text and not exists(text_embeds)
526+
self.condition_on_text and text_embeds is None
536527
), 'text or text encodings must be passed into imagen if specified'
537528
assert not (
538-
not self.condition_on_text and exists(text_embeds)
529+
not self.condition_on_text and text_embeds is not None
539530
), 'imagen specified not to be conditioned on text, yet it is presented'
540531
assert not (
541-
exists(text_embeds) and
532+
text_embeds is not None and
542533
text_embeds.shape[-1] != self.text_embed_dim
543534
), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
544535

545536
assert not (
546-
exists(inpaint_images) ^ exists(inpaint_masks)
537+
(inpaint_images is not None) ^ (inpaint_masks is not None)
547538
), 'inpaint images and masks must be both passed in to do inpainting'
548539

549540
outputs = []
@@ -609,8 +600,7 @@ def sample(
609600

610601
outputs.append(img)
611602

612-
if exists(stop_at_unet_number
613-
) and stop_at_unet_number == unet_number:
603+
if stop_at_unet_number is not None and stop_at_unet_number == unet_number:
614604
break
615605

616606
output_index = -1 if not return_all_unet_outputs else slice(
@@ -633,7 +623,6 @@ def p_losses(self,
633623
unet,
634624
x_start,
635625
times,
636-
*,
637626
noise_scheduler,
638627
lowres_cond_img=None,
639628
lowres_aug_times=None,
@@ -655,7 +644,7 @@ def p_losses(self,
655644
# random cropping during training
656645
# for upsamplers
657646

658-
if exists(random_crop_size):
647+
if random_crop_size is not None:
659648
aug = K.RandomCrop((random_crop_size, random_crop_size), p=1.)
660649
# make sure low res conditioner and image both get augmented the same way
661650
# detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
@@ -672,7 +661,7 @@ def p_losses(self,
672661
# at sample time, they then fix the noise level of 0.1 - 0.3
673662

674663
lowres_cond_img_noisy = None
675-
if exists(lowres_cond_img):
664+
if lowres_cond_img is not None:
676665
lowres_aug_times = default(lowres_aug_times, times)
677666
lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
678667
x_start=lowres_cond_img,
@@ -715,11 +704,10 @@ def forward(self,
715704
assert images.shape[-1] == images.shape[
716705
-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
717706
assert not (
718-
len(self.unets) > 1 and not exists(unet_number)
707+
len(self.unets) > 1 and unet_number is None
719708
), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
720709
unet_number = default(unet_number, 1)
721-
assert not exists(
722-
self.only_train_unet_number
710+
assert (self.only_train_unet_number is None
723711
) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}'
724712

725713
images = cast_uint8_images_to_float(images)
@@ -748,19 +736,19 @@ def forward(self,
748736
text_masks, lambda: paddle.any(text_embeds != 0., axis=-1))
749737

750738
assert not (
751-
self.condition_on_text and not exists(text_embeds)
739+
self.condition_on_text and text_embeds is None
752740
), 'text or text encodings must be passed into decoder if specified'
753741
assert not (
754-
not self.condition_on_text and exists(text_embeds)
742+
not self.condition_on_text and text_embeds is not None
755743
), 'decoder specified not to be conditioned on text, yet it is presented'
756744

757745
assert not (
758-
exists(text_embeds) and
746+
(text_embeds is not None) and
759747
text_embeds.shape[-1] != self.text_embed_dim
760748
), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
761749

762750
lowres_cond_img = lowres_aug_times = None
763-
if exists(prev_image_size):
751+
if prev_image_size is not None:
764752
lowres_cond_img = resize_image_to(
765753
images, prev_image_size, clamp_range=self.input_image_range)
766754
lowres_cond_img = resize_image_to(

0 commit comments

Comments
 (0)