
Commit 5081b53

Merge pull request #2308 from huggingface/device_amp_cleanup
Cleanup some amp related behaviour to better support different (non-cuda) devices
2 parents a852318 + c3992d5 commit 5081b53

8 files changed: +61 additions, −52 deletions

benchmark.py

Lines changed: 1 addition & 8 deletions
@@ -32,13 +32,6 @@
 except ImportError:
     pass
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ def __init__(
         self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         if self.amp_dtype is not None:
-            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+            self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
         else:
             self.amp_autocast = suppress

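For context, a minimal sketch of the device-agnostic autocast setup that benchmark.py now relies on; make_autocast, device, and amp_dtype are illustrative names here, not part of the timm API:

from contextlib import suppress
from functools import partial

import torch

def make_autocast(device: str = 'cuda', amp_dtype=None):
    if amp_dtype is not None:
        # torch.amp.autocast takes the device type explicitly, so the same code path
        # covers cuda, cpu, and other autocast-capable backends
        return partial(torch.amp.autocast, device_type=device, dtype=amp_dtype)
    # AMP disabled: a no-op context manager keeps call sites unchanged
    return suppress

amp_autocast = make_autocast('cpu', torch.bfloat16)
with amp_autocast():
    pass  # model forward would run here under autocast
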
inference.py

Lines changed: 0 additions & 8 deletions
@@ -28,13 +28,6 @@
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +163,6 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     amp_autocast = suppress
     if args.amp:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)

timm/data/loader.py

Lines changed: 9 additions & 2 deletions
@@ -113,13 +113,17 @@ def __init__(
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ def __iter__(self):
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target

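The prefetcher's stream handling roughly follows the pattern below; this is a sketch, assuming torch.npu is only present when the Ascend torch_npu extension is installed (make_prefetch_stream is not a timm function):

from contextlib import suppress
from functools import partial

import torch

def make_prefetch_stream(device: torch.device):
    # returns (stream, stream_context) for overlapping host-to-device copies
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    if device.type == 'npu' and getattr(torch, 'npu', None) is not None and torch.npu.is_available():
        stream = torch.npu.Stream()
        return stream, partial(torch.npu.stream, stream=stream)
    # no side stream: copies run on the default stream
    return None, suppress

stream, stream_context = make_prefetch_stream(torch.device('cpu'))
with stream_context():
    pass  # tensor .to(device, non_blocking=True) copies would happen here
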
timm/layers/fast_norm.py

Lines changed: 30 additions & 6 deletions
@@ -28,6 +28,30 @@
 _USE_FAST_NORM = False  # defaulting to False for now
 
 
+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM
 
@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
 
@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)

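As a quick usage check (assuming the helpers land in timm.layers.fast_norm exactly as shown above), the new wrappers behave the same on old and new PyTorch; CPU bfloat16 autocast is used here only because it runs without a GPU:

import torch
from timm.layers.fast_norm import get_autocast_dtype, is_autocast_enabled

x = torch.randn(2, 8)
with torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16):
    # new PyTorch: the device string is passed straight through;
    # older PyTorch: the helpers fall back to the *_cpu_* / *_gpu_* variants
    assert is_autocast_enabled(x.device.type)
    assert get_autocast_dtype(x.device.type) == torch.bfloat16
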
timm/utils/cuda.py

Lines changed: 5 additions & 2 deletions
@@ -46,8 +46,11 @@ def load_state_dict(self, state_dict):
 class NativeScaler:
     state_dict_key = "amp_scaler"
 
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
+    def __init__(self, device='cuda'):
+        try:
+            self._scaler = torch.amp.GradScaler(device=device)
+        except (AttributeError, TypeError) as e:
+            self._scaler = torch.cuda.amp.GradScaler()
 
     def __call__(
         self,

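The scaler construction fallback, as a standalone sketch (make_grad_scaler is an illustrative name, not timm's NativeScaler):

import torch

def make_grad_scaler(device: str = 'cuda'):
    try:
        # newer PyTorch: device-agnostic GradScaler
        return torch.amp.GradScaler(device=device)
    except (AttributeError, TypeError):
        # older PyTorch: only the CUDA-specific scaler exists
        return torch.cuda.amp.GradScaler()

scaler = make_grad_scaler('cuda')
# typical loop: scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
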
timm/utils/distributed.py

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and

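The backend selection reduces to a small lookup, with 'gloo' as the fallback for unlisted device types; this is a sketch only (pick_backend is not a timm function):

_DIST_BACKENDS = {
    "xpu": "ccl",
    "hpu": "hccl",
    "cuda": "nccl",
    "npu": "hccl",  # Ascend NPU collectives go through HCCL
}

def pick_backend(device_type: str) -> str:
    return _DIST_BACKENDS.get(device_type, "gloo")

# e.g. torch.distributed.init_process_group(backend=pick_backend("npu"), init_method="env://")
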
train.py

Lines changed: 10 additions & 17 deletions
@@ -48,12 +48,6 @@
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
 
 try:
     import wandb
@@ -442,7 +436,6 @@ def main():
             use_amp = 'apex'
             assert args.amp_dtype == 'float16'
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             use_amp = 'native'
             assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@ def main():
         if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        try:
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        except (AttributeError, TypeError):
-            # fallback to CUDA only AMP for PyTorch < 1.10
-            assert device.type == 'cuda'
-            amp_autocast = torch.cuda.amp.autocast
-        if device.type == 'cuda' and amp_dtype == torch.float16:
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
-            loss_scaler = NativeScaler()
+            loss_scaler = NativeScaler(device=device.type)
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
@@ -1054,8 +1042,11 @@ def _backward(_loss):
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1146,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))

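The synchronize logic can be read as a small device dispatch; a hypothetical helper (not part of this diff) makes the intent clear:

import torch

def sync_device(device: torch.device) -> None:
    # wait for all queued accelerator work before taking step timings
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu' and hasattr(torch, 'npu'):
        torch.npu.synchronize()
    # cpu and other backends: nothing to synchronize here

sync_device(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
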
validate.py

Lines changed: 3 additions & 9 deletions
@@ -34,13 +34,6 @@
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
             use_amp = 'apex'
             _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             assert args.amp_dtype in ('float16', 'bfloat16')
             use_amp = 'native'
             amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
@@ -395,8 +387,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:

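And the cache clearing between batch-size retries, as a hedged sketch (empty_device_cache is illustrative; args.device in validate.py is a string such as 'cuda:0'):

import torch

def empty_device_cache(device: str) -> None:
    # free cached allocator blocks before retrying with a smaller batch size
    if 'cuda' in device and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif 'npu' in device and getattr(torch, 'npu', None) is not None and torch.npu.is_available():
        torch.npu.empty_cache()

empty_device_cache('cpu')  # no-op on a CPU-only machine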