4 files changed, +22 −5 lines changed
@@ -113,13 +113,17 @@ def __init__(
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ def __iter__(self):
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
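These two hunks extend the prefetching loader: the availability checks now test `device.type` first so a backend namespace is only touched when that backend was actually requested (on a CUDA-only install, `torch.npu` may not exist at all, so the short-circuit matters), and both the side-stream setup and the `wait_stream` call gain an NPU branch mirroring the CUDA one. Below is a minimal standalone sketch of the same selection logic, assuming the Ascend `torch_npu` plugin is installed whenever an `'npu'` device is requested (importing it registers the `torch.npu` namespace); the helper name is illustrative, not part of the diff.

from contextlib import suppress
from functools import partial

import torch
# Assumption: on Ascend hardware, `import torch_npu` has been done elsewhere,
# which registers the torch.npu namespace used below.


def pick_prefetch_stream(device: torch.device):
    """Return (stream, stream_context) for overlapped host-to-device copies.

    Checking device.type before the backend's is_available() short-circuits
    so torch.npu is never touched on machines without the NPU plugin.
    """
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    if device.type == 'npu' and torch.npu.is_available():
        stream = torch.npu.Stream()
        return stream, partial(torch.npu.stream, stream=stream)
    return None, suppress  # CPU or other backends: copies run on the default stream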
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
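In `init_distributed_device_so`, Ascend NPUs are mapped to the HCCL process-group backend (registered by `torch_npu`), and an availability assert mirrors the existing CUDA one. For context, a hedged sketch of how such a backend table is typically consumed when bringing up distributed training; the `env://` rendezvous and `LOCAL_RANK` handling here are illustrative assumptions, not taken from the diff.

import os

import torch
import torch.distributed as dist
# Assumption: `import torch_npu` has been done on Ascend systems, which
# registers both the torch.npu namespace and the 'hccl' backend.

_DIST_BACKENDS = {"xpu": "ccl", "hpu": "hccl", "cuda": "nccl", "npu": "hccl"}


def init_distributed(device_type: str) -> int:
    """Initialize the default process group and bind this rank to a device index."""
    backend = _DIST_BACKENDS.get(device_type, "gloo")
    dist.init_process_group(backend=backend, init_method="env://")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if device_type == "cuda":
        torch.cuda.set_device(local_rank)
    elif device_type == "npu":
        torch.npu.set_device(local_rank)
    return local_rank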
@@ -1042,8 +1042,11 @@ def _backward(_loss):
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1143,6 +1146,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
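Both training-script hunks serve the same purpose: when step timing (or the reduced loss in `validate`) must be accurate, the host waits for all queued kernels on the active device before reading the clock or the tensor. One way to fold that branching into a helper is sketched below; this is an illustration under the assumption that `device` is a `torch.device` and `torch_npu` is imported on Ascend systems, not the form used in the diff.

import torch


def synchronize(device: torch.device) -> None:
    """Block the host until all work queued on `device` has finished (no-op on CPU)."""
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu':
        # Assumption: torch_npu provides torch.npu.synchronize().
        torch.npu.synchronize()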
@@ -387,8 +387,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
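`_try_run` retries validation at progressively smaller batch sizes after a runtime error; clearing the caching allocator before each attempt gives the retry the best chance of fitting. As in the loader, the reordered condition checks the device string before calling into the backend. A sketch of the same device-aware cache release, assuming `args.device` is a string such as `'cuda:0'` or `'npu:0'` and that `torch_npu` is installed for the NPU path:

import torch


def release_cached_memory(device_str: str) -> None:
    """Return unused cached blocks to the device allocator before a retry."""
    if 'cuda' in device_str and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif 'npu' in device_str and torch.npu.is_available():
        # Assumption: `import torch_npu` has registered torch.npu.
        torch.npu.empty_cache()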