From 5856c277a90b0231e3c93bdbb9f881286c3efc41 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 16:43:05 -0400 Subject: [PATCH 1/4] support models Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index ac156a63..814209c0 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -123,14 +123,24 @@ def is_module_offloaded(module: torch.nn.Module) -> bool: def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ + Get the offload device of this module. If this module has multiple offloaded + parameters, return the first one. If this module is not offloaded, return the + device of the first parameter + :param module: module to check - :return: device module is offloaded to onto after forward pass + :return: offload device of module """ - if has_offloaded_params(module): - first_key = list(module._hf_hook.weights_map.keys())[0] - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device - return next(module.parameters()).device + for submodule in module.modules(): + name, param = next(submodule.named_parameters(recurse=False), (None, None)) + if has_offloaded_params(submodule): + assert name is not None + return submodule._hf_hook.weights_map[name].device + + elif param is not None: + return param.device + + warnings.warn(f"Cannot get offload device of {module}, falling back to CPU") + return torch.device("cpu") @check_accelerate(fallback=None) From f4d8cfef1469eb04f90b0847dd011b6a2e1e4b0e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 16:58:31 -0400 Subject: [PATCH 2/4] clarify docstring Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 814209c0..c26d69c7 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -123,9 +123,11 @@ def is_module_offloaded(module: torch.nn.Module) -> bool: def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ - Get the offload device of this module. If this module has multiple offloaded - parameters, return the first one. If this module is not offloaded, return the - device of the first parameter + Get the offload device of this module. This is determined through analysis of the + first module with parameters returned by `module.modules()`. + + If this module is offloaded, return its offload device. If this module is not + offloaded, return the parameter's device. :param module: module to check :return: offload device of module From 0e77936396a9038783f75832e93b9deba9582348 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 17:38:03 -0400 Subject: [PATCH 3/4] add tests Signed-off-by: Kyle Sayers --- tests/test_utils/test_offload.py | 70 ++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 1fce49b3..782124ea 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -21,6 +21,7 @@ disable_hf_hook, disable_offloading, get_execution_device, + get_offloaded_device, has_offloaded_params, offloaded_dispatch, register_offload_module, @@ -121,6 +122,75 @@ def forward(self, x): assert get_execution_device(model) == torch.device("cuda:0") +def test_get_offloaded_device(): + from accelerate import init_empty_weights + + # no offloading + module = ExampleModule() + assert get_offloaded_device(module) == torch.device("cpu") + + # with offloading + offloaded_dispatch( + module, + execution_device=torch.device("cpu"), + offload_device=torch.device("cuda:0"), + ) + assert
get_offloaded_device(module) == torch.device("cuda:0") + + # in meta context + with torch.device("meta"): + module = ExampleModule() + assert get_offloaded_device(module) == torch.device("meta") + + # offloaded in meta context + module = ExampleModule() + offloaded_dispatch( + module, + execution_device=torch.device("cpu"), + offload_device=torch.device("cuda:0"), + ) + with torch.device("meta"): + assert get_offloaded_device(module) == torch.device("cuda:0") + + # in empty weights context + with init_empty_weights(): + module = ExampleModule() + assert get_offloaded_device(module) == torch.device("meta") + + # offloaded in empty weights context + module = ExampleModule() + offloaded_dispatch( + module, + execution_device=torch.device("cpu"), + offload_device=torch.device("cuda:0"), + ) + with init_empty_weights(): + assert get_offloaded_device(module) == torch.device("cuda:0") + + +@requires_gpu +@requires_accelerate() +def test_get_offloaded_device_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Linear(1, 2) + self.b = torch.nn.Linear(2, 2, device="cuda:0") + + def forward(self, x): + return self.b(self.a(x).to("cuda:0")) + + model = Model() + assert get_offloaded_device(model) == torch.device("cpu") + + offloaded_dispatch( + model.a, + execution_device=torch.device("cpu"), + offload_device=torch.device("cuda:0"), + ) + assert get_offloaded_device(model) == torch.device("cuda:0") + + @requires_accelerate() def test_register_offload_parameter(): from accelerate import init_empty_weights From b6f77f34dab37c96b5869bd08b4aafb1cb4d6429 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 18:00:19 -0400 Subject: [PATCH 4/4] fix testing marks Signed-off-by: Kyle Sayers --- tests/test_utils/test_offload.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 782124ea..d3dcc65c 100644 --- a/tests/test_utils/test_offload.py +++
b/tests/test_utils/test_offload.py @@ -122,6 +122,8 @@ def forward(self, x): assert get_execution_device(model) == torch.device("cuda:0") +@requires_gpu +@requires_accelerate() def test_get_offloaded_device(): from accelerate import init_empty_weights