Skip to content

Commit 1cd21fc

Browse files
author
um1
committed
ema later
1 parent e0c6dae commit 1cd21fc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
254254
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
255255
# print('-' * 10)
256256

257-
if opt.wa and wa_flag and epoch >= num_epochs*0.5:
257+
if opt.wa and wa_flag and epoch >= num_epochs*0.8:
258258
wa_flag = False
259259
swa_model = swa_utils.AveragedModel(model)
260260
swa_model.avg_fn = swa_utils.get_ema_avg_fn(decay=0.996)
@@ -356,7 +356,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
356356

357357
del inputs
358358
# use extra DG Dataset (https://github.com/NVlabs/DG-Net#dg-market)
359-
if opt.DG and phase == 'train' and epoch > num_epochs*0.5:
359+
if opt.DG and phase == 'train' and epoch > num_epochs*0.8:
360360
# print("DG-Market is involved. It will double the training time.")
361361
try:
362362
_, batch = DGloader_iter.__next__()
@@ -436,7 +436,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
436436
pbar.set_postfix(ordered_dict=ordered_dict)
437437
pbar.close()
438438

439-
if phase == 'train' and opt.wa and epoch >= num_epochs*0.5:
439+
if phase == 'train' and opt.wa and epoch >= num_epochs*0.8:
440440
swa_model.update_parameters(model)
441441
swa_utils.update_bn(dataloaders['train'], swa_model, device='cuda:0')
442442

0 commit comments

Comments
 (0)