@@ -544,7 +544,7 @@ def __init__(

        assert scale <= 1, 'scale must be less than or equal to 1'
        self.scale = scale
-        self.normalize_img_variance = normalize_img_variance if scale < 1 else identity
+        self.maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity

        # gamma schedules

@@ -596,8 +596,8 @@ def ddpm_sample(self, shape, time_difference = None):

            # get predicted x0

-            img = self.normalize_img_variance(img)
-            model_output, last_latents = self.model(img, noise_cond, x_start, last_latents, return_latents = True)
+            maybe_normalized_img = self.maybe_normalize_img_variance(img)
+            model_output, last_latents = self.model(maybe_normalized_img, noise_cond, x_start, last_latents, return_latents = True)

            # get log(snr)

@@ -675,8 +675,8 @@ def ddim_sample(self, shape, time_difference = None):

            # predict x0

-            img = self.normalize_img_variance(img)
-            model_output, last_latents = self.model(img, times, x_start, last_latents, return_latents = True)
+            maybe_normalized_img = self.maybe_normalize_img_variance(img)
+            model_output, last_latents = self.model(maybe_normalized_img, times, x_start, last_latents, return_latents = True)

            # get log(snr)

@@ -732,7 +732,7 @@ def forward(self, img, *args, **kwargs):

        noised_img = alpha * img + sigma * noise

-        noised_img = self.normalize_img_variance(noised_img)
+        noised_img = self.maybe_normalize_img_variance(noised_img)

        # in the paper, they had to use a really high probability of latent self conditioning, up to 90% of the time
        # slight drawback
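
For readers skimming the diff: the rename makes it explicit that variance normalization is conditional (a no-op when scale == 1), and that the sampling loops should feed a normalized copy to the model instead of overwriting img itself. Below is a minimal, self-contained sketch of that pattern. ScaledInputDenoiser, its denoise_step method, and the toy normalize_img_variance are illustrative stand-ins, not the repository's actual classes; only the attribute name maybe_normalize_img_variance, the scale < 1 condition, and the identity fallback come from the change itself.

import torch
from torch import nn

def identity(t, *args, **kwargs):
    # passthrough used when no variance normalization is wanted
    return t

def normalize_img_variance(img, eps = 1e-5):
    # toy stand-in: rescale each image in the batch to unit variance
    std = img.std(dim = (1, 2, 3), keepdim = True)
    return img / std.clamp(min = eps)

class ScaledInputDenoiser(nn.Module):
    # hypothetical wrapper that mirrors the renamed attribute
    def __init__(self, model, scale = 1.):
        super().__init__()
        assert scale <= 1, 'scale must be less than or equal to 1'
        self.model = model
        self.scale = scale
        # normalize only when inputs are scaled down; otherwise a no-op
        self.maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity

    def denoise_step(self, img, noise_cond):
        # feed the (maybe) normalized copy to the model, but keep `img`
        # itself untouched for the rest of the sampling loop
        maybe_normalized_img = self.maybe_normalize_img_variance(img)
        return self.model(maybe_normalized_img, noise_cond)

# usage with a dummy denoiser
denoiser = ScaledInputDenoiser(lambda x, t: x, scale = 0.5)
out = denoiser.denoise_step(torch.randn(2, 3, 32, 32), torch.zeros(2))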