def apply_compression_status(self, model: Module) -> None:
    """
    Put `model` into the compressed state in place.

    Sparsity-only models (no quantization config): every module is marked
    `QuantizationStatus.COMPRESSED`, the whole model is compressed once up
    front, and a state-dict hook substitutes that compressed state dict whenever
    `model.state_dict()` is called on the root.

    Quantized models: every `torch.nn.Linear` that carries a
    `quantization_scheme` has its parameters replaced by the output of
    `self.compress(module)` and is marked `QuantizationStatus.COMPRESSED`.

    :param model: model to compress; modified in place
    """
    if self.quantization_config is None:
        # sparsity-only compression
        for module in model.modules():
            module.quantization_status = QuantizationStatus.COMPRESSED

        # hack: compress state dict upfront, since CompressedLinear doesn't have
        # support for sparsified models
        model_state_dict = self.compress(model)

        def swap_in_compressed_state(module, state_dict, prefix, local_metadata):
            # A state-dict *pre*-hook's return value is ignored by PyTorch, so
            # the substitution has to happen in a post-hook that edits the
            # destination dict in place. Only rewrite when `model` is the root
            # of the state_dict() call; with a non-empty prefix the destination
            # also holds entries that belong to the parent module.
            if prefix == "":
                state_dict.clear()
                state_dict.update(model_state_dict)

        # prefer the public post-hook API, fall back to the private one on
        # torch versions that do not expose it; both accept in-place hooks
        register_post_hook = getattr(model, "register_state_dict_post_hook", None)
        if register_post_hook is None:
            register_post_hook = model._register_state_dict_hook
        register_post_hook(swap_in_compressed_state)

        return

    def replace_with_compressed(module: Module) -> Module:
        scheme = getattr(module, "quantization_scheme", None)
        if isinstance(module, torch.nn.Linear) and scheme is not None:
            # TODO: in the future, implement this with CompressedLinear
            state_dict = self.compress(module, show_progress=False)

            # remove any existing parameters
            for name, _ in list(module.named_parameters()):
                delattr(module, name)

            # replace with compressed parameters
            for name, value in state_dict.items():
                param = torch.nn.Parameter(value, requires_grad=False)
                register_offload_parameter(module, name, param)

            module.quantization_status = QuantizationStatus.COMPRESSED

        return module

    # context manager guarantees the bar is closed, even if compression raises
    total = sum(1 for _ in model.modules())
    with tqdm(desc="Compressing modules", total=total) as progress:
        module_map_replace(model, replace_with_compressed, progress=progress)
compression_targets=sparse_compression_targets, + show_progress=show_progress, ) # HACK: Override the dtype_byte_size function in transformers to diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 832cc4c0..5b108d63 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -71,6 +71,7 @@ def compress( self, model_state: Dict[str, Tensor], names_to_scheme: Dict[str, QuantizationScheme], + show_progress: bool = False, **kwargs, ) -> Dict[str, Tensor]: """ @@ -79,13 +80,16 @@ def compress( :param model_state: state dict of uncompressed model :param names_to_scheme: quantization args for each quantized weight, needed for quantize function to calculate bit depth + :param show_progress: whether to show tqdm progress :return: compressed state dict """ + uncompressed_names = list(model_state.keys()) compressed_dict = {} save_device = "cpu" - uncompressed_names = list(model_state.keys()) - for name in tqdm(uncompressed_names, desc="Compressing with quantization"): + # compress values + desc = "Compressing with quantization" + for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)): value = model_state[name] # compress weights diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index 33075819..ed11cbe7 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -63,6 +63,7 @@ def compress( self, model_state: Dict[str, Tensor], compression_targets: Optional[Set[str]] = None, + show_progress: bool = False, ) -> Dict[str, Tensor]: """ Compresses a dense state dict using bitmask compression @@ -76,7 +77,11 @@ def compress( _LOGGER.debug( f"Compressing model with {len(model_state)} parameterized 
layers..." ) - for name, value in tqdm(model_state.items(), desc="Compressing model"): + for name, value in tqdm( + model_state.items(), + desc="Compressing with sparsity", + disable=(not show_progress), + ): if not self.should_compress(name, compression_targets): compressed_dict[name] = value continue diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py index a8487cb7..7b8fea02 100644 --- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -125,6 +125,7 @@ def compress( self, model_state: Dict[str, Tensor], names_to_scheme: Dict[str, QuantizationScheme], + show_progress: bool = False, **kwargs, ) -> Dict[str, Tensor]: """ @@ -134,6 +135,7 @@ def compress( :param model_state: state dict of uncompressed model :param names_to_scheme: quantization scheme for each quantized weight, needed for quantize function to calculate bit depth + :param show_progress: whether to show tqdm progress :return: compressed state dict """ self.validate_quant_compatability(names_to_scheme) @@ -144,7 +146,9 @@ def compress( f"Compressing model with {len(model_state)} parameterized layers..." 
) - for name, value in tqdm(model_state.items(), desc="Compressing model"): + for name, value in tqdm( + model_state.items(), desc="Compressing model", disable=(not show_progress) + ): if name.endswith(weight_suffix): prefix = name[: -(len(weight_suffix))] scale = model_state.get(merge_names(prefix, "weight_scale"), None) diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py index 65f992b6..d24df2fc 100644 --- a/src/compressed_tensors/linear/compressed_linear.py +++ b/src/compressed_tensors/linear/compressed_linear.py @@ -23,6 +23,7 @@ initialize_module_for_quantization, ) from compressed_tensors.utils import register_offload_parameter +from compressed_tensors.utils.offload import get_execution_device from torch import Tensor from torch.nn import Parameter from torch.nn.functional import linear @@ -60,7 +61,7 @@ def from_linear( """ module.__class__ = CompressedLinear module.compressor = BaseCompressor.load_from_registry(quantization_format) - device = next(module.parameters()).device + init_device = get_execution_device(module) # this will initialize all the scales and zero points initialize_module_for_quantization( @@ -79,7 +80,7 @@ def from_linear( # populate compressed weights and quantization parameters for name, (shape, dtype) in compression_params.items(): param = Parameter( - torch.empty(shape, device=device, dtype=dtype), requires_grad=False + torch.empty(shape, device=init_device, dtype=dtype), requires_grad=False ) register_offload_parameter(module, name, param) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index f4f93f27..4a3af6fa 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -37,6 +37,7 @@ "dequantize", "fake_quantize", "wrap_module_forward_quantized", + "unwrap_module_forward_quantized", "forward_quantize", ] @@ -312,6 
def module_map_replace(
    module: torch.nn.Module,
    func: Callable[[torch.nn.Module], torch.nn.Module],
    progress: Union[bool, "tqdm.tqdm"] = False,
    pre: bool = True,
) -> torch.nn.Module:
    """
    Replaces modules in a given `torch.nn.Module` recursively using a provided function.

    This function traverses the module hierarchy and applies the `func` transformation
    either before (`pre=True`) or after (`pre=False`) recursing into children modules.
    Optionally displays progress using tqdm.

    :param module: root module to replace
    :param func: module mapping function
    :param progress: if True, display a tqdm progress bar (created and closed by
        this call). If a `tqdm.tqdm` instance (or any object exposing
        `update(n)`) is provided, the instance will be updated once per visited
        module and left open for the caller to close
    :param pre: if True, apply with pre-order, post-order otherwise
    :return: the modified module after applying the function to all submodules
    """
    # only the outermost call can see `progress is True`; recursive calls always
    # receive the bar instance (or False), so exactly one bar is ever created
    owns_progress = progress is True
    if owns_progress:
        progress = tqdm.tqdm(total=sum(1 for _ in module.modules()))

    if pre:
        module = func(module)

    # snapshot the children: add_module mutates the registry being iterated
    for name, child in list(module.named_children()):
        # pass by keyword: the signature order is (progress, pre), so positional
        # forwarding of (pre, progress) would silently swap the two
        replaced = module_map_replace(child, func, progress=progress, pre=pre)
        module.add_module(name, replaced)

    if not pre:
        module = func(module)

    # bools (and None) have no `update`, so "no progress" is a no-op here
    advance = getattr(progress, "update", None)
    if callable(advance):
        advance(1)

    if owns_progress:
        progress.close()

    return module
"sparse-24-bitmask", + ), + ], +) +def test_apply_compression_status(model_stub, q_format, s_format): + model = AutoModelForCausalLM.from_pretrained(model_stub) + compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format) + original_compressed_state_dict = dict(compressor.compress(model)) + original_compressed_state_dict = { + key: value.clone() for key, value in original_compressed_state_dict.items() + } + + compressor.apply_compression_status(model) + + for module in model.modules(): + # scheme <=> CompressedLinear + has_scheme = hasattr(module, "quantization_scheme") + is_compressed = ( + getattr(module, "quantization_status", None) + == QuantizationStatus.COMPRESSED + ) + # assert has_scheme == is_compressed + + # equivalent to eagerly compressing state dict + compressed_state_dict = dict(model.state_dict()) + assert compressed_state_dict.keys() == original_compressed_state_dict.keys() + for key in compressed_state_dict.keys(): + assert torch.all( + compressed_state_dict[key] == original_compressed_state_dict[key] + ), f"{key}" + + # can run to completion + # model(**model.dummy_inputs)