
Commit 43e6143

Fix #1712 broken support for AMP w/ PyTorch < 1.10. Disable loss scaler for bfloat16
1 parent 3a636ee commit 43e6143

File tree

1 file changed: +8 −2 lines

train.py

Lines changed: 8 additions & 2 deletions
@@ -514,8 +514,14 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        if device.type == 'cuda':
+        try:
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        except (AttributeError, TypeError):
+            # fallback to CUDA only AMP for PyTorch < 1.10
+            assert device.type == 'cuda'
+            amp_autocast = torch.cuda.amp.autocast
+        if device.type == 'cuda' and amp_dtype == torch.float16:
+            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
             loss_scaler = NativeScaler()
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
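
For reference, here is a minimal standalone sketch of the AMP setup pattern this diff applies: prefer the device-agnostic torch.autocast introduced in PyTorch 1.10, fall back to torch.cuda.amp.autocast on older releases, and only create a gradient scaler when training in float16, since bfloat16 keeps float32's exponent range and does not hit the same underflow in backward. The helper name make_amp_tools and the direct use of torch.cuda.amp.GradScaler (rather than timm's NativeScaler wrapper) are illustrative assumptions, not code from this commit.

    # Minimal sketch (not timm's actual train.py) of the AMP setup this commit fixes.
    # make_amp_tools is a hypothetical helper; timm wraps GradScaler in NativeScaler.
    from functools import partial

    import torch


    def make_amp_tools(device: torch.device, amp_dtype: torch.dtype = torch.float16):
        try:
            # PyTorch >= 1.10: device-agnostic autocast that accepts an explicit dtype
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # PyTorch < 1.10: only torch.cuda.amp.autocast exists (implicit float16)
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast

        loss_scaler = None
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # Gradient scaling guards against float16 underflow in backward;
            # bfloat16 keeps float32's exponent range, so no scaler is needed.
            loss_scaler = torch.cuda.amp.GradScaler()
        return amp_autocast, loss_scaler


    # Example training step using the returned pair:
    #
    #   amp_autocast, loss_scaler = make_amp_tools(torch.device('cuda'), torch.bfloat16)
    #   with amp_autocast():
    #       loss = model(inputs).sum()
    #   if loss_scaler is not None:
    #       loss_scaler.scale(loss).backward()
    #       loss_scaler.step(optimizer)
    #       loss_scaler.update()
    #   else:
    #       loss.backward()
    #       optimizer.step()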
