Skip to content

Commit 77e4ced

Browse files
committed
add min snr loss weight
1 parent 0a4e941 commit 77e4ced

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,11 @@ sampled_images.shape # (4, 3, 128, 128)
141141
copyright = {Creative Commons Attribution 4.0 International}
142142
}
143143
```
144+
145+
```bibtex
146+
@inproceedings{Hang2023EfficientDT,
147+
title = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
148+
author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
149+
year = {2023}
150+
}
151+
```

rin_pytorch/rin_pytorch.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ def __init__(
568568
objective = 'v',
569569
schedule_kwargs: dict = dict(),
570570
time_difference = 0.,
571+
min_snr_loss_weight = True,
572+
min_snr_gamma = 5,
571573
train_prob_self_cond = 0.9,
572574
scale = 1. # this will be set to < 1. for better convergence when training on higher resolution images
573575
):
@@ -611,6 +613,11 @@ def __init__(
611613

612614
self.train_prob_self_cond = train_prob_self_cond
613615

616+
# min snr loss weight
617+
618+
self.min_snr_loss_weight = min_snr_loss_weight
619+
self.min_snr_gamma = min_snr_gamma
620+
614621
@property
615622
def device(self):
616623
return next(self.model.parameters()).device
@@ -811,16 +818,36 @@ def forward(self, img, *args, **kwargs):
811818

812819
pred = self.model(noised_img, times, self_cond, self_latents)
813820

814-
if self.objective == 'x0':
815-
target = img
816-
817-
elif self.objective == 'eps':
821+
if self.objective == 'eps':
818822
target = noise
819823

824+
elif self.objective == 'x0':
825+
target = img
826+
820827
elif self.objective == 'v':
821828
target = alpha * noise - sigma * img
822829

823-
return F.mse_loss(pred, target)
830+
loss = F.mse_loss(pred, target, reduction = 'none')
831+
loss = reduce(loss, 'b ... -> b', 'mean')
832+
833+
# min snr loss weight
834+
835+
snr = (alpha * alpha) / (sigma * sigma)
836+
maybe_clipped_snr = snr.clone()
837+
838+
if self.min_snr_loss_weight:
839+
maybe_clipped_snr.clamp_(min = self.min_snr_gamma)
840+
841+
if self.objective == 'eps':
842+
loss_weight = maybe_clipped_snr / snr
843+
844+
elif self.objective == 'x0':
845+
loss_weight = maybe_clipped_snr
846+
847+
elif self.objective == 'v':
848+
loss_weight = maybe_clipped_snr / (snr + 1)
849+
850+
return (loss * loss_weight).mean()
824851

825852
# dataset classes
826853

@@ -872,7 +899,7 @@ def __init__(
872899
train_num_steps = 100000,
873900
ema_update_every = 10,
874901
ema_decay = 0.995,
875-
adam_betas = (0.9, 0.99),
902+
betas = (0.9, 0.99),
876903
save_and_sample_every = 1000,
877904
num_samples = 25,
878905
results_folder = './results',
@@ -912,7 +939,7 @@ def __init__(
912939

913940
# optimizer
914941

915-
self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
942+
self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = betas)
916943

917944
# for logging results in a folder periodically
918945

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.6.1',
6+
version = '0.7.1',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)