     "delete_offload_module",
     "offloaded_dispatch",
     "disable_offloading",
+    "remove_dispatch",
 ]
 
 
@@ -514,6 +515,9 @@ def offloaded_dispatch(
     if offload_device == "disk":
         raise NotImplementedError("Disk offloading is not currently supported")
 
+    # remove any existing hooks
+    remove_dispatch(module)
+
     # create weights map
     state_dict = module.state_dict()
     state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
@@ -535,6 +539,33 @@ def offloaded_dispatch(
         weights_map=weights_map,
         tied_params_map=tied_params_map,
     )
+
+    # when saving a model, `PretrainedModel.save_pretrained` will only
+    # onload weights if the following requirements are met
+    # if (
+    #     hasattr(self, "hf_device_map")
+    #     and len(set(self.hf_device_map.values())) > 1
+    #     and ("cpu" in self.hf_device_map.values()
+    #         or "disk" in self.hf_device_map.values())
+    # ):
+    # because this function always offloads, disregard actual devices and
+    # always use `cpu` and `cuda:0` to guarantee this condition passes
+    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
+
+    return module
+
+
+def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
+
+    :param module: module which may be dispatched with hf hooks
+    :return: module without dispatch
+    """
+    remove_hook_from_module(module, recurse=True)
+    if hasattr(module, "hf_device_map"):
+        delattr(module, "hf_device_map")
+
     return module
 
 
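For context, a minimal usage sketch of the new helper follows. The import path and the exact offloaded_dispatch signature beyond the offload_device parameter visible in the hunks above are assumptions, not confirmed by this diff.

import torch
# assumed import path -- adjust to wherever offloaded_dispatch / remove_dispatch are defined
from compressed_tensors.utils.offload import offloaded_dispatch, remove_dispatch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

# re-dispatching is now safe: offloaded_dispatch strips any pre-existing
# accelerate hooks via remove_dispatch before attaching new ones
model = offloaded_dispatch(model, offload_device=torch.device("cpu"))
model = offloaded_dispatch(model, offload_device=torch.device("cpu"))

# the placeholder device map guarantees the save_pretrained onload condition passes
assert model.hf_device_map == {"fake_offload": "cpu", "fake_exec": "cuda:0"}

# remove_dispatch removes all accelerate hooks recursively and deletes
# hf_device_map, returning an undispatched module
model = remove_dispatch(model)
assert not hasattr(model, "hf_device_map")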