
Revert "Enable module state_dict compression, simplify compression lo… #306

Merged
merged 1 commit on Apr 28, 2025
@@ -19,7 +19,7 @@
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union

import compressed_tensors
import torch
@@ -36,12 +36,12 @@
from compressed_tensors.quantization import (
DEFAULT_QUANTIZATION_METHOD,
QuantizationConfig,
QuantizationScheme,
QuantizationStatus,
apply_quantization_config,
load_pretrained_quantization_parameters,
)
from compressed_tensors.quantization.lifecycle import expand_target_names
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import (
is_module_quantized,
iter_named_leaf_modules,
@@ -64,7 +64,7 @@
from transformers.file_utils import CONFIG_NAME


__all__ = ["ModelCompressor", "map_module_to_scheme"]
__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -372,17 +372,20 @@ def compress(
:param state_dict: optional uncompressed state_dict to insert into model
:return: compressed state dict
"""

if state_dict is None:
state_dict = model.state_dict()

compressed_state_dict = state_dict

quantized_modules_to_args: Dict[
str, QuantizationArgs
] = map_modules_to_quant_args(model)

if self.quantization_compressor is not None:
module_to_scheme = map_module_to_scheme(model)
state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=module_to_scheme
compressed_state_dict = self.quantization_compressor.compress(
state_dict, names_to_scheme=quantized_modules_to_args
)

# TODO: consider sparse compression to also be compression
if self.quantization_config.format != CompressionFormat.dense.value:
self.quantization_config.quantization_status = (
QuantizationStatus.COMPRESSED
@@ -394,8 +397,8 @@ def compress(
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
state_dict = self.sparsity_compressor.compress(
state_dict,
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict,
compression_targets=sparse_compression_targets,
)

@@ -404,7 +407,7 @@ def compress(
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

return state_dict
return compressed_state_dict
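
For orientation, here is a minimal usage sketch of the `compress()` entry point touched above. It assumes an already-configured `ModelCompressor` instance (called `compressor` here) and a quantized `model`; the save path is purely illustrative and not part of the PR:

```python
# Hypothetical usage sketch, not taken from this PR: `compressor` is assumed to be an
# already-configured ModelCompressor and `model` a quantized torch.nn.Module.
import torch

# compress() falls back to model.state_dict() when no state_dict is passed in
compressed_sd = compressor.compress(model)

# illustrative destination for the compressed state dict
torch.save(compressed_sd, "compressed_state_dict.pt")
```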

def decompress(self, model_path: str, model: Module):
"""
@@ -602,15 +605,30 @@ def _replace_weights(self, dense_weight_generator, model: Module):
update_parameter_data(module, param_data, param_name)


def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
def map_modules_to_quant_args(
model: Module,
) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
"""
Returns a dictionary which maps quantized module names to their quantization schemes
Given a pytorch model, map out the submodule name (usually linear layers)
to the weight QuantizationArgs. If running input activation quantization, will also
map to the input QuantizationArgs in a tuple.

:param model: pytorch model
"""
return {
fix_fsdp_module_name(name): module.quantization_scheme
for name, module in iter_named_leaf_modules(model)
if is_module_quantized(module)
}
quantized_modules_to_args = {}
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
if submodule.quantization_scheme.weights is not None:
name = fix_fsdp_module_name(name)
quantized_modules_to_args[name] = submodule.quantization_scheme.weights
if submodule.quantization_scheme.input_activations is not None:
weight_args = quantized_modules_to_args.get(name)
quantized_modules_to_args[name] = (
weight_args,
submodule.quantization_scheme.input_activations,
)

return quantized_modules_to_args


# HACK: Override the dtype_byte_size function in transformers to support float8 types
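
To make the restored mapping concrete, the following sketch contrasts the structure returned by `map_modules_to_quant_args` (restored by this revert) with the scheme mapping produced by the reverted `map_module_to_scheme`. The module name and the constructor field values are made-up examples, not taken from the PR:

```python
# Illustrative only: shows the two mapping shapes involved in this revert.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Assumed example scheme for a single linear layer (field values are made up)
weight_args = QuantizationArgs(num_bits=4, symmetric=True)
input_args = QuantizationArgs(num_bits=8, symmetric=False)
scheme = QuantizationScheme(
    targets=["Linear"], weights=weight_args, input_activations=input_args
)

# Reverted helper (map_module_to_scheme): module name -> full QuantizationScheme
names_to_scheme = {"model.layers.0.mlp.down_proj": scheme}

# Restored helper (map_modules_to_quant_args): module name -> weight args,
# or a (weight args, input activation args) tuple when inputs are also quantized
names_to_args = {"model.layers.0.mlp.down_proj": (weight_args, input_args)}
```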
src/compressed_tensors/compressors/quantized_compressors/base.py (161 changes: 77 additions & 84 deletions)
@@ -14,16 +14,15 @@

import logging
from pathlib import Path
from typing import Any, Dict, Generator, Tuple, Union
from typing import Any, Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
merge_names,
remove_suffix,
)
from safetensors import safe_open
from torch import Tensor
@@ -70,7 +69,7 @@ class BaseQuantizationCompressor(BaseCompressor):
def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationScheme],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
@@ -82,87 +81,87 @@ def compress(
:return: compressed state dict
"""
compressed_dict = {}
save_device = "cpu"

uncompressed_names = list(model_state.keys())
for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
value = model_state[name]

# compress weights
if name.endswith("weight"):
prefix = remove_suffix(name, "weight")

# gather qparams
scale = model_state.get(prefix + "weight_scale", None)
g_idx = model_state.get(prefix + "weight_g_idx", None)
zp = model_state.get(prefix + "weight_zero_point", None)

# if scale does not exist, then weight cannot be compressed
if scale is None:
compressed_dict[name] = value.to(save_device)
continue

# compress values on cpu (memory movement too expensive)
module_path = prefix[:-1] if prefix.endswith(".") else prefix
quant_args = names_to_scheme[module_path].weights
compressed_values = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
)

# update state dict
for key, value in compressed_values.items():
compressed_dict[prefix + key] = value.to(save_device)
weight_suffix = ".weight"
input_zp_suffix = ".input_zero_point"
weight_zp_suffix = ".weight_zero_point"
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
# check if the parameter we're compressing is the weight zp
# or the input zp
is_weight_zp = name.endswith(weight_zp_suffix)
is_input_zp = name.endswith(input_zp_suffix)

# if we're saving the weight zp, fetch weight quant args
if is_weight_zp:
quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
if isinstance(quant_args_zp, tuple):
# If tuple, first value is weight args, second is input args
quant_args_zp = quant_args_zp[0]

# if we're saving the input zp, fetch input quant args
if is_input_zp:
input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
if isinstance(input_args_zp, tuple):
# If tuple, first value is weight args, second is input args
input_args_zp = input_args_zp[-1]

if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
if scale is not None:
# weight is quantized, compress it
if isinstance(names_to_scheme[prefix], tuple):
quant_args = names_to_scheme[prefix][0]
else:
quant_args = names_to_scheme[prefix]

compressed_data = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
)
for key, value in compressed_data.items():
compressed_dict[merge_names(prefix, key)] = value
else:
compressed_dict[name] = value.to("cpu")
# only save zp if asym and not packed zp
elif is_weight_zp and (
quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
):
continue
# only save if asym
elif is_input_zp and input_args_zp.symmetric:
continue
elif name.endswith("g_idx") and torch.any(value <= -1):
continue
else:
# omit saving zero points for symmetric or packed quantization
if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
continue

# omit saving for g_idx if uninitialized
# TODO: does this case actually occur?
elif name.endswith("g_idx") and torch.any(value <= -1):
continue

compressed_dict[name] = value.to(save_device)
compressed_dict[name] = value.to("cpu")

return compressed_dict

def _skip_zp(
self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
def _check_if_zp_pack_quantized(self, quant_args):
from compressed_tensors.compressors import PackedQuantizationCompressor

module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
scheme = names_to_scheme[module_name]

if zp_name == "weight_zero_point":
args = scheme.weights
if zp_name == "input_zero_point":
args = scheme.input_activations
if zp_name == "output_zero_point":
args = scheme.output_activations

symmetric = args.symmetric
packable_strategies = [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]
packed = (
isinstance(self, PackedQuantizationCompressor)
and args.strategy in packable_strategies
)

return symmetric or packed
if isinstance(self, PackedQuantizationCompressor):
if not quant_args.symmetric and quant_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
return True
return False
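
Restated outside the class, the zero-point saving rule that this revert restores looks roughly like the sketch below. The standalone helper name and its parameters are mine, but the conditions mirror the diff above: a weight zero point is written only when quantization is asymmetric and the compressor will not pack it away.

```python
# Hypothetical standalone restatement of the zero-point rule restored above.
from compressed_tensors.quantization import QuantizationStrategy

PACKABLE_STRATEGIES = {
    QuantizationStrategy.GROUP.value,
    QuantizationStrategy.CHANNEL.value,
}

def should_save_weight_zero_point(quant_args, compressor_is_packed: bool) -> bool:
    if quant_args.symmetric:
        return False  # symmetric quantization has no meaningful zero point
    if compressor_is_packed and quant_args.strategy in PACKABLE_STRATEGIES:
        return False  # packed compressors fold the zero point into the packed weights
    return True
```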

def decompress(
self,
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
names_to_scheme: Dict[str, QuantizationScheme],
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
@@ -171,9 +170,8 @@ def decompress(
dense state dict
:param path_to_model_or_tensors: path to compressed safetensors model (directory
with one or more safetensors files) or compressed tensors file
:param names_to_scheme: quantization scheme for each quantized weight
:param device: optional device to load intermediate weights into (must be `str`,
not `torch.device`)
:param names_to_scheme: quantization args for each quantized weight
:param device: optional device to load intermediate weights into
:return: compressed state dict
"""
if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -186,12 +184,7 @@ def decompress(
path_to_model_or_tensors, names_to_scheme
)

def _decompress_from_path(
self,
path_to_model: Union[str, Path, Dict[str, Any]],
names_to_scheme: Dict[str, QuantizationScheme],
device: str,
):
def _decompress_from_path(self, path_to_model, names_to_scheme, device):
weight_mappings = get_nested_weight_mappings(
path_to_model, self.compression_param_names
)
@@ -202,7 +195,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name].weights
quant_args = names_to_scheme[weight_name]
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
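
Finally, a hedged sketch of consuming the `decompress()` generator shown above; the `compressor` instance, the model directory path, and what is done with each yielded tensor are assumptions for illustration:

```python
# Illustrative only: iterate the (name, tensor) pairs yielded by decompress().
# `compressor` and `names_to_scheme` are assumed to exist; the path is made up.
decompressed = compressor.decompress(
    "path/to/compressed-model",       # directory with safetensors files (hypothetical)
    names_to_scheme=names_to_scheme,  # mapping produced by map_modules_to_quant_args
    device="cpu",
)
for param_name, tensor in decompressed:
    print(param_name, tuple(tensor.shape))  # e.g. inspect or load into a dense model
```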