Skip to content

Commit 71820e1

Browse files
committed
explicitly call gc collect
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 487885c commit 71820e1

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ def save_pretrained_compressed(save_pretrained_method):
5656
model_class = model_ref().__class__
5757
del save_pretrained_method
5858

59-
# hotfix: create a weak reference to the model to avoid circular dep
60-
# TODO: determine why circular dep is not collected and how to clean up this fn
61-
model_ref = weakref.ref(model)
62-
6359
@wraps(original_save_pretrained)
6460
def save_pretrained_wrapper(
6561
save_directory: str,
@@ -99,11 +95,11 @@ def save_pretrained_wrapper(
9995
state_dict = kwargs.pop("state_dict", None)
10096
if state_dict is None:
10197
logger.info("Fetching state_dict - this may take some time")
102-
state_dict = get_state_dict_offloaded_model(model_ref())
98+
state_dict = get_state_dict_offloaded_model(model)
10399

104100
logger.info("Fetching compressor")
105101
compressor = get_model_compressor(
106-
model=model_ref(),
102+
model=model,
107103
sparsity_config=sparsity_config,
108104
quantization_format=quantization_format,
109105
save_compressed=save_compressed,
@@ -115,7 +111,7 @@ def save_pretrained_wrapper(
115111
if compressor is None:
116112
# model is not compressed or quantized, save as normal
117113
original_save_pretrained_func = original_save_pretrained.__get__(
118-
model_ref(), model_class
114+
model, model_class
119115
)
120116
original_save_pretrained_func(
121117
save_directory, state_dict=state_dict, **kwargs
@@ -125,10 +121,10 @@ def save_pretrained_wrapper(
125121
# make sure we're on the main process when saving
126122
if state_dict is not None and len(state_dict) > 0:
127123
compressed_state_dict = compressor.compress(
128-
model_ref(), state_dict, show_progress=True
124+
model, state_dict, show_progress=True
129125
)
130126
logger.info("Saving compressed model to disk")
131-
original_save_pretrained.__get__(model_ref(), model_class)(
127+
original_save_pretrained.__get__(model, model_class)(
132128
save_directory,
133129
state_dict=compressed_state_dict,
134130
safe_serialization=safe_serialization,
@@ -137,10 +133,10 @@ def save_pretrained_wrapper(
137133
compressor.update_config(save_directory)
138134

139135
# update existing recipe
140-
update_and_save_recipe(model_ref().name_or_path, save_directory)
136+
update_and_save_recipe(model.name_or_path, save_directory)
141137

142138
# copy python files from cache dir to save_path if any
143-
copy_python_files_from_model_cache(model_ref(), save_directory)
139+
copy_python_files_from_model_cache(model, save_directory)
144140

145141
save_pretrained_wrapper._overridden = True
146142
return save_pretrained_wrapper

tests/llmcompressor/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import os
23
import shutil
34
import tempfile
@@ -80,7 +81,7 @@ def check_for_created_files():
8081

8182

8283
@pytest.fixture(autouse=True, scope="function")
83-
def setup_fresh_session():
84+
def llm_compressor_setup_teardown():
8485
"""
8586
setup any state tied to the execution of the given method in a
8687
class. setup_method is invoked for every test method of a class.
@@ -92,3 +93,6 @@ def setup_fresh_session():
9293
yield
9394
# reset the session after each test
9495
reset_session()
96+
# explicitly collect memory to catch memory bugs,
97+
# see https://github.com/vllm-project/llm-compressor/pull/1503
98+
gc.collect()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
cadence: "nightly"
2+
test_type: "regression"
3+
model: "meta-llama/Llama-2-7b-hf"
4+
dataset: open_platypus
5+
recipe: "tests/llmcompressor/transformers/obcq/recipes/sparse.yaml"
6+
sparsity: 0.3
7+
device: "cuda:0"

0 commit comments

Comments
 (0)