Commit 8f67b97

Added support for compression on meta device (#376)
* added support for compression on meta device
* remove breakpoint
* remove comment
* address reviewed issues
* fix style
* new line
* Update src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py (review suggestions, applied three times)
* Added docstring
* removed is_meta input
* added test

---------

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 40fa3c5 commit 8f67b97

File tree

5 files changed: +117 -33 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py
src/compressed_tensors/compressors/quantized_compressors/base.py
src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
tests/test_compressors/model_compressors/test_model_compressor.py

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 10 additions & 3 deletions
@@ -390,9 +390,16 @@ def compress_model(self, model: Module):
         )
 
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                module_device = get_execution_device(module).type
+                is_meta = (module_device == "meta")
+
+                exec_device = "meta" if is_meta else "cpu"
+                onloading_device = "meta" if is_meta else module_device
+
                 # in the future, support compression on same device
-                with align_module_device(module, execution_device="cpu"):
+                with align_module_device(module, execution_device=exec_device):
                     state_dict = module.state_dict(prefix=f"{prefix}.")
 
                 # quantization first
@@ -401,6 +408,7 @@ def compress_model(self, model: Module):
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
+                        compression_device=exec_device,
                     )
 
                 # sparsity second
@@ -412,15 +420,14 @@ def compress_model(self, model: Module):
                     )
 
                 # remove any existing parameters
-                exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
                     delete_offload_parameter(module, name)
 
                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
+                    value = value.to(onloading_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
                     register_offload_parameter(module, name, param, offload_device)
 
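In short, compression now runs on the meta device whenever the module itself is meta-backed, so no real memory is touched while compressing. A minimal standalone sketch of that device-routing decision (the helper name and the parameter probe are illustrative; the actual code uses compressed-tensors' `get_execution_device` utility):

```python
import torch

def pick_compression_devices(module: torch.nn.Module) -> tuple[str, str]:
    """Hypothetical helper mirroring the logic added to compress_model():
    meta-backed modules are compressed directly on "meta" and their compressed
    parameters stay on "meta"; all other modules are compressed on CPU and the
    results are moved back to the module's original device."""
    param = next(module.parameters(), None)
    module_device = param.device.type if param is not None else "cpu"
    is_meta = module_device == "meta"

    exec_device = "meta" if is_meta else "cpu"                 # where compression math runs
    onloading_device = "meta" if is_meta else module_device    # where results are registered
    return exec_device, onloading_device

# usage: a layer materialized on the meta device never allocates real memory
with torch.device("meta"):
    layer = torch.nn.Linear(64, 64)
print(pick_compression_devices(layer))  # ('meta', 'meta')
```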

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 6 additions & 7 deletions
@@ -72,6 +72,7 @@ def compress(
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
         show_progress: bool = False,
+        compression_device: str = "cpu",
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -85,7 +86,6 @@ def compress(
         """
         uncompressed_names = list(model_state.keys())
         compressed_dict = {}
-        save_device = "cpu"
 
         # compress values
         desc = "Compressing with quantization"
@@ -104,10 +104,10 @@ def compress(
 
                 # is scale does not exist, then weight cannot be compressed
                 if scale is None:
-                    compressed_dict[name] = value.to(save_device)
+                    compressed_dict[name] = value.to(compression_device)
                     continue
 
-                # compress values on cpu (memory movement too expensive)
+                # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive)
                 module_path = prefix[:-1] if prefix.endswith(".") else prefix
                 quant_args = names_to_scheme[module_path].weights
                 compressed_values = self.compress_weight(
@@ -117,12 +117,12 @@ def compress(
                     global_scale=global_scale,
                     g_idx=g_idx,
                     quantization_args=quant_args,
-                    device="cpu",
+                    device=compression_device,
                 )
 
                 # update state dict
                 for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(save_device)
+                    compressed_dict[prefix + key] = value.to(compression_device)
 
             else:
                 # omit saving zero points for symmetric or packed quantization
@@ -133,8 +133,7 @@ def compress(
                 # TODO: does this case actually occur?
                 elif name.endswith("g_idx") and torch.any(value <= -1):
                     continue
-
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
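The new `compression_device` argument simply replaces the previously hard-coded `"cpu"` save device. A small sketch of why a plain device string is enough (my own illustration, not repository code): moving a meta tensor with `.to("meta")` keeps it shape-only and allocates nothing, so the same code path serves both real and meta state dicts.

```python
import torch

# compression_device replaces the old hard-coded save_device = "cpu".
# For real weights nothing changes; for meta weights every .to(compression_device)
# stays on "meta" and allocates no storage, only shape/dtype metadata.
compression_device = "meta"

weight = torch.empty(4, 8, dtype=torch.float16, device="meta")
saved = weight.to(compression_device)

print(saved.device, saved.shape, saved.dtype)
# meta torch.Size([4, 8]) torch.float16
```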

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 22 additions & 18 deletions
@@ -220,30 +220,34 @@ def pack_to_int32(
     if num_bits < 1:
         raise ValueError(f"num_bits must be at least 1, got {num_bits}")
 
-    # convert to unsigned for packing
+    # Convert to unsigned range for packing, matching quantization offset
     offset = 1 << (num_bits - 1)
     value = (value + offset).to(torch.uint8)
-    value = value.cpu().numpy().astype(np.uint32)
+    device = value.device
+
     pack_factor = 32 // num_bits
 
-    # pad input tensor and initialize packed output
-    packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
-    padding = packed_size * pack_factor - value.shape[packed_dim]
-    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
+    if packed_dim == 0:
+        value = value.transpose(0, 1)
 
-    # pack values
-    if packed_dim == 1:
-        packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-        for i in range(pack_factor):
-            packed |= value[:, i::pack_factor] << num_bits * i
-    else:
-        packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
-        for i in range(pack_factor):
-            packed |= value[i::pack_factor, :] << num_bits * i
+    rows, cols = value.shape
+    padded_cols = math.ceil(cols / pack_factor) * pack_factor
+    pad_len = padded_cols - cols
+
+    if pad_len > 0:
+        value = torch.nn.functional.pad(value, (0, pad_len))
+
+    num_groups = padded_cols // pack_factor
+
+    # Use int32 here
+    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
+    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
+    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
+
+    if packed_dim == 0:
+        packed = packed.transpose(0, 1)
 
-    # convert back to signed and torch
-    packed = np.ascontiguousarray(packed).view(np.int32)
-    return torch.from_numpy(packed)
+    return packed
 
 
 def unpack_from_int32(
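`pack_to_int32` now stays in pure torch instead of round-tripping through NumPy, which is what makes it usable on meta tensors (NumPy needs materialized CPU data). Below is a simplified sketch of the packing scheme for the `packed_dim == 1` case with a tiny worked example; the function name is mine, and the real implementation also handles `packed_dim == 0` by transposing first:

```python
import math
import torch

def pack_rows_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Fold each group of 32 // num_bits low-precision values along dim 1
    into one int32 word (sketch of the pure-torch packing used above)."""
    offset = 1 << (num_bits - 1)                 # shift signed range to unsigned
    value = (value + offset).to(torch.int32)
    pack_factor = 32 // num_bits

    rows, cols = value.shape
    padded_cols = math.ceil(cols / pack_factor) * pack_factor
    if padded_cols != cols:                      # pad so columns divide evenly
        value = torch.nn.functional.pad(value, (0, padded_cols - cols))

    shifts = torch.arange(pack_factor, device=value.device, dtype=torch.int32) * num_bits
    grouped = value.view(rows, padded_cols // pack_factor, pack_factor)
    return (grouped << shifts).sum(dim=2, dtype=torch.int32)  # disjoint bits, so sum == OR

# eight 4-bit values fit exactly into one int32 per row
vals = torch.tensor([[-8, -1, 0, 7, 1, 2, 3, 4]], dtype=torch.int8)
print(pack_rows_to_int32(vals, num_bits=4))

# the same call works for meta tensors: only shapes are propagated, no data is touched
meta_vals = torch.empty(2, 16, dtype=torch.int8, device="meta")
print(pack_rows_to_int32(meta_vals, num_bits=4).shape)  # torch.Size([2, 2])
```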

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 19 additions & 5 deletions
@@ -56,8 +56,10 @@ def compress_weight(self, name, value):
         bitmask_tensor = Sparse24BitMaskTensor.from_dense(
             value, self.config.sparsity_structure
         )
-        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
-        return bitmask_dict
+        return bitmask_tensor.dict(
+            name_prefix=name,
+            device="meta" if value.is_meta else "cpu",
+        )
 
     def decompress_weight(self, weight_data):
         data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
@@ -90,9 +92,14 @@ def from_dense(
         :return: instantiated compressed tensor
         """
         shape = list(tensor.shape)
-        compressed, bitmask = sparse24_bitmask_compress(
-            tensor.cpu(), sparsity_structure=sparsity_structure
-        )
+        if tensor.is_meta:
+            compressed, bitmask = sparse24_bitmask_compress(
+                tensor, sparsity_structure=sparsity_structure
+            )
+        else:
+            compressed, bitmask = sparse24_bitmask_compress(
+                tensor.cpu(), sparsity_structure=sparsity_structure
+            )
         return Sparse24BitMaskTensor(
             shape=shape,
             compressed=compressed,
@@ -169,6 +176,13 @@ def sparse24_bitmask_compress(
         SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
     ), "Only 2:4 sparsity is supported"
 
+    if tensor.is_meta:
+        num_rows, num_cols = tensor.shape
+        compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
+        packed_cols = (num_cols + 7) // 8
+        bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
+        return compressed_values, bitmasks_packed
+
     bytemasks = get_24_bytemasks(tensor=tensor)
 
     if tensor.dtype == FP8_DTYPE:
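On the meta device, `sparse24_bitmask_compress` short-circuits and only has to produce outputs with the right shapes and dtypes: 2:4 pruning keeps half of the values in each row, and the bitmask packs one bit per original element into bytes. A standalone sketch of that shape arithmetic with a worked example:

```python
import torch

def meta_sparse24_shapes(tensor: torch.Tensor):
    """Shape-only placeholders matching the meta-device shortcut above:
    half the columns survive 2:4 pruning, and the bitmask packs 8 bits per uint8 byte."""
    num_rows, num_cols = tensor.shape
    compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
    bitmasks_packed = torch.empty((num_rows, (num_cols + 7) // 8), dtype=torch.uint8, device="meta")
    return compressed_values, bitmasks_packed

# e.g. a 4096 x 11008 projection weight held on the meta device
w = torch.empty(4096, 11008, dtype=torch.bfloat16, device="meta")
values, bitmask = meta_sparse24_shapes(w)
print(values.shape)   # torch.Size([4096, 5504])  -- 2 of every 4 values kept
print(bitmask.shape)  # torch.Size([4096, 1376])  -- 11008 bits packed into uint8 bytes
```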

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 60 additions & 0 deletions
@@ -412,6 +412,66 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
         assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
 
 
+@pytest.mark.parametrize(
+    "model_stub,q_format,s_config",
+    [
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "float-quantized",
+            None,
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            None,
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
+            "float-quantized",
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "pack-quantized",
+            None,
+        ),
+    ],
+)
+def test_compress_model_meta(model_stub, q_format, s_config):
+    # Load model on CPU to get expected compressed state_dict
+    cpu_model = AutoModelForCausalLM.from_pretrained(
+        model_stub, torch_dtype=torch.float32
+    )
+    reference_compressor = ModelCompressor.from_pretrained_model(
+        cpu_model, s_config, q_format
+    )
+    # Only stores dtype because meta model does not store values
+    expected = {
+        k: v.dtype
+        for k, v in reference_compressor.compress(cpu_model).items()
+    }
+
+    # Load model on meta device
+    meta_model = AutoModelForCausalLM.from_pretrained(
+        model_stub,
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+    )
+    for module in meta_model.modules():
+        if hasattr(module, "to_empty"):
+            module.to_empty(device="meta")
+
+    # Compress in-place on meta model
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor.compress_model(meta_model)
+
+    # Compare keys and dtypes
+    compressed = dict(meta_model.state_dict())
+    assert set(compressed.keys()) == set(expected.keys())
+    for key, dtype in expected.items():
+        assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"
+
+
 @pytest.mark.parametrize(
     "model_stub,comp_stub",
     [
