Skip to content

Commit b3e89a2

Browse files
authored
[Hotfix] Implement quantization compressor methods on dense compressor (#344)
* hotfix dense compressor
* fix import structure
* fix import structure part 2
* fix type hint
* fix typehint

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f5dbfc3 commit b3e89a2

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,10 @@
6868
from transformers.file_utils import CONFIG_NAME
6969

7070

71+
if TYPE_CHECKING:
72+
from compressed_tensors.compressors import BaseQuantizationCompressor
73+
74+
7175
__all__ = ["ModelCompressor", "map_module_to_scheme"]
7276

7377
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -257,7 +261,9 @@ def __init__(
257261
self.sparsity_config = sparsity_config
258262
self.quantization_config = quantization_config
259263
self.sparsity_compressor = None
260-
self.quantization_compressor = None
264+
self.quantization_compressor: Optional[
265+
Union[BaseQuantizationCompressor, DenseCompressor]
266+
] = None
261267

262268
if sparsity_config is not None:
263269
self.sparsity_compressor = BaseCompressor.load_from_registry(

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,9 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Dict, Generator, Optional, Set, Tuple
16+
from typing import TYPE_CHECKING, Dict, Generator, Optional, Set, Tuple
1717

18+
import torch
1819
from compressed_tensors.compressors.base import BaseCompressor
1920
from compressed_tensors.utils import (
2021
get_nested_mappings_from_state_dict,
@@ -26,6 +27,10 @@
2627
from tqdm import tqdm
2728

2829

30+
if TYPE_CHECKING:
31+
from compressed_tensors.quantization import QuantizationScheme
32+
33+
2934
__all__ = ["BaseSparseCompressor"]
3035

3136
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -200,3 +205,16 @@ def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> b
200205
return (
201206
name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
202207
)
208+
209+
def decompress_module_from_state_dict(
210+
self,
211+
prefix: str,
212+
state_dict: Dict[str, torch.Tensor],
213+
scheme: "QuantizationScheme",
214+
) -> Dict[str, torch.Tensor]:
215+
"""
216+
This function is implemented as a workaround because of how
217+
`ModelCompressor.quantization_compressor` can be set to either
218+
an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`.
219+
"""
220+
return state_dict.copy()

src/compressed_tensors/utils/offload.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,12 +87,15 @@ def decorator(func: Callable[[Any], Any]):
8787
if not _has_accelerate:
8888

8989
if fallback == "error":
90+
9091
@wraps(func)
9192
def fallback_fn(*args, **kwargs):
9293
raise ValueError(
9394
"Please install `accelerate` in order to use this function"
9495
)
96+
9597
else:
98+
9699
@wraps(func)
97100
def fallback_fn(*args, **kwargs):
98101
return fallback

0 commit comments

Comments (0)