@@ -568,6 +568,8 @@ def __init__(
         objective = 'v',
         schedule_kwargs: dict = dict(),
         time_difference = 0.,
+        min_snr_loss_weight = True,
+        min_snr_gamma = 5,
         train_prob_self_cond = 0.9,
         scale = 1.  # this will be set to < 1. for better convergence when training on higher resolution images
     ):
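The two new constructor arguments gate the min-SNR loss weighting from "Efficient Diffusion Training via Min-SNR Weighting Strategy" (Hang et al., 2023); min_snr_gamma = 5 matches the clamp value recommended in the paper. A hypothetical construction after this change (the class name and the remaining arguments are assumptions; only the two min-SNR kwargs come from this diff):

    diffusion = GaussianDiffusion(   # hypothetical class name, not shown in this diff
        model,                       # assumed denoiser network
        objective = 'v',
        min_snr_loss_weight = True,  # enable min-SNR weighting (new)
        min_snr_gamma = 5            # clamp SNR at 5 (new)
    )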
@@ -611,6 +613,11 @@ def __init__(
 
         self.train_prob_self_cond = train_prob_self_cond
 
+        # min snr loss weight
+
+        self.min_snr_loss_weight = min_snr_loss_weight
+        self.min_snr_gamma = min_snr_gamma
+
     @property
     def device(self):
         return next(self.model.parameters()).device
@@ -811,16 +818,36 @@ def forward(self, img, *args, **kwargs):
 
         pred = self.model(noised_img, times, self_cond, self_latents)
 
-        if self.objective == 'x0':
-            target = img
-
-        elif self.objective == 'eps':
+        if self.objective == 'eps':
             target = noise
 
+        elif self.objective == 'x0':
+            target = img
+
         elif self.objective == 'v':
             target = alpha * noise - sigma * img
 
-        return F.mse_loss(pred, target)
+        loss = F.mse_loss(pred, target, reduction = 'none')
+        loss = reduce(loss, 'b ... -> b', 'mean')
+
+        # min snr loss weight
+
+        snr = (alpha * alpha) / (sigma * sigma)
+        maybe_clipped_snr = snr.clone()
+
+        if self.min_snr_loss_weight:
+            maybe_clipped_snr.clamp_(max = self.min_snr_gamma)
+
+        if self.objective == 'eps':
+            loss_weight = maybe_clipped_snr / snr
+
+        elif self.objective == 'x0':
+            loss_weight = maybe_clipped_snr
+
+        elif self.objective == 'v':
+            loss_weight = maybe_clipped_snr / (snr + 1)
+
+        return (loss * loss_weight).mean()
 
 # dataset classes
 
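Written in x0 space, the min-SNR objective weights every timestep by min(SNR, gamma), where SNR = alpha^2 / sigma^2. Because the eps-prediction loss equals SNR times the x0 loss, and the v-prediction loss equals (SNR + 1) times it (for a variance-preserving schedule with alpha^2 + sigma^2 = 1), dividing the clipped SNR by those factors makes the three branches above equivalent. A minimal standalone sketch of that weighting, assuming alpha and sigma are per-sample tensors from a variance-preserving schedule:

    import torch

    def min_snr_loss_weight(alpha, sigma, objective = 'v', gamma = 5.):
        snr = (alpha * alpha) / (sigma * sigma)
        clipped_snr = snr.clamp(max = gamma)    # min(SNR, gamma)

        if objective == 'eps':
            return clipped_snr / snr            # undo the SNR factor of the eps loss
        elif objective == 'x0':
            return clipped_snr
        elif objective == 'v':
            return clipped_snr / (snr + 1)      # undo the (SNR + 1) factor of the v loss

        raise ValueError(f'unknown objective {objective}')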
@@ -872,7 +899,7 @@ def __init__(
         train_num_steps = 100000,
         ema_update_every = 10,
         ema_decay = 0.995,
-        adam_betas = (0.9, 0.99),
+        betas = (0.9, 0.99),
         save_and_sample_every = 1000,
         num_samples = 25,
         results_folder = './results',
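Renaming adam_betas to betas is a breaking change to the trainer signature, so existing calls that pass adam_betas = ... need updating. A hypothetical call after this change (the positional arguments and the train_lr value are assumptions; only the betas keyword comes from this diff):

    trainer = Trainer(
        diffusion,              # assumed diffusion model instance
        './data',               # assumed dataset folder
        train_lr = 1e-4,        # assumed value
        betas = (0.9, 0.99)     # formerly adam_betas
    )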
@@ -912,7 +939,7 @@ def __init__(
 
         # optimizer
 
-        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
+        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = betas)
 
         # for logging results in a folder periodically
 