Skip to content

Commit 56cf39c

Browse files
authored
fix deleting offload (#319)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 7b5a7a4 commit 56cf39c

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
align_module_device,
5151
delete_offload_parameter,
5252
get_execution_device,
53+
get_offloaded_device,
5354
get_safetensors_folder,
5455
has_offloaded_params,
5556
merge_names,
@@ -408,16 +409,17 @@ def compress_model(self, model: Module):
408409
)
409410

410411
# remove any existing parameters
411-
device = get_execution_device(module)
412+
exec_device = get_execution_device(module)
413+
offload_device = get_offloaded_device(module)
412414
for name, _ in list(module.named_parameters()):
413-
delattr(module, name)
415+
delete_offload_parameter(module, name)
414416

415417
# replace with compressed parameters
416418
for name, value in state_dict.items():
417419
name = name.removeprefix(f"{prefix}.")
418-
value = value.to(device)
420+
value = value.to(exec_device)
419421
param = torch.nn.Parameter(value, requires_grad=False)
420-
register_offload_parameter(module, name, param)
422+
register_offload_parameter(module, name, param, offload_device)
421423

422424
module.quantization_status = QuantizationStatus.COMPRESSED
423425

@@ -474,16 +476,17 @@ def decompress_model(self, model: Module):
474476
}
475477

476478
# remove any existing parameters
477-
device = get_execution_device(module)
479+
exec_device = get_execution_device(module)
480+
offload_device = get_offloaded_device(module)
478481
for name, _ in list(module.named_parameters()):
479482
delete_offload_parameter(module, name)
480483

481484
# replace with decompressed parameters
482485
for name, value in state_dict.items():
483486
name = name.removeprefix(f"{prefix}.")
484-
value = value.to(device)
487+
value = value.to(exec_device)
485488
param = torch.nn.Parameter(value, requires_grad=False)
486-
register_offload_parameter(module, name, param)
489+
register_offload_parameter(module, name, param, offload_device)
487490

488491
module.quantization_status = QuantizationStatus.FROZEN
489492

0 commit comments

Comments
 (0)