Skip to content

Commit ff3323b

Browse files
committed
perform ops on cpu, move back to module device
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 1a70148 commit ff3323b

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
iter_named_leaf_modules,
4848
)
4949
from compressed_tensors.utils import (
50+
align_module_device,
51+
delete_offload_parameter,
5052
get_execution_device,
5153
get_safetensors_folder,
5254
has_offloaded_params,
@@ -386,7 +388,10 @@ def compress_model(self, model: Module):
386388

387389
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
388390
if prefix in module_to_scheme or prefix in sparse_compression_targets:
389-
state_dict = module.state_dict(prefix=f"{prefix}.")
391+
# in the future, support compression on same device
392+
with align_module_device(module, execution_device="cpu"):
393+
state_dict = module.state_dict(prefix=f"{prefix}.")
394+
390395
# quantization first
391396
if prefix in module_to_scheme:
392397
state_dict = self.quantization_compressor.compress(
@@ -433,7 +438,10 @@ def decompress_model(self, model: Module):
433438

434439
for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
435440
if prefix in module_to_scheme or prefix in sparse_compression_targets:
436-
state_dict = module.state_dict(prefix=f"{prefix}.")
441+
# in the future, support decompression on same device
442+
with align_module_device(module, execution_device="cpu"):
443+
state_dict = module.state_dict(prefix=f"{prefix}.")
444+
437445
# sparsity first
438446
if prefix in sparse_compression_targets:
439447
# sparse_compression_targets are automatically inferred by this fn
@@ -461,7 +469,7 @@ def decompress_model(self, model: Module):
461469
# remove any existing parameters
462470
device = get_execution_device(module)
463471
for name, _ in list(module.named_parameters()):
464-
delattr(module, name)
472+
delete_offload_parameter(module, name)
465473

466474
# replace with decompressed parameters
467475
for name, value in state_dict.items():

0 commit comments

Comments
 (0)