
Added support for compression on meta device #376


Merged · 12 commits · Jul 9, 2025
@@ -378,7 +378,7 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

# ----- model memory compression/decompression pathways ----- #

def compress_model(self, model: Module):
def compress_model(self, model: Module, is_meta: bool = False):
"""
Compress a model in memory. Because the model structure is modified in place,
this method is more memory-efficient than `self.compress`
@@ -394,8 +394,11 @@ def compress_model(self, model: Module):

for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
if prefix in module_to_scheme or prefix in sparse_compression_targets:
exec_device = "meta" if is_meta else "cpu"
onloading_device = "meta" if is_meta else get_execution_device(module)

# in the future, support compression on same device
with align_module_device(module, execution_device="cpu"):
with align_module_device(module, execution_device=exec_device):
state_dict = module.state_dict(prefix=f"{prefix}.")

# quantization first
@@ -404,6 +407,7 @@ def compress_model(self, model: Module):
state_dict,
names_to_scheme=module_to_scheme,
show_progress=False,
compression_device=exec_device,
)

# sparsity second
@@ -415,15 +419,14 @@ def compress_model(self, model: Module):
)

# remove any existing parameters
exec_device = get_execution_device(module)
offload_device = get_offloaded_device(module)
for name, _ in list(module.named_parameters()):
delete_offload_parameter(module, name)

# replace with compressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(exec_device)
value = value.to(onloading_device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param, offload_device)

@@ -485,7 +488,9 @@ def decompress_model(self, model: Module):
# replace with decompressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(exec_device)
# skip the device move if we're just registering the model on the meta device
if exec_device != "meta":
value = value.to(exec_device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param, offload_device)

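For context, here is a minimal sketch of how the new is_meta path might be exercised: build the model skeleton on the meta device so no weight memory is allocated, then let compress_model run entirely on meta. The checkpoint name, the transformers loading calls, and the import paths are illustrative assumptions, not part of this PR.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

stub = "org/quantized-checkpoint"  # hypothetical checkpoint, for illustration only

config = AutoConfig.from_pretrained(stub)
with torch.device("meta"):
    # parameters are created as shape/dtype-only meta tensors, no real memory
    model = AutoModelForCausalLM.from_config(config)

# assumes the checkpoint carries a compression config; from_pretrained returns None otherwise
compressor = ModelCompressor.from_pretrained(stub)
# with is_meta=True the compressors run on "meta", so only the compressed
# parameter shapes and dtypes are registered and no weight data is materialized
compressor.compress_model(model, is_meta=True)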
13 changes: 6 additions & 7 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -72,6 +72,7 @@ def compress(
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
show_progress: bool = False,
compression_device: str = "cpu",
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -85,7 +86,6 @@ def compress(
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
save_device = "cpu"

# compress values
desc = "Compressing with quantization"
@@ -104,10 +104,10 @@ def compress(

# if scale does not exist, then the weight cannot be compressed
if scale is None:
compressed_dict[name] = value.to(save_device)
compressed_dict[name] = value.to(compression_device)
continue

# compress values on cpu (memory movement too expensive)
# compress values on meta if loading from meta, otherwise on cpu (memory movement is too expensive)
module_path = prefix[:-1] if prefix.endswith(".") else prefix
quant_args = names_to_scheme[module_path].weights
compressed_values = self.compress_weight(
@@ -117,12 +117,12 @@ def compress(
global_scale=global_scale,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
device=compression_device,
)

# update state dict
for key, value in compressed_values.items():
compressed_dict[prefix + key] = value.to(save_device)
compressed_dict[prefix + key] = value.to(compression_device)

else:
# omit saving zero points for symmetric or packed quantization
@@ -133,8 +133,7 @@ def compress(
# TODO: does this case actually occur?
elif name.endswith("g_idx") and torch.any(value <= -1):
continue

compressed_dict[name] = value.to(save_device)
compressed_dict[name] = value.to(compression_device)

return compressed_dict

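The reason a plain device string is enough here: on the meta device, tensor operations track only shapes and dtypes, so routing compression_device="meta" through compress_weight and the final .to(...) calls never allocates or copies real weight data. A standalone illustration in plain torch (not repo code):

import torch

cpu_weight = torch.randn(8, 8)
meta_weight = cpu_weight.to("meta")      # drops the data, keeps shape and dtype
print(meta_weight.device, meta_weight.shape, meta_weight.dtype)
# meta torch.Size([8, 8]) torch.float32

# downstream ops on meta tensors also stay on meta and never touch memory
grouped = meta_weight.reshape(8, 4, 2).sum(dim=-1)
print(grouped.device, grouped.shape)     # meta torch.Size([8, 4])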
@@ -220,30 +220,34 @@ def pack_to_int32(
if num_bits < 1:
raise ValueError(f"num_bits must be at least 1, got {num_bits}")

# convert to unsigned for packing
# Convert to unsigned range for packing, matching quantization offset
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
value = value.cpu().numpy().astype(np.uint32)
device = value.device

pack_factor = 32 // num_bits

# pad input tensor and initialize packed output
packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
padding = packed_size * pack_factor - value.shape[packed_dim]
value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
if packed_dim == 0:
value = value.transpose(0, 1)

# pack values
if packed_dim == 1:
packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[:, i::pack_factor] << num_bits * i
else:
packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[i::pack_factor, :] << num_bits * i
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols

if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))

num_groups = padded_cols // pack_factor

# widen to int32 so each group of pack_factor values can be shift-packed into one 32-bit word
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

if packed_dim == 0:
packed = packed.transpose(0, 1)

# convert back to signed and torch
packed = np.ascontiguousarray(packed).view(np.int32)
return torch.from_numpy(packed)
return packed


def unpack_from_int32(
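To make the shift-and-sum packing above concrete, a small standalone example (illustration only, not repo code) that packs eight 4-bit values into one int32 word and unpacks them again; the real pack_to_int32 additionally offsets signed values into the unsigned range first and handles padding and packed_dim.

import torch

num_bits = 4
pack_factor = 32 // num_bits                       # 8 values per int32 word
vals = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 0]], dtype=torch.int32)
shifts = torch.arange(pack_factor, dtype=torch.int32) * num_bits
packed = (vals << shifts).sum(dim=1, dtype=torch.int32)
print(hex(packed.item()))                          # 0x7654321

# unpacking reverses the shifts and masks out each num_bits-wide field
unpacked = (packed.unsqueeze(-1) >> shifts) & ((1 << num_bits) - 1)
print(unpacked.tolist())                           # [[1, 2, 3, 4, 5, 6, 7, 0]]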
@@ -56,8 +56,10 @@ def compress_weight(self, name, value):
bitmask_tensor = Sparse24BitMaskTensor.from_dense(
value, self.config.sparsity_structure
)
bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
return bitmask_dict
return bitmask_tensor.dict(
name_prefix=name,
device="meta" if value.device.type == "meta" else "cpu",
)

def decompress_weight(self, weight_data):
data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
@@ -90,9 +92,14 @@ def from_dense(
:return: instantiated compressed tensor
"""
shape = list(tensor.shape)
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
if tensor.device.type == "meta":
compressed, bitmask = sparse24_bitmask_compress(
tensor, sparsity_structure=sparsity_structure
)
else:
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
return Sparse24BitMaskTensor(
shape=shape,
compressed=compressed,
@@ -169,6 +176,13 @@ def sparse24_bitmask_compress(
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
), "Only 2:4 sparsity is supported"

if tensor.device.type == "meta":
num_rows, num_cols = tensor.shape
compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
packed_cols = (num_cols + 7) // 8
bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
return compressed_values, bitmasks_packed

bytemasks = get_24_bytemasks(tensor=tensor)

if tensor.dtype == FP8_DTYPE:
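A quick shape check for the meta branch above, as a standalone illustration: for a 2:4-sparse (rows, cols) weight, half of the values survive and the bitmask stores one bit per element, packed eight per uint8 byte.

import torch

w = torch.empty(4, 16, dtype=torch.bfloat16, device="meta")
num_rows, num_cols = w.shape
compressed_values = torch.empty((num_rows, num_cols // 2), dtype=w.dtype, device="meta")
bitmasks_packed = torch.empty((num_rows, (num_cols + 7) // 8), dtype=torch.uint8, device="meta")
print(compressed_values.shape, bitmasks_packed.shape)
# torch.Size([4, 8]) torch.Size([4, 2])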