Skip to content

Commit 81b59fa

Browse files
committed
Merge branch 'npu_support' of github.com:MengqingCao/pytorch-image-models into MengqingCao-npu_support
2 parents a852318 + 37c731c commit 81b59fa

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

timm/data/loader.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,13 +113,17 @@ def __init__(
113113
)
114114
else:
115115
self.random_erasing = None
116-
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
116+
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
117+
self.is_npu = device.type == 'npu' and torch.npu.is_available()
117118

118119
def __iter__(self):
119120
first = True
120121
if self.is_cuda:
121122
stream = torch.cuda.Stream()
122123
stream_context = partial(torch.cuda.stream, stream=stream)
124+
elif self.is_npu:
125+
stream = torch.npu.Stream()
126+
stream_context = partial(torch.npu.stream, stream=stream)
123127
else:
124128
stream = None
125129
stream_context = suppress
@@ -139,7 +143,10 @@ def __iter__(self):
139143
first = False
140144

141145
if stream is not None:
142-
torch.cuda.current_stream().wait_stream(stream)
146+
if self.is_cuda:
147+
torch.cuda.current_stream().wait_stream(stream)
148+
elif self.is_npu:
149+
torch.npu.current_stream().wait_stream(stream)
143150

144151
input = next_input
145152
target = next_target

timm/utils/distributed.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,7 @@ def init_distributed_device_so(
116116
"xpu": "ccl",
117117
"hpu": "hccl",
118118
"cuda": "nccl",
119+
"npu": "hccl",
119120
}
120121
dist_backend = dist_backends.get(device_type, 'gloo')
121122
dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
159160

160161
if device_type == 'cuda':
161162
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
163+
if device_type == 'npu':
164+
assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
162165

163166
if distributed and device != 'cpu':
164167
# Ignore manually specified device index in distributed mode and

train.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1054,8 +1054,11 @@ def _backward(_loss):
10541054
if model_ema is not None:
10551055
model_ema.update(model, step=num_updates)
10561056

1057-
if args.synchronize_step and device.type == 'cuda':
1058-
torch.cuda.synchronize()
1057+
if args.synchronize_step:
1058+
if device.type == 'cuda':
1059+
torch.cuda.synchronize()
1060+
elif device.type == 'npu':
1061+
torch.npu.synchronize()
10591062
time_now = time.time()
10601063
update_time_m.update(time.time() - update_start_time)
10611064
update_start_time = time_now
@@ -1155,6 +1158,8 @@ def validate(
11551158

11561159
if device.type == 'cuda':
11571160
torch.cuda.synchronize()
1161+
elif device.type == "npu":
1162+
torch.npu.synchronize()
11581163

11591164
losses_m.update(reduced_loss.item(), input.size(0))
11601165
top1_m.update(acc1.item(), output.size(0))

validate.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -395,8 +395,10 @@ def _try_run(args, initial_batch_size):
395395
while batch_size:
396396
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
397397
try:
398-
if torch.cuda.is_available() and 'cuda' in args.device:
398+
if 'cuda' in args.device and torch.cuda.is_available():
399399
torch.cuda.empty_cache()
400+
elif "npu" in args.device and torch.npu.is_available():
401+
torch.npu.empty_cache()
400402
results = validate(args)
401403
return results
402404
except RuntimeError as e:

0 commit comments

Comments (0)