Skip to content

Commit ebf0225

Browse files
committed
make it clear that the normalize function may not be in effect
1 parent 87faae9 commit ebf0225

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def __init__(
544544

545545
assert scale <= 1, 'scale must be less than or equal to 1'
546546
self.scale = scale
547-
self.normalize_img_variance = normalize_img_variance if scale < 1 else identity
547+
self.maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity
548548

549549
# gamma schedules
550550

@@ -596,8 +596,8 @@ def ddpm_sample(self, shape, time_difference = None):
596596

597597
# get predicted x0
598598

599-
img = self.normalize_img_variance(img)
600-
model_output, last_latents = self.model(img, noise_cond, x_start, last_latents, return_latents = True)
599+
maybe_normalized_img = self.maybe_normalize_img_variance(img)
600+
model_output, last_latents = self.model(maybe_normalized_img, noise_cond, x_start, last_latents, return_latents = True)
601601

602602
# get log(snr)
603603

@@ -675,8 +675,8 @@ def ddim_sample(self, shape, time_difference = None):
675675

676676
# predict x0
677677

678-
img = self.normalize_img_variance(img)
679-
model_output, last_latents = self.model(img, times, x_start, last_latents, return_latents = True)
678+
maybe_normalized_img = self.maybe_normalize_img_variance(img)
679+
model_output, last_latents = self.model(maybe_normalized_img, times, x_start, last_latents, return_latents = True)
680680

681681
# calculate x0 and noise
682682

@@ -732,7 +732,7 @@ def forward(self, img, *args, **kwargs):
732732

733733
noised_img = alpha * img + sigma * noise
734734

735-
noised_img = self.normalize_img_variance(noised_img)
735+
noised_img = self.maybe_normalize_img_variance(noised_img)
736736

737737
# in the paper, they had to use a really high probability of latent self conditioning, up to 90% of the time
738738
# slight drawback

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.4.4',
6+
version = '0.4.5',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)