diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index b74a9a81..74bd1429 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -78,7 +78,6 @@
     "delete_offload_parameter",
     "has_offloaded_params",
     "disable_hf_hook",
-    "disable_offload",
     "align_modules",
     "align_module_device",
     "register_offload_module",
@@ -386,23 +385,6 @@ def delete_from_weights_map(
         )
 
 
-@check_accelerate(fallback=contextlib.nullcontext())
-@contextlib.contextmanager
-def disable_offload(module: torch.nn.Module):
-    """
-    Context manager to disable module onloading and offloading. Parameters will stay on
-    their current device
-
-    :param module: module to disable offloading for
-    """
-    if has_offloaded_params(module):
-        module._hf_hook.offload = False
-        yield
-        module._hf_hook.offload = True
-    else:
-        yield
-
-
 @check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
@@ -419,9 +401,9 @@ def align_modules(
     modules = (modules,) if isinstance(modules, torch.nn.Module) else modules
 
     with contextlib.ExitStack() as stack:
+        stack.enter_context(disable_offloading())
         for module in modules:
             stack.enter_context(align_module_device(module, execution_device))
-            stack.enter_context(disable_offload(module))  # disable redundant onloading
 
         yield
 
diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py
index d69eccc8..a9cec10e 100644
--- a/tests/test_utils/test_offload.py
+++ b/tests/test_utils/test_offload.py
@@ -305,6 +305,7 @@ def test_align_modules():
     module0 = ExampleModule()
     module1 = ExampleModule()
     module2 = ExampleModule()
+    new_data = torch.tensor(1.0)
     model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2))
     attach_align_device_hook(
         model,
@@ -318,13 +319,16 @@ def test_align_modules():
     assert module2.a.device == torch.device("meta")
 
     with align_modules((module0, module1)):
-        assert module0.a.device != torch.device("meta")
-        assert module1.a.device != torch.device("meta")
+        assert module0.a.device == torch.device("cpu")
+        assert module1.a.device == torch.device("cpu")
         assert module2.a.device == torch.device("meta")
+        update_offload_parameter(module0, "a", new_data)
+        assert module0.a == new_data
 
     assert module0.a.device == torch.device("meta")
     assert module1.a.device == torch.device("meta")
     assert module2.a.device == torch.device("meta")
+    assert module0._hf_hook.weights_map["a"] == new_data
 
 
 @requires_accelerate()
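
Usage sketch (not part of the diff): a minimal illustration of the behavior the updated test asserts, assuming accelerate and compressed-tensors are installed and that disable_offloading is available alongside align_modules. TinyModule and the hook arguments below are illustrative stand-ins for the test's ExampleModule fixture and its attach_align_device_hook call.

import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils.offload import align_modules, update_offload_parameter

class TinyModule(torch.nn.Module):
    # illustrative stand-in for the test's ExampleModule fixture
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.zeros(1))

m0, m1 = TinyModule(), TinyModule()
model = torch.nn.Sequential(m0, m1)
# offload parameters to a CPU weights map; aligned execution happens on CPU here
attach_align_device_hook(model, offload=True, execution_device=torch.device("cpu"))

with align_modules((m0, m1)):
    # with this change, offloading is disabled once for the whole block,
    # so both modules stay onloaded for the duration of the context
    update_offload_parameter(m0, "a", torch.ones(1))

# after exit the parameters are offloaded again, and the update persists
# in the offloaded weights map, mirroring the new test assertion
assert torch.equal(m0._hf_hook.weights_map["a"], torch.ones(1))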