     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.cuda import amp
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
+
+

 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
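For context, once this hunk is applied the import guard roughly takes the shape sketched below. The `try:` side above line 25 is assumed from the surrounding file and is not shown in this hunk; the point is that when NVIDIA Apex is not importable, `amp` resolves to `torch.cuda.amp` and the stock `DistributedDataParallel` is used instead.

```python
# Rough sketch of the resulting import guard (the apex imports are assumed, not shown in the hunk).
try:
    from apex import amp                        # Apex AMP (assumed from later amp.initialize usage)
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    from torch.cuda import amp                  # native AMP namespace: autocast, GradScaler
    from torch.nn.parallel import DistributedDataParallel as DDP
    has_apex = False
```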
@@ -327,6 +330,10 @@ def main():
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
+    elif args.amp:
+        _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.')
+        scaler = torch.cuda.amp.GradScaler()
+        use_amp = True
     if args.local_rank == 0:
         _logger.info('NVIDIA APEX {}. AMP {}.'.format(
             'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
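As a standalone illustration of the selection logic this hunk adds (a sketch with assumed helper names, not the project's code): Apex AMP rewrites the model and optimizer via `amp.initialize`, while native torch AMP only needs a `GradScaler` that is later threaded through to the training loop.

```python
# Sketch of the AMP backend selection; select_amp is an assumed helper, not part of the commit.
import torch

def select_amp(model, optimizer, want_amp, has_apex):
    """Return (model, optimizer, use_amp, scaler); scaler is None unless native AMP is chosen."""
    use_amp, scaler = False, None
    if has_apex and want_amp:
        from apex import amp
        # Apex patches model and optimizer in place for O1 mixed precision.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    elif want_amp:
        # Native AMP: keep a GradScaler around for loss scaling in the train loop.
        scaler = torch.cuda.amp.GradScaler()
        use_amp = True
    return model, optimizer, use_amp, scaler
```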
@@ -506,7 +513,8 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
+                use_amp=use_amp, has_apex=has_apex, scaler=scaler,
+                model_ema=model_ema, mixup_fn=mixup_fn)

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
@@ -546,7 +554,8 @@ def main():

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
+        has_apex=False, scaler=None, model_ema=None, mixup_fn=None):

     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -570,20 +579,32 @@ def train_epoch(
             input, target = input.cuda(), target.cuda()
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
-
-        output = model(input)
-
-        loss = loss_fn(output, target)
+        if not has_apex and use_amp:
+            with torch.cuda.amp.autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+        else:
+            output = model(input)
+            loss = loss_fn(output, target)
+
         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))

         optimizer.zero_grad()
         if use_amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+            if has_apex:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                scaler.scale(loss).backward()
+
         else:
             loss.backward()
-        optimizer.step()
+        if not has_apex and use_amp:
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()

         torch.cuda.synchronize()
         if model_ema is not None:
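Taken together, the native-AMP branch added in this hunk follows the standard `autocast` plus `GradScaler` pattern. Below is a minimal, self-contained sketch of one such training step using a toy model and random data (requires a CUDA device; this is not the project's `train_epoch`):

```python
# Minimal native torch AMP step, mirroring the pattern the hunk above adds.
import torch
import torch.nn as nn

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

input = torch.randn(8, 10, device='cuda')
target = torch.randint(0, 2, (8,), device='cuda')

with torch.cuda.amp.autocast():        # run forward pass and loss in mixed precision
    output = model(input)
    loss = loss_fn(output, target)

optimizer.zero_grad()
scaler.scale(loss).backward()          # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                 # unscale gradients, skip the step if they are inf/nan
scaler.update()                        # adjust the scale factor for the next iteration
```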