File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -514,8 +514,14 @@ def main():
514
514
if utils .is_primary (args ):
515
515
_logger .info ('Using NVIDIA APEX AMP. Training in mixed precision.' )
516
516
elif use_amp == 'native' :
517
- amp_autocast = partial (torch .autocast , device_type = device .type , dtype = amp_dtype )
518
- if device .type == 'cuda' :
517
+ try :
518
+ amp_autocast = partial (torch .autocast , device_type = device .type , dtype = amp_dtype )
519
+ except (AttributeError , TypeError ):
520
+ # fallback to CUDA only AMP for PyTorch < 1.10
521
+ assert device .type == 'cuda'
522
+ amp_autocast = torch .cuda .amp .autocast
523
+ if device .type == 'cuda' and amp_dtype == torch .float16 :
524
+ # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
519
525
loss_scaler = NativeScaler ()
520
526
if utils .is_primary (args ):
521
527
_logger .info ('Using native Torch AMP. Training in mixed precision.' )
You can’t perform that action at this time.
0 commit comments