@@ -56,6 +56,10 @@ def save_pretrained_compressed(save_pretrained_method):
56
56
model_class = model_ref ().__class__
57
57
del save_pretrained_method
58
58
59
+ # hotfix: hold only a weak reference to the model to avoid a circular reference
60
+ # TODO: determine why the circular reference is not garbage-collected and how to clean up this function
61
+ model_ref = weakref .ref (model )
62
+
59
63
@wraps (original_save_pretrained )
60
64
def save_pretrained_wrapper (
61
65
save_directory : str ,
@@ -95,11 +99,11 @@ def save_pretrained_wrapper(
95
99
state_dict = kwargs .pop ("state_dict" , None )
96
100
if state_dict is None :
97
101
logger .info ("Fetching state_dict - this may take some time" )
98
- state_dict = get_state_dict_offloaded_model (model )
102
+ state_dict = get_state_dict_offloaded_model (model_ref () )
99
103
100
104
logger .info ("Fetching compressor" )
101
105
compressor = get_model_compressor (
102
- model = model ,
106
+ model = model_ref () ,
103
107
sparsity_config = sparsity_config ,
104
108
quantization_format = quantization_format ,
105
109
save_compressed = save_compressed ,
@@ -111,7 +115,7 @@ def save_pretrained_wrapper(
111
115
if compressor is None :
112
116
# model is not compressed or quantized, save as normal
113
117
original_save_pretrained_func = original_save_pretrained .__get__ (
114
- model , model_class
118
+ model_ref () , model_class
115
119
)
116
120
original_save_pretrained_func (
117
121
save_directory , state_dict = state_dict , ** kwargs
@@ -121,10 +125,10 @@ def save_pretrained_wrapper(
121
125
# make sure we're on the main process when saving
122
126
if state_dict is not None and len (state_dict ) > 0 :
123
127
compressed_state_dict = compressor .compress (
124
- model , state_dict , show_progress = True
128
+ model_ref () , state_dict , show_progress = True
125
129
)
126
130
logger .info ("Saving compressed model to disk" )
127
- original_save_pretrained .__get__ (model , model_class )(
131
+ original_save_pretrained .__get__ (model_ref () , model_class )(
128
132
save_directory ,
129
133
state_dict = compressed_state_dict ,
130
134
safe_serialization = safe_serialization ,
@@ -133,10 +137,10 @@ def save_pretrained_wrapper(
133
137
compressor .update_config (save_directory )
134
138
135
139
# update existing recipe
136
- update_and_save_recipe (model .name_or_path , save_directory )
140
+ update_and_save_recipe (model_ref () .name_or_path , save_directory )
137
141
138
142
# copy python files from cache dir to save_path if any
139
- copy_python_files_from_model_cache (model , save_directory )
143
+ copy_python_files_from_model_cache (model_ref () , save_directory )
140
144
141
145
save_pretrained_wrapper ._overridden = True
142
146
return save_pretrained_wrapper
0 commit comments