import paddle.vision.transforms as T

from .unet import Unet
- from .utils import (GaussianDiffusionContinuousTimes, default, exists,
-                     cast_tuple, first, maybe, eval_decorator, identity,
+ from .utils import (GaussianDiffusionContinuousTimes, default, cast_tuple,
+                     first, maybe, eval_decorator, identity,
                    pad_tuple_to_length, right_pad_dims_to, resize_image_to,
                    normalize_neg_one_to_one, rearrange, repeat, reduce,
                    unnormalize_zero_to_one, cast_uint8_images_to_float)
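
Note on the hunk above: `exists` is dropped from the import list because every `exists(x)` call in this module is rewritten below as an explicit `x is not None` check. For reference, here is a minimal sketch of the remaining `.utils` helpers the diff keeps relying on (illustrative definitions in the spirit of the usual lucidrains-style utilities; the real module may differ in details):

```python
def default(val, d):
    # fall back to `d` (or `d()` if it is callable) when `val` is None
    if val is not None:
        return val
    return d() if callable(d) else d


def first(arr, default_val=None):
    # first element of a sequence, or a fallback when it is empty
    return arr[0] if len(arr) > 0 else default_val


def maybe(fn):
    # wrap `fn` so it is only applied when its first argument is not None
    def inner(x, *args, **kwargs):
        return x if x is None else fn(x, *args, **kwargs)
    return inner


def cast_tuple(val, length=1):
    # broadcast a single setting to a per-unet tuple of the given length
    return val if isinstance(val, tuple) else ((val,) * length)
```
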
@@ -195,9 +195,8 @@ def __init__(self,
        # randomly cropping for upsampler training

        self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
-         assert not exists(
-             first(self.random_crop_sizes)
-         ), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
+         assert first(
+             self.random_crop_sizes) is None, 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

        # lowres augmentation noise schedule

        self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(
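
The rewritten assertion keeps the original contract without `exists`: the first (base) unet must have no random crop size; only upsamplers may. A hedged construction sketch (the unet variables are illustrative, and keyword names besides `random_crop_sizes` follow the usual Imagen constructor):

```python
# only the upsampler stages get a crop size; the leading None keeps the
# `first(self.random_crop_sizes) is None` assert above satisfied
imagen = Imagen(
    unets=(base_unet, sr_unet_1, sr_unet_2),   # illustrative unet instances
    image_sizes=(64, 256, 1024),
    random_crop_sizes=(None, 128, 256),
)
```
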
@@ -284,22 +283,17 @@ def get_unet(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1

-         if isinstance(self.unets, nn.LayerList):
-             unets_list = [unet for unet in self.unets]
-             delattr(self, 'unets')
-             self.unets = unets_list
        self.unet_being_trained_index = index
        return self.unets[index]

    def reset_unets(self, ):
-         self.unets = nn.LayerList([*self.unets])
        self.unet_being_trained_index = -1

    @contextmanager
    def one_unet_in_gpu(self, unet_number=None, unet=None):
-         assert exists(unet_number) ^ exists(unet)
+         assert (unet_number is not None) ^ (unet is not None)

-         if exists(unet_number):
+         if unet_number is not None:
            unet = self.unets[unet_number - 1]

        yield
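
The xor assertion preserves the old contract of `one_unet_in_gpu`: exactly one of `unet_number` or `unet` may be given. A tiny self-contained illustration of that contract (a toy function, not the class method):

```python
def select_unet(unet_number=None, unet=None):
    # exactly one selector must be provided, same as the assert above
    assert (unet_number is not None) ^ (unet is not None)
    return unet if unet is not None else f'unet #{unet_number}'


print(select_unet(unet_number=2))            # ok
# select_unet()                              # AssertionError: neither given
# select_unet(unet_number=1, unet=object())  # AssertionError: both given
```
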
@@ -320,7 +314,6 @@ def p_mean_variance(self,
                        unet,
                        x,
                        t,
-                         *,
                        noise_scheduler,
                        text_embeds=None,
                        text_mask=None,
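
This hunk (and the matching ones below for `p_sample`, `p_sample_loop` and `p_losses`) drops the bare `*`, so `noise_scheduler` and the parameters after it become positional-or-keyword instead of keyword-only. A quick illustration of the difference:

```python
def f_old(x, t, *, noise_scheduler):
    return noise_scheduler


def f_new(x, t, noise_scheduler):
    return noise_scheduler


f_old(1, 2, noise_scheduler='sched')   # ok: keyword-only argument
# f_old(1, 2, 'sched')                 # TypeError: takes 2 positional arguments
f_new(1, 2, 'sched')                   # ok after the change: positional works too
```
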
@@ -370,7 +363,6 @@ def p_sample(self,
                 unet,
                 x,
                 t,
-                  *,
                 noise_scheduler,
                 t_next=None,
                 text_embeds=None,
@@ -412,7 +404,6 @@ def p_sample(self,
    def p_sample_loop(self,
                      unet,
                      shape,
-                       *,
                      noise_scheduler,
                      lowres_cond_img=None,
                      lowres_noise_times=None,
@@ -433,7 +424,7 @@ def p_sample_loop(self,

        # prepare inpainting

-         has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
+         has_inpainting = inpaint_images is not None and inpaint_masks is not None
        resample_times = inpaint_resample_times if has_inpainting else 1

        if has_inpainting:
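
`resample_times` controls how many times each denoising step is repeated when inpainting (resampling in the spirit of RePaint). Below is a minimal, paddle-only sketch of the masked merge such a loop performs; the mask convention (1 = keep the known pixels) and the tensor shapes are assumptions for illustration, not the file's exact code:

```python
import paddle


def merge_inpaint(x_denoised, noised_reference, mask):
    # keep known pixels from the (appropriately noised) reference image and
    # let the model's prediction fill only the masked-out region
    return noised_reference * mask + x_denoised * (1. - mask)


x = paddle.randn([1, 3, 64, 64])      # current sample from the model
ref = paddle.randn([1, 3, 64, 64])    # stand-in for noised inpaint_images
mask = paddle.ones([1, 1, 64, 64])    # 1 = keep reference pixel
out = merge_inpaint(x, ref, mask)
```
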
@@ -532,18 +523,18 @@ def sample(
            batch_size = text_embeds.shape[0]

        assert not (
-             self.condition_on_text and not exists(text_embeds)
+             self.condition_on_text and text_embeds is None
        ), 'text or text encodings must be passed into imagen if specified'
        assert not (
-             not self.condition_on_text and exists(text_embeds)
+             not self.condition_on_text and text_embeds is not None
        ), 'imagen specified not to be conditioned on text, yet it is presented'
        assert not (
-             exists(text_embeds) and
+             text_embeds is not None and
            text_embeds.shape[-1] != self.text_embed_dim
        ), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        assert not (
-             exists(inpaint_images) ^ exists(inpaint_masks)
+             (inpaint_images is not None) ^ (inpaint_masks is not None)
        ), 'inpaint images and masks must be both passed in to do inpainting'

        outputs = []
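
The rewritten conditions are logically identical to the old `exists` forms; a quick self-contained check of that equivalence:

```python
def exists(v):
    return v is not None


for text_embeds in (None, object()):
    for condition_on_text in (False, True):
        old = not (condition_on_text and not exists(text_embeds))
        new = not (condition_on_text and text_embeds is None)
        assert old == new
print('old and new conditions agree on all cases')
```
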
@@ -609,8 +600,7 @@ def sample(

            outputs.append(img)

-             if exists(stop_at_unet_number
-                       ) and stop_at_unet_number == unet_number:
+             if stop_at_unet_number is not None and stop_at_unet_number == unet_number:
                break

        output_index = -1 if not return_all_unet_outputs else slice(
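
For context, `output_index` later selects either the last unet's images or every stage's output. The `slice(` call is cut off at the hunk boundary; assuming the usual `slice(None)` form, the indexing works like this (list contents illustrative):

```python
outputs = ['img_64', 'img_256', 'img_1024']   # per-unet outputs, illustrative

for return_all_unet_outputs in (False, True):
    output_index = -1 if not return_all_unet_outputs else slice(None)
    print(outputs[output_index])
# prints 'img_1024' (last unet only), then the whole list
```
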
@@ -633,7 +623,6 @@ def p_losses(self,
                 unet,
                 x_start,
                 times,
-                  *,
                 noise_scheduler,
                 lowres_cond_img=None,
                 lowres_aug_times=None,
@@ -655,7 +644,7 @@ def p_losses(self,
        # random cropping during training
        # for upsamplers

-         if exists(random_crop_size):
+         if random_crop_size is not None:
            aug = K.RandomCrop((random_crop_size, random_crop_size), p=1.)
            # make sure low res conditioner and image both get augmented the same way
            # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
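
The point of the comment above is that the target image and its low-res conditioner must be cropped with the same box (the augmentation op achieves this by reusing the parameters sampled on its first call). A minimal paddle-only sketch of that aligned-crop idea, with hypothetical names; it is not the module's actual implementation:

```python
import random

import paddle


def paired_random_crop(img, lowres, crop_size):
    # sample one crop box and apply it to both tensors so they stay aligned
    _, _, h, w = img.shape
    top = random.randint(0, h - crop_size)
    left = random.randint(0, w - crop_size)
    crop = lambda t: t[:, :, top:top + crop_size, left:left + crop_size]
    return crop(img), crop(lowres)


img = paddle.randn([1, 3, 256, 256])
lowres = paddle.randn([1, 3, 256, 256])   # conditioner already resized to 256
img_c, lowres_c = paired_random_crop(img, lowres, 128)
```
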
@@ -672,7 +661,7 @@ def p_losses(self,
        # at sample time, they then fix the noise level of 0.1 - 0.3

        lowres_cond_img_noisy = None
-         if exists(lowres_cond_img):
+         if lowres_cond_img is not None:
            lowres_aug_times = default(lowres_aug_times, times)
            lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
                x_start=lowres_cond_img,
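
This is the noise conditioning augmentation for the upsamplers: the low-res conditioner is itself noised with the continuous-time schedule before being fed to the unet, and at sample time the level is fixed around 0.1 to 0.3 (per the comment above). A simplified, self-contained sketch of what a continuous-time `q_sample` computes, assuming a cosine schedule and returning only the noised tensor (the real method also returns schedule information, hence the `, _` unpacking above):

```python
import math

import paddle


def q_sample(x_start, t, noise=None):
    # forward diffusion at continuous time t in [0, 1]:
    #   x_t = alpha_t * x_0 + sigma_t * eps, with a cosine alpha/sigma schedule
    noise = paddle.randn(x_start.shape) if noise is None else noise
    alpha = math.cos(t * math.pi / 2)
    sigma = math.sin(t * math.pi / 2)
    return alpha * x_start + sigma * noise


lowres_cond_img = paddle.randn([1, 3, 64, 64])
noisy = q_sample(lowres_cond_img, t=0.2)   # fixed augmentation level at sample time
```
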
@@ -715,11 +704,10 @@ def forward(self,
        assert images.shape[-1] == images.shape[
            -2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
        assert not (
-             len(self.unets) > 1 and not exists(unet_number)
+             len(self.unets) > 1 and unet_number is None
        ), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
        unet_number = default(unet_number, 1)
-         assert not exists(
-             self.only_train_unet_number
+         assert (self.only_train_unet_number is None
        ) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'

        images = cast_uint8_images_to_float(images)
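
`cast_uint8_images_to_float` normalizes integer image batches before the loss computation. A hedged sketch of what such a helper typically does (illustrative; the real `.utils` version may handle more cases):

```python
import paddle


def cast_uint8_images_to_float(images):
    # uint8 images in [0, 255] -> float32 in [0, 1]; floats pass through
    if images.dtype != paddle.uint8:
        return images
    return images.astype('float32') / 255.


imgs = paddle.randint(0, 256, [2, 3, 64, 64]).astype('uint8')
print(cast_uint8_images_to_float(imgs).dtype)   # paddle.float32
```
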
@@ -748,19 +736,19 @@ def forward(self,
            text_masks, lambda: paddle.any(text_embeds != 0., axis=-1))

        assert not (
-             self.condition_on_text and not exists(text_embeds)
+             self.condition_on_text and text_embeds is None
        ), 'text or text encodings must be passed into decoder if specified'
        assert not (
-             not self.condition_on_text and exists(text_embeds)
+             not self.condition_on_text and text_embeds is not None
        ), 'decoder specified not to be conditioned on text, yet it is presented'

        assert not (
-             exists(text_embeds) and
+             (text_embeds is not None) and
            text_embeds.shape[-1] != self.text_embed_dim
        ), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        lowres_cond_img = lowres_aug_times = None
-         if exists(prev_image_size):
+         if prev_image_size is not None:
            lowres_cond_img = resize_image_to(
                images, prev_image_size, clamp_range=self.input_image_range)
            lowres_cond_img = resize_image_to(
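
The two consecutive `resize_image_to` calls build the upsampler's conditioner: the training image is first shrunk to the previous stage's resolution, and the second call, cut off at the hunk boundary, presumably resizes it back up to the current stage's size, so the conditioner has the right spatial shape but only low-frequency detail. A minimal sketch of that degrade-then-upsample step, using `paddle.nn.functional.interpolate` as a stand-in for `resize_image_to` (which additionally clamps to `input_image_range`):

```python
import paddle
import paddle.nn.functional as F


def degrade_for_upsampler(images, prev_image_size, target_image_size):
    # shrink to the previous unet's resolution, then resize back up
    lowres = F.interpolate(images, size=[prev_image_size, prev_image_size],
                           mode='bilinear')
    return F.interpolate(lowres, size=[target_image_size, target_image_size],
                         mode='bilinear')


imgs = paddle.rand([2, 3, 256, 256])
cond = degrade_for_upsampler(imgs, prev_image_size=64, target_image_size=256)
print(cond.shape)   # [2, 3, 256, 256]
```
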