Enable module state_dict compression, simplify compression logic #307


Merged: 10 commits, May 2, 2025
@@ -19,7 +19,7 @@
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
@@ -36,12 +36,12 @@
from compressed_tensors.quantization import (
DEFAULT_QUANTIZATION_METHOD,
QuantizationConfig,
QuantizationScheme,
QuantizationStatus,
apply_quantization_config,
load_pretrained_quantization_parameters,
)
from compressed_tensors.quantization.lifecycle import expand_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
iter_named_leaf_modules,
@@ -64,7 +64,7 @@
from transformers.file_utils import CONFIG_NAME


__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
__all__ = ["ModelCompressor", "map_module_to_scheme"]

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -372,20 +372,17 @@ def compress(
:param state_dict: optional uncompressed state_dict to insert into model
:return: compressed state dict
"""

if state_dict is None:
state_dict = model.state_dict()

compressed_state_dict = state_dict

quantized_modules_to_args: Dict[
str, QuantizationArgs
] = map_modules_to_quant_args(model)

if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=quantized_modules_to_args
module_to_scheme = map_module_to_scheme(model)
state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=module_to_scheme
)

# TODO: consider sparse compression to also be compression
if self.quantization_config.format != CompressionFormat.dense.value:
self.quantization_config.quantization_status = (
QuantizationStatus.COMPRESSED
@@ -397,8 +394,8 @@ def compress(
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict,
state_dict = self.sparsity_compressor.compress(
state_dict,
compression_targets=sparse_compression_targets,
)

@@ -407,7 +404,7 @@ def compress(
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

return compressed_state_dict
return state_dict
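
A hedged usage sketch of the new behavior (editorial, not part of the diff): compress() can now be handed a state_dict gathered outside of the model, for example one assembled from offloaded or sharded parameters, instead of always reading model.state_dict() itself. The from_pretrained_model helper and the save destination below are assumptions for illustration only.

compressor = ModelCompressor.from_pretrained_model(model)  # assumes the model already carries a quantization config
gathered = {k: v.cpu() for k, v in model.state_dict().items()}  # stand-in for an externally gathered state_dict
compressed = compressor.compress(model, state_dict=gathered)
torch.save(compressed, "compressed_state_dict.pt")  # hypothetical destination
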

def decompress(self, model_path: str, model: Module):
"""
@@ -605,30 +602,15 @@ def _replace_weights(self, dense_weight_generator, model: Module):
update_parameter_data(module, param_data, param_name)


def map_modules_to_quant_args(
model: Module,
) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
"""
Given a pytorch model, map out the submodule name (usually linear layers)
to the weight QuantizationArgs. If running input activation quantization, will also
map to the input QuantizationArgs in a tuple.

:param model: pytorch model
Returns a dictionary which maps quantized module names to their quantization schemes
"""
quantized_modules_to_args = {}
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
if submodule.quantization_scheme.weights is not None:
name = fix_fsdp_module_name(name)
quantized_modules_to_args[name] = submodule.quantization_scheme.weights
if submodule.quantization_scheme.input_activations is not None:
weight_args = quantized_modules_to_args.get(name)
quantized_modules_to_args[name] = (
weight_args,
submodule.quantization_scheme.input_activations,
)

return quantized_modules_to_args
return {
fix_fsdp_module_name(name): module.quantization_scheme
for name, module in iter_named_leaf_modules(model)
if is_module_quantized(module)
}
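
Illustrative sketch (editorial): the mapping now carries whole QuantizationScheme objects, so callers pick out the args they need rather than unpacking (weight_args, input_args) tuples. The module name below is hypothetical.

module_to_scheme = map_module_to_scheme(model)
# e.g. {"model.layers.0.self_attn.q_proj": QuantizationScheme(...), ...}
for name, scheme in module_to_scheme.items():
    weight_args = scheme.weights              # weight QuantizationArgs, may be None
    input_args = scheme.input_activations     # input-activation QuantizationArgs, may be None
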


# HACK: Override the dtype_byte_size function in transformers to support float8 types
161 changes: 84 additions & 77 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -14,15 +14,16 @@

import logging
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple, Union
from typing import Any, Dict, Generator, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
merge_names,
remove_suffix,
)
from safetensors import safe_open
from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationArgs],
names_to_scheme: Dict[str, QuantizationScheme],
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -81,87 +82,87 @@ def compress(
:return: compressed state dict
"""
compressed_dict = {}
weight_suffix = ".weight"
input_zp_suffix = ".input_zero_point"
weight_zp_suffix = ".weight_zero_point"
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
save_device = "cpu"

uncompressed_names = list(model_state.keys())
for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
value = model_state[name]

# compress weights
if name.endswith("weight"):
prefix = remove_suffix(name, "weight")

# gather qparams
scale = model_state.get(prefix + "weight_scale", None)
g_idx = model_state.get(prefix + "weight_g_idx", None)
zp = model_state.get(prefix + "weight_zero_point", None)

# if scale does not exist, then weight cannot be compressed
if scale is None:
compressed_dict[name] = value.to(save_device)
continue

# compress values on cpu (memory movement too expensive)
module_path = prefix[:-1] if prefix.endswith(".") else prefix
quant_args = names_to_scheme[module_path].weights
compressed_values = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
)

# update state dict
for key, value in compressed_values.items():
compressed_dict[prefix + key] = value.to(save_device)

for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
# check if the parameter we're compressing is the weight zp
# or the input zp
is_weight_zp = name.endswith(weight_zp_suffix)
is_input_zp = name.endswith(input_zp_suffix)

# if we're saving the weight zp, fetch weight quant args
if is_weight_zp:
quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
if isinstance(quant_args_zp, tuple):
# If tuple, first value is weight args, second is input args
quant_args_zp = quant_args_zp[0]

# if we're saving the input zp, fetch input quant args
if is_input_zp:
input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
if isinstance(input_args_zp, tuple):
# If tuple, first value is weight args, second is input args
input_args_zp = input_args_zp[-1]

if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
if scale is not None:
# weight is quantized, compress it
if isinstance(names_to_scheme[prefix], tuple):
quant_args = names_to_scheme[prefix][0]
else:
quant_args = names_to_scheme[prefix]

compressed_data = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
)
for key, value in compressed_data.items():
compressed_dict[merge_names(prefix, key)] = value
else:
compressed_dict[name] = value.to("cpu")
# only save zp if asym and not packed zp
elif is_weight_zp and (
quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
):
continue
# only save if asym
elif is_input_zp and input_args_zp.symmetric:
continue
elif name.endswith("g_idx") and torch.any(value <= -1):
continue
else:
compressed_dict[name] = value.to("cpu")
# omit saving zero points for symmetric or packed quantization
if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
continue

# omit saving for g_idx if uninitialized
# TODO: does this case actually occur?
elif name.endswith("g_idx") and torch.any(value <= -1):
continue

compressed_dict[name] = value.to(save_device)

return compressed_dict
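
Worked example (editorial, with a hypothetical layer name) of how the loop above locates quantization parameters for a weight entry:

name = "model.layers.0.mlp.down_proj.weight"
prefix = remove_suffix(name, "weight")                   # "model.layers.0.mlp.down_proj."
scale = model_state.get(prefix + "weight_scale", None)   # None means the weight is left uncompressed
module_path = prefix[:-1]                                # "model.layers.0.mlp.down_proj"
quant_args = names_to_scheme[module_path].weights        # weight QuantizationArgs for compress_weight
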

def _check_if_zp_pack_quantized(self, quant_args):
def _skip_zp(
self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
from compressed_tensors.compressors import PackedQuantizationCompressor

if isinstance(self, PackedQuantizationCompressor):
if not quant_args.symmetric and quant_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
return True
return False
module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
scheme = names_to_scheme[module_name]

if zp_name == "weight_zero_point":
args = scheme.weights
if zp_name == "input_zero_point":
args = scheme.input_activations
if zp_name == "output_zero_point":
args = scheme.output_activations

symmetric = args.symmetric
packable_strategies = [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]
packed = (
isinstance(self, PackedQuantizationCompressor)
and args.strategy in packable_strategies
)

return symmetric or packed
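
Editorial sketch of the decision above, assuming QuantizationScheme and QuantizationArgs accept the fields shown and that compressor is any concrete quantization compressor instance: a symmetric weight zero point is skipped because it never needs to be stored.

scheme = QuantizationScheme(
    targets=["Linear"],  # assumed target spec
    weights=QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=128),
)
names_to_scheme = {"model.fc1": scheme}
compressor._skip_zp("model.fc1.weight_zero_point", names_to_scheme)  # True: symmetric, zp not saved
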

def decompress(
self,
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
names_to_scheme: Dict[str, QuantizationArgs],
names_to_scheme: Dict[str, QuantizationScheme],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
@@ -170,8 +171,9 @@ def decompress(
dense state dict
:param path_to_model_or_tensors: path to compressed safetensors model (directory
with one or more safetensors files) or compressed tensors file
:param names_to_scheme: quantization args for each quantized weight
:param device: optional device to load intermediate weights into
:param names_to_scheme: quantization scheme for each quantized weight
:param device: optional device to load intermediate weights into (must be `str`,
not `torch.device`)
:return: compressed state dict
"""
if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -184,7 +186,12 @@ def _decompress_from_path(
path_to_model_or_tensors, names_to_scheme
)

def _decompress_from_path(self, path_to_model, names_to_scheme, device):
def _decompress_from_path(
self,
path_to_model: Union[str, Path, Dict[str, Any]],
names_to_scheme: Dict[str, QuantizationScheme],
device: str,
):
weight_mappings = get_nested_weight_mappings(
path_to_model, self.compression_param_names
)
@@ -195,7 +202,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
quant_args = names_to_scheme[weight_name].weights
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
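
Hedged end-to-end sketch (editorial): decompressing a saved compressed model back to dense tensors using the scheme mapping. The path and variable names are hypothetical; per the signature above, decompress yields (name, tensor) pairs and device must be a string.

names_to_scheme = map_module_to_scheme(model)
for name, dense_tensor in compressor.decompress(
    "path/to/compressed_model", names_to_scheme=names_to_scheme, device="cpu"
):
    print(name, tuple(dense_tensor.shape))
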