Commit 1d34a0a

Merge branch 'master' of https://github.com/tgisaturday/pytorch-image-models into torchamp
2 parents 6d158ad + 5f563ca

1 file changed: +33 −11 lines

train.py (33 additions, 11 deletions)

@@ -25,8 +25,11 @@
     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.cuda import amp
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
+
+

 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
@@ -327,6 +330,10 @@ def main():
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
+    elif args.amp:
+        _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.')
+        scaler = torch.cuda.amp.GradScaler()
+        use_amp = True
     if args.local_rank == 0:
         _logger.info('NVIDIA APEX {}. AMP {}.'.format(
             'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
@@ -506,7 +513,8 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
+                use_amp=use_amp, has_apex=has_apex, scaler = scaler,
+                model_ema=model_ema, mixup_fn=mixup_fn)

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
@@ -536,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)

     except KeyboardInterrupt:
         pass
@@ -546,7 +554,8 @@ def main():

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
+        has_apex=False, scaler = None, model_ema=None, mixup_fn=None):

     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -570,20 +579,32 @@ def train_epoch(
             input, target = input.cuda(), target.cuda()
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
-
-        output = model(input)
-
-        loss = loss_fn(output, target)
+        if not has_apex and use_amp:
+            with torch.cuda.amp.autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+        else:
+            output = model(input)
+            loss = loss_fn(output, target)
+
         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))

         optimizer.zero_grad()
         if use_amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+            if has_apex:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                scaler.scale(loss).backward()
+
         else:
             loss.backward()
-        optimizer.step()
+        if not has_apex and use_amp:
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()

         torch.cuda.synchronize()
         if model_ema is not None:
@@ -626,8 +647,9 @@ def train_epoch(

         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+
             saver.save_recovery(
-                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
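
For reference, the non-Apex branch added above follows the standard torch.cuda.amp recipe: run the forward pass and loss under autocast, then scale the loss, call backward, and let the GradScaler step the optimizer and update its scale factor. The sketch below is a minimal, self-contained illustration of that recipe; the toy model, optimizer, batch shapes, and step count are placeholders for illustration only, not timm's actual training objects.

import torch
import torch.nn as nn

# Placeholder model, optimizer, and loss; the real script builds these
# elsewhere in train.py.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# Mirrors `scaler = torch.cuda.amp.GradScaler()` in main(); disabled on CPU
# so the sketch still runs without a GPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

for step in range(10):
    inputs = torch.randn(32, 128, device=device)
    targets = torch.randint(0, 10, (32,), device=device)

    optimizer.zero_grad()
    # Forward pass and loss in mixed precision, matching the
    # `with torch.cuda.amp.autocast():` block in train_epoch().
    with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
        output = model(inputs)
        loss = loss_fn(output, targets)

    # Scaled backward, optimizer step, and scale update, as in the
    # `not has_apex and use_amp` branches of the diff.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Keeping both the Apex amp.scale_loss path and this GradScaler path behind the same use_amp flag means the script only falls back to native torch AMP when Apex is not installed (has_apex is False).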
