|
47 | 47 | iter_named_leaf_modules, |
48 | 48 | ) |
49 | 49 | from compressed_tensors.utils import ( |
| 50 | + align_module_device, |
| 51 | + delete_offload_parameter, |
50 | 52 | get_execution_device, |
51 | 53 | get_safetensors_folder, |
52 | 54 | has_offloaded_params, |
@@ -386,7 +388,10 @@ def compress_model(self, model: Module): |
386 | 388 |
|
387 | 389 | for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): |
388 | 390 | if prefix in module_to_scheme or prefix in sparse_compression_targets: |
389 | | - state_dict = module.state_dict(prefix=f"{prefix}.") |
| 391 | + # in the future, support compression on same device |
| 392 | + with align_module_device(module, execution_device="cpu"): |
| 393 | + state_dict = module.state_dict(prefix=f"{prefix}.") |
| 394 | + |
390 | 395 | # quantization first |
391 | 396 | if prefix in module_to_scheme: |
392 | 397 | state_dict = self.quantization_compressor.compress( |
@@ -433,7 +438,10 @@ def decompress_model(self, model: Module): |
433 | 438 |
|
434 | 439 | for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"): |
435 | 440 | if prefix in module_to_scheme or prefix in sparse_compression_targets: |
436 | | - state_dict = module.state_dict(prefix=f"{prefix}.") |
| 441 | + # in the future, support decompression on same device |
| 442 | + with align_module_device(module, execution_device="cpu"): |
| 443 | + state_dict = module.state_dict(prefix=f"{prefix}.") |
| 444 | + |
437 | 445 | # sparsity first |
438 | 446 | if prefix in sparse_compression_targets: |
439 | 447 | # sparse_compression_targets are automatically inferred by this fn |
@@ -461,7 +469,7 @@ def decompress_model(self, model: Module): |
461 | 469 | # remove any existing parameters |
462 | 470 | device = get_execution_device(module) |
463 | 471 | for name, _ in list(module.named_parameters()): |
464 | | - delattr(module, name) |
| 472 | + delete_offload_parameter(module, name) |
465 | 473 |
|
466 | 474 | # replace with decompressed parameters |
467 | 475 | for name, value in state_dict.items(): |
|
0 commit comments