From 4a19f2c24149d61247a59dc77ddbac8baa8fe62e Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 9 Jun 2025 14:14:34 -0400
Subject: [PATCH 1/5] hotfix dense compressor

Signed-off-by: Kyle Sayers
---
 .../model_compressors/model_compressor.py  | 12 ++++++++---
 .../compressors/sparse_compressors/base.py | 20 ++++++++++++++++++-
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 4059ae8d..ba028f47 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -30,8 +30,12 @@
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
 )
-from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.compressors.sparse_compressors import DenseCompressor
+from compressed_tensors.compressors import (
+    BaseCompressor,
+    BaseQuantizationCompressor,
+    BaseSparseCompressor,
+    DenseCompressor,
+)
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
@@ -257,7 +261,9 @@ def __init__(
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
-        self.quantization_compressor = None
+        self.quantization_compressor: Optional[
+            Union[BaseQuantizationCompressor, BaseSparseCompressor]
+        ] = None
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py
index e29b8284..b87de647 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
@@ -26,6 +27,10 @@
 from tqdm import tqdm
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 __all__ = ["BaseSparseCompressor"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -200,3 +205,16 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> b
         return (
             name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
         )
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: QuantizationScheme,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`.
+        """
+        return state_dict.copy()
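Note on the workaround above: since the dense format stores weights uncompressed, "decompressing" a module's state dict is the identity apart from a defensive copy. A minimal sketch of that contract follows (not part of the patch; the DenseLike class and example tensors are hypothetical stand-ins):

    from typing import Dict

    import torch


    class DenseLike:
        def decompress_module_from_state_dict(
            self,
            prefix: str,
            state_dict: Dict[str, torch.Tensor],
            scheme=None,  # QuantizationScheme in the real signature
        ) -> Dict[str, torch.Tensor]:
            # Dense weights are stored as-is, so there is nothing to decode.
            # Copy so callers may rebind keys without mutating the input dict.
            return state_dict.copy()


    compressor = DenseLike()
    weights = {"layer.weight": torch.randn(4, 4)}
    out = compressor.decompress_module_from_state_dict("layer", weights)
    # dict.copy() is shallow: the mapping is new, the tensors are shared.
    assert out is not weights
    assert out["layer.weight"] is weights["layer.weight"]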
+ """ + return state_dict.copy() From 20634fc2885f8eb2b746b46502b0e373db49761a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Jun 2025 14:30:24 -0400 Subject: [PATCH 2/5] fix import structure Signed-off-by: Kyle Sayers --- .../model_compressors/model_compressor.py | 14 ++++++++------ src/compressed_tensors/utils/offload.py | 3 +++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index ba028f47..3e8830d3 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -30,12 +30,7 @@ QUANTIZATION_METHOD_NAME, SPARSITY_CONFIG_NAME, ) -from compressed_tensors.compressors import ( - BaseCompressor, - BaseQuantizationCompressor, - BaseSparseCompressor, - DenseCompressor, -) +from compressed_tensors.compressors import BaseCompressor, DenseCompressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, @@ -72,6 +67,13 @@ from transformers.file_utils import CONFIG_NAME +if TYPE_CHECKING: + from compressed_tensors.compressors import ( + BaseQuantizationCompressor, + BaseSparseCompressor, + ) + + __all__ = ["ModelCompressor", "map_module_to_scheme"] _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index c7b4cc05..c1a46e3c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -87,12 +87,15 @@ def decorator(func: Callable[[Any], Any]): if not _has_accelerate: if fallback == "error": + @wraps(func) def fallback_fn(*args, **kwargs): raise ValueError( "Please install `accelerate` in order to use this function" ) + else: + @wraps(func) def fallback_fn(*args, **kwargs): return fallback From 4540f26fa912d7605c8480a6f44c11fd3e07a761 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Jun 2025 14:34:18 -0400 Subject: [PATCH 3/5] fix import structure part 2 Signed-off-by: Kyle Sayers --- .../compressors/model_compressors/model_compressor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 3e8830d3..f20f939e 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -30,7 +30,8 @@ QUANTIZATION_METHOD_NAME, SPARSITY_CONFIG_NAME, ) -from compressed_tensors.compressors import BaseCompressor, DenseCompressor +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.sparse_compressors import DenseCompressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, From e04ab1660a844d8953b970e671bd472f3e115c73 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Jun 2025 14:50:29 -0400 Subject: [PATCH 4/5] fix type hint Signed-off-by: Kyle Sayers --- .../compressors/model_compressors/model_compressor.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py 
From e04ab1660a844d8953b970e671bd472f3e115c73 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 9 Jun 2025 14:50:29 -0400
Subject: [PATCH 4/5] fix type hint

Signed-off-by: Kyle Sayers
---
 .../compressors/model_compressors/model_compressor.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index f20f939e..700c1769 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -69,10 +69,7 @@
 
 if TYPE_CHECKING:
-    from compressed_tensors.compressors import (
-        BaseQuantizationCompressor,
-        BaseSparseCompressor,
-    )
+    from compressed_tensors.compressors import BaseQuantizationCompressor
 
 
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
@@ -265,7 +262,7 @@ def __init__(
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, BaseSparseCompressor]
+            Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
 
         if sparsity_config is not None:

From 6400912073dedacb3d2c4c5bfebb866c89a9a95a Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 9 Jun 2025 14:58:37 -0400
Subject: [PATCH 5/5] fix typehint

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/compressors/sparse_compressors/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py
index b87de647..bfba059e 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -210,7 +210,7 @@ def decompress_module_from_state_dict(
         self,
         prefix: str,
         state_dict: Dict[str, torch.Tensor],
-        scheme: QuantizationScheme,
+        scheme: "QuantizationScheme",
     ) -> Dict[str, torch.Tensor]:
         """
         This function is implemented as a workaround because of how
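Why patch 5 quotes the annotation: QuantizationScheme is imported under TYPE_CHECKING only (patch 1), so the name does not exist when base.py is imported, and under Python's default semantics (no `from __future__ import annotations`) unquoted annotations are evaluated eagerly when the `def` statement runs, raising NameError at import time. Quoting turns the annotation into a string that only type checkers resolve. A self-contained toy reproduction (names are illustrative):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        class QuantizationScheme:  # stand-in for the real class
            ...

    # Unquoted, this def would raise NameError at import time:
    #     def apply(scheme: QuantizationScheme) -> None: ...
    def apply(scheme: "QuantizationScheme") -> None:
        # Runtime behavior is unchanged; the string is simply stored in
        # apply.__annotations__ for tools to resolve later.
        print(scheme)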