Hi there, thanks for this package, it's really helpful!
On a cluster with multiple GPUs, I have my model on device `cuda:1`.
When calculating FID with a passed `gen` function, new samples are generated during the FID computation. To that end, a `model_fn(x)` function is defined here:
clean-fid/cleanfid/features.py, lines 23 to 25 in bd44693:

```python
if use_dataparallel:
    model = torch.nn.DataParallel(model)
def model_fn(x): return model(x)
```
If `use_dataparallel=True`, the model will be wrapped with `model = torch.nn.DataParallel(model)`.
Problem: `DataParallel` has a kwarg `device_ids=None`, which defaults to all available devices and then selects the first of them as the "source" device, i.e., `cuda:0`. Later it asserts that all parameters and buffers of the module are on that device. So if `device_ids` is not passed, this results in an error, because my model's device is different from `cuda:0`.
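
For concreteness, here is a minimal sketch of how the default behaves (my own repro, assuming a machine with at least two GPUs; a small `Linear` stands in for the Inception feature extractor):

```python
import torch

model = torch.nn.Linear(4, 4).to("cuda:1")  # model lives on cuda:1

# device_ids=None expands to all visible GPUs, and device_ids[0]
# (i.e. cuda:0) becomes the source device DataParallel replicates from.
dp = torch.nn.DataParallel(model)

x = torch.randn(8, 4, device="cuda:1")
out = dp(x)  # RuntimeError: module must have its parameters and buffers
             # on device cuda:0 (device_ids[0]) but found one of them
             # on device: cuda:1
```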
I am wondering why `DataParallel` just hard-codes everything to the first of all available devices, but there is a solution on the `cleanfid` side for this problem.
Solution: pass `device_ids` with the device of the model:

```python
if use_dataparallel:
    device_ids = [torch.cuda.current_device()]  # or use next(model.parameters()).device
    model = torch.nn.DataParallel(model, device_ids=device_ids)
def model_fn(x): return model(x)
```
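
One caveat (my own note, not something in the codebase): `torch.cuda.current_device()` still reports device 0 unless the caller has run `torch.cuda.set_device(...)` beforehand, so deriving the index from the model's own parameters may be more robust. A sketch of that variant:

```python
if use_dataparallel:
    # Take the device the model actually lives on; .index is the
    # ordinal of a CUDA device (e.g. 1 for cuda:1).
    model_device = next(model.parameters()).device
    model = torch.nn.DataParallel(model, device_ids=[model_device.index])
def model_fn(x): return model(x)
```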
I would be happy to make a PR fixing this. Unless I am missing something?
Cheers,
Jan