
Added support for compression on meta device #376


Merged · 12 commits · Jul 9, 2025
@@ -378,7 +378,7 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

# ----- model memory compression/decompression pathways ----- #

def compress_model(self, model: Module):
def compress_model(self, model: Module, is_meta: bool = False):
"""
Compress a model in memory. Because the model structure is modified in place,
this method is more memory-efficient than `self.compress`
@@ -394,8 +394,11 @@ def compress_model(self, model: Module):

for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
if prefix in module_to_scheme or prefix in sparse_compression_targets:
exec_device = "meta" if is_meta else "cpu"
onloading_device = "meta" if is_meta else get_execution_device(module)

# in the future, support compression on same device
with align_module_device(module, execution_device="cpu"):
with align_module_device(module, execution_device=exec_device):
state_dict = module.state_dict(prefix=f"{prefix}.")

# quantization first
@@ -404,6 +407,7 @@ def compress_model(self, model: Module):
state_dict,
names_to_scheme=module_to_scheme,
show_progress=False,
save_device=exec_device,
)

# sparsity second
@@ -412,18 +416,18 @@ def compress_model(self, model: Module):
state_dict,
compression_targets=sparse_compression_targets,
show_progress=False,
module=module,
)

# remove any existing parameters
exec_device = get_execution_device(module)
offload_device = get_offloaded_device(module)
for name, _ in list(module.named_parameters()):
delete_offload_parameter(module, name)

# replace with compressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(exec_device)
value = value.to(onloading_device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param, offload_device)

@@ -485,7 +489,9 @@ def decompress_model(self, model: Module):
# replace with decompressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(exec_device)
# skip moving the value if we're just registering the model on the meta device
if exec_device != "meta":
value = value.to(exec_device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param, offload_device)

@@ -694,6 +700,7 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
device = "cpu" if has_offloaded_params(module) else params_device
delattr(module, param_name)
requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)

param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
register_offload_parameter(module, param_name, param)

@@ -711,7 +718,6 @@ def _replace_weights(self, dense_weight_generator, model: Module):
'data' is the updated param data
:param model: The model whose weights are to be updated.
"""

for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
module = operator.attrgetter(mod_path)(model)

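To put the new flag in context, a rough usage sketch follows: the model is instantiated on the meta device so no real weights are allocated, and compress_model(..., is_meta=True) only registers correctly shaped compressed parameters. The checkpoint id and the loading helpers (AutoModelForCausalLM.from_config, ModelCompressor.from_pretrained) are illustrative assumptions, not part of this diff.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

stub = "org/quantized-checkpoint"  # hypothetical checkpoint id

# Build the architecture on the meta device: shapes and dtypes only, no storage
config = AutoConfig.from_pretrained(stub)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# Assumed: construct the compressor from the checkpoint's compression config
compressor = ModelCompressor.from_pretrained(stub)

# With is_meta=True, compression runs on the meta device and registers empty
# (meta) compressed parameters instead of materializing data on CPU
compressor.compress_model(model, is_meta=True)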
@@ -72,6 +72,7 @@ def compress(
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
show_progress: bool = False,
save_device: str = "cpu",
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -85,7 +86,6 @@
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
save_device = "cpu"

# compress values
desc = "Compressing with quantization"
@@ -107,7 +107,7 @@
compressed_dict[name] = value.to(save_device)
continue

# compress values on cpu (memory movement too expensive)
# compress values on meta if loading from meta, otherwise on cpu (memory movement is too expensive)
module_path = prefix[:-1] if prefix.endswith(".") else prefix
quant_args = names_to_scheme[module_path].weights
compressed_values = self.compress_weight(
@@ -117,7 +117,7 @@
global_scale=global_scale,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
device=save_device,
)

# update state dict
@@ -133,7 +133,6 @@
# TODO: does this case actually occur?
elif name.endswith("g_idx") and torch.any(value <= -1):
continue

compressed_dict[name] = value.to(save_device)

return compressed_dict
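As a usage note, the new save_device argument simply lets the caller decide where compressed tensors are collected. In the sketch below, compressor, state_dict, and module_to_scheme stand in for objects the caller already has; this is a hedged illustration of the call, not code from the PR.

# Default path: compressed tensors are gathered on CPU, as before
compressed = compressor.compress(
    state_dict, names_to_scheme=module_to_scheme, show_progress=False
)

# Meta path: when the source weights live on the meta device, save_device="meta"
# keeps the compressed tensors on meta, so no host memory is allocated or moved
compressed_meta = compressor.compress(
    state_dict, names_to_scheme=module_to_scheme, save_device="meta"
)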
@@ -220,30 +220,36 @@ def pack_to_int32(
if num_bits < 1:
raise ValueError(f"num_bits must be at least 1, got {num_bits}")

# convert to unsigned for packing
# Convert to unsigned range for packing, matching quantization offset
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
value = value.cpu().numpy().astype(np.uint32)
device = value.device

pack_factor = 32 // num_bits

# pad input tensor and initialize packed output
packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
padding = packed_size * pack_factor - value.shape[packed_dim]
value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
if packed_dim == 0:
value = value.transpose(0, 1)

# pack values
if packed_dim == 1:
packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[:, i::pack_factor] << num_bits * i
else:
packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[i::pack_factor, :] << num_bits * i
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols

if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))

num_groups = padded_cols // pack_factor

# widen to int32 so the shifted values accumulate into packed 32-bit words
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)

bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits

packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

if packed_dim == 0:
packed = packed.transpose(0, 1)

# convert back to signed and torch
packed = np.ascontiguousarray(packed).view(np.int32)
return torch.from_numpy(packed)
return packed


def unpack_from_int32(
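To make the torch-only packing concrete, here is a standalone sketch of the same arithmetic for num_bits=4 (pack_factor = 32 // 4 = 8): eight signed 4-bit values are offset into the unsigned range, left-shifted into disjoint bit positions, and summed into a single int32 word. It mirrors the logic above but is not the library function itself.

import math
import torch

def pack_rows_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
    # offset signed values into the unsigned range, e.g. [-8, 7] -> [0, 15] for 4-bit
    offset = 1 << (num_bits - 1)
    value = (value + offset).to(torch.int32)

    pack_factor = 32 // num_bits  # values per 32-bit word (8 for 4-bit)
    rows, cols = value.shape
    padded_cols = math.ceil(cols / pack_factor) * pack_factor
    if padded_cols != cols:
        value = torch.nn.functional.pad(value, (0, padded_cols - cols))

    # group each row into chunks of pack_factor and shift each element into place
    reshaped = value.view(rows, padded_cols // pack_factor, pack_factor)
    shifts = torch.arange(pack_factor, dtype=torch.int32) * num_bits
    return (reshaped << shifts).sum(dim=2, dtype=torch.int32)

vals = torch.tensor([[-8, -1, 0, 7, 3, -4, 2, 1]], dtype=torch.int8)
packed = pack_rows_to_int32(vals, num_bits=4)
print(packed.shape)  # torch.Size([1, 1]): eight 4-bit values in one int32 word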
@@ -25,6 +25,8 @@
from torch import Tensor
from tqdm import tqdm

from torch.nn import Module


__all__ = ["BaseSparseCompressor"]

@@ -68,6 +70,7 @@ def compress(
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
show_progress: bool = False,
module: Optional[Module] = None,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression
@@ -93,7 +96,10 @@
if prefix.endswith(".weight"):
prefix = prefix[: -(len(".weight"))]

compression_data = self.compress_weight(prefix, value)
compression_data = self.compress_weight(
prefix, value, module=module
)

for key in compression_data.keys():
if key in compressed_dict:
_LOGGER.warn(
@@ -52,12 +52,22 @@ def compression_param_names(self) -> Tuple[str]:
"bitmask",
)

def compress_weight(self, name, value):
def compress_weight(self, name, value, *, module=None):
bitmask_tensor = Sparse24BitMaskTensor.from_dense(
value, self.config.sparsity_structure
)
bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
return bitmask_dict
if value.device.type == "meta":
if module is None:
raise ValueError("compress_weight requires module argument when is_meta=True")
# Create empty parameter matching compressed shape
empty_weight = torch.empty_like(bitmask_tensor.compressed, device="meta")
module.weight = torch.nn.Parameter(empty_weight, requires_grad=False)

# return the compression dict; tensors stay on meta when the source value is on meta
return bitmask_tensor.dict(
name_prefix=name,
device="meta" if value.device.type == "meta" else "cpu",
)

def decompress_weight(self, weight_data):
data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
@@ -90,9 +100,14 @@ def from_dense(
:return: instantiated compressed tensor
"""
shape = list(tensor.shape)
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
if tensor.device.type == "meta":
compressed, bitmask = sparse24_bitmask_compress(
tensor, sparsity_structure=sparsity_structure
)
else:
compressed, bitmask = sparse24_bitmask_compress(
tensor.cpu(), sparsity_structure=sparsity_structure
)
return Sparse24BitMaskTensor(
shape=shape,
compressed=compressed,
@@ -169,6 +184,13 @@ def sparse24_bitmask_compress(
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
), "Only 2:4 sparsity is supported"

if tensor.device.type == "meta":
num_rows, num_cols = tensor.shape
compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
packed_cols = (num_cols + 7) // 8
bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
return compressed_values, bitmasks_packed

bytemasks = get_24_bytemasks(tensor=tensor)

if tensor.dtype == FP8_DTYPE:
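The shapes in the new meta branch follow directly from the format: 2:4 sparsity keeps two of every four values, so the compressed values have num_cols // 2 columns, and the boolean mask is packed eight bits per byte, giving (num_cols + 7) // 8 uint8 columns. A quick standalone check of that arithmetic, shape-only so it runs entirely on the meta device:

import torch

num_rows, num_cols = 128, 4096
weight = torch.empty(num_rows, num_cols, dtype=torch.float16, device="meta")

# 2:4 sparsity keeps half the values in each row
compressed_values = torch.empty(
    (num_rows, num_cols // 2), dtype=weight.dtype, device="meta"
)
# one mask bit per original element, packed 8 bits per uint8
bitmask_packed = torch.empty(
    (num_rows, (num_cols + 7) // 8), dtype=torch.uint8, device="meta"
)

print(compressed_values.shape)  # torch.Size([128, 2048])
print(bitmask_packed.shape)     # torch.Size([128, 512])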
@@ -46,7 +46,7 @@ def compression_param_names(self) -> Tuple[str]:
"""
return ("shape", "compressed", "bitmask", "row_offsets")

def compress_weight(self, name, value):
def compress_weight(self, name, value, **kwargs):
bitmask_tensor = BitmaskTensor.from_dense(value)
bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
return bitmask_dict