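```python
import torch
import torch.nn as nn


def multi_gpu(model):
    """Move `model` to the best available device, wrapping it in
    nn.DataParallel when more than one GPU is visible."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_gpus = torch.cuda.device_count()
    if device.type == 'cuda' and n_gpus > 1:
        # Replicate the model and split each input batch across all visible GPUs.
        model = nn.DataParallel(model)
        print(f'Using {n_gpus} GPUs!')
    elif device.type == 'cuda':
        print('Using 1 GPU!')
    else:
        print('Using CPU!')
    # nn.Module.to() moves parameters and buffers in place and returns the module.
    model = model.to(device)
    return model, device
```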
This is how I do it.
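A minimal usage sketch, assuming a hypothetical toy `nn.Linear(4, 2)` model in place of a real one. It repeats the device selection inline so it runs on its own, creates the input batch on the same `device` as the model, and unwraps `nn.DataParallel` through its `.module` attribute before calling `state_dict()`, so the saved keys carry no `module.` prefix.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
model = nn.Linear(4, 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

# Inputs must live on the same device as the (primary copy of the) model.
x = torch.randn(8, 4, device=device)
y = model(x)

# Unwrap DataParallel before saving so checkpoint keys have no 'module.' prefix.
to_save = model.module if isinstance(model, nn.DataParallel) else model
state = to_save.state_dict()
```

Note that the PyTorch documentation recommends `DistributedDataParallel` over `nn.DataParallel` for multi-GPU training, even on a single machine.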