|
50 | 50 | align_module_device,
|
51 | 51 | delete_offload_parameter,
|
52 | 52 | get_execution_device,
|
| 53 | + get_offloaded_device, |
53 | 54 | get_safetensors_folder,
|
54 | 55 | has_offloaded_params,
|
55 | 56 | merge_names,
|
@@ -408,16 +409,17 @@ def compress_model(self, model: Module):
|
408 | 409 | )
|
409 | 410 |
|
410 | 411 | # remove any existing parameters
|
411 |
| - device = get_execution_device(module) |
| 412 | + exec_device = get_execution_device(module) |
| 413 | + offload_device = get_offloaded_device(module) |
412 | 414 | for name, _ in list(module.named_parameters()):
|
413 |
| - delattr(module, name) |
| 415 | + delete_offload_parameter(module, name) |
414 | 416 |
|
415 | 417 | # replace with compressed parameters
|
416 | 418 | for name, value in state_dict.items():
|
417 | 419 | name = name.removeprefix(f"{prefix}.")
|
418 |
| - value = value.to(device) |
| 420 | + value = value.to(exec_device) |
419 | 421 | param = torch.nn.Parameter(value, requires_grad=False)
|
420 |
| - register_offload_parameter(module, name, param) |
| 422 | + register_offload_parameter(module, name, param, offload_device) |
421 | 423 |
|
422 | 424 | module.quantization_status = QuantizationStatus.COMPRESSED
|
423 | 425 |
|
@@ -474,16 +476,17 @@ def decompress_model(self, model: Module):
|
474 | 476 | }
|
475 | 477 |
|
476 | 478 | # remove any existing parameters
|
477 |
| - device = get_execution_device(module) |
| 479 | + exec_device = get_execution_device(module) |
| 480 | + offload_device = get_offloaded_device(module) |
478 | 481 | for name, _ in list(module.named_parameters()):
|
479 | 482 | delete_offload_parameter(module, name)
|
480 | 483 |
|
481 | 484 | # replace with decompressed parameters
|
482 | 485 | for name, value in state_dict.items():
|
483 | 486 | name = name.removeprefix(f"{prefix}.")
|
484 |
| - value = value.to(device) |
| 487 | + value = value.to(exec_device) |
485 | 488 | param = torch.nn.Parameter(value, requires_grad=False)
|
486 |
| - register_offload_parameter(module, name, param) |
| 489 | + register_offload_parameter(module, name, param, offload_device) |
487 | 490 |
|
488 | 491 | module.quantization_status = QuantizationStatus.FROZEN
|
489 | 492 |
|
|
0 commit comments