@@ -56,10 +56,6 @@ def save_pretrained_compressed(save_pretrained_method):
56
56
model_class = model_ref ().__class__
57
57
del save_pretrained_method
58
58
59
- # hotfix: create a weak reference to the model to avoid circular dep
60
- # TODO: determine why circular dep is not collected and how to clean up this fn
61
- model_ref = weakref .ref (model )
62
-
63
59
@wraps (original_save_pretrained )
64
60
def save_pretrained_wrapper (
65
61
save_directory : str ,
@@ -99,11 +95,11 @@ def save_pretrained_wrapper(
99
95
state_dict = kwargs .pop ("state_dict" , None )
100
96
if state_dict is None :
101
97
logger .info ("Fetching state_dict - this may take some time" )
102
- state_dict = get_state_dict_offloaded_model (model_ref () )
98
+ state_dict = get_state_dict_offloaded_model (model )
103
99
104
100
logger .info ("Fetching compressor" )
105
101
compressor = get_model_compressor (
106
- model = model_ref () ,
102
+ model = model ,
107
103
sparsity_config = sparsity_config ,
108
104
quantization_format = quantization_format ,
109
105
save_compressed = save_compressed ,
@@ -115,7 +111,7 @@ def save_pretrained_wrapper(
115
111
if compressor is None :
116
112
# model is not compressed or quantized, save as normal
117
113
original_save_pretrained_func = original_save_pretrained .__get__ (
118
- model_ref () , model_class
114
+ model , model_class
119
115
)
120
116
original_save_pretrained_func (
121
117
save_directory , state_dict = state_dict , ** kwargs
@@ -125,10 +121,10 @@ def save_pretrained_wrapper(
125
121
# make sure we're on the main process when saving
126
122
if state_dict is not None and len (state_dict ) > 0 :
127
123
compressed_state_dict = compressor .compress (
128
- model_ref () , state_dict , show_progress = True
124
+ model , state_dict , show_progress = True
129
125
)
130
126
logger .info ("Saving compressed model to disk" )
131
- original_save_pretrained .__get__ (model_ref () , model_class )(
127
+ original_save_pretrained .__get__ (model , model_class )(
132
128
save_directory ,
133
129
state_dict = compressed_state_dict ,
134
130
safe_serialization = safe_serialization ,
@@ -137,10 +133,10 @@ def save_pretrained_wrapper(
137
133
compressor .update_config (save_directory )
138
134
139
135
# update existing recipe
140
- update_and_save_recipe (model_ref () .name_or_path , save_directory )
136
+ update_and_save_recipe (model .name_or_path , save_directory )
141
137
142
138
# copy python files from cache dir to save_path if any
143
- copy_python_files_from_model_cache (model_ref () , save_directory )
139
+ copy_python_files_from_model_cache (model , save_directory )
144
140
145
141
save_pretrained_wrapper ._overridden = True
146
142
return save_pretrained_wrapper
0 commit comments