18
18
import time
19
19
import yaml
20
20
from datetime import datetime
21
+ from contextlib import suppress
21
22
22
- try :
23
- from apex import amp
24
- from apex .parallel import DistributedDataParallel as DDP
25
- from apex .parallel import convert_syncbn_model
26
- has_apex = True
27
- except ImportError :
28
- from torch .cuda import amp
29
- from torch .nn .parallel import DistributedDataParallel as DDP
30
- has_apex = False
31
-
32
-
23
+ import torch
24
+ import torch .nn as nn
25
+ import torchvision .utils
26
+ from torch .nn .parallel import DistributedDataParallel as NativeDDP
33
27
34
28
from timm .data import Dataset , create_loader , resolve_data_config , Mixup , FastCollateMixup , AugMixDataset
35
29
from timm .models import create_model , resume_checkpoint , convert_splitbn_model
38
32
from timm .optim import create_optimizer
39
33
from timm .scheduler import create_scheduler
40
34
41
- import torch
42
- import torch .nn as nn
43
- import torchvision .utils
35
+ try :
36
+ from apex import amp
37
+ from apex .parallel import DistributedDataParallel as ApexDDP
38
+ from apex .parallel import convert_syncbn_model
39
+ has_apex = True
40
+ except ImportError :
41
+ has_apex = False
42
+
43
# Probe for native mixed precision support (torch.cuda.amp.autocast).
# The lookup is wrapped in try/except because the attribute may be absent,
# in which case getattr raises AttributeError and the flag stays False.
try:
    has_native_amp = getattr(torch.cuda.amp, 'autocast') is not None
except AttributeError:
    has_native_amp = False
44
49
45
50
torch .backends .cudnn .benchmark = True
46
51
_logger = logging .getLogger ('train' )
47
52
48
-
49
53
# The first arg parser parses out only the --config argument, this argument is used to
50
54
# load a yaml file containing key-values that override the defaults for the main parser below
51
55
config_parser = parser = argparse .ArgumentParser (description = 'Training Config' , add_help = False )
221
225
parser .add_argument ('--save-images' , action = 'store_true' , default = False ,
222
226
help = 'save images of input bathes every log interval for debugging' )
223
227
parser .add_argument ('--amp' , action = 'store_true' , default = False ,
224
- help = 'use NVIDIA amp for mixed precision training' )
228
+ help = 'use NVIDIA Apex AMP or Native AMP for mixed precision training' )
229
+ parser .add_argument ('--apex-amp' , action = 'store_true' , default = False ,
230
+ help = 'Use NVIDIA Apex AMP mixed precision' )
231
+ parser .add_argument ('--native-amp' , action = 'store_true' , default = False ,
232
+ help = 'Use Native Torch AMP mixed precision' )
233
+ parser .add_argument ('--channels-last' , action = 'store_true' , default = False ,
234
+ help = 'Use channels_last memory layout' )
225
235
parser .add_argument ('--pin-mem' , action = 'store_true' , default = False ,
226
236
help = 'Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.' )
227
237
parser .add_argument ('--no-prefetcher' , action = 'store_true' , default = False ,
@@ -254,6 +264,23 @@ def _parse_args():
254
264
return args , args_text
255
265
256
266
267
class ApexScaler:
    """Callable that runs the backward pass and optimizer step via Apex AMP.

    Uses the module-level ``amp`` name, which this file binds only when the
    ``apex`` import succeeds, so instances must only be called in that case.
    """

    def __call__(self, loss, optimizer):
        """Backpropagate ``loss`` through ``amp.scale_loss`` and step ``optimizer``.

        Args:
            loss: Loss value handed to ``amp.scale_loss``.
            optimizer: Optimizer handed to ``amp.scale_loss`` and then stepped.
        """
        loss_scaling = amp.scale_loss(loss, optimizer)
        with loss_scaling as scaled_loss:
            scaled_loss.backward()
        # The step happens after the scale_loss context has exited.
        optimizer.step()
272
+
273
+
274
class NativeScaler:
    """Callable that runs the backward pass and optimizer step via torch's GradScaler.

    Holds one ``torch.cuda.amp.GradScaler`` for the lifetime of the instance.
    """

    def __init__(self):
        """Create the underlying ``torch.cuda.amp.GradScaler``."""
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer):
        """Scale ``loss``, backpropagate, step ``optimizer``, then update the scaler.

        Args:
            loss: Loss tensor to be scaled before ``backward()``.
            optimizer: Optimizer stepped through the GradScaler.
        """
        grad_scaler = self._scaler
        scaled_loss = grad_scaler.scale(loss)
        scaled_loss.backward()
        # The optimizer is stepped through the scaler, not directly.
        grad_scaler.step(optimizer)
        grad_scaler.update()
282
+
283
+
257
284
def main ():
258
285
setup_default_logging ()
259
286
args , args_text = _parse_args ()
@@ -263,7 +290,8 @@ def main():
263
290
if 'WORLD_SIZE' in os .environ :
264
291
args .distributed = int (os .environ ['WORLD_SIZE' ]) > 1
265
292
if args .distributed and args .num_gpu > 1 :
266
- _logger .warning ('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' )
293
+ _logger .warning (
294
+ 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.' )
267
295
args .num_gpu = 1
268
296
269
297
args .device = 'cuda:0'
@@ -315,28 +343,50 @@ def main():
315
343
assert num_aug_splits > 1 or args .resplit
316
344
model = convert_splitbn_model (model , max (num_aug_splits , 2 ))
317
345
346
+ use_amp = None
347
+ if args .amp :
348
+ # for backwards compat, `--amp` arg tries apex before native amp
349
+ if has_apex :
350
+ args .apex_amp = True
351
+ elif has_native_amp :
352
+ args .native_amp = True
353
+ if args .apex_amp and has_apex :
354
+ use_amp = 'apex'
355
+ elif args .native_amp and has_native_amp :
356
+ use_amp = 'native'
357
+ elif args .apex_amp or args .native_amp :
358
+ _logger .warning ("Neither APEX or native Torch AMP is available, using float32. "
359
+ "Install NVIDA apex or upgrade to PyTorch 1.6" )
360
+
318
361
if args .num_gpu > 1 :
319
- if args . amp :
362
+ if use_amp == 'apex' :
320
363
_logger .warning (
321
- 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' )
322
- args . amp = False
364
+ 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.' )
365
+ use_amp = None
323
366
model = nn .DataParallel (model , device_ids = list (range (args .num_gpu ))).cuda ()
367
+ assert not args .channels_last , "Channels last not supported with DP, use DDP."
324
368
else :
325
369
model .cuda ()
370
+ if args .channels_last :
371
+ model = model .to (memory_format = torch .channels_last )
326
372
327
373
optimizer = create_optimizer (args , model )
328
374
329
- use_amp = False
330
- if has_apex and args .amp :
375
+ amp_autocast = suppress # do nothing
376
+ loss_scaler = None
377
+ if use_amp == 'apex' :
331
378
model , optimizer = amp .initialize (model , optimizer , opt_level = 'O1' )
332
- use_amp = True
333
- elif args .amp :
334
- _logger .info ('Using torch AMP. Install NVIDIA Apex for Apex AMP.' )
335
- scaler = torch .cuda .amp .GradScaler ()
336
- use_amp = True
337
- if args .local_rank == 0 :
338
- _logger .info ('NVIDIA APEX {}. AMP {}.' .format (
339
- 'installed' if has_apex else 'not installed' , 'on' if use_amp else 'off' ))
379
+ loss_scaler = ApexScaler ()
380
+ if args .local_rank == 0 :
381
+ _logger .info ('Using NVIDIA APEX AMP. Training in mixed precision.' )
382
+ elif use_amp == 'native' :
383
+ amp_autocast = torch .cuda .amp .autocast
384
+ loss_scaler = NativeScaler ()
385
+ if args .local_rank == 0 :
386
+ _logger .info ('Using native Torch AMP. Training in mixed precision.' )
387
+ else :
388
+ if args .local_rank == 0 :
389
+ _logger .info ('AMP not enabled. Training in float32.' )
340
390
341
391
# optionally resume from a checkpoint
342
392
resume_state = {}
@@ -346,7 +396,7 @@ def main():
346
396
if resume_state and not args .no_resume_opt :
347
397
if 'optimizer' in resume_state :
348
398
if args .local_rank == 0 :
349
- _logger .info ('Restoring Optimizer state from checkpoint' )
399
+ _logger .info ('Restoring optimizer state from checkpoint' )
350
400
optimizer .load_state_dict (resume_state ['optimizer' ])
351
401
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp .__dict__ :
352
402
if args .local_rank == 0 :
@@ -367,7 +417,8 @@ def main():
367
417
if args .sync_bn :
368
418
assert not args .split_bn
369
419
try :
370
- if has_apex :
420
+ if has_apex and use_amp != 'native' :
421
+ # Apex SyncBN preferred unless native amp is activated
371
422
model = convert_syncbn_model (model )
372
423
else :
373
424
model = torch .nn .SyncBatchNorm .convert_sync_batchnorm (model )
@@ -377,12 +428,15 @@ def main():
377
428
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' )
378
429
except Exception as e :
379
430
_logger .error ('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' )
380
- if has_apex :
381
- model = DDP (model , delay_allreduce = True )
431
+ if has_apex and use_amp != 'native' :
432
+ # Apex DDP preferred unless native amp is activated
433
+ if args .local_rank == 0 :
434
+ _logger .info ("Using NVIDIA APEX DistributedDataParallel." )
435
+ model = ApexDDP (model , delay_allreduce = True )
382
436
else :
383
437
if args .local_rank == 0 :
384
- _logger .info ("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP ." )
385
- model = DDP (model , device_ids = [args .local_rank ]) # can use device str in Torch >= 1.1
438
+ _logger .info ("Using native Torch DistributedDataParallel ." )
439
+ model = NativeDDP (model , device_ids = [args .local_rank ]) # can use device str in Torch >= 1.1
386
440
# NOTE: EMA model does not need to be wrapped by DDP
387
441
388
442
lr_scheduler , num_epochs = create_scheduler (args , optimizer )
@@ -501,7 +555,7 @@ def main():
501
555
])
502
556
output_dir = get_outdir (output_base , 'train' , exp_name )
503
557
decreasing = True if eval_metric == 'loss' else False
504
- saver = CheckpointSaver (checkpoint_dir = output_dir , decreasing = decreasing )
558
+ saver = CheckpointSaver (checkpoint_dir = output_dir , decreasing = decreasing , save_amp = use_amp == 'apex' )
505
559
with open (os .path .join (output_dir , 'args.yaml' ), 'w' ) as f :
506
560
f .write (args_text )
507
561
@@ -513,22 +567,20 @@ def main():
513
567
train_metrics = train_epoch (
514
568
epoch , model , loader_train , optimizer , train_loss_fn , args ,
515
569
lr_scheduler = lr_scheduler , saver = saver , output_dir = output_dir ,
516
- use_amp = use_amp , has_apex = has_apex , scaler = scaler ,
517
- model_ema = model_ema , mixup_fn = mixup_fn )
570
+ amp_autocast = amp_autocast , loss_scaler = loss_scaler , model_ema = model_ema , mixup_fn = mixup_fn )
518
571
519
572
if args .distributed and args .dist_bn in ('broadcast' , 'reduce' ):
520
573
if args .local_rank == 0 :
521
574
_logger .info ("Distributing BatchNorm running means and vars" )
522
575
distribute_bn (model , args .world_size , args .dist_bn == 'reduce' )
523
576
524
- eval_metrics = validate (model , loader_eval , validate_loss_fn , args )
577
+ eval_metrics = validate (model , loader_eval , validate_loss_fn , args , amp_autocast = amp_autocast )
525
578
526
579
if model_ema is not None and not args .model_ema_force_cpu :
527
580
if args .distributed and args .dist_bn in ('broadcast' , 'reduce' ):
528
581
distribute_bn (model_ema , args .world_size , args .dist_bn == 'reduce' )
529
-
530
582
ema_eval_metrics = validate (
531
- model_ema .ema , loader_eval , validate_loss_fn , args , log_suffix = ' (EMA)' )
583
+ model_ema .ema , loader_eval , validate_loss_fn , args , amp_autocast = amp_autocast , log_suffix = ' (EMA)' )
532
584
eval_metrics = ema_eval_metrics
533
585
534
586
if lr_scheduler is not None :
@@ -543,8 +595,7 @@ def main():
543
595
# save proper checkpoint with eval metric
544
596
save_metric = eval_metrics [eval_metric ]
545
597
best_metric , best_epoch = saver .save_checkpoint (
546
- model , optimizer , args ,
547
- epoch = epoch , model_ema = model_ema , metric = save_metric , use_amp = has_apex & use_amp )
598
+ model , optimizer , args , epoch = epoch , model_ema = model_ema , metric = save_metric )
548
599
549
600
except KeyboardInterrupt :
550
601
pass
@@ -554,8 +605,8 @@ def main():
554
605
555
606
def train_epoch (
556
607
epoch , model , loader , optimizer , loss_fn , args ,
557
- lr_scheduler = None , saver = None , output_dir = '' , use_amp = False ,
558
- has_apex = False , scaler = None , model_ema = None , mixup_fn = None ):
608
+ lr_scheduler = None , saver = None , output_dir = '' , amp_autocast = suppress ,
609
+ loss_scaler = None , model_ema = None , mixup_fn = None ):
559
610
560
611
if args .mixup_off_epoch and epoch >= args .mixup_off_epoch :
561
612
if args .prefetcher and loader .mixup_enabled :
@@ -579,31 +630,21 @@ def train_epoch(
579
630
input , target = input .cuda (), target .cuda ()
580
631
if mixup_fn is not None :
581
632
input , target = mixup_fn (input , target )
582
- if not has_apex and use_amp :
583
- with torch .cuda .amp .autocast ():
584
- output = model (input )
585
- loss = loss_fn (output , target )
586
- else :
633
+ if args .channels_last :
634
+ input = input .contiguous (memory_format = torch .channels_last )
635
+
636
+ with amp_autocast ():
587
637
output = model (input )
588
638
loss = loss_fn (output , target )
589
-
639
+
590
640
if not args .distributed :
591
641
losses_m .update (loss .item (), input .size (0 ))
592
642
593
643
optimizer .zero_grad ()
594
- if use_amp :
595
- if has_apex :
596
- with amp .scale_loss (loss , optimizer ) as scaled_loss :
597
- scaled_loss .backward ()
598
- else :
599
- scaler .scale (loss ).backward ()
600
-
644
+ if loss_scaler is not None :
645
+ loss_scaler (loss , optimizer )
601
646
else :
602
647
loss .backward ()
603
- if not has_apex and use_amp :
604
- scaler .step (optimizer )
605
- scaler .update ()
606
- else :
607
648
optimizer .step ()
608
649
609
650
torch .cuda .synchronize ()
@@ -648,8 +689,7 @@ def train_epoch(
648
689
if saver is not None and args .recovery_interval and (
649
690
last_batch or (batch_idx + 1 ) % args .recovery_interval == 0 ):
650
691
651
- saver .save_recovery (
652
- model , optimizer , args , epoch , model_ema = model_ema , use_amp = has_apex & use_amp , batch_idx = batch_idx )
692
+ saver .save_recovery (model , optimizer , args , epoch , model_ema = model_ema , batch_idx = batch_idx )
653
693
654
694
if lr_scheduler is not None :
655
695
lr_scheduler .step_update (num_updates = num_updates , metric = losses_m .avg )
@@ -663,7 +703,7 @@ def train_epoch(
663
703
return OrderedDict ([('loss' , losses_m .avg )])
664
704
665
705
666
- def validate (model , loader , loss_fn , args , log_suffix = '' ):
706
+ def validate (model , loader , loss_fn , args , amp_autocast = suppress , log_suffix = '' ):
667
707
batch_time_m = AverageMeter ()
668
708
losses_m = AverageMeter ()
669
709
top1_m = AverageMeter ()
@@ -679,8 +719,11 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
679
719
if not args .prefetcher :
680
720
input = input .cuda ()
681
721
target = target .cuda ()
722
+ if args .channels_last :
723
+ input = input .contiguous (memory_format = torch .channels_last )
682
724
683
- output = model (input )
725
+ with amp_autocast ():
726
+ output = model (input )
684
727
if isinstance (output , (tuple , list )):
685
728
output = output [0 ]
686
729
0 commit comments