Skip to content

Commit cbb258b

Browse files
committed
ensure correct device placement
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 82dfe9d commit cbb258b

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
iter_named_leaf_modules,
4848
)
4949
from compressed_tensors.utils import (
50+
get_execution_device,
5051
get_safetensors_folder,
5152
has_offloaded_params,
5253
merge_names,
@@ -403,12 +404,14 @@ def compress_model(self, model: Module):
403404
)
404405

405406
# remove any existing parameters
407+
device = get_execution_device(module)
406408
for name, _ in list(module.named_parameters()):
407409
delattr(module, name)
408410

409411
# replace with compressed parameters
410412
for name, value in state_dict.items():
411413
name = name.removeprefix(f"{prefix}.")
414+
value = value.to(device)
412415
param = torch.nn.Parameter(value, requires_grad=False)
413416
register_offload_parameter(module, name, param)
414417

@@ -456,12 +459,14 @@ def decompress_model(self, model: Module):
456459
}
457460

458461
# remove any existing parameters
462+
device = get_execution_device(module)
459463
for name, _ in list(module.named_parameters()):
460464
delattr(module, name)
461465

462466
# replace with decompressed parameters
463467
for name, value in state_dict.items():
464468
name = name.removeprefix(f"{prefix}.")
469+
value = value.to(device)
465470
param = torch.nn.Parameter(value, requires_grad=False)
466471
register_offload_parameter(module, name, param)
467472

0 commit comments

Comments
 (0)