@@ -254,7 +254,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
254
254
print ('Epoch {}/{}' .format (epoch , num_epochs - 1 ))
255
255
# print('-' * 10)
256
256
257
- if opt .wa and wa_flag and epoch >= num_epochs * 0.1 :
257
+ if opt .wa and wa_flag and epoch >= num_epochs * 0.5 :
258
258
wa_flag = False
259
259
swa_model = swa_utils .AveragedModel (model )
260
260
swa_model .avg_fn = swa_utils .get_ema_avg_fn (decay = 0.996 )
@@ -356,7 +356,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
356
356
357
357
del inputs
358
358
# use extra DG Dataset (https://github.com/NVlabs/DG-Net#dg-market)
359
- if opt .DG and phase == 'train' and epoch > num_epochs * 0.1 :
359
+ if opt .DG and phase == 'train' and epoch > num_epochs * 0.5 :
360
360
# print("DG-Market is involved. It will double the training time.")
361
361
try :
362
362
_ , batch = DGloader_iter .__next__ ()
@@ -436,7 +436,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
436
436
pbar .set_postfix (ordered_dict = ordered_dict )
437
437
pbar .close ()
438
438
439
- if phase == 'train' and opt .wa and epoch >= num_epochs * 0.1 :
439
+ if phase == 'train' and opt .wa and epoch >= num_epochs * 0.5 :
440
440
swa_model .update_parameters (model )
441
441
swa_utils .update_bn (dataloaders ['train' ], swa_model , device = 'cuda:0' )
442
442
0 commit comments