diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 4059ae8d..700c1769 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -68,6 +68,10 @@ from transformers.file_utils import CONFIG_NAME +if TYPE_CHECKING: + from compressed_tensors.compressors import BaseQuantizationCompressor + + __all__ = ["ModelCompressor", "map_module_to_scheme"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -257,7 +261,9 @@ def __init__( self.sparsity_config = sparsity_config self.quantization_config = quantization_config self.sparsity_compressor = None - self.quantization_compressor = None + self.quantization_compressor: Optional[ + Union[BaseQuantizationCompressor, DenseCompressor] + ] = None if sparsity_config is not None: self.sparsity_compressor = BaseCompressor.load_from_registry( diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index e29b8284..bfba059e 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -13,8 +13,9 @@ # limitations under the License. import logging -from typing import Dict, Generator, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple +import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, @@ -26,6 +27,10 @@ from tqdm import tqdm +if TYPE_CHECKING: + from compressed_tensors.quantization import QuantizationScheme + + __all__ = ["BaseSparseCompressor"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -200,3 +205,16 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> b return ( name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets ) + + def decompress_module_from_state_dict( + self, + prefix: str, + state_dict: Dict[str, torch.Tensor], + scheme: "QuantizationScheme", + ) -> Dict[str, torch.Tensor]: + """ + This function is implemented as a workaround because of how + `ModelCompressor.quantization_compressor` can be set to either + an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`. + """ + return state_dict.copy() diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index c7b4cc05..c1a46e3c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -87,12 +87,15 @@ def decorator(func: Callable[[Any], Any]): if not _has_accelerate: if fallback == "error": + @wraps(func) def fallback_fn(*args, **kwargs): raise ValueError( "Please install `accelerate` in order to use this function" ) + else: + @wraps(func) def fallback_fn(*args, **kwargs): return fallback