Skip to content

Commit cef3b60

Browse files
authored
[Hotfix] Implement method on dense compressor (#345)
* swap to dense

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* update docstring

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent b3e89a2 commit cef3b60

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple
16+
from typing import Dict, Generator, Optional, Set, Tuple
1717

18-
import torch
1918
from compressed_tensors.compressors.base import BaseCompressor
2019
from compressed_tensors.utils import (
2120
get_nested_mappings_from_state_dict,
@@ -27,10 +26,6 @@
2726
from tqdm import tqdm
2827

2928

30-
if TYPE_CHECKING:
31-
from compressed_tensors.quantization import QuantizationScheme
32-
33-
3429
__all__ = ["BaseSparseCompressor"]
3530

3631
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -205,16 +200,3 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> b
205200
return (
206201
name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
207202
)
208-
209-
def decompress_module_from_state_dict(
210-
self,
211-
prefix: str,
212-
state_dict: Dict[str, torch.Tensor],
213-
scheme: "QuantizationScheme",
214-
) -> Dict[str, torch.Tensor]:
215-
"""
216-
This function is implemented as a workaround because of how
217-
`ModelCompressor.quantization_compressor` can be set to either
218-
an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`.
219-
"""
220-
return state_dict.copy()

src/compressed_tensors/compressors/sparse_compressors/dense.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, Generator, Tuple
15+
from typing import TYPE_CHECKING, Dict, Generator, Tuple
1616

17+
import torch
1718
from compressed_tensors.compressors.base import BaseCompressor
1819
from compressed_tensors.config import CompressionFormat
1920
from torch import Tensor
2021

2122

23+
if TYPE_CHECKING:
24+
from compressed_tensors.quantization import QuantizationScheme
25+
26+
2227
@BaseCompressor.register(name=CompressionFormat.dense.value)
2328
class DenseCompressor(BaseCompressor):
2429
"""
@@ -47,3 +52,16 @@ def decompress_from_state_dict(
4752
) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
4853
for key, value in state_dict.items():
4954
yield key, value
55+
56+
def decompress_module_from_state_dict(
57+
self,
58+
prefix: str,
59+
state_dict: Dict[str, torch.Tensor],
60+
scheme: "QuantizationScheme",
61+
) -> Dict[str, torch.Tensor]:
62+
"""
63+
This function is implemented as a workaround because of how
64+
`ModelCompressor.quantization_compressor` can be set to either
65+
an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
66+
"""
67+
return state_dict.copy()

0 commit comments

Comments (0)