diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py
index bfba059e..e29b8284 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
-import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
@@ -27,10 +26,6 @@
 from tqdm import tqdm
 
 
-if TYPE_CHECKING:
-    from compressed_tensors.quantization import QuantizationScheme
-
-
 __all__ = ["BaseSparseCompressor"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -205,16 +200,3 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> b
         return (
             name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
         )
-
-    def decompress_module_from_state_dict(
-        self,
-        prefix: str,
-        state_dict: Dict[str, torch.Tensor],
-        scheme: "QuantizationScheme",
-    ) -> Dict[str, torch.Tensor]:
-        """
-        This function is implemented as a workaround because of how
-        `ModelCompressor.quantization_compressor` can be set to either
-        an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`.
-        """
-        return state_dict.copy()
diff --git a/src/compressed_tensors/compressors/sparse_compressors/dense.py b/src/compressed_tensors/compressors/sparse_compressors/dense.py
index 0ec2b5f6..d782ad27 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/dense.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/dense.py
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 @BaseCompressor.register(name=CompressionFormat.dense.value)
 class DenseCompressor(BaseCompressor):
     """
@@ -47,3 +52,16 @@ def decompress_from_state_dict(
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         for key, value in state_dict.items():
             yield key, value
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
+        """
+        return state_dict.copy()
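
Context for the workaround described in the added docstring: `ModelCompressor.quantization_compressor` may hold either a quantization compressor or a `DenseCompressor`, and callers invoke `decompress_module_from_state_dict` on whichever instance they were given, so `DenseCompressor` needs a no-op implementation that simply copies the state dict. The sketch below (not part of the patch) illustrates that calling pattern with toy stand-in classes; `ToyQuantizationCompressor`, `ToyDenseCompressor`, and `load_module_weights` are hypothetical names used only for illustration, and the dequantization step is a placeholder rather than the library's real logic.

# Illustrative sketch only: toy stand-ins showing why a uniform
# decompress_module_from_state_dict lets the caller avoid type checks.
from typing import Dict, Optional

import torch


class ToyQuantizationCompressor:
    """Stand-in for a quantization compressor that reconstructs weights."""

    def decompress_module_from_state_dict(
        self, prefix: str, state_dict: Dict[str, torch.Tensor], scheme: Optional[object]
    ) -> Dict[str, torch.Tensor]:
        # Placeholder for real dequantization of packed tensors.
        return {f"{prefix}weight": state_dict[f"{prefix}weight_packed"].float()}


class ToyDenseCompressor:
    """Stand-in for DenseCompressor: weights are already dense, so no-op copy."""

    def decompress_module_from_state_dict(
        self, prefix: str, state_dict: Dict[str, torch.Tensor], scheme: Optional[object]
    ) -> Dict[str, torch.Tensor]:
        return state_dict.copy()


def load_module_weights(compressor, prefix, state_dict, scheme=None):
    # The caller never checks which compressor it was handed; both expose
    # the same method, which is what the patch's no-op method enables.
    return compressor.decompress_module_from_state_dict(prefix, state_dict, scheme)


dense_sd = {"layer.weight": torch.ones(2, 2)}
packed_sd = {"layer.weight_packed": torch.ones(2, 2, dtype=torch.int8)}
print(load_module_weights(ToyDenseCompressor(), "layer.", dense_sd))
print(load_module_weights(ToyQuantizationCompressor(), "layer.", packed_sd))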