Skip to content

Commit 17a746c

Browse files
authored
[Bugfix] Safeguard against submodule parameter deletion in decompress_model (#347)
* safeguard against removing parameters not immediately on the module Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add safeguard to compress method Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * safeguard against state dict creation Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 7e0dc32 commit 17a746c

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,10 @@ def compress_model(self, model: Module):
400400

401401
# in the future, support compression on same device
402402
with align_module_device(module, execution_device=exec_device):
403-
state_dict = module.state_dict(prefix=f"{prefix}.")
403+
state_dict = {
404+
f"{prefix}.{name}": param
405+
for name, param in module.named_parameters(recurse=False)
406+
}
404407

405408
# quantization first
406409
if prefix in module_to_scheme:
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):
421424

422425
# remove any existing parameters
423426
offload_device = get_offloaded_device(module)
424-
for name, _ in list(module.named_parameters()):
427+
for name, _ in list(module.named_parameters(recurse=False)):
425428
delete_offload_parameter(module, name)
426429

427430
# replace with compressed parameters
@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
458461
if prefix in module_to_scheme or prefix in sparse_compression_targets:
459462
# in the future, support decompression on same device
460463
with align_module_device(module, execution_device="cpu"):
461-
state_dict = module.state_dict(prefix=f"{prefix}.")
464+
state_dict = {
465+
f"{prefix}.{name}": param
466+
for name, param in module.named_parameters(recurse=False)
467+
}
462468

463469
# sparsity first
464470
if prefix in sparse_compression_targets:
@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
483489
# remove any existing parameters
484490
exec_device = get_execution_device(module)
485491
offload_device = get_offloaded_device(module)
486-
for name, _ in list(module.named_parameters()):
492+
for name, _ in list(module.named_parameters(recurse=False)):
487493
delete_offload_parameter(module, name)
488494

489495
# replace with decompressed parameters

0 commit comments

Comments
 (0)