
Commit ccb4703

incorporate the findings of Ting Chen's new paper: the ability to noise at higher levels (with the variance of the noised image normalized automatically). the paper shows good results using RIN on higher resolution images with this technique. also redo all the schedules to be gamma-centric
1 parent ee85704 commit ccb4703

3 files changed (+80, −34 lines): README.md, rin_pytorch/rin_pytorch.py, setup.py

README.md

Lines changed: 14 additions & 2 deletions
````diff
@@ -12,6 +12,8 @@ Additionally, we will try adding an extra linear attention on the main branch as
 
 The insight of being able to self-condition on any hidden state of the network as well as the newly proposed sigmoid noise schedule are the two main findings.
 
+This repository also contains the ability to <a href="https://arxiv.org/abs/2301.10972">noise higher resolution images more</a>, using the `scale` keyword argument on the `GaussianDiffusion` class. It also contains the simple linear gamma schedule proposed in that paper.
+
 ## Appreciation
 
 - <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on cutting edge artificial intelligence research
@@ -42,7 +44,8 @@ diffusion = GaussianDiffusion(
     image_size = 128,
     use_ddim = False,
     timesteps = 400,
-    train_prob_self_cond = 0.9 # how often to self condition on latents
+    train_prob_self_cond = 0.9, # how often to self condition on latents
+    scale = 1.                  # this will be set to < 1. for more noising and leads to better convergence when training on higher resolution images (512, 1024) - input noised images will be auto variance normalized
 ).cuda()
 
 trainer = Trainer(
@@ -81,7 +84,8 @@ diffusion = GaussianDiffusion(
     model,
     image_size = 128,
     timesteps = 1000,
-    train_prob_self_cond = 0.9
+    train_prob_self_cond = 0.9,
+    scale = 1.
 )
 
 training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1
@@ -110,3 +114,11 @@ sampled_images.shape # (4, 3, 128, 128)
     primaryClass = {cs.LG}
 }
 ```
+
+```bibtex
+@inproceedings{Chen2023OnTI,
+    title  = {On the Importance of Noise Scheduling for Diffusion Models},
+    author = {Ting Chen},
+    year   = {2023}
+}
+```
````
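For orientation, a minimal sketch of how the new `scale` argument might be used at higher resolution, with `model` assumed to be a `RIN` instance constructed as in the README examples (elided from this diff); the value `0.5` is illustrative, not a recommendation from the commit:

```python
# `model` is assumed to be a RIN instance, built as in the README examples above

diffusion = GaussianDiffusion(
    model,
    image_size = 512,            # higher resolution than the 128px examples
    timesteps = 400,
    train_prob_self_cond = 0.9,
    scale = 0.5                  # illustrative: < 1. noises more aggressively; noised inputs are auto variance-normalized
)
```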

rin_pytorch/rin_pytorch.py

Lines changed: 65 additions & 31 deletions
```diff
@@ -29,6 +29,9 @@
 def exists(x):
     return x is not None
 
+def identity(x):
+    return x
+
 def default(val, d):
     if exists(val):
         return val
@@ -457,6 +460,12 @@ def normalize_img(x):
 def unnormalize_img(x):
     return (x + 1) * 0.5
 
+# normalize variance of noised image, if scale is not 1
+
+def normalize_img_variance(x, eps = 1e-5):
+    std = reduce(x, 'b c h w -> b 1 1 1', partial(torch.std, unbiased = False))
+    return x / std.clamp(min = eps)
+
 # helper functions
 
 def log(t, eps = 1e-20):
```
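A quick standalone check of what the new `normalize_img_variance` helper does: it divides each image by its own standard deviation (biased, over channels and spatial dims), restoring unit variance however under-dispersed the input was. The demo values below are illustrative:

```python
import torch
from functools import partial
from einops import reduce

def normalize_img_variance(x, eps = 1e-5):
    # per-image std over (c, h, w), kept broadcastable as (b, 1, 1, 1)
    std = reduce(x, 'b c h w -> b 1 1 1', partial(torch.std, unbiased = False))
    return x / std.clamp(min = eps)

x = 0.5 * torch.randn(2, 3, 64, 64)                    # stand-in for an under-dispersed noised image
out = normalize_img_variance(x)
print(out.flatten(1).std(dim = 1, unbiased = False))   # ~ tensor([1., 1.])
```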
```diff
@@ -468,21 +477,32 @@ def right_pad_dims_to(x, t):
         return t
     return t.view(*t.shape, *((1,) * padding_dims))
 
-def beta_linear_log_snr(t):
-    return -log(expm1(1e-4 + 10 * (t ** 2)))
+# noise schedules
 
-def alpha_cosine_log_snr(t, s = 0.008):
-    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)
+def simple_linear_schedule(t, clip_min = 1e-9):
+    return (1 - t).clamp(min = clip_min)
 
-def gamma_sigmoid_log_snr(t, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
+def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
+    power = 2 * tau
+    v_start = math.cos(start * math.pi / 2) ** power
+    v_end = math.cos(end * math.pi / 2) ** power
+    output = math.cos((t * (end - start) + start) * math.pi / 2) ** power
+    output = (v_end - output) / (v_end - v_start)
+    return output.clamp(min = clip_min)
+
+def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
     v_start = torch.tensor(start / tau).sigmoid()
     v_end = torch.tensor(end / tau).sigmoid()
     gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
-    gamma.clamp_(min = clamp_min, max = 1.)
-    return -log(gamma ** -1. - 1, eps = 1e-5)
+    return gamma.clamp_(min = clamp_min, max = 1.)
+
+# converting gamma to alpha, sigma or logsnr
 
-def log_snr_to_alpha_sigma(log_snr):
-    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
+def gamma_to_alpha_sigma(gamma):
+    return torch.sqrt(gamma), torch.sqrt(1 - gamma)
+
+def gamma_to_log_snr(gamma, eps = 1e-5):
+    return -log(gamma ** -1. - 1, eps = eps)
 
 # gaussian diffusion
 
```
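The invariant behind the gamma-centric refactor, as a standalone sanity check: `gamma(t)` plays the role of alpha-bar (the fraction of signal variance remaining), so `alpha = sqrt(gamma)` and `sigma = sqrt(1 - gamma)` with `alpha**2 + sigma**2 == 1`, and the old log(SNR) parameterization is recovered exactly because `sigmoid(-log(1/gamma - 1)) == gamma`:

```python
import torch

gamma = torch.linspace(0.05, 0.95, 5)   # any gamma schedule output, away from the endpoints

alpha, sigma = gamma.sqrt(), (1 - gamma).sqrt()
assert torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(gamma))

# gamma_to_log_snr inverts the old parameterization: sigmoid(log_snr) == gamma,
# so the old log_snr_to_alpha_sigma and the new gamma_to_alpha_sigma agree
log_snr = -torch.log(gamma ** -1. - 1.)
assert torch.allclose(torch.sigmoid(log_snr), gamma, atol = 1e-6)
```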

```diff
@@ -499,7 +519,8 @@ def __init__(
         objective = 'eps',
         schedule_kwargs: dict = dict(),
         time_difference = 0.,
-        train_prob_self_cond = 0.9
+        train_prob_self_cond = 0.9,
+        scale = 1.                  # this will be set to < 1. for better convergence when training on higher resolution images
     ):
         super().__init__()
         self.model = model
@@ -511,15 +532,23 @@ def __init__(
         self.image_size = image_size
 
         if noise_schedule == "linear":
-            self.log_snr = beta_linear_log_snr
+            self.gamma_schedule = simple_linear_schedule
         elif noise_schedule == "cosine":
-            self.log_snr = alpha_cosine_log_snr
+            self.gamma_schedule = cosine_schedule
         elif noise_schedule == "sigmoid":
-            self.log_snr = gamma_sigmoid_log_snr
+            self.gamma_schedule = sigmoid_schedule
         else:
             raise ValueError(f'invalid noise schedule {noise_schedule}')
 
-        self.log_snr = partial(self.log_snr, **schedule_kwargs)
+        # the main finding presented in Ting Chen's paper - that higher resolution images require more noise for better training
+
+        assert scale <= 1, 'scale must be less than or equal to 1'
+        self.scale = scale
+        self.normalize_img_variance = normalize_img_variance if scale < 1 else identity
+
+        # gamma schedules
+
+        self.gamma_schedule = partial(self.gamma_schedule, **schedule_kwargs)
 
         self.timesteps = timesteps
         self.use_ddim = use_ddim
```
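What `scale < 1` buys, per Chen's paper: shrinking the signal lowers the SNR at every `t`, which acts like heavier noising and is what helps at 512/1024px. Below is a hypothetical standalone illustration of that idea (`noise_with_scale` is a made-up name, and the exact point where `self.scale` enters the noising step is not shown in this diff); it assumes roughly unit-variance data:

```python
import torch

def noise_with_scale(x0, gamma_t, scale = 0.7, eps = 1e-5):
    # hypothetical helper: scale the signal down, then re-normalize the result
    noise = torch.randn_like(x0)
    x_t = gamma_t.sqrt() * scale * x0 + (1. - gamma_t).sqrt() * noise

    # with unit-variance data, Var(x_t) = (scale**2 - 1) * gamma_t + 1 < 1,
    # hence the variance normalization before the image is fed to the network
    std = x_t.flatten(1).std(dim = 1, unbiased = False).view(-1, 1, 1, 1)
    return x_t / std.clamp(min = eps)

x0 = torch.randn(4, 3, 512, 512)            # pretend unit-variance images
gamma_t = torch.full((4, 1, 1, 1), 0.5)
x_t = noise_with_scale(x0, gamma_t)         # unit variance again after normalization
```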
```diff
@@ -563,22 +592,23 @@ def ddpm_sample(self, shape, time_difference = None):
 
             time_next = (time_next - self.time_difference).clamp(min = 0.)
 
-            noise_cond = self.log_snr(time)
+            noise_cond = time
 
             # get predicted x0
 
+            img = self.normalize_img_variance(img)
             model_output, last_latents = self.model(img, noise_cond, x_start, last_latents, return_latents = True)
 
             # get log(snr)
 
-            log_snr = self.log_snr(time)
-            log_snr_next = self.log_snr(time_next)
-            log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
+            gamma = self.gamma_schedule(time)
+            gamma_next = self.gamma_schedule(time_next)
+            gamma, gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))
 
             # get alpha sigma of time and next time
 
-            alpha, sigma = log_snr_to_alpha_sigma(log_snr)
-            alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
+            alpha, sigma = gamma_to_alpha_sigma(gamma)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next)
 
             # calculate x0 and noise
 
@@ -594,6 +624,8 @@ def ddpm_sample(self, shape, time_difference = None):
 
             # derive posterior mean and variance
 
+            log_snr, log_snr_next = map(gamma_to_log_snr, (gamma, gamma_next))
+
             c = -expm1(log_snr - log_snr_next)
 
             mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
```
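The two added lines above convert gamma to log(SNR) only where the posterior needs it; the coefficient itself is unchanged. With lambda = log(SNR), `c = -expm1(lambda_t - lambda_next)` is just `1 - SNR_t / SNR_next`, which a quick check confirms (values illustrative):

```python
import torch
from torch import expm1

log_snr, log_snr_next = torch.tensor(-1.0), torch.tensor(0.5)

c = -expm1(log_snr - log_snr_next)
assert torch.isclose(c, 1 - log_snr.exp() / log_snr_next.exp())   # 1 - SNR_t / SNR_next
```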
```diff
@@ -629,21 +661,22 @@ def ddim_sample(self, shape, time_difference = None):
 
             # get times and noise levels
 
-            log_snr = self.log_snr(times)
-            log_snr_next = self.log_snr(times_next)
+            gamma = self.gamma_schedule(times)
+            gamma_next = self.gamma_schedule(times_next)
 
-            padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
+            padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))
 
-            alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
-            alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)
+            alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next)
 
             # add the time delay
 
             times_next = (times_next - time_difference).clamp(min = 0.)
 
             # predict x0
 
-            model_output, last_latents = self.model(img, log_snr, x_start, last_latents, return_latents = True)
+            img = self.normalize_img_variance(img)
+            model_output, last_latents = self.model(img, times, x_start, last_latents, return_latents = True)
 
             # calculate x0 and noise
 
```
```diff
@@ -693,9 +726,9 @@ def forward(self, img, *args, **kwargs):
 
         noise = torch.randn_like(img)
 
-        noise_level = self.log_snr(times)
-        padded_noise_level = right_pad_dims_to(img, noise_level)
-        alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level)
+        gamma = self.gamma_schedule(times)
+        padded_gamma = right_pad_dims_to(img, gamma)
+        alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
 
         noised_img = alpha * img + sigma * noise
 
@@ -706,13 +739,14 @@ def forward(self, img, *args, **kwargs):
 
         if random() < self.train_prob_self_cond:
             with torch.no_grad():
-                self_cond, self_latents = self.model(noised_img, noise_level, return_latents = True)
+                self_cond, self_latents = self.model(noised_img, times, return_latents = True)
                 self_cond = self_cond.detach()
                 self_latents = self_latents.detach()
 
         # predict and take gradient step
 
-        pred = self.model(noised_img, noise_level, self_cond, self_latents)
+        noised_img = self.normalize_img_variance(noised_img)
+        pred = self.model(noised_img, times, self_cond, self_latents)
 
         if self.objective == 'x0':
             target = img
```
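A thread running through the training-side hunks: the model is now conditioned on the raw time `t` (previously log(SNR)), and variance normalization is folded into the forward pass. A condensed sketch of the new flow under those conventions (`toy_forward` is illustrative; the uniform time sampling is assumed from surrounding code not shown in this diff):

```python
import torch

def toy_forward(model, img, gamma_schedule, scale = 1.):
    batch = img.shape[0]
    times = torch.rand(batch, device = img.device)    # assumed ~ U(0, 1)

    gamma = gamma_schedule(times).view(batch, 1, 1, 1)
    alpha, sigma = gamma.sqrt(), (1. - gamma).sqrt()

    noised = alpha * img + sigma * torch.randn_like(img)

    if scale < 1:                                     # mirrors self.normalize_img_variance
        std = noised.flatten(1).std(dim = 1, unbiased = False).view(batch, 1, 1, 1)
        noised = noised / std.clamp(min = 1e-5)

    return model(noised, times)                       # conditioned on plain t, not log(SNR)
```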

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.3.0',
6+
version = '0.4.0',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',
