Skip to content

Commit 487885c

Browse files
committed
use weakref to model
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent db1ebc7 commit 487885c

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def save_pretrained_compressed(save_pretrained_method):
5656
model_class = model_ref().__class__
5757
del save_pretrained_method
5858

59+
# hotfix: create a weak reference to the model to avoid circular dep
60+
# TODO: determine why circular dep is not collected and how to clean up this fn
61+
model_ref = weakref.ref(model)
62+
5963
@wraps(original_save_pretrained)
6064
def save_pretrained_wrapper(
6165
save_directory: str,
@@ -95,11 +99,11 @@ def save_pretrained_wrapper(
9599
state_dict = kwargs.pop("state_dict", None)
96100
if state_dict is None:
97101
logger.info("Fetching state_dict - this may take some time")
98-
state_dict = get_state_dict_offloaded_model(model)
102+
state_dict = get_state_dict_offloaded_model(model_ref())
99103

100104
logger.info("Fetching compressor")
101105
compressor = get_model_compressor(
102-
model=model,
106+
model=model_ref(),
103107
sparsity_config=sparsity_config,
104108
quantization_format=quantization_format,
105109
save_compressed=save_compressed,
@@ -111,7 +115,7 @@ def save_pretrained_wrapper(
111115
if compressor is None:
112116
# model is not compressed or quantized, save as normal
113117
original_save_pretrained_func = original_save_pretrained.__get__(
114-
model, model_class
118+
model_ref(), model_class
115119
)
116120
original_save_pretrained_func(
117121
save_directory, state_dict=state_dict, **kwargs
@@ -121,10 +125,10 @@ def save_pretrained_wrapper(
121125
# make sure we're on the main process when saving
122126
if state_dict is not None and len(state_dict) > 0:
123127
compressed_state_dict = compressor.compress(
124-
model, state_dict, show_progress=True
128+
model_ref(), state_dict, show_progress=True
125129
)
126130
logger.info("Saving compressed model to disk")
127-
original_save_pretrained.__get__(model, model_class)(
131+
original_save_pretrained.__get__(model_ref(), model_class)(
128132
save_directory,
129133
state_dict=compressed_state_dict,
130134
safe_serialization=safe_serialization,
@@ -133,10 +137,10 @@ def save_pretrained_wrapper(
133137
compressor.update_config(save_directory)
134138

135139
# update existing recipe
136-
update_and_save_recipe(model.name_or_path, save_directory)
140+
update_and_save_recipe(model_ref().name_or_path, save_directory)
137141

138142
# copy python files from cache dir to save_path if any
139-
copy_python_files_from_model_cache(model, save_directory)
143+
copy_python_files_from_model_cache(model_ref(), save_directory)
140144

141145
save_pretrained_wrapper._overridden = True
142146
return save_pretrained_wrapper

0 commit comments

Comments
 (0)