|
47 | 47 | iter_named_leaf_modules,
|
48 | 48 | )
|
49 | 49 | from compressed_tensors.utils import (
|
| 50 | + get_execution_device, |
50 | 51 | get_safetensors_folder,
|
51 | 52 | has_offloaded_params,
|
52 | 53 | merge_names,
|
@@ -403,12 +404,14 @@ def compress_model(self, model: Module):
|
403 | 404 | )
|
404 | 405 |
|
405 | 406 | # remove any existing parameters
|
| 407 | + device = get_execution_device(module) |
406 | 408 | for name, _ in list(module.named_parameters()):
|
407 | 409 | delattr(module, name)
|
408 | 410 |
|
409 | 411 | # replace with compressed parameters
|
410 | 412 | for name, value in state_dict.items():
|
411 | 413 | name = name.removeprefix(f"{prefix}.")
|
| 414 | + value = value.to(device) |
412 | 415 | param = torch.nn.Parameter(value, requires_grad=False)
|
413 | 416 | register_offload_parameter(module, name, param)
|
414 | 417 |
|
@@ -456,12 +459,14 @@ def decompress_model(self, model: Module):
|
456 | 459 | }
|
457 | 460 |
|
458 | 461 | # remove any existing parameters
|
| 462 | + device = get_execution_device(module) |
459 | 463 | for name, _ in list(module.named_parameters()):
|
460 | 464 | delattr(module, name)
|
461 | 465 |
|
462 | 466 | # replace with decompressed parameters
|
463 | 467 | for name, value in state_dict.items():
|
464 | 468 | name = name.removeprefix(f"{prefix}.")
|
| 469 | + value = value.to(device) |
465 | 470 | param = torch.nn.Parameter(value, requires_grad=False)
|
466 | 471 | register_offload_parameter(module, name, param)
|
467 | 472 |
|
|
0 commit comments