diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 4832e3b7f..9c28a3166 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -11,7 +11,7 @@ CompressionFormat, ModelCompressor, SparsityCompressionConfig, - is_module_offloaded, + has_offloaded_params, update_offload_parameter, ) from loguru import logger @@ -162,7 +162,7 @@ def patch_tied_tensors_bug(model: torch.nn.Module): if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight): for module in (input_embed, output_embed): - if not is_module_offloaded(module): + if not has_offloaded_params(module): # create new storage ptr for onloaded weight untied_data = module.weight.data.clone() module.weight.data = untied_data