|
47 | 47 | iter_named_leaf_modules, |
48 | 48 | ) |
49 | 49 | from compressed_tensors.utils import ( |
| 50 | + get_execution_device, |
50 | 51 | get_safetensors_folder, |
51 | 52 | has_offloaded_params, |
52 | 53 | merge_names, |
@@ -403,12 +404,14 @@ def compress_model(self, model: Module): |
403 | 404 | ) |
404 | 405 |
|
405 | 406 | # remove any existing parameters |
| 407 | + device = get_execution_device(module) |
406 | 408 | for name, _ in list(module.named_parameters()): |
407 | 409 | delattr(module, name) |
408 | 410 |
|
409 | 411 | # replace with compressed parameters |
410 | 412 | for name, value in state_dict.items(): |
411 | 413 | name = name.removeprefix(f"{prefix}.") |
| 414 | + value = value.to(device) |
412 | 415 | param = torch.nn.Parameter(value, requires_grad=False) |
413 | 416 | register_offload_parameter(module, name, param) |
414 | 417 |
|
@@ -456,12 +459,14 @@ def decompress_model(self, model: Module): |
456 | 459 | } |
457 | 460 |
|
458 | 461 | # remove any existing parameters |
| 462 | + device = get_execution_device(module) |
459 | 463 | for name, _ in list(module.named_parameters()): |
460 | 464 | delattr(module, name) |
461 | 465 |
|
462 | 466 | # replace with decompressed parameters |
463 | 467 | for name, value in state_dict.items(): |
464 | 468 | name = name.removeprefix(f"{prefix}.") |
| 469 | + value = value.to(device) |
465 | 470 | param = torch.nn.Parameter(value, requires_grad=False) |
466 | 471 | register_offload_parameter(module, name, param) |
467 | 472 |
|
|
0 commit comments