Skip to content

Commit 8775b29

Browse files
authored
[Bugfix] Fix saving of models dispatched by offloaded_dispatch (#357)
* add hf map — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* ensure offload device is cpu for now — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* remove existing hooks — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* support gpu offloading — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* harden device map — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* add remove_dispatch util function — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* add to export — Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
---------
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 0ffe3c3 commit 8775b29

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
"delete_offload_module",
8686
"offloaded_dispatch",
8787
"disable_offloading",
88+
"remove_dispatch",
8889
]
8990

9091

@@ -514,6 +515,9 @@ def offloaded_dispatch(
514515
if offload_device == "disk":
515516
raise NotImplementedError("Disk offloading is not currently supported")
516517

518+
# remove any existing hooks
519+
remove_dispatch(module)
520+
517521
# create weights map
518522
state_dict = module.state_dict()
519523
state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
@@ -535,6 +539,33 @@ def offloaded_dispatch(
535539
weights_map=weights_map,
536540
tied_params_map=tied_params_map,
537541
)
542+
543+
# when saving a model, `PretrainedModel.save_pretrained` will only
544+
# onload weights if the following requirements are met
545+
# if (
546+
# hasattr(self, "hf_device_map")
547+
# and len(set(self.hf_device_map.values())) > 1
548+
# and ("cpu" in self.hf_device_map.values()
549+
# or "disk" in self.hf_device_map.values())
550+
# ):
551+
# because this function always offloads, disregard actual devices and
552+
# always use `cpu` and `cuda:0` to guarantee this condition passes
553+
setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
554+
555+
return module
556+
557+
558+
def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
    """
    Strip any existing dispatch state from ``module``.

    Recursively detaches all hf hooks and, when present, deletes the
    ``hf_device_map`` attribute recorded by a prior dispatch, so the module
    behaves as if it had never been dispatched.

    :param module: module which may be dispatched with hf hooks
    :return: module without dispatch
    """
    # recursively detach any hf hooks from this module and its submodules
    remove_hook_from_module(module, recurse=True)

    # a prior dispatch may have recorded a device map on the module; drop it
    # so downstream logic no longer treats the model as dispatched
    if hasattr(module, "hf_device_map"):
        delattr(module, "hf_device_map")

    return module
539570

540571

0 commit comments

Comments (0)