@@ -498,8 +498,8 @@ def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
 
 # converting gamma to alpha, sigma or logsnr
 
-def gamma_to_alpha_sigma(gamma):
-    return torch.sqrt(gamma), torch.sqrt(1 - gamma)
+def gamma_to_alpha_sigma(gamma, scale = 1):
+    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)
 
 def gamma_to_log_snr(gamma, eps = 1e-5):
     return -log(gamma ** -1. - 1, eps = eps)
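On its own, the new scale argument only shrinks alpha while leaving sigma untouched, which shifts the log-SNR of every timestep down by 2 * log(scale), so the same timestep carries relatively more noise. A minimal sketch (not part of the commit) that just illustrates this effect:

    import torch

    def gamma_to_alpha_sigma(gamma, scale = 1):
        return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

    gamma = torch.tensor(0.9)

    alpha, sigma = gamma_to_alpha_sigma(gamma)           # logSNR = log(alpha^2 / sigma^2) ~ 2.20
    alpha_s, sigma_s = gamma_to_alpha_sigma(gamma, 0.5)  # logSNR ~ 2.20 + 2 * log(0.5) ~ 0.81

    log_snr = torch.log(alpha ** 2 / sigma ** 2)
    log_snr_scaled = torch.log(alpha_s ** 2 / sigma_s ** 2)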
@@ -543,7 +543,7 @@ def __init__(
         # the main finding presented in Ting Chen's paper - that higher resolution images require more noise for better training
 
         assert scale <= 1, 'scale must be less than or equal to 1'
-        self.scale = scale #
+        self.scale = scale
         self.normalize_img_variance = normalize_img_variance if scale < 1 else identity
 
         # gamma schedules
@@ -607,8 +607,8 @@ def ddpm_sample(self, shape, time_difference = None):
 
             # get alpha sigma of time and next time
 
-            alpha, sigma = gamma_to_alpha_sigma(gamma)
-            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next)
+            alpha, sigma = gamma_to_alpha_sigma(gamma, self.scale)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next, self.scale)
 
             # calculate x0 and noise
 
@@ -666,8 +666,8 @@ def ddim_sample(self, shape, time_difference = None):
 
             padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))
 
-            alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
-            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next)
+            alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next, self.scale)
 
             # add the time delay
 
@@ -728,10 +728,12 @@ def forward(self, img, *args, **kwargs):
 
         gamma = self.gamma_schedule(times)
         padded_gamma = right_pad_dims_to(img, gamma)
-        alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
+        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
 
         noised_img = alpha * img + sigma * noise
 
+        noised_img = self.normalize_img_variance(noised_img)
+
         # in the paper, they had to use a really high probability of latent self conditioning, up to 90% of the time
         # slight drawback
 
@@ -745,7 +747,6 @@ def forward(self, img, *args, **kwargs):
 
         # predict and take gradient step
 
-        noised_img = self.normalize_img_variance(noised_img)
         pred = self.model(noised_img, times, self_cond, self_latents)
 
         if self.objective == 'x0':
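The second half of the change moves the variance normalization so it is applied right after the image is noised, before the latent self-conditioning pass, rather than just before the final model call, so both forward passes see the same normalized input. The definition of normalize_img_variance is not shown in this commit; a hypothetical sketch of what such a per-image normalization could look like, for illustration only:

    import torch

    def normalize_img_variance(x, eps = 1e-5):
        # hypothetical: rescale each noised image to unit standard deviation,
        # compensating for the variance lost when alpha is multiplied by scale < 1
        std = x.std(dim = (1, 2, 3), keepdim = True, unbiased = False)
        return x / std.clamp(min = eps)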