[Performance] Add memory compression and decompression pathways #301

Merged
merged 4 commits on May 14, 2025

@@ -47,6 +47,9 @@
iter_named_leaf_modules,
)
from compressed_tensors.utils import (
align_module_device,
delete_offload_parameter,
get_execution_device,
get_safetensors_folder,
has_offloaded_params,
merge_names,
@@ -98,6 +101,9 @@ class ModelCompressor:
:param quantization_config: config specifying quantization compression parameters
"""

sparsity_config: Optional[SparsityCompressionConfig] = None
quantization_config: Optional[QuantizationConfig] = None

@classmethod
def from_pretrained(
cls,
@@ -261,6 +267,8 @@ def __init__(
quantization_config.format, config=quantization_config
)

# ----- used by hf quantizer ----- #

def get_missing_module_keys(self, model: Module) -> List[str]:
"""
Identifies the expected missing weight keys in the compressed state_dict.
@@ -270,7 +278,6 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
This function determines which weight keys are missing based on the
applied compression techniques.


:param model: The PyTorch model to check for missing keys.
:return: A list of missing keys expected in the compressed state_dict.
"""
@@ -362,8 +369,124 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

return list(unexpected_keys)

# ----- model memory compression/decompression pathways ----- #

def compress_model(self, model: Module):
"""
Compress a model in memory. Because the model structure is modified in place,
this method is more memory-efficient than `self.compress`.

:param model: model containing parameters to compress
"""
module_to_scheme = map_module_to_scheme(model)
sparse_compression_targets: Set[str] = expand_target_names(
model=model,
targets=self.sparsity_config.targets if self.sparsity_config else [],
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
)

for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
if prefix in module_to_scheme or prefix in sparse_compression_targets:
# in the future, support compression on same device
with align_module_device(module, execution_device="cpu"):
state_dict = module.state_dict(prefix=f"{prefix}.")

# quantization first
if prefix in module_to_scheme:
state_dict = self.quantization_compressor.compress(
state_dict,
names_to_scheme=module_to_scheme,
show_progress=False,
)

# sparsity second
if prefix in sparse_compression_targets:
state_dict = self.sparsity_compressor.compress(
state_dict,
compression_targets=sparse_compression_targets,
show_progress=False,
)

# remove any existing parameters
device = get_execution_device(module)
for name, _ in list(module.named_parameters()):
delattr(module, name)

# replace with compressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param)

module.quantization_status = QuantizationStatus.COMPRESSED

def decompress_model(self, model: Module):
"""
Decompress a model in memory. Because the model structure is modified in place,
this method does not require loading compression parameters from disk.

:param model: model containing parameters to decompress
"""
module_to_scheme = map_module_to_scheme(model)
sparse_compression_targets: Set[str] = expand_target_names(
model=model,
targets=self.sparsity_config.targets if self.sparsity_config else [],
ignore=self.sparsity_config.ignore if self.sparsity_config else [],
)

for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
if prefix in module_to_scheme or prefix in sparse_compression_targets:
# in the future, support decompression on same device
with align_module_device(module, execution_device="cpu"):
state_dict = module.state_dict(prefix=f"{prefix}.")

# sparsity first
if prefix in sparse_compression_targets:
# sparse_compression_targets are automatically inferred by this fn
generator = self.sparsity_compressor.decompress_from_state_dict(
state_dict,
)
# generates (param_path, param_val)
# of compressed and unused params
state_dict = {key: value for key, value in generator}

# quantization second
if prefix in module_to_scheme:
generator = self.quantization_compressor.decompress_from_state_dict(
state_dict,
names_to_scheme=module_to_scheme,
)
# generates (mod_path, {param_name: param_val})
# of compressed params and used params, but not unused params
# some used params are removed by get_unexpected_file_keys
state_dict = {
merge_names(module_path, param_name): param_value
for module_path, compressed_data in generator
for param_name, param_value in compressed_data.items()
}

# remove any existing parameters
device = get_execution_device(module)
for name, _ in list(module.named_parameters()):
delete_offload_parameter(module, name)

# replace with decompressed parameters
for name, value in state_dict.items():
name = name.removeprefix(f"{prefix}.")
value = value.to(device)
param = torch.nn.Parameter(value, requires_grad=False)
register_offload_parameter(module, name, param)

module.quantization_status = QuantizationStatus.FROZEN

# ----- state dict compression pathways ----- #

def compress(
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
self,
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
show_progress: bool = False,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict or model with sparsity and/or quantization
@@ -379,7 +502,9 @@ def compress(
if self.quantization_compressor is not None:
module_to_scheme = map_module_to_scheme(model)
state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=module_to_scheme
state_dict,
names_to_scheme=module_to_scheme,
show_progress=show_progress,
)

# TODO: consider sparse compression to also be compression
@@ -397,6 +522,7 @@
state_dict = self.sparsity_compressor.compress(
state_dict,
compression_targets=sparse_compression_targets,
show_progress=show_progress,
)

# HACK: Override the dtype_byte_size function in transformers to
@@ -406,6 +532,8 @@

return state_dict

# ----- disk decompression pathways ----- #

def decompress(self, model_path: str, model: Module):
"""
Overwrites the weights in model with weights decompressed from model_path
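For orientation, a rough sketch of how the two new in-memory pathways (compress_model and decompress_model) might be driven. How the ModelCompressor and its configs are obtained here is an assumption for illustration (in the real flow this is handled by the HF quantizer integration referenced above); only compress_model and decompress_model come from this diff.

# Illustration only: model, sparsity_config, and quantization_config are assumed
# to already exist; in practice this flow is driven by the HF quantizer.
from compressed_tensors.compressors import ModelCompressor

compressor = ModelCompressor(
    sparsity_config=sparsity_config,          # hypothetical, may be None
    quantization_config=quantization_config,  # hypothetical, may be None
)

# swap dense parameters for compressed ones, module by module, in place
compressor.compress_model(model)

# ... later, restore dense parameters in memory without re-reading them from disk
compressor.decompress_model(model)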
21 changes: 14 additions & 7 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -23,7 +23,6 @@
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
merge_names,
remove_suffix,
)
from safetensors import safe_open
from torch import Tensor
@@ -71,6 +70,7 @@ def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
show_progress: bool = False,
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -79,18 +79,21 @@
:param model_state: state dict of uncompressed model
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:param show_progress: whether to show tqdm progress
:return: compressed state dict
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
save_device = "cpu"

uncompressed_names = list(model_state.keys())
for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
# compress values
desc = "Compressing with quantization"
for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
value = model_state[name]

# compress weights
if name.endswith("weight"):
prefix = remove_suffix(name, "weight")
prefix = name.removesuffix("weight")

# gather qparams
scale = model_state.get(prefix + "weight_scale", None)
@@ -182,7 +185,7 @@ def decompress(
)

else:
yield from self._decompress_from_state_dict(
yield from self.decompress_from_state_dict(
path_to_model_or_tensors, names_to_scheme
)

@@ -209,7 +212,11 @@ def _decompress_from_path(
weight_data["weight"] = decompressed
yield module_path, weight_data

def _decompress_from_state_dict(self, state_dict, names_to_scheme):
def decompress_from_state_dict(
self,
state_dict: Dict[str, torch.Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
weight_mappings = get_nested_mappings_from_state_dict(
state_dict, self.compression_param_names
)
@@ -219,7 +226,7 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
weight_data[param_name] = param_value

if "weight_scale" in weight_data:
quant_args = names_to_scheme[module_path]
quant_args = names_to_scheme[module_path].weights
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
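The comments in decompress_model above describe the quantization compressor's decompress_from_state_dict (made public in this file) as yielding (mod_path, {param_name: param_val}) pairs. A minimal sketch, with q_compressor, state_dict, and names_to_scheme assumed to already exist, of flattening that generator back into flat state-dict keys the way ModelCompressor.decompress_model does:

# q_compressor, state_dict, and names_to_scheme are assumed (illustration only)
from compressed_tensors.utils import merge_names

flat_state_dict = {}
for module_path, params in q_compressor.decompress_from_state_dict(
    state_dict, names_to_scheme=names_to_scheme
):
    # each yielded item is (module_path, {param_name: tensor})
    for param_name, value in params.items():
        flat_state_dict[merge_names(module_path, param_name)] = value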
50 changes: 44 additions & 6 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -16,7 +16,11 @@
from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
merge_names,
)
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm
@@ -63,6 +67,7 @@ def compress(
self,
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
show_progress: bool = False,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression
@@ -76,7 +81,11 @@
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
for name, value in tqdm(
model_state.items(),
desc="Compressing with sparsity",
disable=(not show_progress),
):
if not self.should_compress(name, compression_targets):
compressed_dict[name] = value
continue
@@ -124,15 +133,15 @@ def decompress(
self.compression_param_names,
return_unmatched_params=True,
)
for weight_name in weight_mappings.keys():
for module_path in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
for param_name, safe_path in weight_mappings[module_path].items():
full_name = merge_names(module_path, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

decompressed = self.decompress_weight(weight_data)
yield merge_names(weight_name, "weight"), decompressed
yield merge_names(module_path, "weight"), decompressed

for ignored_param_name, safe_path in ignored_params.items():
should_skip = False
@@ -146,6 +155,35 @@
value = f.get_tensor(ignored_param_name)
yield ignored_param_name, value

def decompress_from_state_dict(
self,
state_dict: Dict[str, Tensor],
) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
"""
Decompress the state dict of a module (or model)

Unlike `self.decompress`, this function does not need to explicitly skip params
via params_to_skip_load because it is more convenient for its only caller
(ModelCompressor.decompress_model) to retrieve all unused param keys

:param state_dict: state dict containing parameters to decompress
:return: Generator of (param_path, param_val)
"""
weight_mappings, ignored_params = get_nested_mappings_from_state_dict(
state_dict, self.compression_param_names, return_unmatched_params=True
)

for module_path in weight_mappings.keys():
weight_data = {}
for param_name, param_value in weight_mappings[module_path].items():
weight_data[param_name] = param_value

decompressed = self.decompress_weight(weight_data)
yield merge_names(module_path, "weight"), decompressed

for ignored_param_path, ignored_param_value in ignored_params.items():
yield ignored_param_path, ignored_param_value

@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
"""
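The sparse decompress_from_state_dict added above instead yields flat (param_path, param_val) pairs: a decompressed "weight" for each matched module plus any unmatched params passed through unchanged. A short sketch of rebuilding a dense state dict from it, with sparsity_compressor and state_dict assumed to exist:

# sparsity_compressor and state_dict are assumed (illustration only)
dense_state_dict = {
    param_path: value
    for param_path, value in sparsity_compressor.decompress_from_state_dict(state_dict)
}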
@@ -40,3 +40,10 @@ def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
) -> Generator[Tuple[str, Tensor], None, None]:
return iter([])

def decompress_from_state_dict(
self,
state_dict: Dict[str, Tensor],
) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
for key, value in state_dict.items():
yield key, value
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
from typing import Dict, Generator, List, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
@@ -202,11 +202,7 @@ def sparse24_bitmask_decompress(
decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
decompressed_tensor = decompressed_tensor.to(values.device)
values = values.flatten()
if decompressed_tensor.dtype == FP8_DTYPE:
decompressed_tensor[bytemasks_unpacked] = values
decompressed_tensor = decompressed_tensor.cuda()
else:
decompressed_tensor[bytemasks_unpacked] = values
decompressed_tensor[bytemasks_unpacked] = values
return decompressed_tensor


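The final hunk drops the FP8-only branch (which also forced a .cuda() move), so every dtype now goes through the same masked scatter. A toy sketch of that scatter with made-up shapes and values:

import torch

# toy 2:4 pattern: two kept values per group of four, positions given by the mask
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True]])
values = torch.tensor([1.0, 2.0, 3.0, 4.0])  # kept values, flattened row-major

dense = torch.zeros(mask.shape, dtype=values.dtype, device=values.device)
dense[mask] = values  # scatter kept values back into their original positions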