Commit 5f563ca

fix save_checkpoint bug with native amp

1 parent d98967e · commit 5f563ca

File tree

1 file changed · +3 -2 lines changed

train.py

Lines changed: 3 additions & 2 deletions
@@ -544,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)
 
     except KeyboardInterrupt:
         pass
@@ -647,8 +647,9 @@ def train_epoch(
 
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+
             saver.save_recovery(
-                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
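The commit message and the has_apex mask suggest the saver's use_amp handling is apex-specific; with native torch.cuda.amp there is no apex amp state object to serialize. Below is a minimal sketch of that failure mode, assuming a hypothetical save_state helper (not the repository's actual CheckpointSaver) and the has_apex / use_amp flags used in train.py:

# Minimal sketch, assuming checkpointing saves apex AMP state via apex.amp.state_dict().
# `save_state` is a hypothetical helper, not the repo's CheckpointSaver; `has_apex` and
# `use_amp` mirror the flags in train.py.
try:
    from apex import amp  # NVIDIA apex mixed precision
    has_apex = True
except ImportError:
    amp = None
    has_apex = False


def save_state(model, optimizer, use_amp=False):
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if use_amp:
        # Valid only when apex is the AMP backend; with native torch.cuda.amp
        # there is no apex amp state, so callers mask the flag with has_apex
        # to keep this branch from running.
        state['amp'] = amp.state_dict()
    return state

Passing use_amp=has_apex&use_amp therefore leaves apex-based checkpoints unchanged while skipping the apex-only branch when native AMP (or no AMP) is in use.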
