From e8748e92161a8e2a0f6f34d35dbdcc17b461ee4a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Jun 2025 16:46:40 -0400 Subject: [PATCH 1/2] swap to dense Signed-off-by: Kyle Sayers --- .../compressors/sparse_compressors/base.py | 20 +------------------ .../compressors/sparse_compressors/dense.py | 20 ++++++++++++++++++- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py index bfba059e..e29b8284 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/base.py +++ b/src/compressed_tensors/compressors/sparse_compressors/base.py @@ -13,9 +13,8 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple +from typing import Dict, Generator, Optional, Set, Tuple -import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, @@ -27,10 +26,6 @@ from tqdm import tqdm -if TYPE_CHECKING: - from compressed_tensors.quantization import QuantizationScheme - - __all__ = ["BaseSparseCompressor"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -205,16 +200,3 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> b return ( name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets ) - - def decompress_module_from_state_dict( - self, - prefix: str, - state_dict: Dict[str, torch.Tensor], - scheme: "QuantizationScheme", - ) -> Dict[str, torch.Tensor]: - """ - This function is implemented as a workaround because of how - `ModelCompressor.quantization_compressor` can be set to either - an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`. - """ - return state_dict.copy() diff --git a/src/compressed_tensors/compressors/sparse_compressors/dense.py b/src/compressed_tensors/compressors/sparse_compressors/dense.py index 0ec2b5f6..2625a5a3 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/dense.py +++ b/src/compressed_tensors/compressors/sparse_compressors/dense.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Generator, Tuple +from typing import TYPE_CHECKING, Dict, Generator, Tuple +import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat from torch import Tensor +if TYPE_CHECKING: + from compressed_tensors.quantization import QuantizationScheme + + @BaseCompressor.register(name=CompressionFormat.dense.value) class DenseCompressor(BaseCompressor): """ @@ -47,3 +52,16 @@ def decompress_from_state_dict( ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]: for key, value in state_dict.items(): yield key, value + + def decompress_module_from_state_dict( + self, + prefix: str, + state_dict: Dict[str, torch.Tensor], + scheme: "QuantizationScheme", + ) -> Dict[str, torch.Tensor]: + """ + This function is implemented as a workaround because of how + `ModelCompressor.quantization_compressor` can be set to either + an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`. + """ + return state_dict.copy() From c7b83fa2a8f51c9fa6a932a81bafc4e8cab8aa40 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Jun 2025 16:55:57 -0400 Subject: [PATCH 2/2] update docstring Signed-off-by: Kyle Sayers --- src/compressed_tensors/compressors/sparse_compressors/dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/dense.py b/src/compressed_tensors/compressors/sparse_compressors/dense.py index 2625a5a3..d782ad27 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/dense.py +++ b/src/compressed_tensors/compressors/sparse_compressors/dense.py @@ -62,6 +62,6 @@ def decompress_module_from_state_dict( """ This function is implemented as a workaround because of how `ModelCompressor.quantization_compressor` can be set to either - an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`. + an instance of `BaseQuantizationCompressor` or `DenseCompressor`. """ return state_dict.copy()