|
47 | 47 | iter_named_leaf_modules,
|
48 | 48 | )
|
49 | 49 | from compressed_tensors.utils import (
|
| 50 | + align_module_device, |
| 51 | + delete_offload_parameter, |
50 | 52 | get_execution_device,
|
51 | 53 | get_safetensors_folder,
|
52 | 54 | has_offloaded_params,
|
@@ -386,7 +388,10 @@ def compress_model(self, model: Module):
|
386 | 388 |
|
387 | 389 | for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
|
388 | 390 | if prefix in module_to_scheme or prefix in sparse_compression_targets:
|
389 |
| - state_dict = module.state_dict(prefix=f"{prefix}.") |
| 391 | + # in the future, support compression on same device |
| 392 | + with align_module_device(module, execution_device="cpu"): |
| 393 | + state_dict = module.state_dict(prefix=f"{prefix}.") |
| 394 | + |
390 | 395 | # quantization first
|
391 | 396 | if prefix in module_to_scheme:
|
392 | 397 | state_dict = self.quantization_compressor.compress(
|
@@ -433,7 +438,10 @@ def decompress_model(self, model: Module):
|
433 | 438 |
|
434 | 439 | for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
|
435 | 440 | if prefix in module_to_scheme or prefix in sparse_compression_targets:
|
436 |
| - state_dict = module.state_dict(prefix=f"{prefix}.") |
| 441 | + # in the future, support decompression on same device |
| 442 | + with align_module_device(module, execution_device="cpu"): |
| 443 | + state_dict = module.state_dict(prefix=f"{prefix}.") |
| 444 | + |
437 | 445 | # sparsity first
|
438 | 446 | if prefix in sparse_compression_targets:
|
439 | 447 | # sparse_compression_targets are automatically inferred by this fn
|
@@ -461,7 +469,7 @@ def decompress_model(self, model: Module):
|
461 | 469 | # remove any existing parameters
|
462 | 470 | device = get_execution_device(module)
|
463 | 471 | for name, _ in list(module.named_parameters()):
|
464 |
| - delattr(module, name) |
| 472 | + delete_offload_parameter(module, name) |
465 | 473 |
|
466 | 474 | # replace with decompressed parameters
|
467 | 475 | for name, value in state_dict.items():
|
|
0 commit comments