
Commit cabab0e

self conditioning step needs to also receive variance normalized noised image
1 parent ccb4703 commit cabab0e

File tree: 2 files changed (+11, -10 lines)

rin_pytorch/rin_pytorch.py

Lines changed: 10 additions & 9 deletions
@@ -498,8 +498,8 @@ def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
 
 # converting gamma to alpha, sigma or logsnr
 
-def gamma_to_alpha_sigma(gamma):
-    return torch.sqrt(gamma), torch.sqrt(1 - gamma)
+def gamma_to_alpha_sigma(gamma, scale = 1):
+    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)
 
 def gamma_to_log_snr(gamma, eps = 1e-5):
     return -log(gamma ** -1. - 1, eps = eps)
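
The `scale` argument attenuates only the signal term, so at a given gamma the noised image carries relatively more noise, which is the higher-resolution prescription from Ting Chen's paper that the constructor comment below refers to. A quick standalone check of the relationship (just this hunk's two-line function exercised in isolation, not the surrounding repository code):

import torch

def gamma_to_alpha_sigma(gamma, scale = 1):
    # alpha is shrunk by `scale`; sigma is untouched, so the signal-to-noise ratio drops
    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

gamma = torch.tensor(0.9)
alpha, sigma = gamma_to_alpha_sigma(gamma)                  # alpha ≈ 0.949, sigma ≈ 0.316
alpha_scaled, _ = gamma_to_alpha_sigma(gamma, scale = 0.7)  # alpha ≈ 0.664, same sigma
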
@@ -543,7 +543,7 @@ def __init__(
         # the main finding presented in Ting Chen's paper - that higher resolution images requires more noise for better training
 
         assert scale <= 1, 'scale must be less than or equal to 1'
-        self.scale = scale #
+        self.scale = scale
         self.normalize_img_variance = normalize_img_variance if scale < 1 else identity
 
         # gamma schedules
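
With `scale < 1` the forward process is no longer variance preserving, since (scale * sqrt(gamma))**2 + (1 - gamma) < 1, which is why `normalize_img_variance` is only wired in when `scale < 1` and swapped for `identity` otherwise. A small numeric illustration (standalone sketch, not repository code):

import torch

gamma, scale = torch.tensor(0.9), 0.7
alpha, sigma = torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)
print(alpha ** 2 + sigma ** 2)   # ≈ 0.541 rather than 1 -> the noised image's variance shrinks
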
@@ -607,8 +607,8 @@ def ddpm_sample(self, shape, time_difference = None):
 
             # get alpha sigma of time and next time
 
-            alpha, sigma = gamma_to_alpha_sigma(gamma)
-            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next)
+            alpha, sigma = gamma_to_alpha_sigma(gamma, self.scale)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next, self.scale)
 
             # calculate x0 and noise

@@ -666,8 +666,8 @@ def ddim_sample(self, shape, time_difference = None):
 
             padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))
 
-            alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
-            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next)
+            alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next, self.scale)
 
             # add the time delay

@@ -728,10 +728,12 @@ def forward(self, img, *args, **kwargs):
 
         gamma = self.gamma_schedule(times)
         padded_gamma = right_pad_dims_to(img, gamma)
-        alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
+        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
 
         noised_img = alpha * img + sigma * noise
 
+        noised_img = self.normalize_img_variance(noised_img)
+
         # in the paper, they had to use a really high probability of latent self conditioning, up to 90% of the time
         # slight drawback

@@ -745,7 +747,6 @@ def forward(self, img, *args, **kwargs):
 
         # predict and take gradient step
 
-        noised_img = self.normalize_img_variance(noised_img)
         pred = self.model(noised_img, times, self_cond, self_latents)
 
         if self.objective == 'x0':
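
Taken together, the two forward() hunks move the `normalize_img_variance` call from just before the prediction to just after the image is noised, so the no-grad self-conditioning pass and the gradient-step prediction both consume the same variance-normalized input, as the commit message requires. A condensed sketch of the corrected ordering, using hypothetical stand-ins for the model and the normalizer (their real definitions live elsewhere in the file and are not shown in this diff):

import torch

def normalize_img_variance(x, eps = 1e-5):
    # hypothetical stand-in: rescale each sample in the batch to unit variance
    std = x.std(dim = (1, 2, 3), keepdim = True, unbiased = False)
    return x / std.clamp(min = eps)

def model(x):
    # hypothetical stand-in for self.model(noised_img, times, self_cond, self_latents)
    return x.mean()

img, noise = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
alpha, sigma = 0.66, 0.32

noised_img = alpha * img + sigma * noise
noised_img = normalize_img_variance(noised_img)   # applied once, before any model call

with torch.no_grad():
    self_cond = model(noised_img)                 # self-conditioning pass sees normalized input
pred = model(noised_img)                          # gradient step sees the same tensor
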

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.0',
+  version = '0.4.2',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
