Skip to content

Commit 5a451b0

Browse files
authored
maintain module_device (#384)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent b163bd9 commit 5a451b0

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ def compress_model(self, model: Module):
392392
for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
393393

394394
if prefix in module_to_scheme or prefix in sparse_compression_targets:
395-
module_device = get_execution_device(module).type
396-
is_meta = (module_device == "meta")
395+
module_device = get_execution_device(module)
396+
is_meta = module_device.type == "meta"
397397

398398
exec_device = "meta" if is_meta else "cpu"
399399
onloading_device = "meta" if is_meta else module_device

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,13 @@ def sparse24_bitmask_compress(
178178

179179
if tensor.is_meta:
180180
num_rows, num_cols = tensor.shape
181-
compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
181+
compressed_values = torch.empty(
182+
(num_rows, num_cols // 2), dtype=tensor.dtype, device="meta"
183+
)
182184
packed_cols = (num_cols + 7) // 8
183-
bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
185+
bitmasks_packed = torch.empty(
186+
(num_rows, packed_cols), dtype=torch.uint8, device="meta"
187+
)
184188
return compressed_values, bitmasks_packed
185189

186190
bytemasks = get_24_bytemasks(tensor=tensor)

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
446446
cpu_model, s_config, q_format
447447
)
448448
# Only stores dtype because meta model does not store values
449-
expected = {
450-
k: v.dtype
451-
for k, v in reference_compressor.compress(cpu_model).items()
452-
}
449+
expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
453450

454451
# Load model on meta device
455452
meta_model = AutoModelForCausalLM.from_pretrained(

0 commit comments

Comments (0)