[Bugfix] [Tests] Perform explicit garbage collection in between tests #1503

Draft: wants to merge 3 commits into base main.
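The test-side change named in the title is not part of this diff, but the idea is to force a full collection after each test so that memory held only by reference cycles is reclaimed deterministically rather than whenever CPython's cyclic collector happens to run. A minimal sketch of what that looks like in a pytest suite (the fixture name and its placement in conftest.py are assumptions, not taken from this PR):

import gc

import pytest


@pytest.fixture(autouse=True)
def _collect_garbage_between_tests():
    # Run the test, then force the cyclic collector so objects kept alive
    # only by reference cycles are freed before the next test allocates.
    yield
    gc.collect()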
@@ -56,6 +56,10 @@ def save_pretrained_compressed(save_pretrained_method):
         model_class = model_ref().__class__
         del save_pretrained_method
 
+        # hotfix: create a weak reference to the model to avoid circular dep
+        # TODO: determine why circular dep is not collected and how to clean up this fn
+        model_ref = weakref.ref(model)
+
         @wraps(original_save_pretrained)
         def save_pretrained_wrapper(
             save_directory: str,
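Why the hotfix helps: the wrapper is attached back onto the model as its save_pretrained method, and a wrapper that closes over model itself forms a cycle (model -> wrapper -> closure cell -> model) that reference counting alone can never free; a weak reference breaks that cycle. A self-contained illustration of the difference, using a hypothetical Model stand-in rather than anything from this repository:

import gc
import weakref


class Model:
    # Hypothetical stand-in for a large transformers model.
    pass


def wrap_strong(model):
    # Old behavior: the wrapper closes over `model` itself. Attaching it to
    # the model creates a cycle, so refcounting alone cannot free the model.
    def wrapper():
        return model

    model.save_pretrained = wrapper


def wrap_weak(model):
    # Hotfixed behavior: the wrapper closes over a weak reference, which
    # does not keep the model alive.
    model_ref = weakref.ref(model)

    def wrapper():
        return model_ref()

    model.save_pretrained = wrapper


m = Model()
wrap_strong(m)
probe = weakref.ref(m)
del m
assert probe() is not None  # the cycle keeps the model alive past `del`
gc.collect()                # only a cyclic collection finally reclaims it
assert probe() is None

m = Model()
wrap_weak(m)
probe = weakref.ref(m)
del m                       # no cycle: refcounting frees it immediately
assert probe() is None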
@@ -95,11 +99,11 @@ def save_pretrained_wrapper(
             state_dict = kwargs.pop("state_dict", None)
             if state_dict is None:
                 logger.info("Fetching state_dict - this may take some time")
-                state_dict = get_state_dict_offloaded_model(model)
+                state_dict = get_state_dict_offloaded_model(model_ref())
 
             logger.info("Fetching compressor")
             compressor = get_model_compressor(
-                model=model,
+                model=model_ref(),
                 sparsity_config=sparsity_config,
                 quantization_format=quantization_format,
                 save_compressed=save_compressed,
@@ -111,7 +115,7 @@ def save_pretrained_wrapper(
             if compressor is None:
                 # model is not compressed or quantized, save as normal
                 original_save_pretrained_func = original_save_pretrained.__get__(
-                    model, model_class
+                    model_ref(), model_class
                 )
                 original_save_pretrained_func(
                     save_directory, state_dict=state_dict, **kwargs
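A note on the original_save_pretrained.__get__(model_ref(), model_class) pattern used here and below: original_save_pretrained is a plain function (the __func__ of the original bound method), so the descriptor protocol is used to rebind it to the live model at call time. Holding a bound method instead would itself keep a strong reference to the model, defeating the weakref. The mechanism in isolation (hypothetical class, not from this PR):

class Greeter:
    def greet(self):
        return f"hello, {self.name}"


g = Greeter()
g.name = "world"

unbound = Greeter.greet              # a plain function, like __func__
bound = unbound.__get__(g, Greeter)  # rebind to an instance on demand
assert bound() == "hello, world"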
@@ -121,10 +125,10 @@ def save_pretrained_wrapper(
                 # make sure we're on the main process when saving
                 if state_dict is not None and len(state_dict) > 0:
                     compressed_state_dict = compressor.compress(
-                        model, state_dict, show_progress=True
+                        model_ref(), state_dict, show_progress=True
                     )
                     logger.info("Saving compressed model to disk")
-                    original_save_pretrained.__get__(model, model_class)(
+                    original_save_pretrained.__get__(model_ref(), model_class)(
                         save_directory,
                         state_dict=compressed_state_dict,
                         safe_serialization=safe_serialization,
@@ -133,10 +137,10 @@ def save_pretrained_wrapper(
                 compressor.update_config(save_directory)
 
             # update existing recipe
-            update_and_save_recipe(model.name_or_path, save_directory)
+            update_and_save_recipe(model_ref().name_or_path, save_directory)
 
             # copy python files from cache dir to save_path if any
-            copy_python_files_from_model_cache(model, save_directory)
+            copy_python_files_from_model_cache(model_ref(), save_directory)
 
         save_pretrained_wrapper._overridden = True
         return save_pretrained_wrapper
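One caveat of swapping model for model_ref() at every call site: each dereference can return None if the model has already been collected by the time save_pretrained is invoked, in which case the wrapper as written fails with an AttributeError somewhere downstream. A follow-up might want an explicit guard along these lines (hypothetical helper, not part of this PR):

def _resolve_model(model_ref):
    # Dereference the weakref, failing loudly rather than passing None
    # into the state_dict and compression helpers.
    model = model_ref()
    if model is None:
        raise RuntimeError(
            "model was garbage-collected before save_pretrained was called"
        )
    return model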