
cuda device mismatch in DataParallel when not using cuda:0 #60

@janfb


Hi there, thanks for this package, it's really helpful!

On a cluster with multiple GPUs, I have my model on device cuda:1.

When calculating FID with a passed gen function, new samples are generated during FID calculation. To that end, a model_fn(x) function is defined here:

        if use_dataparallel:
            model = torch.nn.DataParallel(model)
        def model_fn(x): return model(x)

If use_dataparallel=True, the model is wrapped with torch.nn.DataParallel(model), i.e. without passing device_ids.

Problem: DataParallel has a kwarg device_ids=None, which defaults to all available CUDA devices. The first of these, cuda:0, then becomes the "source" device. In forward(), DataParallel checks that all parameters and buffers of the model are on that source device and raises a RuntimeError otherwise.
So if device_ids is not passed, this results in an error, because my model lives on cuda:1 rather than cuda:0.
I am not sure why DataParallel hard-codes the first available device, but the problem can be fixed on the cleanfid side.
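To make the failure mode concrete, here is a simplified pure-Python model of the device logic described above. It is not PyTorch's actual code. resolve_device_ids and check_source_device are hypothetical names that only mirror the behavior: device_ids=None means every visible GPU, the source device is device_ids[0], and a parameter on any other device is an error.

```python
from typing import List, Optional, Sequence


def resolve_device_ids(device_count: int, device_ids: Optional[Sequence[int]] = None) -> List[int]:
    """Simplified model of DataParallel's default: None means every visible GPU."""
    if device_ids is None:
        return list(range(device_count))
    return list(device_ids)


def check_source_device(param_devices: Sequence[str], device_ids: Sequence[int]) -> str:
    """Mimic the forward-time check: all parameters/buffers must sit on device_ids[0].

    Returns the source device string, or raises RuntimeError on a mismatch.
    """
    source = f"cuda:{device_ids[0]}"
    for dev in param_devices:
        if dev != source:
            raise RuntimeError(f"expected parameters on {source} (device_ids[0]), found {dev}")
    return source


if __name__ == "__main__":
    model_devices = ["cuda:1", "cuda:1"]  # a model that was moved to cuda:1

    # Default on a 2-GPU node: device_ids=None resolves to [0, 1], so the source is cuda:0.
    try:
        check_source_device(model_devices, resolve_device_ids(2))
    except RuntimeError as err:
        print("default device_ids:", err)

    # Passing the model's own GPU as device_ids[0] makes the check pass.
    print("device_ids=[1]:", check_source_device(model_devices, resolve_device_ids(2, [1])))
```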

Solution: pass device_ids explicitly, set to the device the model lives on:

        if use_dataparallel:
            # Anchor DataParallel on the model's own GPU. torch.cuda.current_device()
            # only matches it if torch.cuda.set_device() was called beforehand.
            device_ids = [next(model.parameters()).device]
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        def model_fn(x): return model(x)
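A slightly more general sketch of the same fix, assuming all of the model's parameters live on a single device. wrap_on_model_device and ordered_device_ids are hypothetical names, not part of cleanfid. With a single entry in device_ids, DataParallel has only one GPU to use, so batches are no longer split; the all_gpus=True path keeps every visible GPU but reorders the list so the model's own GPU is device_ids[0], the device DataParallel treats as the source.

```python
from typing import Callable, List, Tuple

import torch
import torch.nn as nn


def ordered_device_ids(source_index: int, device_count: int) -> List[int]:
    """Every visible GPU index, with the model's own GPU moved to the front."""
    return [source_index] + [i for i in range(device_count) if i != source_index]


def wrap_on_model_device(
    model: nn.Module, use_dataparallel: bool = True, all_gpus: bool = False
) -> Tuple[nn.Module, Callable[[torch.Tensor], torch.Tensor]]:
    """Hypothetical helper: wrap `model` in DataParallel without assuming cuda:0."""
    device = next(model.parameters()).device
    if use_dataparallel and device.type == "cuda":
        if all_gpus:
            # Keep multi-GPU splitting, but make the model's GPU device_ids[0].
            device_ids = ordered_device_ids(device.index, torch.cuda.device_count())
        else:
            # Same idea as the proposed fix: use only the model's own GPU.
            device_ids = [device.index]
        model = nn.DataParallel(model, device_ids=device_ids)

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        return model(x)

    return model, model_fn
```

Models on the CPU are returned unwrapped, since there is no GPU for DataParallel to anchor on.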

I would be happy to make a PR fixing this, unless I am missing something?

Cheers,
Jan
