diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index c5c10845..3637ecd3 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -400,7 +400,10 @@ def compress_model(self, model: Module): # in the future, support compression on same device with align_module_device(module, execution_device=exec_device): - state_dict = module.state_dict(prefix=f"{prefix}.") + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } # quantization first if prefix in module_to_scheme: @@ -421,7 +424,7 @@ def compress_model(self, model: Module): # remove any existing parameters offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters()): + for name, _ in list(module.named_parameters(recurse=False)): delete_offload_parameter(module, name) # replace with compressed parameters @@ -458,7 +461,10 @@ def decompress_model(self, model: Module): if prefix in module_to_scheme or prefix in sparse_compression_targets: # in the future, support decompression on same device with align_module_device(module, execution_device="cpu"): - state_dict = module.state_dict(prefix=f"{prefix}.") + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } # sparsity first if prefix in sparse_compression_targets: @@ -483,7 +489,7 @@ def decompress_model(self, model: Module): # remove any existing parameters exec_device = get_execution_device(module) offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters()): + for name, _ in list(module.named_parameters(recurse=False)): delete_offload_parameter(module, name) # replace with decompressed parameters