diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index ac156a63..c26d69c7 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -123,14 +123,26 @@ def is_module_offloaded(module: torch.nn.Module) -> bool:
 
 def get_offloaded_device(module: torch.nn.Module) -> torch.device:
     """
+    Get the offload device of this module. This is determined through analysis of the
+    first module with parameters returned by `module.modules()`.
+
+    If this module is offloaded, return its offload device. If this module is not
+    offloaded, return the parameter's device.
+
     :param module: module to check
-    :return: device module is offloaded to onto after forward pass
+    :return: offload device of module
     """
-    if has_offloaded_params(module):
-        first_key = list(module._hf_hook.weights_map.keys())[0]
-        prefix_dataset = module._hf_hook.weights_map.dataset
-        return prefix_dataset[first_key].device
-    return next(module.parameters()).device
+    for submodule in module.modules():
+        name, param = next(submodule.named_parameters(recurse=False), (None, None))
+        if has_offloaded_params(submodule):
+            assert name is not None
+            return submodule._hf_hook.weights_map[name].device
+
+        elif param is not None:
+            return param.device
+
+    warnings.warn(f"Cannot get offload device of {module}, falling back to CPU")
+    return torch.device("cpu")
 
 
 @check_accelerate(fallback=None)
diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py
index 1fce49b3..d3dcc65c 100644
--- a/tests/test_utils/test_offload.py
+++ b/tests/test_utils/test_offload.py
@@ -21,6 +21,7 @@
     disable_hf_hook,
     disable_offloading,
     get_execution_device,
+    get_offloaded_device,
     has_offloaded_params,
     offloaded_dispatch,
     register_offload_module,
@@ -121,6 +122,77 @@ def forward(self, x):
     assert get_execution_device(model) == torch.device("cuda:0")
 
 
+@requires_gpu
+@requires_accelerate()
+def test_get_offloaded_device():
+    from accelerate import init_empty_weights
+
+    # no offloading
+    module = ExampleModule()
+    assert get_offloaded_device(module) == torch.device("cpu")
+
+    # with offloading
+    offloaded_dispatch(
+        module,
+        execution_device=torch.device("cpu"),
+        offload_device=torch.device("cuda:0"),
+    )
+    assert get_offloaded_device(module) == torch.device("cuda:0")
+
+    # in meta context
+    with torch.device("meta"):
+        module = ExampleModule()
+        assert get_offloaded_device(module) == torch.device("meta")
+
+    # offloaded in meta context
+    module = ExampleModule()
+    offloaded_dispatch(
+        module,
+        execution_device=torch.device("cpu"),
+        offload_device=torch.device("cuda:0"),
+    )
+    with torch.device("meta"):
+        assert get_offloaded_device(module) == torch.device("cuda:0")
+
+    # in empty weights context
+    with init_empty_weights():
+        module = ExampleModule()
+        assert get_offloaded_device(module) == torch.device("meta")
+
+    # offloaded in empty weights context
+    module = ExampleModule()
+    offloaded_dispatch(
+        module,
+        execution_device=torch.device("cpu"),
+        offload_device=torch.device("cuda:0"),
+    )
+    with init_empty_weights():
+        assert get_offloaded_device(module) == torch.device("cuda:0")
+
+
+@requires_gpu
+@requires_accelerate()
+def test_get_offloaded_device_model():
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.a = torch.nn.Linear(1, 2)
+            self.b = torch.nn.Linear(2, 2, device="cuda:0")
+
+        def forward(self, x):
+            return self.b(self.a(x).to("cuda:0"))
+
+    model = Model()
+    assert get_offloaded_device(model) == torch.device("cpu")
+
+    offloaded_dispatch(
+        model.a,
+        execution_device=torch.device("cpu"),
+        offload_device=torch.device("cuda:0"),
+    )
+    assert get_offloaded_device(model) == torch.device("cuda:0")
+
+
 @requires_accelerate()
 def test_register_offload_parameter():
     from accelerate import init_empty_weights
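Not part of the diff: a minimal usage sketch of the reworked `get_offloaded_device`, assuming a CUDA device and `accelerate` are installed. `TinyModule` is a hypothetical stand-in for the test suite's `ExampleModule`; `offloaded_dispatch` and `get_offloaded_device` are the repo utilities exercised by the new tests.

```python
# Minimal sketch (not part of the diff): assumes a CUDA device and `accelerate`
# are available. TinyModule is a hypothetical stand-in for the test suite's
# ExampleModule; the two utilities come from this repo's offload module.
import torch
from compressed_tensors.utils.offload import get_offloaded_device, offloaded_dispatch


class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


module = TinyModule()
# No hook attached yet: the first parameter found by module.modules() decides.
assert get_offloaded_device(module) == torch.device("cpu")

# After dispatch, the device recorded in the hook's weights_map wins.
offloaded_dispatch(
    module,
    execution_device=torch.device("cpu"),
    offload_device=torch.device("cuda:0"),
)
assert get_offloaded_device(module) == torch.device("cuda:0")
```

Because the traversal returns at the first submodule that owns parameters, dispatching only a model's first parameterized child (as `test_get_offloaded_device_model` does with `model.a`) is enough to change the result for the whole model.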