Skip to content

Commit 37c731c

Browse files
committed
fix device check
1 parent 234f975 commit 37c731c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

timm/data/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def __init__(
113113
)
114114
else:
115115
self.random_erasing = None
116-
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
117-
self.is_npu = torch.npu.is_available() and device.type == 'npu'
116+
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
117+
self.is_npu = device.type == 'npu' and torch.npu.is_available()
118118

119119
def __iter__(self):
120120
first = True

validate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
395395
while batch_size:
396396
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
397397
try:
398-
if torch.cuda.is_available() and 'cuda' in args.device:
398+
if 'cuda' in args.device and torch.cuda.is_available():
399399
torch.cuda.empty_cache()
400-
elif torch.npu.is_available() and "npu" in args.device:
400+
elif "npu" in args.device and torch.npu.is_available():
401401
torch.npu.empty_cache()
402402
results = validate(args)
403403
return results

0 commit comments

Comments
 (0)