
Added support for compression on meta device #376

Merged (12 commits, Jul 9, 2025)
@@ -393,9 +393,16 @@ def compress_model(self, model: Module):
)

for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):

if prefix in module_to_scheme or prefix in sparse_compression_targets:
module_device = get_execution_device(module).type
is_meta = (module_device == "meta")

exec_device = "meta" if is_meta else "cpu"
onloading_device = "meta" if is_meta else module_device

# in the future, support compression on same device
with align_module_device(module, execution_device="cpu"):
with align_module_device(module, execution_device=exec_device):
state_dict = module.state_dict(prefix=f"{prefix}.")

# quantization first
@@ -404,6 +411,7 @@ def compress_model(self, model: Module):
state_dict,
names_to_scheme=module_to_scheme,
show_progress=False,
compression_device=exec_device,
)

# sparsity second
@@ -415,15 +423,14 @@ def compress_model(self, model: Module):
)

# remove any existing parameters
exec_device = get_execution_device(module)
offload_device = get_offloaded_device(module)
for name, _ in list(module.named_parameters()):
delete_offload_parameter(module, name)

# replace with compressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(exec_device)
value = value.to(onloading_device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param, offload_device)

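The change above picks two devices per module: the device compression runs on (`exec_device`) and the device the freshly compressed parameters are registered to (`onloading_device`). For meta-loaded modules both stay on `meta`, since there are no values to move; otherwise compression happens on CPU and results are onloaded back to the module's original device. A minimal sketch of that selection logic, using a hypothetical `pick_devices` helper that reads the device from the module's parameters instead of compressed-tensors' `get_execution_device`:

```python
import torch


def pick_devices(module: torch.nn.Module) -> tuple[str, str]:
    """Return (exec_device, onloading_device) for compressing one module."""
    try:
        module_device = next(module.parameters()).device.type
    except StopIteration:
        module_device = "cpu"  # parameterless module, nothing to compress anyway

    if module_device == "meta":
        # meta tensors carry no data, so compress "on" meta and keep the
        # compressed parameters on meta as well (shapes/dtypes only)
        return "meta", "meta"

    # otherwise compress on cpu (cheap memory movement) and onload the
    # compressed parameters back to the module's original device
    return "cpu", module_device


assert pick_devices(torch.nn.Linear(8, 8)) == ("cpu", "cpu")
assert pick_devices(torch.nn.Linear(8, 8, device="meta")) == ("meta", "meta")
```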
13 changes: 6 additions & 7 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -72,6 +72,7 @@ def compress(
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
show_progress: bool = False,
compression_device: str = "cpu",
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -85,7 +86,6 @@
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
save_device = "cpu"

# compress values
desc = "Compressing with quantization"
@@ -104,10 +104,10 @@

# if scale does not exist, then weight cannot be compressed
if scale is None:
compressed_dict[name] = value.to(save_device)
compressed_dict[name] = value.to(compression_device)
continue

# compress values on cpu (memory movement too expensive)
# compress values on meta if loading from meta, otherwise on cpu (memory movement too expensive)
module_path = prefix[:-1] if prefix.endswith(".") else prefix
quant_args = names_to_scheme[module_path].weights
compressed_values = self.compress_weight(
@@ -117,12 +117,12 @@
global_scale=global_scale,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
device=compression_device,
)

# update state dict
for key, value in compressed_values.items():
compressed_dict[prefix + key] = value.to(save_device)
compressed_dict[prefix + key] = value.to(compression_device)

else:
# omit saving zero points for symmetric or packed quantization
@@ -133,8 +133,7 @@
# TODO: does this case actually occur?
elif name.endswith("g_idx") and torch.any(value <= -1):
continue

compressed_dict[name] = value.to(save_device)
compressed_dict[name] = value.to(compression_device)

return compressed_dict

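With the hard-coded `save_device = "cpu"` gone, the target device is now whatever the caller passes as `compression_device`, and every tensor in the returned dict ends up there. A toy sketch of that routing (the `toy_compress` name and the simplified loop are illustrative, not the library's API):

```python
import torch


def toy_compress(model_state: dict, compression_device: str = "cpu") -> dict:
    """Move (or 'compress') every entry onto compression_device."""
    compressed = {}
    for name, value in model_state.items():
        # on "meta" this only propagates shape/dtype metadata; no bytes move
        compressed[name] = value.to(compression_device)
    return compressed


meta_state = {"layer.weight": torch.empty(4, 4, device="meta")}
out = toy_compress(meta_state, compression_device="meta")
assert out["layer.weight"].is_meta
```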
@@ -220,30 +220,34 @@ def pack_to_int32(
if num_bits < 1:
raise ValueError(f"num_bits must be at least 1, got {num_bits}")

# convert to unsigned for packing
# Convert to unsigned range for packing, matching quantization offset
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
value = value.cpu().numpy().astype(np.uint32)
device = value.device

pack_factor = 32 // num_bits

# pad input tensor and initialize packed output
packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
padding = packed_size * pack_factor - value.shape[packed_dim]
value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
if packed_dim == 0:
value = value.transpose(0, 1)

# pack values
if packed_dim == 1:
packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[:, i::pack_factor] << num_bits * i
else:
packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[i::pack_factor, :] << num_bits * i
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols

if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))

num_groups = padded_cols // pack_factor

# widen to int32 so the shifted values can be accumulated without overflow
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

if packed_dim == 0:
packed = packed.transpose(0, 1)

# convert back to signed and torch
packed = np.ascontiguousarray(packed).view(np.int32)
return torch.from_numpy(packed)
return packed


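The rewritten `pack_to_int32` drops the `.cpu().numpy()` round trip and packs with a broadcasted shift-and-sum in torch, which keeps the op on the tensor's own device (including `meta`). Below is a self-contained toy example of that shift-and-sum packing for 4-bit values, with the inverse shift-and-mask shown inline; sizes and values are made up, and the library pairs this with its own `unpack_from_int32` helper:

```python
import torch

num_bits = 4
pack_factor = 32 // num_bits  # 8 four-bit values fit in one int32

# one row of already-offset (unsigned) 4-bit values
values = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)

# shift each value into its own 4-bit slot, then sum the slots into one int32
shifts = torch.arange(pack_factor, dtype=torch.int32) * num_bits
packed = (values.view(1, 1, pack_factor) << shifts).sum(dim=2, dtype=torch.int32)

# unpack by shifting back out and masking each slot to num_bits
mask = (1 << num_bits) - 1
unpacked = (packed.unsqueeze(-1) >> shifts) & mask
assert torch.equal(unpacked.view(1, -1), values)
```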
@@ -56,8 +56,10 @@ def compress_weight(self, name, value):
bitmask_tensor = Sparse24BitMaskTensor.from_dense(
value, self.config.sparsity_structure
)
bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
return bitmask_dict
return bitmask_tensor.dict(
name_prefix=name,
device="meta" if value.is_meta else "cpu",
)

def decompress_weight(self, weight_data):
data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
@@ -90,9 +92,14 @@ def from_dense(
:return: instantiated compressed tensor
"""
shape = list(tensor.shape)
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
if tensor.is_meta:
compressed, bitmask = sparse24_bitmask_compress(
tensor, sparsity_structure=sparsity_structure
)
else:
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
return Sparse24BitMaskTensor(
shape=shape,
compressed=compressed,
@@ -169,6 +176,13 @@ def sparse24_bitmask_compress(
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
), "Only 2:4 sparsity is supported"

if tensor.is_meta:
num_rows, num_cols = tensor.shape
compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
packed_cols = (num_cols + 7) // 8
bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
return compressed_values, bitmasks_packed

bytemasks = get_24_bytemasks(tensor=tensor)

if tensor.dtype == FP8_DTYPE:
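The `tensor.is_meta` branch added above never touches data; it only has to allocate outputs with the shapes 2:4 compression would produce: the value matrix keeps two of every four elements (columns halved), and the bitmask packs one bit per original column into bytes. A quick worked check of that shape arithmetic with toy sizes (not taken from the PR):

```python
import torch

num_rows, num_cols = 4, 20
weight = torch.empty(num_rows, num_cols, dtype=torch.bfloat16, device="meta")

# 2:4 sparsity keeps half of the values per row
compressed_values = torch.empty(
    (num_rows, num_cols // 2), dtype=weight.dtype, device="meta"
)
# one bit per original column, packed 8 columns per uint8 byte
bitmasks_packed = torch.empty(
    (num_rows, (num_cols + 7) // 8), dtype=torch.uint8, device="meta"
)

assert compressed_values.shape == (4, 10)
assert bitmasks_packed.shape == (4, 3)  # ceil(20 / 8) = 3 bytes per row
```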
60 changes: 60 additions & 0 deletions tests/test_compressors/model_compressors/test_model_compressor.py
@@ -412,6 +412,66 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
assert torch.all(compressed[key] == true_compressed[key]), f"{key}"


@pytest.mark.parametrize(
"model_stub,q_format,s_config",
[
(
"nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
"float-quantized",
None,
),
(
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
None,
"sparse-24-bitmask",
),
(
"nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
"float-quantized",
"sparse-24-bitmask",
),
(
"nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
"pack-quantized",
None,
),
],
)
def test_compress_model_meta(model_stub, q_format, s_config):
# Load model on CPU to get expected compressed state_dict
cpu_model = AutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.float32
)
reference_compressor = ModelCompressor.from_pretrained_model(
cpu_model, s_config, q_format
)
# Only keep dtypes; the meta model stores no real values to compare against
expected = {
k: v.dtype
for k, v in reference_compressor.compress(cpu_model).items()
}

# Load model on meta device
meta_model = AutoModelForCausalLM.from_pretrained(
model_stub,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
)
for module in meta_model.modules():
if hasattr(module, "to_empty"):
module.to_empty(device="meta")

# Compress in-place on meta model
compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
compressor.compress_model(meta_model)

# Compare keys and dtypes
compressed = dict(meta_model.state_dict())
assert set(compressed.keys()) == set(expected.keys())
for key, dtype in expected.items():
assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"
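
The meta-side setup in this test leans on `torch.nn.Module.to_empty`, which swaps every parameter and buffer for an uninitialized tensor on the target device while preserving shapes and dtypes. A minimal standalone illustration of that step:

```python
import torch

layer = torch.nn.Linear(16, 16)
assert not layer.weight.is_meta

# replaces parameters/buffers with uninitialized tensors on the meta device
layer.to_empty(device="meta")
assert layer.weight.is_meta
assert layer.weight.shape == (16, 16)
```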


@pytest.mark.parametrize(
"model_stub,comp_stub",
[