diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index e1376957..f79d31c4 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -94,22 +94,6 @@ def is_module_offloaded(module: torch.nn.Module) -> bool: return has_offloaded_params(module) -def get_execution_device(module: torch.nn.Module) -> torch.device: - """ - :param module: module to check - :return: device module is loaded onto during forward pass - """ - if has_offloaded_params(module): - return module._hf_hook.execution_device - device = next(module.parameters()).device - - # offload only gets set for leaf modules, fallback to checking for device type - if device.type == "meta": - return module._hf_hook.execution_device - - return device - - def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ :param module: module to check @@ -158,6 +142,26 @@ def update_parameter_data( """ Candidates for Upstreaming """ +def get_execution_device(module: torch.nn.Module) -> torch.device: + """ + Get the device which inputs should be moved to before module execution + + :param module: module to check, may be offloaded + :return: onload device of module + """ + if has_offloaded_params(module): + return module._hf_hook.execution_device + + first_param = next(module.parameters(), None) + if first_param is None: + warnings.warn( + f"Unable to infer execution device of {module}, falling back to CPU" + ) + return torch.device("cpu") + + return first_param.device + + def register_offload_parameter( module: torch.nn.Module, name: str, diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index e013b058..7dc1d563 100644 @@ -17,12 +17,13 @@ align_module_device, delete_offload_parameter, disable_hf_hook, + get_execution_device, has_offloaded_params, register_offload_parameter, update_offload_parameter, ) from 
compressed_tensors.utils.offload import offload_to_weights_map -from tests.testing_utils import requires_accelerate +from tests.testing_utils import requires_accelerate, requires_gpu class ExampleModule(torch.nn.Module): @@ -55,8 +56,46 @@ def test_has_offloaded_params(): assert has_offloaded_params(module) +@requires_gpu +@requires_accelerate() +def test_get_execution_device(): + from accelerate import init_empty_weights + from accelerate.big_modeling import attach_align_device_hook + + # no offloading + module = ExampleModule() + assert get_execution_device(module) == torch.device("cpu") + + # with offloading + attach_align_device_hook(module, torch.device("cuda:0")) + assert get_execution_device(module) == torch.device("cuda:0") + + # in meta context + with torch.device("meta"): + module = ExampleModule() + assert get_execution_device(module) == torch.device("meta") + + # offloaded in meta context + module = ExampleModule() + attach_align_device_hook(module, torch.device("cuda:0")) + with torch.device("meta"): + assert get_execution_device(module) == torch.device("cuda:0") + + # in empty weights context + with init_empty_weights(): + module = ExampleModule() + assert get_execution_device(module) == torch.device("meta") + + # offloaded in empty weights context + module = ExampleModule() + attach_align_device_hook(module, torch.device("cuda:0")) + with init_empty_weights(): + assert get_execution_device(module) == torch.device("cuda:0") + + @requires_accelerate() def test_register_offload_parameter(): + from accelerate import init_empty_weights from accelerate.hooks import attach_align_device_hook module = ExampleModule() @@ -94,6 +133,12 @@ def test_register_offload_parameter(): assert module.f.device == torch.device("cpu") assert module._hf_hook.weights_map["f"].device == torch.device("cpu") + # parameters registered in the empty init context are still empty + with init_empty_weights(): + module = ExampleModule() + register_offload_parameter(module, "c", 
parameter) + assert module.a.device == module.b.device == module.c.device == torch.device("meta") + @requires_accelerate() def test_update_offload_parameter():